@@ -3,7 +3,7 @@ module EvaluateEquationModule
33import LoopVectorization: @turbo , indices
44import .. EquationModule: Node, string_tree
55import .. OperatorEnumModule: OperatorEnum, GenericOperatorEnum
6- import .. UtilsModule: @return_on_false , @maybe_turbo , is_bad_array, vals
6+ import .. UtilsModule: @return_on_false , @maybe_turbo , is_bad_array
77import .. EquationUtilsModule: is_constant
88
99macro return_on_check (val, T, n)
@@ -98,53 +98,48 @@ function _eval_tree_array(
9898 ! flag && return Array {T,1} (undef, size (cX, 2 )), false
9999 return fill (result, size (cX, 2 )), true
100100 elseif tree. degree == 1
101+ op = operators. unaops[tree. op]
101102 if tree. l. degree == 2 && tree. l. l. degree == 0 && tree. l. r. degree == 0
102103 # op(op2(x, y)), where x, y, z are constants or variables.
103- return deg1_l2_ll0_lr0_eval (
104- tree, cX, vals[tree. op], vals[tree. l. op], operators, Val (turbo)
105- )
104+ op_l = operators. binops[tree. l. op]
105+ return deg1_l2_ll0_lr0_eval (tree, cX, op, op_l, Val (turbo))
106106 elseif tree. l. degree == 1 && tree. l. l. degree == 0
107107 # op(op2(x)), where x is a constant or variable.
108- return deg1_l1_ll0_eval (
109- tree, cX, vals[tree. op], vals[tree. l. op], operators, Val (turbo)
110- )
108+ op_l = operators. unaops[tree. l. op]
109+ return deg1_l1_ll0_eval (tree, cX, op, op_l, Val (turbo))
111110 else
112111 # op(x), for any x.
113- return deg1_eval (tree, cX, vals[tree . op] , operators, Val (turbo))
112+ return deg1_eval (tree, cX, op , operators, Val (turbo))
114113 end
115114 elseif tree. degree == 2
116115 # TODO - add op(op2(x, y), z) and op(x, op2(y, z))
116+ op = operators. binops[tree. op]
117117 if tree. l. degree == 0 && tree. r. degree == 0
118118 # op(x, y), where x, y are constants or variables.
119- return deg2_l0_r0_eval (tree, cX, vals[tree . op], operators , Val (turbo))
119+ return deg2_l0_r0_eval (tree, cX, op , Val (turbo))
120120 elseif tree. l. degree == 0
121121 # op(x, y), where x is a constant or variable but y is not.
122- return deg2_l0_eval (tree, cX, vals[tree . op] , operators, Val (turbo))
122+ return deg2_l0_eval (tree, cX, op , operators, Val (turbo))
123123 elseif tree. r. degree == 0
124124 # op(x, y), where y is a constant or variable but x is not.
125- return deg2_r0_eval (tree, cX, vals[tree . op] , operators, Val (turbo))
125+ return deg2_r0_eval (tree, cX, op , operators, Val (turbo))
126126 else
127127 # op(x, y), for any x or y
128- return deg2_eval (tree, cX, vals[tree . op] , operators, Val (turbo))
128+ return deg2_eval (tree, cX, op , operators, Val (turbo))
129129 end
130130 end
131131end
132132
133133function deg2_eval (
134- tree:: Node{T} ,
135- cX:: AbstractMatrix{T} ,
136- :: Val{op_idx} ,
137- operators:: OperatorEnum ,
138- :: Val{turbo} ,
139- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
134+ tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , operators:: OperatorEnum , :: Val{turbo}
135+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
140136 n = size (cX, 2 )
141137 (cumulator, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
142138 @return_on_false complete cumulator
143139 @return_on_nonfinite_array cumulator T n
144140 (array2, complete2) = _eval_tree_array (tree. r, cX, operators, Val (turbo))
145141 @return_on_false complete2 cumulator
146142 @return_on_nonfinite_array array2 T n
147- op = operators. binops[op_idx]
148143
149144 # We check inputs (and intermediates), not outputs.
150145 @maybe_turbo turbo for j in indices (cumulator)
@@ -156,22 +151,17 @@ function deg2_eval(
156151end
157152
158153function deg1_eval (
159- tree:: Node{T} ,
160- cX:: AbstractMatrix{T} ,
161- :: Val{op_idx} ,
162- operators:: OperatorEnum ,
163- :: Val{turbo} ,
164- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
154+ tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , operators:: OperatorEnum , :: Val{turbo}
155+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
165156 n = size (cX, 2 )
166157 (cumulator, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
167158 @return_on_false complete cumulator
168159 @return_on_nonfinite_array cumulator T n
169- op = operators. unaops[op_idx]
170160 @maybe_turbo turbo for j in indices (cumulator)
171161 x = op (cumulator[j]):: T
172162 cumulator[j] = x
173163 end
174- return (cumulator, true ) #
164+ return (cumulator, true )
175165end
176166
177167function deg0_eval (
@@ -188,14 +178,11 @@ end
188178function deg1_l2_ll0_lr0_eval (
189179 tree:: Node{T} ,
190180 cX:: AbstractMatrix{T} ,
191- :: Val{op_idx} ,
192- :: Val{op_l_idx} ,
193- operators:: OperatorEnum ,
181+ op:: F ,
182+ op_l:: F2 ,
194183 :: Val{turbo} ,
195- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,op_l_idx ,turbo}
184+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,F2 ,turbo}
196185 n = size (cX, 2 )
197- op = operators. unaops[op_idx]
198- op_l = operators. binops[op_l_idx]
199186 if tree. l. l. constant && tree. l. r. constant
200187 val_ll = tree. l. l. val:: T
201188 val_lr = tree. l. r. val:: T
@@ -245,14 +232,11 @@ end
245232function deg1_l1_ll0_eval (
246233 tree:: Node{T} ,
247234 cX:: AbstractMatrix{T} ,
248- :: Val{op_idx} ,
249- :: Val{op_l_idx} ,
250- operators:: OperatorEnum ,
235+ op:: F ,
236+ op_l:: F2 ,
251237 :: Val{turbo} ,
252- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,op_l_idx ,turbo}
238+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,F2 ,turbo}
253239 n = size (cX, 2 )
254- op = operators. unaops[op_idx]
255- op_l = operators. unaops[op_l_idx]
256240 if tree. l. l. constant
257241 val_ll = tree. l. l. val:: T
258242 @return_on_check val_ll T n
275259
276260# op(x, y) for x and y variable/constant
277261function deg2_l0_r0_eval (
278- tree:: Node{T} ,
279- cX:: AbstractMatrix{T} ,
280- :: Val{op_idx} ,
281- operators:: OperatorEnum ,
282- :: Val{turbo} ,
283- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
262+ tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , :: Val{turbo}
263+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
284264 n = size (cX, 2 )
285- op = operators. binops[op_idx]
286265 if tree. l. constant && tree. r. constant
287266 val_l = tree. l. val:: T
288267 @return_on_check val_l T n
@@ -323,17 +302,12 @@ end
323302
324303# op(x, y) for x variable/constant, y arbitrary
325304function deg2_l0_eval (
326- tree:: Node{T} ,
327- cX:: AbstractMatrix{T} ,
328- :: Val{op_idx} ,
329- operators:: OperatorEnum ,
330- :: Val{turbo} ,
331- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
305+ tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , operators:: OperatorEnum , :: Val{turbo}
306+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
332307 n = size (cX, 2 )
333308 (cumulator, complete) = _eval_tree_array (tree. r, cX, operators, Val (turbo))
334309 @return_on_false complete cumulator
335310 @return_on_nonfinite_array cumulator T n
336- op = operators. binops[op_idx]
337311 if tree. l. constant
338312 val = tree. l. val:: T
339313 @return_on_check val T n
@@ -353,17 +327,12 @@ end
353327
354328# op(x, y) for x arbitrary, y variable/constant
355329function deg2_r0_eval (
356- tree:: Node{T} ,
357- cX:: AbstractMatrix{T} ,
358- :: Val{op_idx} ,
359- operators:: OperatorEnum ,
360- :: Val{turbo} ,
361- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx,turbo}
330+ tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , operators:: OperatorEnum , :: Val{turbo}
331+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F,turbo}
362332 n = size (cX, 2 )
363333 (cumulator, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
364334 @return_on_false complete cumulator
365335 @return_on_nonfinite_array cumulator T n
366- op = operators. binops[op_idx]
367336 if tree. r. constant
368337 val = tree. r. val:: T
369338 @return_on_check val T n
@@ -394,9 +363,9 @@ function _eval_constant_tree(
394363 if tree. degree == 0
395364 return deg0_eval_constant (tree)
396365 elseif tree. degree == 1
397- return deg1_eval_constant (tree, vals [tree. op], operators)
366+ return deg1_eval_constant (tree, operators . unaops [tree. op], operators)
398367 else
399- return deg2_eval_constant (tree, vals [tree. op], operators)
368+ return deg2_eval_constant (tree, operators . binops [tree. op], operators)
400369 end
401370end
402371
@@ -405,19 +374,17 @@ end
405374end
406375
407376function deg1_eval_constant (
408- tree:: Node{T} , :: Val{op_idx} , operators:: OperatorEnum
409- ):: Tuple{T,Bool} where {T<: Real ,op_idx}
410- op = operators. unaops[op_idx]
377+ tree:: Node{T} , op:: F , operators:: OperatorEnum
378+ ):: Tuple{T,Bool} where {T<: Real ,F}
411379 (cumulator, complete) = _eval_constant_tree (tree. l, operators)
412380 ! complete && return zero (T), false
413381 output = op (cumulator):: T
414382 return output, isfinite (output)
415383end
416384
417385function deg2_eval_constant (
418- tree:: Node{T} , :: Val{op_idx} , operators:: OperatorEnum
419- ):: Tuple{T,Bool} where {T<: Real ,op_idx}
420- op = operators. binops[op_idx]
386+ tree:: Node{T} , op:: F , operators:: OperatorEnum
387+ ):: Tuple{T,Bool} where {T<: Real ,F}
421388 (cumulator, complete) = _eval_constant_tree (tree. l, operators)
422389 ! complete && return zero (T), false
423390 (cumulator2, complete2) = _eval_constant_tree (tree. r, operators)
@@ -442,31 +409,29 @@ function differentiable_eval_tree_array(
442409 return (cX[tree. feature, :], true )
443410 end
444411 elseif tree. degree == 1
445- return deg1_diff_eval (tree, cX, vals [tree. op], operators)
412+ return deg1_diff_eval (tree, cX, operators . unaops [tree. op], operators)
446413 else
447- return deg2_diff_eval (tree, cX, vals [tree. op], operators)
414+ return deg2_diff_eval (tree, cX, operators . binops [tree. op], operators)
448415 end
449416end
450417
451418function deg1_diff_eval (
452- tree:: Node{T1} , cX:: AbstractMatrix{T} , :: Val{op_idx} , operators:: OperatorEnum
453- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx ,T1}
419+ tree:: Node{T1} , cX:: AbstractMatrix{T} , op :: F , operators:: OperatorEnum
420+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F ,T1}
454421 (left, complete) = differentiable_eval_tree_array (tree. l, cX, operators)
455422 @return_on_false complete left
456- op = operators. unaops[op_idx]
457423 out = op .(left)
458424 no_nans = ! any (x -> (! isfinite (x)), out)
459425 return (out, no_nans)
460426end
461427
462428function deg2_diff_eval (
463- tree:: Node{T1} , cX:: AbstractMatrix{T} , :: Val{op_idx} , operators:: OperatorEnum
464- ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,op_idx ,T1}
429+ tree:: Node{T1} , cX:: AbstractMatrix{T} , op :: F , operators:: OperatorEnum
430+ ):: Tuple{AbstractVector{T},Bool} where {T<: Real ,F ,T1}
465431 (left, complete) = differentiable_eval_tree_array (tree. l, cX, operators)
466432 @return_on_false complete left
467433 (right, complete2) = differentiable_eval_tree_array (tree. r, cX, operators)
468434 @return_on_false complete2 left
469- op = operators. binops[op_idx]
470435 out = op .(left, right)
471436 no_nans = ! any (x -> (! isfinite (x)), out)
472437 return (out, no_nans)
@@ -557,30 +522,28 @@ function _eval_tree_array_generic(
557522 end
558523 end
559524 elseif tree. degree == 1
560- return deg1_eval_generic (tree, cX, vals [tree. op], operators, Val (throw_errors))
525+ return deg1_eval_generic (tree, cX, operators . unaops [tree. op], operators, Val (throw_errors))
561526 else
562- return deg2_eval_generic (tree, cX, vals [tree. op], operators, Val (throw_errors))
527+ return deg2_eval_generic (tree, cX, operators . binops [tree. op], operators, Val (throw_errors))
563528 end
564529end
565530
566531function deg1_eval_generic (
567- tree, cX, :: Val{op_idx} , operators:: GenericOperatorEnum , :: Val{throw_errors}
568- ) where {op_idx ,throw_errors}
532+ tree, cX, op :: F , operators:: GenericOperatorEnum , :: Val{throw_errors}
533+ ) where {F ,throw_errors}
569534 left, complete = eval_tree_array (tree. l, cX, operators)
570535 ! throw_errors && ! complete && return nothing , false
571- op = operators. unaops[op_idx]
572536 ! throw_errors && ! hasmethod (op, Tuple{typeof (left)}) && return nothing , false
573537 return op (left), true
574538end
575539
576540function deg2_eval_generic (
577- tree, cX, :: Val{op_idx} , operators:: GenericOperatorEnum , :: Val{throw_errors}
578- ) where {op_idx ,throw_errors}
541+ tree, cX, op :: F , operators:: GenericOperatorEnum , :: Val{throw_errors}
542+ ) where {F ,throw_errors}
579543 left, complete = eval_tree_array (tree. l, cX, operators)
580544 ! throw_errors && ! complete && return nothing , false
581545 right, complete = eval_tree_array (tree. r, cX, operators)
582546 ! throw_errors && ! complete && return nothing , false
583- op = operators. binops[op_idx]
584547 ! throw_errors &&
585548 ! hasmethod (op, Tuple{typeof (left),typeof (right)}) &&
586549 return nothing , false
0 commit comments