Skip to content

Commit 10d9e76

Browse files
committed
refactor: remove calling module from parse expression
1 parent ec10651 commit 10d9e76

File tree

4 files changed

+45
-105
lines changed

4 files changed

+45
-105
lines changed

src/Parse.jl

Lines changed: 25 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ typeof(x) = Node{Float32}
8282
"""
8383
macro parse_expression(ex, kws...)
8484
parsed_kws = _parse_kws(kws)
85-
calling_module = __module__
8685
return esc(
8786
:($(parse_expression)(
8887
$(Meta.quot(ex));
@@ -91,7 +90,6 @@ macro parse_expression(ex, kws...)
9190
node_type=$(parsed_kws.node_type),
9291
expression_type=$(parsed_kws.expression_type),
9392
evaluate_on=$(parsed_kws.evaluate_on),
94-
calling_module=$calling_module,
9593
$(parsed_kws.extra_metadata)...,
9694
)),
9795
)
@@ -166,7 +164,6 @@ end
166164
"""Parse an expression Julia `Expr` object."""
167165
function parse_expression(
168166
ex;
169-
calling_module,
170167
operators::AbstractOperatorEnum,
171168
variable_names::Union{AbstractVector,Nothing}=nothing,
172169
node_type::Type{N}=Node,
@@ -182,22 +179,22 @@ function parse_expression(
182179
else
183180
string.(variable_names)
184181
end
185-
tree = _parse_expression(
186-
ex, operators, variable_names, N, E, evaluate_on, calling_module; kws...
187-
)
182+
tree = _parse_expression(ex, operators, variable_names, N, E, evaluate_on; kws...)
188183

189184
return constructorof(E)(tree; operators, variable_names, kws...)
190185
end
191186
end
192187

188+
"""An empty module for evaluation without collisions."""
189+
module EmptyModule end
190+
193191
function _parse_expression(
194192
ex::Expr,
195193
operators::AbstractOperatorEnum,
196194
variable_names::Union{AbstractVector{<:AbstractString},Nothing},
197195
::Type{N},
198196
::Type{E},
199-
evaluate_on::Union{Nothing,AbstractVector},
200-
calling_module;
197+
evaluate_on::Union{Nothing,AbstractVector};
201198
kws...,
202199
) where {N<:AbstractExpressionNode,E<:AbstractExpression}
203200
ex.head != :call && throw(
@@ -208,17 +205,14 @@ function _parse_expression(
208205
)
209206
args = ex.args
210207
func = try
211-
Core.eval(calling_module, first(ex.args))
208+
Core.eval(EmptyModule, first(ex.args))
212209
catch
213210
throw(
214-
ArgumentError(
215-
"Failed to evaluate function `$(first(ex.args))` within `$(calling_module)`. " *
216-
"Make sure the function is defined in that module.",
217-
),
211+
ArgumentError("Tried to interpolate function `$(first(ex.args))` but failed."),
218212
)
219213
end::Function
220214
return _parse_expression(
221-
func, args, operators, variable_names, N, E, evaluate_on, calling_module; kws...
215+
func, args, operators, variable_names, N, E, evaluate_on; kws...
222216
)
223217
end
224218
function _parse_expression(
@@ -228,8 +222,7 @@ function _parse_expression(
228222
variable_names::Union{AbstractVector{<:AbstractString},Nothing},
229223
::Type{N},
230224
::Type{E},
231-
evaluate_on::Union{Nothing,AbstractVector},
232-
calling_module;
225+
evaluate_on::Union{Nothing,AbstractVector};
233226
kws...,
234227
)::N where {F<:Function,N<:AbstractExpressionNode,E<:AbstractExpression}
235228
if length(args) == 2 && func operators.unaops
@@ -238,14 +231,7 @@ function _parse_expression(
238231
return N(;
239232
op=op::Int,
240233
l=_parse_expression(
241-
args[2],
242-
operators,
243-
variable_names,
244-
N,
245-
E,
246-
evaluate_on,
247-
calling_module;
248-
kws...,
234+
args[2], operators, variable_names, N, E, evaluate_on; kws...
249235
),
250236
)
251237
elseif length(args) == 3 && func operators.binops
@@ -254,24 +240,10 @@ function _parse_expression(
254240
return N(;
255241
op=op::Int,
256242
l=_parse_expression(
257-
args[2],
258-
operators,
259-
variable_names,
260-
N,
261-
E,
262-
evaluate_on,
263-
calling_module;
264-
kws...,
243+
args[2], operators, variable_names, N, E, evaluate_on; kws...
265244
),
266245
r=_parse_expression(
267-
args[3],
268-
operators,
269-
variable_names,
270-
N,
271-
E,
272-
evaluate_on,
273-
calling_module;
274-
kws...,
246+
args[3], operators, variable_names, N, E, evaluate_on; kws...
275247
),
276248
)
277249
elseif length(args) > 3 && func in (+, -, *) && func operators.binops
@@ -280,39 +252,18 @@ function _parse_expression(
280252
inner = N(;
281253
op=op::Int,
282254
l=_parse_expression(
283-
args[2],
284-
operators,
285-
variable_names,
286-
N,
287-
E,
288-
evaluate_on,
289-
calling_module;
290-
kws...,
255+
args[2], operators, variable_names, N, E, evaluate_on; kws...
291256
),
292257
r=_parse_expression(
293-
args[3],
294-
operators,
295-
variable_names,
296-
N,
297-
E,
298-
evaluate_on,
299-
calling_module;
300-
kws...,
258+
args[3], operators, variable_names, N, E, evaluate_on; kws...
301259
),
302260
)
303261
for arg in args[4:end]
304262
inner = N(;
305263
op=op::Int,
306264
l=inner,
307265
r=_parse_expression(
308-
arg,
309-
operators,
310-
variable_names,
311-
N,
312-
E,
313-
evaluate_on,
314-
calling_module;
315-
kws...,
266+
arg, operators, variable_names, N, E, evaluate_on; kws...
316267
),
317268
)
318269
end
@@ -322,14 +273,7 @@ function _parse_expression(
322273
func(
323274
map(
324275
arg -> _parse_expression(
325-
arg,
326-
operators,
327-
variable_names,
328-
N,
329-
E,
330-
evaluate_on,
331-
calling_module;
332-
kws...,
276+
arg, operators, variable_names, N, E, evaluate_on; kws...
333277
),
334278
args[2:end],
335279
)...,
@@ -366,16 +310,19 @@ function _parse_expression(
366310
variable_names::Union{AbstractVector{<:AbstractString},Nothing},
367311
node_type::Type{<:AbstractExpressionNode},
368312
expression_type::Type{<:AbstractExpression},
369-
evaluate_on::Union{Nothing,AbstractVector},
370-
calling_module;
313+
evaluate_on::Union{Nothing,AbstractVector};
371314
kws...,
372315
)
373-
return parse_leaf(
374-
ex, variable_names, node_type, expression_type, calling_module; kws...
375-
)
316+
return parse_leaf(ex, variable_names, node_type, expression_type; kws...)
376317
end
377318

378-
function parse_leaf(ex, variable_names, node_type, expression_type, calling_module; kws...)
319+
function parse_leaf(
320+
ex,
321+
variable_names,
322+
node_type::Type{<:AbstractExpressionNode},
323+
expression_type::Type{<:AbstractExpression};
324+
kws...,
325+
)
379326
if ex isa AbstractExpression
380327
throw(
381328
ArgumentError(

test/test_extra_node_fields.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ freeze!(n) = (n.frozen = true; n)
8282
thaw!(n) = (n.frozen = false; n)
8383

8484
ex = parse_expression(
85-
:(x + freeze!(sin(thaw!(y + 2.1))));
85+
:(x + $freeze!(sin($thaw!(y + 2.1))));
8686
operators=OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin]),
8787
variable_names=[:x, :y],
8888
evaluate_on=[freeze!, thaw!],

test/test_parametric_expression.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,7 @@ function DE.parse_leaf(
252252
ex,
253253
variable_names,
254254
node_type::Type{<:ParametricNode},
255-
expression_type::Type{<:ParametricExpression},
256-
evaluate_on;
255+
expression_type::Type{<:ParametricExpression};
257256
parameter_names,
258257
kws...,
259258
)

test/test_parse.jl

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ operators = OperatorEnum(;
1818
OperatorEnum(; binary_operators=[/])
1919

2020
let ex = @parse_expression(
21-
my_custom_op(x, sin(y) + 0.3), operators = operators, variable_names = ["x", "y"],
21+
$(my_custom_op)(x, sin(y) + 0.3),
22+
operators = operators,
23+
variable_names = ["x", "y"],
2224
)
2325
@test typeof(ex) <: Expression
2426

@@ -181,7 +183,7 @@ show_type(x) = (show(typeof(x)); x)
181183
let
182184
logged_out = @capture_out begin
183185
ex = @parse_expression(
184-
x * 2.5 - show_type(cos(y)),
186+
x * 2.5 - $(show_type)(cos(y)),
185187
operators = OperatorEnum(; binary_operators=[*, -], unary_operators=[cos]),
186188
variable_names = [:x, :y],
187189
evaluate_on = [show_type],
@@ -238,7 +240,7 @@ let operators = OperatorEnum(; unary_operators=[sin])
238240
@eval blah(x...) = first(x)
239241
if VERSION >= v"1.9"
240242
@test_throws "Unrecognized operator: `blah` with no matches in `[show]`." parse_expression(
241-
:(blah(x, x, y));
243+
:($blah(x, x, y));
242244
operators=operators,
243245
variable_names=[:x, :y],
244246
evaluate_on=[show],
@@ -248,34 +250,24 @@ let operators = OperatorEnum(; unary_operators=[sin])
248250
end
249251

250252
# Helpful error for missing function in scope
251-
let my_badly_scoped_function(x) = x
252-
@test_throws ArgumentError begin
253+
my_badly_scoped_function(x) = x
254+
@test_throws ArgumentError begin
255+
ex = @parse_expression(
256+
my_badly_scoped_function(x),
257+
operators = operators,
258+
variable_names = ["x"],
259+
evaluate_on = [my_badly_scoped_function]
260+
)
261+
end
262+
if VERSION >= v"1.9"
263+
@test_throws "Tried to interpolate function `my_badly_scoped_function` but failed." begin
253264
ex = @parse_expression(
254265
my_badly_scoped_function(x),
255266
operators = operators,
256267
variable_names = ["x"],
257268
evaluate_on = [my_badly_scoped_function]
258269
)
259270
end
260-
@test_throws ArgumentError begin
261-
ex = parse_expression(
262-
:(my_badly_scoped_function(x));
263-
operators,
264-
variable_names=["x"],
265-
evaluate_on=[my_badly_scoped_function],
266-
calling_module=@__MODULE__,
267-
)
268-
end
269-
if VERSION >= v"1.9"
270-
@test_throws "Make sure the function is defined in that module." begin
271-
ex = @parse_expression(
272-
my_badly_scoped_function(x),
273-
operators = operators,
274-
variable_names = ["x"],
275-
evaluate_on = [my_badly_scoped_function]
276-
)
277-
end
278-
end
279271
end
280272

281273
# Helpful error for missing variable name
@@ -348,3 +340,5 @@ let
348340
)
349341
@test string_tree(ex) == "x"
350342
end
343+
344+
# TODO: Test parsing with custom operators

0 commit comments

Comments
 (0)