11using DynamicExpressions, BenchmarkTools, Random
22using DynamicExpressions. EquationUtilsModule: is_constant
3+ using Zygote
4+ if PACKAGE_VERSION < v " 0.14.0"
5+ @eval using DynamicExpressions: Node as GraphNode
6+ else
7+ @eval using DynamicExpressions: GraphNode
8+ end
39
410include (" benchmark_utils.jl" )
511
6571
6672# These macros make the benchmarks work on older versions:
6773# ! format: off
68- @generated function _convert (:: Type{N} , t; preserve_sharing) where {N<: Node }
74+ @generated function _convert (:: Type{N} , t; preserve_sharing) where {N}
6975 PACKAGE_VERSION < v " 0.7.0" && return :(convert (N, t))
70- return :(convert (N, t; preserve_sharing= preserve_sharing))
76+ PACKAGE_VERSION < v " 0.14.0" && return :(convert (N, t; preserve_sharing= preserve_sharing))
77+ return :(convert (N, t)) # Assume type used to infer sharing
7178end
7279@generated function _copy_node (t; preserve_sharing)
7380 PACKAGE_VERSION < v " 0.7.0" && return :(copy_node (t; preserve_topology= preserve_sharing))
74- return :(copy_node (t; preserve_sharing= preserve_sharing))
81+ PACKAGE_VERSION < v " 0.14.0" && return :(copy_node (t; preserve_sharing= preserve_sharing))
82+ return :(copy_node (t)) # Assume type used to infer sharing
7583end
7684@generated function get_set_constants! (tree)
7785 ! (@isdefined set_constants!) && return :(set_constants (tree, get_constants (tree)))
@@ -98,14 +106,42 @@ function benchmark_utilities()
98106 :is_constant ,
99107 :get_set_constants! ,
100108 :index_constants ,
109+ :string_tree ,
110+ :hash ,
101111 )
112+ has_both_modes = [:copy , :convert ]
113+ if PACKAGE_VERSION >= v " 0.14.0"
114+ append! (
115+ has_both_modes,
116+ [
117+ :simplify_tree ,
118+ :count_nodes ,
119+ :count_constants ,
120+ :get_set_constants! ,
121+ :index_constants ,
122+ :string_tree ,
123+ ],
124+ )
125+ end
126+ if PACKAGE_VERSION >= v " 0.14.1"
127+ append! (has_both_modes, [:hash ])
128+ end
102129
103130 operators = OperatorEnum (; binary_operators= [+ , - , / , * ], unary_operators= [cos, exp])
104-
105131 for func_k in all_funcs
106132 suite[func_k] = let s = BenchmarkGroup ()
107- for k in (:break_sharing , :preserve_sharing )
108- k == :preserve_sharing && ! (func_k in (:copy , :convert )) && continue
133+ for k in (
134+ if func_k in has_both_modes
135+ [:break_sharing , :preserve_sharing ]
136+ else
137+ [:break_sharing ]
138+ end
139+ )
140+ preprocess = if k == :preserve_sharing && PACKAGE_VERSION >= v " 0.14.0"
141+ tree -> GraphNode (tree)
142+ else
143+ identity
144+ end
109145
110146 f = if func_k == :copy
111147 tree -> _copy_node (tree; preserve_sharing= (k == :preserve_sharing ))
@@ -115,7 +151,7 @@ function benchmark_utilities()
115151 tree;
116152 preserve_sharing= (k == :preserve_sharing ),
117153 )
118- elseif func_k in (:simplify_tree , :combine_operators )
154+ elseif func_k in (:simplify_tree , :combine_operators , :string_tree )
119155 g = getfield (@__MODULE__ , func_k)
120156 tree -> f_tree_op (g, tree, operators)
121157 else
@@ -130,7 +166,7 @@ function benchmark_utilities()
130166 setup= (
131167 ntrees= 100 ;
132168 n= 20 ;
133- trees= [gen_random_tree_fixed_size (n, $ operators, 5 , Float32) for _ in 1 : ntrees]
169+ trees= [$ preprocess ( gen_random_tree_fixed_size (n, $ operators, 5 , Float32) ) for _ in 1 : ntrees]
134170 )
135171 )
136172 # ! format: on
0 commit comments