@@ -3,7 +3,7 @@ module EvaluateEquationModule
33import LoopVectorization: @turbo , indices
44import .. EquationModule: Node, string_tree
55import .. OperatorEnumModule: OperatorEnum, GenericOperatorEnum
6- import .. UtilsModule: @return_on_false , @maybe_turbo , is_bad_array, vals
6+ import .. UtilsModule: @return_on_false , @maybe_turbo , is_bad_array
77import .. EquationUtilsModule: is_constant
88
99macro return_on_check (val, T, n)
@@ -28,7 +28,7 @@ macro return_on_nonfinite_array(array, T, n)
2828end
2929
3030"""
31- eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool)
31+ eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false )
3232
3333Evaluate a binary tree (equation) over a given input data matrix. The
3434operators contain all of the operators used. This function fuses doublets
8888function _eval_tree_array (
8989 tree:: Node{T} , cX:: AbstractMatrix{T} , operators:: OperatorEnum , :: Val{turbo}
9090):: Tuple{AbstractVector{T},Bool} where {T<: Real ,turbo}
91+ n = size (cX, 2 )
9192 # First, we see if there are only constants in the tree - meaning
9293 # we can just return the constant result.
9394 if tree. degree == 0
@@ -98,104 +99,88 @@ function _eval_tree_array(
9899 ! flag && return Array {T,1} (undef, size (cX, 2 )), false
99100 return fill (result, size (cX, 2 )), true
100101 elseif tree. degree == 1
102+ op = operators. unaops[tree. op]
101103 if tree. l. degree == 2 && tree. l. l. degree == 0 && tree. l. r. degree == 0
102104 # op(op2(x, y)), where x, y, z are constants or variables.
103- return deg1_l2_ll0_lr0_eval (
104- tree, cX, vals[tree. op], vals[tree. l. op], operators, Val (turbo)
105- )
105+ op_l = operators. binops[tree. l. op]
106+ return deg1_l2_ll0_lr0_eval (tree, cX, op, op_l, Val (turbo))
106107 elseif tree. l. degree == 1 && tree. l. l. degree == 0
107108 # op(op2(x)), where x is a constant or variable.
108- return deg1_l1_ll0_eval (
109- tree, cX, vals[tree. op], vals[tree. l. op], operators, Val (turbo)
110- )
111- else
112- # op(x), for any x.
113- return deg1_eval (tree, cX, vals[tree. op], operators, Val (turbo))
109+ op_l = operators. unaops[tree. l. op]
110+ return deg1_l1_ll0_eval (tree, cX, op, op_l, Val (turbo))
114111 end
112+
113+ # op(x), for any x.
114+ (cumulator, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
115+ @return_on_false complete cumulator
116+ @return_on_nonfinite_array cumulator T n
117+ return deg1_eval (cumulator, op, Val (turbo))
118+
115119 elseif tree. degree == 2
120+ op = operators. binops[tree. op]
116121 # TODO - add op(op2(x, y), z) and op(x, op2(y, z))
122+ # op(x, y), where x, y are constants or variables.
117123 if tree. l. degree == 0 && tree. r. degree == 0
118- # op(x, y), where x, y are constants or variables.
119- return deg2_l0_r0_eval (tree, cX, vals[tree. op], operators, Val (turbo))
120- elseif tree. l. degree == 0
121- # op(x, y), where x is a constant or variable but y is not.
122- return deg2_l0_eval (tree, cX, vals[tree. op], operators, Val (turbo))
124+ return deg2_l0_r0_eval (tree, cX, op, Val (turbo))
123125 elseif tree. r. degree == 0
126+ (cumulator_l, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
127+ @return_on_false complete cumulator_l
128+ @return_on_nonfinite_array cumulator_l T n
124129 # op(x, y), where y is a constant or variable but x is not.
125- return deg2_r0_eval (tree, cX, vals[tree. op], operators, Val (turbo))
126- else
127- # op(x, y), for any x or y
128- return deg2_eval (tree, cX, vals[tree. op], operators, Val (turbo))
130+ return deg2_r0_eval (tree, cumulator_l, cX, op, Val (turbo))
131+ elseif tree. l. degree == 0
132+ (cumulator_r, complete) = _eval_tree_array (tree. r, cX, operators, Val (turbo))
133+ @return_on_false complete cumulator_r
134+ @return_on_nonfinite_array cumulator_r T n
135+ # op(x, y), where x is a constant or variable but y is not.
136+ return deg2_l0_eval (tree, cumulator_r, cX, op, Val (turbo))
129137 end
138+ (cumulator_l, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
139+ @return_on_false complete cumulator_l
140+ @return_on_nonfinite_array cumulator_l T n
141+ (cumulator_r, complete) = _eval_tree_array (tree. r, cX, operators, Val (turbo))
142+ @return_on_false complete cumulator_r
143+ @return_on_nonfinite_array cumulator_r T n
144+ # op(x, y), for any x or y
145+ return deg2_eval (cumulator_l, cumulator_r, op, Val (turbo))
130146 end
131147end
132148
133149function deg2_eval (
134- tree:: Node{T} ,
135- cX:: AbstractMatrix{T} ,
136- :: Val{op_idx} ,
137- operators:: OperatorEnum ,
138- :: Val{turbo} ,
139- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
140- n = size (cX, 2 )
141- (cumulator, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
142- @return_on_false complete cumulator
143- @return_on_nonfinite_array cumulator T n
144- (array2, complete2) = _eval_tree_array (tree. r, cX, operators, Val (turbo))
145- @return_on_false complete2 cumulator
146- @return_on_nonfinite_array array2 T n
147- op = operators. binops[op_idx]
148-
149- # We check inputs (and intermediates), not outputs.
150- @maybe_turbo turbo for j in indices (cumulator)
151- x = op (cumulator[j], array2[j]):: T
152- cumulator[j] = x
150+ cumulator_l:: AbstractVector{T} , cumulator_r:: AbstractVector{T} , op:: F , :: Val{turbo}
151+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
152+ @maybe_turbo turbo for j in indices (cumulator_l)
153+ x = op (cumulator_l[j], cumulator_r[j]):: T
154+ cumulator_l[j] = x
153155 end
154- # return (cumulator, finished_loop) #
155- return (cumulator, true )
156+ return (cumulator_l, true )
156157end
157158
158159function deg1_eval (
159- tree:: Node{T} ,
160- cX:: AbstractMatrix{T} ,
161- :: Val{op_idx} ,
162- operators:: OperatorEnum ,
163- :: Val{turbo} ,
164- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
165- n = size (cX, 2 )
166- (cumulator, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
167- @return_on_false complete cumulator
168- @return_on_nonfinite_array cumulator T n
169- op = operators. unaops[op_idx]
160+ cumulator:: AbstractVector{T} , op:: F , :: Val{turbo}
161+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
170162 @maybe_turbo turbo for j in indices (cumulator)
171163 x = op (cumulator[j]):: T
172164 cumulator[j] = x
173165 end
174- return (cumulator, true ) #
166+ return (cumulator, true )
175167end
176168
177169function deg0_eval (
178170 tree:: Node{T} , cX:: AbstractMatrix{T}
179171):: Tuple{AbstractVector{T},Bool} where {T<: Real }
180- n = size (cX, 2 )
181172 if tree. constant
173+ n = size (cX, 2 )
182174 return (fill (tree. val:: T , n), true )
183175 else
184176 return (cX[tree. feature, :], true )
185177 end
186178end
187179
188180function deg1_l2_ll0_lr0_eval (
189- tree:: Node{T} ,
190- cX:: AbstractMatrix{T} ,
191- :: Val{op_idx} ,
192- :: Val{op_l_idx} ,
193- operators:: OperatorEnum ,
194- :: Val{turbo} ,
195- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,op_l_idx,turbo}
181+ tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , op_l:: F2 , :: Val{turbo}
182+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,F2,turbo}
196183 n = size (cX, 2 )
197- op = operators. unaops[op_idx]
198- op_l = operators. binops[op_l_idx]
199184 if tree. l. l. constant && tree. l. r. constant
200185 val_ll = tree. l. l. val:: T
201186 val_lr = tree. l. r. val:: T
243228
244229# op(op2(x)) for x variable or constant
245230function deg1_l1_ll0_eval (
246- tree:: Node{T} ,
247- cX:: AbstractMatrix{T} ,
248- :: Val{op_idx} ,
249- :: Val{op_l_idx} ,
250- operators:: OperatorEnum ,
251- :: Val{turbo} ,
252- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,op_l_idx,turbo}
231+ tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , op_l:: F2 , :: Val{turbo}
232+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,F2,turbo}
253233 n = size (cX, 2 )
254- op = operators. unaops[op_idx]
255- op_l = operators. unaops[op_l_idx]
256234 if tree. l. l. constant
257235 val_ll = tree. l. l. val:: T
258236 @return_on_check val_ll T n
275253
276254# op(x, y) for x and y variable/constant
277255function deg2_l0_r0_eval (
278- tree:: Node{T} ,
279- cX:: AbstractMatrix{T} ,
280- :: Val{op_idx} ,
281- operators:: OperatorEnum ,
282- :: Val{turbo} ,
283- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
256+ tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , :: Val{turbo}
257+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
284258 n = size (cX, 2 )
285- op = operators. binops[op_idx]
286259 if tree. l. constant && tree. r. constant
287260 val_l = tree. l. val:: T
288261 @return_on_check val_l T n
323296
324297# op(x, y) for x variable/constant, y arbitrary
325298function deg2_l0_eval (
326- tree:: Node{T} ,
327- cX:: AbstractMatrix{T} ,
328- :: Val{op_idx} ,
329- operators:: OperatorEnum ,
330- :: Val{turbo} ,
331- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
299+ tree:: Node{T} , cumulator:: AbstractVector{T} , cX:: AbstractArray{T} , op:: F , :: Val{turbo}
300+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
332301 n = size (cX, 2 )
333- (cumulator, complete) = _eval_tree_array (tree. r, cX, operators, Val (turbo))
334- @return_on_false complete cumulator
335- @return_on_nonfinite_array cumulator T n
336- op = operators. binops[op_idx]
337302 if tree. l. constant
338303 val = tree. l. val:: T
339304 @return_on_check val T n
353318
354319# op(x, y) for x arbitrary, y variable/constant
355320function deg2_r0_eval (
356- tree:: Node{T} ,
357- cX:: AbstractMatrix{T} ,
358- :: Val{op_idx} ,
359- operators:: OperatorEnum ,
360- :: Val{turbo} ,
361- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
321+ tree:: Node{T} , cumulator:: AbstractVector{T} , cX:: AbstractArray{T} , op:: F , :: Val{turbo}
322+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
362323 n = size (cX, 2 )
363- (cumulator, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
364- @return_on_false complete cumulator
365- @return_on_nonfinite_array cumulator T n
366- op = operators. binops[op_idx]
367324 if tree. r. constant
368325 val = tree. r. val:: T
369326 @return_on_check val T n
@@ -394,9 +351,9 @@ function _eval_constant_tree(
394351 if tree. degree == 0
395352 return deg0_eval_constant (tree)
396353 elseif tree. degree == 1
397- return deg1_eval_constant (tree, vals [tree. op], operators)
354+ return deg1_eval_constant (tree, operators . unaops [tree. op], operators)
398355 else
399- return deg2_eval_constant (tree, vals [tree. op], operators)
356+ return deg2_eval_constant (tree, operators . binops [tree. op], operators)
400357 end
401358end
402359
@@ -405,19 +362,17 @@ end
405362end
406363
407364function deg1_eval_constant (
408- tree:: Node{T} , :: Val{op_idx} , operators:: OperatorEnum
409- ):: Tuple{T,Bool} where {T<: Real ,op_idx}
410- op = operators. unaops[op_idx]
365+ tree:: Node{T} , op:: F , operators:: OperatorEnum
366+ ):: Tuple{T,Bool} where {T<: Real ,F}
411367 (cumulator, complete) = _eval_constant_tree (tree. l, operators)
412368 ! complete && return zero (T), false
413369 output = op (cumulator):: T
414370 return output, isfinite (output)
415371end
416372
417373function deg2_eval_constant (
418- tree:: Node{T} , :: Val{op_idx} , operators:: OperatorEnum
419- ):: Tuple{T,Bool} where {T<: Real ,op_idx}
420- op = operators. binops[op_idx]
374+ tree:: Node{T} , op:: F , operators:: OperatorEnum
375+ ):: Tuple{T,Bool} where {T<: Real ,F}
421376 (cumulator, complete) = _eval_constant_tree (tree. l, operators)
422377 ! complete && return zero (T), false
423378 (cumulator2, complete2) = _eval_constant_tree (tree. r, operators)
@@ -442,31 +397,29 @@ function differentiable_eval_tree_array(
442397 return (cX[tree. feature, :], true )
443398 end
444399 elseif tree. degree == 1
445- return deg1_diff_eval (tree, cX, vals [tree. op], operators)
400+ return deg1_diff_eval (tree, cX, operators . unaops [tree. op], operators)
446401 else
447- return deg2_diff_eval (tree, cX, vals [tree. op], operators)
402+ return deg2_diff_eval (tree, cX, operators . binops [tree. op], operators)
448403 end
449404end
450405
451406function deg1_diff_eval (
452- tree:: Node{T1} , cX:: AbstractMatrix{T} , :: Val{op_idx} , operators:: OperatorEnum
453- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx ,T1}
407+ tree:: Node{T1} , cX:: AbstractMatrix{T} , op :: F , operators:: OperatorEnum
408+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F ,T1}
454409 (left, complete) = differentiable_eval_tree_array (tree. l, cX, operators)
455410 @return_on_false complete left
456- op = operators. unaops[op_idx]
457411 out = op .(left)
458412 no_nans = ! any (x -> (! isfinite (x)), out)
459413 return (out, no_nans)
460414end
461415
462416function deg2_diff_eval (
463- tree:: Node{T1} , cX:: AbstractMatrix{T} , :: Val{op_idx} , operators:: OperatorEnum
464- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx ,T1}
417+ tree:: Node{T1} , cX:: AbstractMatrix{T} , op :: F , operators:: OperatorEnum
418+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F ,T1}
465419 (left, complete) = differentiable_eval_tree_array (tree. l, cX, operators)
466420 @return_on_false complete left
467421 (right, complete2) = differentiable_eval_tree_array (tree. r, cX, operators)
468422 @return_on_false complete2 left
469- op = operators. binops[op_idx]
470423 out = op .(left, right)
471424 no_nans = ! any (x -> (! isfinite (x)), out)
472425 return (out, no_nans)
@@ -557,30 +510,32 @@ function _eval_tree_array_generic(
557510 end
558511 end
559512 elseif tree. degree == 1
560- return deg1_eval_generic (tree, cX, vals[tree. op], operators, Val (throw_errors))
513+ return deg1_eval_generic (
514+ tree, cX, operators. unaops[tree. op], operators, Val (throw_errors)
515+ )
561516 else
562- return deg2_eval_generic (tree, cX, vals[tree. op], operators, Val (throw_errors))
517+ return deg2_eval_generic (
518+ tree, cX, operators. binops[tree. op], operators, Val (throw_errors)
519+ )
563520 end
564521end
565522
566523function deg1_eval_generic (
567- tree, cX, :: Val{op_idx} , operators:: GenericOperatorEnum , :: Val{throw_errors}
568- ) where {op_idx ,throw_errors}
524+ tree, cX, op :: F , operators:: GenericOperatorEnum , :: Val{throw_errors}
525+ ) where {F ,throw_errors}
569526 left, complete = eval_tree_array (tree. l, cX, operators)
570527 ! throw_errors && ! complete && return nothing , false
571- op = operators. unaops[op_idx]
572528 ! throw_errors && ! hasmethod (op, Tuple{typeof (left)}) && return nothing , false
573529 return op (left), true
574530end
575531
576532function deg2_eval_generic (
577- tree, cX, :: Val{op_idx} , operators:: GenericOperatorEnum , :: Val{throw_errors}
578- ) where {op_idx ,throw_errors}
533+ tree, cX, op :: F , operators:: GenericOperatorEnum , :: Val{throw_errors}
534+ ) where {F ,throw_errors}
579535 left, complete = eval_tree_array (tree. l, cX, operators)
580536 ! throw_errors && ! complete && return nothing , false
581537 right, complete = eval_tree_array (tree. r, cX, operators)
582538 ! throw_errors && ! complete && return nothing , false
583- op = operators. binops[op_idx]
584539 ! throw_errors &&
585540 ! hasmethod (op, Tuple{typeof (left),typeof (right)}) &&
586541 return nothing , false
0 commit comments