11module SpecialOperatorsModule
22
33using .. OperatorEnumModule: OperatorEnum
4- using .. EvaluateModule: _eval_tree_array, @return_on_nonfinite_array , deg2_eval
4+ using .. EvaluateModule:
5+ _eval_tree_array, @return_on_nonfinite_array , deg2_eval, ResultOk, get_filled_array
56using .. ExpressionModule: AbstractExpression
67using .. ExpressionAlgebraModule: @declare_expression_operator
78
2425@inline special_operator (:: Type{AssignOperator} ) = true
2526get_op_name (o:: AssignOperator ) = " [{FEATURE_" * string (o. target_register) * " } =]"
2627
27- # Base.@kwdef struct WhileOperator <: Function
28- # max_iters::Int = 100
29- # end
30- # @inline special_operator(::Type{WhileOperator}) = true
31- # function deg2_eval_special(tree, cX, op::WhileOperator, eval_options)
32- # cond = tree.l
33- # body = tree.r
34- # for _ in 1:(op.max_iters)
35- # let cond_result = _eval_tree_array(cond, cX, operators, eval_options)
36- # !cond_result.ok && return cond_result
37- # @return_on_nonfinite_array(eval_options, cond_result.x)
38- # end
39- # let body_result = _eval_tree_array(body, cX, operators, eval_options)
40- # !body_result.ok && return body_result
41- # @return_on_nonfinite_array(eval_options, body_result.x)
42- # # TODO : Need to somehow mask instances
43- # end
44- # end
45-
46- # return get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2))
47- # end
48- # TODO : Need to void any instance of buffer when using while loop.
49-
5028function deg1_eval_special (tree, cX, op:: AssignOperator , eval_options, operators)
5129 result = _eval_tree_array (tree. l, cX, operators, eval_options)
5230 ! result. ok && return result
@@ -58,4 +36,47 @@ function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators
5836 return result
5937end
6038
39+ Base. @kwdef struct WhileOperator <: Function
40+ max_iters:: Int = 100
41+ end
42+
43+ @declare_expression_operator ((op:: WhileOperator ), 2 )
44+ @inline special_operator (:: Type{WhileOperator} ) = true
45+ get_op_name (o:: WhileOperator ) = " while"
46+
47+ # TODO : Need to void any instance of buffer when using while loop.
48+ function deg2_eval_special (tree, cX, op:: WhileOperator , eval_options, operators)
49+ cond = tree. l
50+ body = tree. r
51+ mask = trues (size (cX, 2 ))
52+ X = @view cX[:, mask]
53+ # Initialize the result array for all columns
54+ result_array = get_filled_array (eval_options. buffer, zero (eltype (cX)), cX, axes (cX, 2 ))
55+ body_result = ResultOk (result_array, true )
56+
57+ for _ in 1 : (op. max_iters)
58+ cond_result = _eval_tree_array (cond, X, operators, eval_options)
59+ ! cond_result. ok && return cond_result
60+ @return_on_nonfinite_array (eval_options, cond_result. x)
61+
62+ new_mask = cond_result. x .> 0.0
63+ any (new_mask) || return body_result
64+
65+ # Track which columns are still active
66+ mask[mask] .= new_mask
67+ X = @view cX[:, mask]
68+
69+ # Evaluate just for active columns
70+ iter_result = _eval_tree_array (body, X, operators, eval_options)
71+ ! iter_result. ok && return iter_result
72+
73+ # Update the corresponding elements in the result array
74+ body_result. x[mask] .= iter_result. x
75+ @return_on_nonfinite_array (eval_options, body_result. x)
76+ end
77+
78+ # We passed max_iters, so this result is invalid
79+ return ResultOk (body_result. x, false )
80+ end
81+
6182end
0 commit comments