Skip to content

Commit cfd6cb8

Browse files
authored
Merge pull request #79 from SymbolicML/dev
Add DispatchDoctor.jl and fix various type instabilities
2 parents a55f966 + 7eda0fe commit cfd6cb8

22 files changed

+171
-132
lines changed

.github/workflows/CI.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ jobs:
6161
path-to-lcov: lcov.info
6262
flag-name: julia-${{ matrix.julia-version }}-${{ matrix.os }}-main-${{ github.event_name }}
6363

64-
integration_tests:
65-
name: Integration test - ${{ matrix.test_name }} - ${{ matrix.os }}
64+
additional_tests:
65+
name: test ${{ matrix.test_name }} - ${{ matrix.os }}
6666
runs-on: ${{ matrix.os }}
6767
timeout-minutes: 60
6868
strategy:
@@ -74,6 +74,7 @@ jobs:
7474
- "1"
7575
test_name:
7676
- "enzyme"
77+
- "jet"
7778
steps:
7879
- uses: actions/checkout@v2
7980
- uses: julia-actions/setup-julia@v1
@@ -100,7 +101,7 @@ jobs:
100101
runs-on: ubuntu-latest
101102
needs:
102103
- test
103-
- integration_tests
104+
- additional_tests
104105
steps:
105106
- name: Finish
106107
uses: coverallsapp/github-action@v2

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.18.0-alpha"
4+
version = "0.18.0-alpha.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
9+
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
910
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1011
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1112
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -31,6 +32,7 @@ DynamicExpressionsZygoteExt = "Zygote"
3132
Bumper = "0.6"
3233
ChainRulesCore = "1"
3334
Compat = "3.37, 4"
35+
DispatchDoctor = "0.4"
3436
LoopVectorization = "0.12"
3537
MacroTools = "0.4, 0.5"
3638
Optim = "0.19, 1"

src/ChainRules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ function ChainRulesCore.rrule(
3838
dtree = let X = X, dY = dY, tree = tree, operators = operators
3939
@thunk(
4040
let
41-
_, gradient, complete = eval_grad_tree_array(
41+
_, gradient, complete2 = eval_grad_tree_array(
4242
tree, X, operators; variable=Val(false)
4343
)
44-
if !complete
44+
if !complete2
4545
gradient .= NaN
4646
end
4747

@@ -55,14 +55,14 @@ function ChainRulesCore.rrule(
5555
dX = let X = X, dY = dY, tree = tree, operators = operators
5656
@thunk(
5757
let
58-
_, gradient, complete = eval_grad_tree_array(
58+
_, gradient2, complete3 = eval_grad_tree_array(
5959
tree, X, operators; variable=Val(true)
6060
)
61-
if !complete
62-
gradient .= NaN
61+
if !complete3
62+
gradient2 .= NaN
6363
end
6464

65-
gradient .* reshape(dY, 1, length(dY))
65+
gradient2 .* reshape(dY, 1, length(dY))
6666
end
6767
)
6868
end

src/DynamicExpressions.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
module DynamicExpressions
22

3-
include("Utils.jl")
4-
include("ExtensionInterface.jl")
5-
include("OperatorEnum.jl")
6-
include("Node.jl")
7-
include("NodeUtils.jl")
8-
include("Strings.jl")
9-
include("Evaluate.jl")
10-
include("EvaluateDerivative.jl")
11-
include("ChainRules.jl")
12-
include("EvaluationHelpers.jl")
13-
include("Simplify.jl")
14-
include("OperatorEnumConstruction.jl")
15-
include("Random.jl")
16-
include("Expression.jl")
17-
include("Parse.jl")
3+
using DispatchDoctor: @stable, @unstable
4+
5+
@stable default_mode = "disable" begin
6+
include("Utils.jl")
7+
include("ExtensionInterface.jl")
8+
include("OperatorEnum.jl")
9+
include("Node.jl")
10+
include("NodeUtils.jl")
11+
include("Strings.jl")
12+
include("Evaluate.jl")
13+
include("EvaluateDerivative.jl")
14+
include("ChainRules.jl")
15+
include("EvaluationHelpers.jl")
16+
include("Simplify.jl")
17+
include("OperatorEnumConstruction.jl")
18+
include("Random.jl")
19+
include("Expression.jl")
20+
include("Parse.jl")
21+
end
1822

1923
import PackageExtensionCompat: @require_extensions
2024
import Reexport: @reexport
25+
macro ignore(args...) end
26+
2127
@reexport import .NodeModule:
2228
AbstractNode,
2329
AbstractExpressionNode,
@@ -86,11 +92,9 @@ const PACKAGE_VERSION = let d = pkgdir(@__MODULE__)
8692
end
8793
end
8894

89-
macro ignore(args...) end
9095
# To get LanguageServer to register library within tests
9196
@ignore include("../test/runtests.jl")
9297

9398
include("precompile.jl")
9499
do_precompilation(; mode=:precompile)
95-
96100
end

src/Evaluate.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module EvaluateModule
22

3+
using DispatchDoctor: @unstable
4+
35
import ..NodeModule: AbstractExpressionNode, constructorof
46
import ..StringsModule: string_tree
57
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
@@ -657,7 +659,7 @@ function eval(current_node)
657659
A `false` complete means an operator was called on input types
658660
that it was not defined for.
659661
"""
660-
function eval_tree_array(
662+
@unstable function eval_tree_array(
661663
tree::AbstractExpressionNode,
662664
cX::AbstractArray,
663665
operators::GenericOperatorEnum;
@@ -680,7 +682,7 @@ function eval_tree_array(
680682
end
681683
end
682684

683-
function _eval_tree_array_generic(
685+
@unstable function _eval_tree_array_generic(
684686
tree::AbstractExpressionNode{T1},
685687
cX::AbstractArray{T2,N},
686688
operators::GenericOperatorEnum,
@@ -707,7 +709,7 @@ function _eval_tree_array_generic(
707709
end
708710
end
709711

710-
function deg1_eval_generic(
712+
@unstable function deg1_eval_generic(
711713
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
712714
) where {F,throw_errors}
713715
left, complete = eval_tree_array(tree.l, cX, operators)
@@ -716,7 +718,7 @@ function deg1_eval_generic(
716718
return op(left), true
717719
end
718720

719-
function deg2_eval_generic(
721+
@unstable function deg2_eval_generic(
720722
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
721723
) where {F,throw_errors}
722724
left, complete = eval_tree_array(tree.l, cX, operators)

src/Expression.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""This module defines a user-facing `Expression` type"""
22
module ExpressionModule
33

4+
using DispatchDoctor: @unstable
45
using ..NodeModule: AbstractExpressionNode
56
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
67
using ..UtilsModule: Undefined
@@ -12,7 +13,7 @@ end
1213
_data(x::Metadata) = getfield(x, :_data)
1314

1415
Base.propertynames(x::Metadata) = propertynames(_data(x))
15-
@inline Base.getproperty(x::Metadata, f::Symbol) = getproperty(_data(x), f)
16+
@unstable @inline Base.getproperty(x::Metadata, f::Symbol) = getproperty(_data(x), f)
1617
Base.show(io::IO, x::Metadata) = print(io, "Metadata(", _data(x), ")")
1718
@inline _copy(x) = copy(x)
1819
@inline _copy(x::Nothing) = nothing

src/OperatorEnumConstruction.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module OperatorEnumConstructionModule
22

3+
using DispatchDoctor: @unstable
4+
35
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
46
import ..NodeModule: Node, GraphNode, AbstractExpressionNode, constructorof
57
import ..StringsModule: string_tree
@@ -43,7 +45,7 @@ function Base.show(io::IO, tree::AbstractExpressionNode)
4345
return print(io, string_tree(tree, latest_operators; kwargs...))
4446
end
4547
end
46-
function (tree::AbstractExpressionNode)(X; kws...)
48+
@unstable function (tree::AbstractExpressionNode)(X; kws...)
4749
Base.depwarn(
4850
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
4951
:AbstractExpressionNode,
@@ -62,7 +64,7 @@ function (tree::AbstractExpressionNode)(X; kws...)
6264
end
6365
end
6466

65-
function _grad_evaluator(tree::AbstractExpressionNode, X; kws...)
67+
@unstable function _grad_evaluator(tree::AbstractExpressionNode, X; kws...)
6668
Base.depwarn(
6769
"The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
6870
:AbstractExpressionNode,
@@ -93,7 +95,7 @@ function set_default_operators!(operators::GenericOperatorEnum)
9395
return LATEST_OPERATORS_TYPE.x = IsGenericOperatorEnum
9496
end
9597

96-
function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
98+
@unstable function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
9799
mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
98100
if !haskey(mapping, f)
99101
error(
@@ -364,7 +366,7 @@ redefine operators for `AbstractExpressionNode` types, as well as `show`, `print
364366
are *not* needed for the package to work; they are purely for convenience.
365367
- `empty_old_operators::Bool=true`: Whether to clear the old operators.
366368
"""
367-
function OperatorEnum(;
369+
@unstable function OperatorEnum(;
368370
binary_operators=Function[],
369371
unary_operators=Function[],
370372
define_helper_functions::Bool=true,
@@ -417,7 +419,7 @@ and `(::AbstractExpressionNode)(X)`.
417419
are *not* needed for the package to work; they are purely for convenience.
418420
- `empty_old_operators::Bool=true`: Whether to clear the old operators.
419421
"""
420-
function GenericOperatorEnum(;
422+
@unstable function GenericOperatorEnum(;
421423
binary_operators=Function[],
422424
unary_operators=Function[],
423425
define_helper_functions::Bool=true,

src/Parse.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module ParseModule
22

3+
using DispatchDoctor: @unstable
4+
35
using ..NodeModule: AbstractExpressionNode, Node, constructorof
46
using ..OperatorEnumModule: AbstractOperatorEnum
57
using ..OperatorEnumConstructionModule: OperatorEnum, empty_all_globals!
@@ -97,7 +99,7 @@ macro parse_expression(ex, kws...)
9799
)
98100
end
99101

100-
function _parse_kws(kws)
102+
@unstable function _parse_kws(kws)
101103
# Initialize default values for operators and variable_names
102104
operators = nothing
103105
variable_names = nothing
@@ -189,7 +191,7 @@ function _parse_kws(kws)
189191
end
190192

191193
"""Parse an expression Julia `Expr` object."""
192-
function parse_expression(
194+
@unstable function parse_expression(
193195
ex;
194196
operators::AbstractOperatorEnum,
195197
variable_names::Union{AbstractVector,Nothing}=nothing,
@@ -215,7 +217,7 @@ end
215217
"""An empty module for evaluation without collisions."""
216218
module EmptyModule end
217219

218-
function _parse_expression(
220+
@unstable function _parse_expression(
219221
ex::Expr,
220222
operators::AbstractOperatorEnum,
221223
variable_names::Union{AbstractVector{<:AbstractString},Nothing},
@@ -242,7 +244,7 @@ function _parse_expression(
242244
func, args, operators, variable_names, N, E, evaluate_on; kws...
243245
)
244246
end
245-
function _parse_expression(
247+
@unstable function _parse_expression(
246248
func::F,
247249
args,
248250
operators::AbstractOperatorEnum,
@@ -331,7 +333,7 @@ function _parse_expression(
331333
)
332334
end
333335
end
334-
function _parse_expression(
336+
@unstable function _parse_expression(
335337
ex,
336338
operators::AbstractOperatorEnum,
337339
variable_names::Union{AbstractVector{<:AbstractString},Nothing},
@@ -343,7 +345,7 @@ function _parse_expression(
343345
return parse_leaf(ex, variable_names, node_type, expression_type; kws...)
344346
end
345347

346-
function parse_leaf(
348+
@unstable function parse_leaf(
347349
ex,
348350
variable_names,
349351
node_type::Type{<:AbstractExpressionNode},

src/Simplify.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,9 @@ end
122122

123123
# Simplify tree
124124
function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum)
125-
tree = tree_mapreduce(
126-
identity,
127-
(p, c...) -> combine_children!(operators, p, c...),
128-
tree,
129-
constructorof(typeof(tree));
125+
return tree_mapreduce(
126+
identity, (p, c...) -> combine_children!(operators, p, c...), tree, typeof(tree);
130127
)
131-
return tree
132128
end
133129

134130
end

src/Strings.jl

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,19 @@ const OP_NAMES = Base.ImmutableDict(
1414
"safe_pow" => "^",
1515
)
1616

17-
function dispatch_op_name(::Val{2}, ::Nothing, idx)::Vector{Char}
18-
return vcat(collect("binary_operator["), collect(string(idx)), [']'])
19-
end
20-
function dispatch_op_name(::Val{1}, ::Nothing, idx)::Vector{Char}
21-
return vcat(collect("unary_operator["), collect(string(idx)), [']'])
22-
end
23-
function dispatch_op_name(::Val{2}, operators::AbstractOperatorEnum, idx)::Vector{Char}
24-
return get_op_name(operators.binops[idx])
17+
function dispatch_op_name(::Val{deg}, ::Nothing, idx)::Vector{Char} where {deg}
18+
if deg == 1
19+
return vcat(collect("unary_operator["), collect(string(idx)), [']'])
20+
else
21+
return vcat(collect("binary_operator["), collect(string(idx)), [']'])
22+
end
2523
end
26-
function dispatch_op_name(::Val{1}, operators::AbstractOperatorEnum, idx)::Vector{Char}
27-
return get_op_name(operators.unaops[idx])
24+
function dispatch_op_name(::Val{deg}, operators::AbstractOperatorEnum, idx) where {deg}
25+
if deg == 1
26+
return get_op_name(operators.unaops[idx])::Vector{Char}
27+
else
28+
return get_op_name(operators.binops[idx])::Vector{Char}
29+
end
2830
end
2931

3032
@generated function get_op_name(op::F)::Vector{Char} where {F}
@@ -137,15 +139,22 @@ function string_tree(
137139
)::String where {T,F1<:Function,F2<:Function}
138140
variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
139141
raw_output = tree_mapreduce(
140-
leaf -> if leaf.constant
141-
collect(f_constant(leaf.val))
142-
else
143-
collect(f_variable(leaf.feature, variable_names))
142+
let f_constant = f_constant,
143+
f_variable = f_variable,
144+
variable_names = variable_names
145+
146+
(leaf,) -> if leaf.constant
147+
collect(f_constant(leaf.val))::Vector{Char}
148+
else
149+
collect(f_variable(leaf.feature, variable_names))::Vector{Char}
150+
end
144151
end,
145-
branch -> if branch.degree == 1
146-
dispatch_op_name(Val(1), operators, branch.op)
147-
else
148-
dispatch_op_name(Val(2), operators, branch.op)
152+
let operators = operators
153+
(branch,) -> if branch.degree == 1
154+
dispatch_op_name(Val(1), operators, branch.op)::Vector{Char}
155+
else
156+
dispatch_op_name(Val(2), operators, branch.op)::Vector{Char}
157+
end
149158
end,
150159
combine_op_with_inputs,
151160
tree,

0 commit comments

Comments
 (0)