Skip to content

Commit e98527c

Browse files
authored
Merge branch 'master' into dev
2 parents c850e97 + a55f966 commit e98527c

File tree

12 files changed

+169
-93
lines changed

12 files changed

+169
-93
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.17.0"
4+
version = "0.18.0-alpha"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ using DynamicExpressions
3030

3131
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
3232

33-
x1 = Node(; feature=1)
34-
x2 = Node(; feature=2)
33+
x1 = Node{Float64}(feature=1)
34+
x2 = Node{Float64}(feature=2)
3535

3636
expression = x1 * cos(x2 - 3.2)
3737

docs/src/utils.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
# Node utilities
22

3-
## Creating trees
4-
5-
```@docs
6-
@parse_expression
7-
```
8-
93
## `Base`
104

115
Various functions in `Base` are overloaded to treat an `AbstractNode` as a

src/DynamicExpressions.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,17 @@ include("deprecated.jl")
7979

8080
import TOML: parsefile
8181

82-
const PACKAGE_VERSION = let
83-
project = parsefile(joinpath(pkgdir(@__MODULE__), "Project.toml"))
84-
VersionNumber(project["version"])
82+
const PACKAGE_VERSION = let d = pkgdir(@__MODULE__)
83+
try
84+
if d isa String
85+
project = parsefile(joinpath(d, "Project.toml"))
86+
VersionNumber(project["version"])
87+
else
88+
v"0.0.0"
89+
end
90+
catch
91+
v"0.0.0"
92+
end
8593
end
8694

8795
# To get LanguageServer to register library within tests

src/EvaluateDerivative.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ end
8080
end
8181
end
8282
deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
83-
diff_deg2_eval(tree, cX, operators.binops[op_idx], operators, direction)
83+
quote
84+
diff_deg2_eval(tree, cX, operators.binops[op_idx], operators, direction)
85+
end
8486
else
8587
quote
8688
Base.Cartesian.@nif(

src/Expression.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828
"""
2929
AbstractExpression{T}
3030
31-
Abstract type for user-facing expression types, which contain
31+
(Experimental) Abstract type for user-facing expression types, which contain
3232
both the raw expression tree operating on a value type of `T`,
3333
as well as associated metadata to evaluate and render the expression.
3434
@@ -74,7 +74,7 @@ abstract type AbstractExpression{T} end
7474
"""
7575
Expression{T, N, D} <: AbstractExpression{T}
7676
77-
Defines a high level, user-facing, expression type that encapsulates an
77+
(Experimental) Defines a high-level, user-facing, expression type that encapsulates an
7878
expression tree (like `Node`) along with associated metadata for evaluation and rendering.
7979
8080
# Fields
@@ -214,8 +214,8 @@ import ..NodeUtilsModule:
214214
set_constants!
215215

216216
#! format: off
217-
count_constants(ex::AbstractExpression; kws...) = count_constants(get_tree(ex); kws...)
218-
count_depth(ex::AbstractExpression; kws...) = count_depth(get_tree(ex); kws...)
217+
count_constants(ex::AbstractExpression) = count_constants(get_tree(ex))
218+
count_depth(ex::AbstractExpression) = count_depth(get_tree(ex))
219219
index_constants(ex::AbstractExpression, ::Type{T}=UInt16) where {T} = index_constants(get_tree(ex), T)
220220
has_operators(ex::AbstractExpression) = has_operators(get_tree(ex))
221221
has_constants(ex::AbstractExpression) = has_constants(get_tree(ex))
@@ -311,14 +311,14 @@ end
311311
import ..SimplifyModule: combine_operators, simplify_tree!
312312

313313
# Avoid implementing a generic version for these, as it is less likely to generalize
314-
function combine_operators(ex::Expression, operators=nothing; kws...)
314+
function combine_operators(ex::Expression, operators=nothing)
315315
return Expression(
316-
combine_operators(get_tree(ex), get_operators(ex, operators); kws...), ex.metadata
316+
combine_operators(get_tree(ex), get_operators(ex, operators)), ex.metadata
317317
)
318318
end
319-
function simplify_tree!(ex::Expression, operators=nothing; kws...)
319+
function simplify_tree!(ex::Expression, operators=nothing)
320320
return Expression(
321-
simplify_tree!(get_tree(ex), get_operators(ex, operators); kws...), ex.metadata
321+
simplify_tree!(get_tree(ex), get_operators(ex, operators)), ex.metadata
322322
)
323323
end
324324

src/OperatorEnumConstruction.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ function set_default_variable_names!(variable_names::Vector{String})
8484
return LATEST_VARIABLE_NAMES.x = copy(variable_names)
8585
end
8686

87-
Base.@deprecate create_evaluation_helpers! set_default_operators!
87+
Base.@deprecate create_evaluation_helpers!(operators) set_default_operators!(operators)
8888

8989
function set_default_operators!(operators::OperatorEnum)
9090
LATEST_OPERATORS.x = operators
@@ -374,7 +374,6 @@ redefine operators for `AbstractExpressionNode` types, as well as `show`, `print
374374
# Deprecated:
375375
enable_autodiff=nothing,
376376
)
377-
@assert length(binary_operators) > 0 || length(unary_operators) > 0
378377
enable_autodiff !== nothing && Base.depwarn(
379378
"The option `enable_autodiff` has been deprecated. " *
380379
"Differential operators are now automatically computed within the gradient call.",

src/Parse.jl

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,27 @@ using DispatchDoctor: @unstable
44

55
using ..NodeModule: AbstractExpressionNode, Node, constructorof
66
using ..OperatorEnumModule: AbstractOperatorEnum
7-
using ..OperatorEnumConstructionModule: empty_all_globals!
7+
using ..OperatorEnumConstructionModule: OperatorEnum, empty_all_globals!
88
using ..ExpressionModule: AbstractExpression, Expression
99

1010
"""
1111
@parse_expression(expr; operators, variable_names, node_type=Node, evaluate_on=[])
1212
13-
Parse a symbolic expression `expr` into a computational graph where nodes represent operations or variables.
13+
(Experimental) Parse a symbolic expression `expr` into a computational graph where nodes represent operations or variables.
1414
1515
## Arguments
1616
1717
- `expr`: An expression to parse into an `AbstractExpression`.
1818
1919
## Keyword Arguments
2020
21-
- `operators`: An instance of `OperatorEnum` specifying the available unary and binary operators.
21+
- `operators`: An instance of `AbstractOperatorEnum` specifying the available unary and binary operators.
2222
- `variable_names`: A list of variable names as strings or symbols that are allowed in the expression.
2323
- `evaluate_on`: A list of external functions to evaluate explicitly when encountered.
2424
- `node_type`: The type of the nodes in the resulting expression tree. Defaults to `Node`.
2525
- `expression_type`: The type of the resulting expression. Defaults to `Expression`.
26+
- `binary_operators`: Convenience syntax for creating an `OperatorEnum`.
27+
- `unary_operators`: Convenience syntax for creating an `OperatorEnum`.
2628
2729
## Usage
2830
@@ -105,6 +107,8 @@ end
105107
expression_type = Expression
106108
evaluate_on = nothing
107109
extra_metadata = ()
110+
binops = nothing
111+
unaops = nothing
108112

109113
# Iterate over keyword arguments to extract operators and variable_names
110114
for kw in kws
@@ -127,6 +131,12 @@ end
127131
elseif kw == :extra_metadata
128132
extra_metadata = kw
129133
continue
134+
elseif kw == :binary_operators
135+
binops = kw
136+
continue
137+
elseif kw == :unary_operators
138+
unaops = kw
139+
continue
130140
end
131141
elseif kw isa Expr && kw.head == :(=)
132142
if kw.args[1] == :operators
@@ -147,6 +157,12 @@ end
147157
elseif kw.args[1] == :extra_metadata
148158
extra_metadata = kw.args[2]
149159
continue
160+
elseif kw.args[1] == :binary_operators
161+
binops = kw.args[2]
162+
continue
163+
elseif kw.args[1] == :unary_operators
164+
unaops = kw.args[2]
165+
continue
150166
end
151167
end
152168
throw(
@@ -156,8 +172,19 @@ end
156172
)
157173
end
158174

159-
# Ensure that operators are provided
160-
@assert operators !== nothing "The 'operators' keyword argument must be provided."
175+
if operators === nothing
176+
@assert(
177+
binops !== nothing || unaops !== nothing,
178+
"You must specify the operators using either `operators`, or `binary_operators` and `unary_operators`"
179+
)
180+
operators = :($(OperatorEnum)(;
181+
binary_operators=$(binops === nothing ? :(Function[]) : binops),
182+
unary_operators=$(unaops === nothing ? :(Function[]) : unaops),
183+
))
184+
else
185+
@assert (binops === nothing && unaops === nothing)
186+
end
187+
161188
return (;
162189
operators, variable_names, node_type, expression_type, evaluate_on, extra_metadata
163190
)

src/deprecated.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import Base: @deprecate
22
import .NodeModule: Node, GraphNode
33

4-
@deprecate set_constants set_constants!
5-
@deprecate simplify_tree simplify_tree!
4+
@deprecate set_constants(tree, constants) set_constants!(tree, constants)
5+
@deprecate simplify_tree(tree, operators) simplify_tree!(tree, operators)
66

77
for N in (:Node, :GraphNode)
88
@eval begin

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
66
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
77
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1112
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

0 commit comments

Comments
 (0)