Skip to content

Commit 27d2d83

Browse files
authored
Merge branch 'master' into compathelper/new_version/2024-05-29-01-06-39-039-00781124978
2 parents d57a4aa + cfd6cb8 commit 27d2d83

34 files changed

+2199
-221
lines changed

.github/workflows/CI.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ jobs:
6161
path-to-lcov: lcov.info
6262
flag-name: julia-${{ matrix.julia-version }}-${{ matrix.os }}-main-${{ github.event_name }}
6363

64-
integration_tests:
65-
name: Integration test - ${{ matrix.test_name }} - ${{ matrix.os }}
64+
additional_tests:
65+
name: test ${{ matrix.test_name }} - ${{ matrix.os }}
6666
runs-on: ${{ matrix.os }}
6767
timeout-minutes: 60
6868
strategy:
@@ -74,6 +74,7 @@ jobs:
7474
- "1"
7575
test_name:
7676
- "enzyme"
77+
- "jet"
7778
steps:
7879
- uses: actions/checkout@v2
7980
- uses: julia-actions/setup-julia@v1
@@ -100,7 +101,7 @@ jobs:
100101
runs-on: ubuntu-latest
101102
needs:
102103
- test
103-
- integration_tests
104+
- additional_tests
104105
steps:
105106
- name: Finish
106107
uses: coverallsapp/github-action@v2

Project.toml

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.17.0"
4+
version = "0.18.0-alpha.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
9+
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
910
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1011
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1112
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -28,11 +29,10 @@ DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
2829
DynamicExpressionsZygoteExt = "Zygote"
2930

3031
[compat]
31-
Aqua = "0.7"
3232
Bumper = "0.6"
3333
ChainRulesCore = "1"
3434
Compat = "3.37, 4"
35-
Enzyme = "^0.11.12"
35+
DispatchDoctor = "0.4"
3636
LoopVectorization = "0.12"
3737
MacroTools = "0.4, 0.5"
3838
Optim = "0.19, 1"
@@ -44,20 +44,8 @@ Zygote = "0.6"
4444
julia = "1.6"
4545

4646
[extras]
47-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4847
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
49-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
50-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
51-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5248
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
5349
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
54-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
55-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
56-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
57-
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
5850
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
59-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6051
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
61-
62-
[targets]
63-
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "Suppressor", "SymbolicUtils", "Zygote"]

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ using DynamicExpressions
3030

3131
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
3232

33-
x1 = Node(; feature=1)
34-
x2 = Node(; feature=2)
33+
x1 = Node{Float64}(feature=1)
34+
x2 = Node{Float64}(feature=2)
3535

3636
expression = x1 * cos(x2 - 3.2)
3737

src/ChainRules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ function ChainRulesCore.rrule(
3838
dtree = let X = X, dY = dY, tree = tree, operators = operators
3939
@thunk(
4040
let
41-
_, gradient, complete = eval_grad_tree_array(
41+
_, gradient, complete2 = eval_grad_tree_array(
4242
tree, X, operators; variable=Val(false)
4343
)
44-
if !complete
44+
if !complete2
4545
gradient .= NaN
4646
end
4747

@@ -55,14 +55,14 @@ function ChainRulesCore.rrule(
5555
dX = let X = X, dY = dY, tree = tree, operators = operators
5656
@thunk(
5757
let
58-
_, gradient, complete = eval_grad_tree_array(
58+
_, gradient2, complete3 = eval_grad_tree_array(
5959
tree, X, operators; variable=Val(true)
6060
)
61-
if !complete
62-
gradient .= NaN
61+
if !complete3
62+
gradient2 .= NaN
6363
end
6464

65-
gradient .* reshape(dY, 1, length(dY))
65+
gradient2 .* reshape(dY, 1, length(dY))
6666
end
6767
)
6868
end

src/DynamicExpressions.jl

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
module DynamicExpressions
22

3-
include("Utils.jl")
4-
include("ExtensionInterface.jl")
5-
include("OperatorEnum.jl")
6-
include("Node.jl")
7-
include("NodeUtils.jl")
8-
include("Strings.jl")
9-
include("Evaluate.jl")
10-
include("EvaluateDerivative.jl")
11-
include("ChainRules.jl")
12-
include("EvaluationHelpers.jl")
13-
include("Simplify.jl")
14-
include("OperatorEnumConstruction.jl")
15-
include("Random.jl")
3+
using DispatchDoctor: @stable, @unstable
4+
5+
@stable default_mode = "disable" begin
6+
include("Utils.jl")
7+
include("ExtensionInterface.jl")
8+
include("OperatorEnum.jl")
9+
include("Node.jl")
10+
include("NodeUtils.jl")
11+
include("Strings.jl")
12+
include("Evaluate.jl")
13+
include("EvaluateDerivative.jl")
14+
include("ChainRules.jl")
15+
include("EvaluationHelpers.jl")
16+
include("Simplify.jl")
17+
include("OperatorEnumConstruction.jl")
18+
include("Random.jl")
19+
include("Expression.jl")
20+
include("Parse.jl")
21+
end
1622

1723
import PackageExtensionCompat: @require_extensions
1824
import Reexport: @reexport
25+
macro ignore(args...) end
26+
1927
@reexport import .NodeModule:
2028
AbstractNode,
2129
AbstractExpressionNode,
@@ -26,7 +34,16 @@ import Reexport: @reexport
2634
tree_mapreduce,
2735
filter_map,
2836
filter_map!
29-
import .NodeModule: constructorof, preserve_sharing
37+
import .NodeModule:
38+
constructorof,
39+
with_type_parameters,
40+
preserve_sharing,
41+
leaf_copy,
42+
branch_copy,
43+
leaf_hash,
44+
branch_hash,
45+
leaf_equal,
46+
branch_equal
3047
@reexport import .NodeUtilsModule:
3148
count_nodes,
3249
count_constants,
@@ -48,6 +65,11 @@ import .NodeModule: constructorof, preserve_sharing
4865
@reexport import .EvaluationHelpersModule
4966
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
5067
@reexport import .RandomModule: NodeSampler
68+
@reexport import .ExpressionModule: AbstractExpression, Expression
69+
# Not for export; just for overloading
70+
import .ExpressionModule: get_tree, get_operators, get_variable_names, Metadata
71+
@reexport import .ParseModule: @parse_expression, parse_expression
72+
import .ParseModule: parse_leaf
5173

5274
function __init__()
5375
@require_extensions
@@ -57,16 +79,22 @@ include("deprecated.jl")
5779

5880
import TOML: parsefile
5981

60-
const PACKAGE_VERSION = let
61-
project = parsefile(joinpath(pkgdir(@__MODULE__), "Project.toml"))
62-
VersionNumber(project["version"])
82+
const PACKAGE_VERSION = let d = pkgdir(@__MODULE__)
83+
try
84+
if d isa String
85+
project = parsefile(joinpath(d, "Project.toml"))
86+
VersionNumber(project["version"])
87+
else
88+
v"0.0.0"
89+
end
90+
catch
91+
v"0.0.0"
92+
end
6393
end
6494

65-
macro ignore(args...) end
6695
# To get LanguageServer to register library within tests
6796
@ignore include("../test/runtests.jl")
6897

6998
include("precompile.jl")
7099
do_precompilation(; mode=:precompile)
71-
72100
end

src/Evaluate.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module EvaluateModule
22

3+
using DispatchDoctor: @unstable
4+
35
import ..NodeModule: AbstractExpressionNode, constructorof
46
import ..StringsModule: string_tree
57
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
@@ -657,7 +659,7 @@ function eval(current_node)
657659
A `false` complete means an operator was called on input types
658660
that it was not defined for.
659661
"""
660-
function eval_tree_array(
662+
@unstable function eval_tree_array(
661663
tree::AbstractExpressionNode,
662664
cX::AbstractArray,
663665
operators::GenericOperatorEnum;
@@ -680,7 +682,7 @@ function eval_tree_array(
680682
end
681683
end
682684

683-
function _eval_tree_array_generic(
685+
@unstable function _eval_tree_array_generic(
684686
tree::AbstractExpressionNode{T1},
685687
cX::AbstractArray{T2,N},
686688
operators::GenericOperatorEnum,
@@ -707,7 +709,7 @@ function _eval_tree_array_generic(
707709
end
708710
end
709711

710-
function deg1_eval_generic(
712+
@unstable function deg1_eval_generic(
711713
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
712714
) where {F,throw_errors}
713715
left, complete = eval_tree_array(tree.l, cX, operators)
@@ -716,7 +718,7 @@ function deg1_eval_generic(
716718
return op(left), true
717719
end
718720

719-
function deg2_eval_generic(
721+
@unstable function deg2_eval_generic(
720722
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
721723
) where {F,throw_errors}
722724
left, complete = eval_tree_array(tree.l, cX, operators)

src/EvaluateDerivative.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ end
8080
end
8181
end
8282
deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
83-
diff_deg2_eval(tree, cX, operators.binops[op_idx], operators, direction)
83+
quote
84+
diff_deg2_eval(tree, cX, operators.binops[op_idx], operators, direction)
85+
end
8486
else
8587
quote
8688
Base.Cartesian.@nif(

0 commit comments

Comments
 (0)