Commit 5f9fd60

Merge pull request #3 from SymbolicML/generic-operators
Extend to have generic operators
2 parents fe8e6df + 2c6a4aa commit 5f9fd60

File tree

12 files changed: +375 additions, -122 deletions

README.md

Lines changed: 63 additions & 38 deletions
@@ -14,16 +14,15 @@ DynamicExpressions.jl is the backbone of
 
 ## Summary
 
-A dynamic expression is a snippet of code that can change throughout
-runtime - compilation is not possible!
+A dynamic expression is a snippet of code that can change throughout runtime - compilation is not possible! DynamicExpressions.jl does the following:
+1. Defines an enum over user-specified operators.
+2. Using this enum, it defines a [very lightweight and type-stable data structure](https://symbolicml.org/DynamicExpressions.jl/dev/types/#DynamicExpressions.EquationModule.Node) for arbitrary expressions.
+3. It then generates specialized [evaluation kernels](https://github.com/SymbolicML/DynamicExpressions.jl/blob/fe8e6dfa160d12485fb77c226d22776dd6ed697a/src/EvaluateEquation.jl#L29-L66) for the space of potential operators.
+4. It also generates kernels for the [first-order derivatives](https://github.com/SymbolicML/DynamicExpressions.jl/blob/fe8e6dfa160d12485fb77c226d22776dd6ed697a/src/EvaluateEquationDerivative.jl#L139-L175), using [Zygote.jl](https://github.com/FluxML/Zygote.jl).
+5. It can also operate on arbitrary other types (vectors, tensors, symbols, strings, etc.) - see last part below.
+
+It also has import and export functionality with [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl), so you can move your runtime expression into a CAS!
 
-DynamicExpressions.jl:
-1. Defines an enum over user-specified scalar operators.
-2. Using this enum, it defines a very lightweight
-and type-stable data structure for arbitrary expressions.
-3. It then generates specialized evaluation kernels for
-the space of potential operators.
-4. It also generates kernels for the first-order derivatives, using [Zygote.jl](https://github.com/FluxML/Zygote.jl).
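As an editorial aside, the enum-plus-tree design described in this summary can be sketched in a few lines of plain Julia. All names here (`ToyNode`, `toy_eval`, the helper constructors) are invented for illustration; this is not the package's actual implementation:

```julia
# Toy sketch: an "operator enum" is just an indexed collection of
# functions, and an expression is a small tree whose nodes store
# indices into it.
struct ToyNode
    degree::Int                  # 0 = leaf, 1 = unary op, 2 = binary op
    val::Float64                 # constant value (constant leaves)
    feature::Int                 # input feature index (0 for constants)
    op::Int                      # index into the operator tuple
    l::Union{ToyNode,Nothing}
    r::Union{ToyNode,Nothing}
end

leaf_feature(i)  = ToyNode(0, 0.0, i, 0, nothing, nothing)
leaf_const(v)    = ToyNode(0, v, 0, 0, nothing, nothing)
unary(op, l)     = ToyNode(1, 0.0, 0, op, l, nothing)
binary(op, l, r) = ToyNode(2, 0.0, 0, op, l, r)

const binops = (+, -, *)         # "enum" of binary operators
const unaops = (cos, sin)        # "enum" of unary operators

# Recursive evaluation for a single sample x (a vector of features).
function toy_eval(t::ToyNode, x)
    t.degree == 0 && return t.feature == 0 ? t.val : x[t.feature]
    t.degree == 1 && return unaops[t.op](toy_eval(t.l, x))
    return binops[t.op](toy_eval(t.l, x), toy_eval(t.r, x))
end

# x1 * cos(x2 - 3.2), evaluated at (x1, x2) = (1.0, 3.2):
tree = binary(3, leaf_feature(1),
              unary(1, binary(2, leaf_feature(2), leaf_const(3.2))))
toy_eval(tree, [1.0, 3.2])  # 1.0 * cos(0.0) == 1.0
```

The real package's speed comes from generating type-stable kernels over the operator enum rather than dispatching through a tuple like this; the sketch only shows the data-structure idea.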
 
 
 ## Example

@@ -41,29 +40,27 @@ X = randn(Float64, 2, 100);
 expression(X) # 100-element Vector{Float64}
 ```
 
-### Speed
+(We can construct this expression with normal operators, since calling `OperatorEnum()` will `@eval` new functions on `Node` that use the specified enum.)
+
+## Speed
 
-First, what happens if we naively use Julia symbols to define
-and then evaluate this expression?
+First, what happens if we naively use Julia symbols to define and then evaluate this expression?
 
 ```julia
 @btime eval(:(X[1, :] .* cos.(X[2, :] .- 3.2)))
 # 117,000 ns
 ```
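For context on why the quoted-expression path is so slow: `eval` runs at global scope and re-enters the compiler on every call. Here is a self-contained version of the two approaches (no packages needed; `f` is a name invented for this sketch, and the timings quoted in this diff come from BenchmarkTools' `@btime`). We only check that both paths compute the same result:

```julia
X = randn(Float64, 2, 100)

# Slow path: a quoted Julia AST, re-compiled by `eval` on every call.
ex = :(X[1, :] .* cos.(X[2, :] .- 3.2))
y_slow = eval(ex)

# Fast path: an ordinary function, compiled once and cheap thereafter.
f(A) = A[1, :] .* cos.(A[2, :] .- 3.2)
y_fast = f(X)

y_slow ≈ y_fast  # true
```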

-This is quite slow, meaning it will be hard to
-quickly search over the space of expressions.
-Let's see how DynamicExpressions.jl compares:
+This is quite slow, meaning it will be hard to quickly search over the space of expressions. Let's see how DynamicExpressions.jl compares:
 
 ```julia
 @btime expression(X)
 # 693 ns
 ```
 
-Much faster!
-And we didn't even need to compile it.
-If we change `expression` dynamically with a random number generator,
-it will have the same performance:
+Much faster! And we didn't even need to compile it. (Internally, this is calling `eval_tree_array(expression, X, operators)`, where `operators` was pre-defined when we called `OperatorEnum()`.)
+
+If we change `expression` dynamically with a random number generator, it will have the same performance:
 
 ```julia
 @btime begin
@@ -72,7 +69,6 @@ it will have the same performance:
 end
 # 842 ns
 ```
-
 Now, let's see the performance if we had hard-coded these expressions:
 
 ```julia
@@ -81,14 +77,12 @@ f(X) = X[1, :] .* cos.(X[2, :] .- 3.2)
 # 708 ns
 ```
 
-So, our dynamic expression evaluation is about the same (or even a bit faster)
-as evaluating a basic hard-coded expression!
-Let's see if we can optimize the hard-coded version:
+So, our dynamic expression evaluation is about the same (or even a bit faster) as evaluating a basic hard-coded expression! Let's see if we can optimize the speed of the hard-coded version:
 
 ```julia
 f_optimized(X) = begin
     y = Vector{Float64}(undef, 100)
-    @inbounds @simd for i=1:100;
+    @inbounds @simd for i=1:100
         y[i] = X[1, i] * cos(X[2, i] - 3.2)
     end
     y
@@ -97,14 +91,9 @@ end
 # 526 ns
 ```
 
-The `DynamicExpressions.jl` version is only 25% slower than one which
-has been optimized by hand into a single SIMD kernel! Not bad at all.
+The `DynamicExpressions.jl` version is only 25% slower than one which has been optimized by hand into a single SIMD kernel! Not bad at all.
 
-More importantly: we can change `expression` throughout runtime,
-and expect the same performance.
-This makes this data structure ideal for symbolic
-regression and other evaluation-based searches
-over expression trees.
+More importantly: we can change `expression` throughout runtime, and expect the same performance. This makes this data structure ideal for symbolic regression and other evaluation-based searches over expression trees.
 
 
 ## Derivatives
@@ -122,8 +111,7 @@ x2 = Node(; feature=2)
 expression = x1 * cos(x2 - 3.2)
 ```
 
-We can take the gradient with respect to inputs
-with simply the `'` character:
+We can take the gradient with respect to inputs simply with the `'` character:
 
 ```julia
 grad = expression'(X)
@@ -133,13 +121,20 @@ This is quite fast:
 
 ```julia
 @btime expression'(X)
-# 2.894 us
+# 2894 ns
 ```
 
-Internally, this is calling the `eval_grad_tree_array` function,
-which performs forward-mode automatic differentiation
-on the expression tree with Zygote-compiled kernels.
-We can also compute the derivative with respect to constants:
+and again, we can change this expression at runtime, without loss in performance!
+
+```julia
+@btime begin
+    expression.op = rand(1:3)
+    expression'(X)
+end
+# 3198 ns
+```
+
+Internally, this is calling the `eval_grad_tree_array` function, which performs forward-mode automatic differentiation on the expression tree with Zygote-compiled kernels. We can also compute the derivative with respect to constants:
 
 ```julia
 result, grad, did_finish = eval_grad_tree_array(expression, X, operators; variable=false)
@@ -151,3 +146,33 @@ or with respect to variables, and only in a single direction:
 feature = 2
 result, grad, did_finish = eval_diff_tree_array(expression, X, operators, feature)
 ```
+
+## Generic types
+
+> Does this work for only scalar operators on real numbers, or will it work for `MyCrazyType`?
+
+I'm so glad you asked. `DynamicExpressions.jl` actually will work for **arbitrary types**! However, to work on operators other than real scalars, you need to use the `GenericOperatorEnum` instead of the normal `OperatorEnum`. Let's try it with strings!
+
+```julia
+x1 = Node(String; feature=1)
+```
+This node will be used to index input data (whatever it may be) with `selectdim(data, 1, feature)`. Let's now define some operators to use:
+```julia
+my_string_func(x::String) = "Hello $x"
+
+operators = GenericOperatorEnum(;
+    binary_operators=[*],
+    unary_operators=[my_string_func],
+    extend_user_operators=true)
+```
+Now, let's create an expression:
+```julia
+tree = x1 * " World!"
+tree(["Hello", "Me?"])
+# Hello World!
+```
+So indeed it works for arbitrary types. It is a bit slower due to the potential for type instability, but it's not too bad:
+```julia
+@btime tree(["Hello", "Me?"])
+# 1738 ns
+```
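To make the derivative machinery in the README changes above concrete: forward-mode differentiation over a tree just propagates `(value, derivative)` pairs upward with the chain and product rules. The following is a toy sketch with invented names (`TNode`, `eval_diff`, `una_rules`), not the library's Zygote-compiled kernels:

```julia
# Toy forward-mode differentiation over an expression tree, for one
# sample and one feature direction (the idea behind eval_diff_tree_array).
struct TNode
    degree::Int                  # 0 = leaf, 1 = unary, 2 = binary
    val::Float64
    feature::Int
    op::Int
    l::Union{TNode,Nothing}
    r::Union{TNode,Nothing}
end

# (f, f') pairs stand in for generated derivative kernels:
const una_rules = ((cos, x -> -sin(x)),)

function eval_diff(t::TNode, x, feature::Int)
    if t.degree == 0
        t.feature == 0 && return (t.val, 0.0)           # constants: d = 0
        return (x[t.feature], t.feature == feature ? 1.0 : 0.0)
    elseif t.degree == 1
        g, dg = una_rules[t.op]
        v, dv = eval_diff(t.l, x, feature)
        return (g(v), dg(v) * dv)                       # chain rule
    else  # binary ops: 1 = +, 2 = -, 3 = *
        vl, dl = eval_diff(t.l, x, feature)
        vr, dr = eval_diff(t.r, x, feature)
        t.op == 1 && return (vl + vr, dl + dr)
        t.op == 2 && return (vl - vr, dl - dr)
        return (vl * vr, dl * vr + vl * dr)             # product rule
    end
end

# d/dx2 of x1 * cos(x2 - 3.2) at (1.0, 3.2) is -x1 * sin(0.0) == 0.0:
leafF(i) = TNode(0, 0.0, i, 0, nothing, nothing)
leafC(v) = TNode(0, v, 0, 0, nothing, nothing)
t = TNode(2, 0.0, 0, 3, leafF(1),
          TNode(1, 0.0, 0, 1, TNode(2, 0.0, 0, 2, leafF(2), leafC(3.2)), nothing))
eval_diff(t, [1.0, 3.2], 2)  # (1.0, 0.0)
```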

docs/src/eval.md

Lines changed: 8 additions & 1 deletion
@@ -27,6 +27,13 @@ It also re-defines `print`, `show`, and the various operators, to work with the
 Thus, if you define an expression with one `OperatorEnum`, and then try to
 evaluate it or print it with a different `OperatorEnum`, you will get undefined behavior!
 
+You can also work with arbitrary types, by defining a `GenericOperatorEnum` instead.
+The notation is the same for `eval_tree_array`, though it will return `nothing`
+when it can't find a method, and not do any NaN checks:
+```@docs
+eval_tree_array(tree, cX::AbstractArray{T,N}, operators::GenericOperatorEnum) where {T,N}
+```
+
 ## Derivatives
 
 `DynamicExpressions.jl` can efficiently compute first-order derivatives

@@ -53,7 +60,7 @@ differentiable_eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::
 You can also print a tree as follows:
 
 ```@docs
-string_tree(tree::Node, operators::OperatorEnum)
+string_tree(tree::Node, operators::AbstractOperatorEnum)
 ```
 
 When you define an `OperatorEnum`, the standard `show` and `print` methods
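The "return `nothing` when it can't find a method" behavior documented above can be mimicked in plain Julia with `Base.hasmethod`. This `try_apply` helper is a hypothetical illustration, not the package's code:

```julia
# Apply a unary operator if a method exists for the argument's type;
# otherwise fail softly by returning `nothing` instead of throwing.
function try_apply(op::F, x) where {F}
    hasmethod(op, Tuple{typeof(x)}) || return nothing
    return op(x)
end

try_apply(uppercase, "hello")  # "HELLO"
try_apply(uppercase, 3.14)     # nothing (no Float64 method for uppercase)
```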

docs/src/types.md

Lines changed: 11 additions & 3 deletions
@@ -13,7 +13,14 @@ OperatorEnum
 Construct this operator specification as follows:
 
 ```@docs
-OperatorEnum(; binary_operators, unary_operators)
+OperatorEnum(; binary_operators, unary_operators, enable_autodiff)
+```
+
+This is just for scalar real operators. However, you can use
+the following for more general operators:
+
+```@docs
+GenericOperatorEnum(; binary_operators=[], unary_operators=[], extend_user_operators::Bool=false)
 ```
 
 ## Equations

@@ -22,13 +29,14 @@ Equations are specified as binary trees with the `Node` type, defined
 as follows:
 
 ```@docs
-Node{T<:Real}
+Node{T}
 ```
 
 There are a variety of constructors for `Node` objects, including:
 
 ```@docs
-Node(; val::Real=nothing, feature::Integer=nothing)
+Node(; val=nothing, feature::Integer=nothing)
+Node(::Type{T}; val=nothing, feature::Integer=nothing) where {T}
 Node(op::Int, l::Node)
 Node(op::Int, l::Node, r::Node)
 Node(var_string::String)

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ using Reexport
     has_constants,
     get_constants,
     set_constants
-@reexport import .OperatorEnumConstructionModule: OperatorEnum
+@reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum
 @reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
 @reexport import .EvaluateEquationDerivativeModule:
     eval_diff_tree_array, eval_grad_tree_array
