11module EquationModule
22
33import .. OperatorEnumModule: AbstractOperatorEnum
4- import .. UtilsModule: @memoize_on , @with_memoize , deprecate_varmap
4+ import .. UtilsModule: @memoize_on , @with_memoize , deprecate_varmap, Undefined
55
66const DEFAULT_NODE_TYPE = Float32
77
@@ -76,6 +76,42 @@ nodes, you can evaluate or print a given expression.
7676- `r::Node{T}`: Right child of the node. Only defined if `degree == 2`.
7777 Same type as the parent node. This is to be passed as the right
7878 argument to the binary operator.
79+
80+ # Constructors
81+
82+ ## Leafs
83+
84+ Node(; val=nothing, feature::Union{Integer,Nothing}=nothing)
85+ Node{T}(; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T}
86+
87+ Create a leaf node: either a constant, or a variable.
88+
89+ - `::Type{T}`, optionally specify the type of the
90+ node, if not already given by the type of
91+ `val`.
92+ - `val`, if you are specifying a constant, pass
93+ the value of the constant here.
94+ - `feature::Integer`, if you are specifying a variable,
95+ pass the index of the variable here.
96+
97+ You can also create a leaf node from variable names:
98+
99+ Node(; var_string::String, variable_names::Array{String,1})
100+ Node{T}(; var_string::String, variable_names::Array{String,1}) where {T}
101+
102+ ## Unary operator
103+
104+ Node(op::Integer, l::Node)
105+
106+ Apply unary operator `op` (enumerating over the order given in `OperatorEnum`)
107+ to `Node` `l`.
108+
109+ ## Binary operator
110+
111+ Node(op::Integer, l::Node, r::Node)
112+
113+ Apply binary operator `op` (enumerating over the order given in `OperatorEnum`)
114+ to `Node`s `l` and `r`.
79115"""
80116mutable struct Node{T} <: AbstractExpressionNode{T}
81117 degree:: UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
104140Exactly the same as `Node{T}`, but with the assumption that some
105141nodes will be shared. All copies of this graph-like structure will
106142be performed with this assumption, to preserve structure of the graph.
107- For example:
143+
144+ # Examples
108145
109146```julia
110147julia> operators = OperatorEnum(;
@@ -158,73 +195,29 @@ preserve_sharing(::Type{<:GraphNode}) = true
158195
159196include (" base.jl" )
160197
161- """
162- Node([::Type{T}]; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T}
163-
164- Create a leaf node: either a constant, or a variable.
165-
166- # Arguments:
167-
168- - `::Type{T}`, optionally specify the type of the
169- node, if not already given by the type of
170- `val`.
171- - `val`, if you are specifying a constant, pass
172- the value of the constant here.
173- - `feature::Integer`, if you are specifying a variable,
174- pass the index of the variable here.
175- """
176- function (:: Type{N} )(;
177- val:: T1 = nothing , feature:: T2 = nothing
178- ) where {T1,T2<: Union{Integer,Nothing} ,N<: AbstractExpressionNode }
179- if T1 <: Nothing && T2 <: Nothing
180- error (" You must specify either `val` or `feature` when creating a leaf node." )
181- elseif ! (T1 <: Nothing || T2 <: Nothing )
182- error (
183- " You must specify either `val` or `feature` when creating a leaf node, not both." ,
184- )
185- elseif T2 <: Nothing
186- return constructorof (N)(0 , true , val)
187- else
188- return constructorof (N)(DEFAULT_NODE_TYPE, 0 , false , nothing , feature)
189- end
190- end
191198function (:: Type{N} )(
192- :: Type{T} ; val:: T1 = nothing , feature:: T2 = nothing
199+ :: Type{T} = Undefined ; val:: T1 = nothing , feature:: T2 = nothing
193200) where {T,T1,T2<: Union{Integer,Nothing} ,N<: AbstractExpressionNode }
194- if T1 <: Nothing && T2 <: Nothing
195- error (" You must specify either `val` or `feature` when creating a leaf node." )
196- elseif ! (T1 <: Nothing || T2 <: Nothing )
197- error (
198- " You must specify either `val` or `feature` when creating a leaf node, not both." ,
199- )
200- elseif T2 <: Nothing
201+ ((T1 <: Nothing ) ⊻ (T2 <: Nothing )) || error (
202+ " You must specify exactly one of `val` or `feature` when creating a leaf node."
203+ )
204+ Tout = compute_value_output_type (N, T, T1)
205+ if T2 <: Nothing
201206 if ! (T1 <: T )
202207 # Only convert if not already in the type union.
203- val = convert (T , val)
208+ val = convert (Tout , val)
204209 end
205- return constructorof (N)(T , 0 , true , val)
210+ return constructorof (N)(Tout , 0 , true , val)
206211 else
207- return constructorof (N)(T , 0 , false , nothing , feature)
212+ return constructorof (N)(Tout , 0 , false , nothing , feature)
208213 end
209214end
210-
211- """
212- Node(op::Integer, l::Node)
213-
214- Apply unary operator `op` (enumerating over the order given) to `Node` `l`
215- """
216215function (:: Type{N} )(
217216 op:: Integer , l:: AbstractExpressionNode{T}
218217) where {T,N<: AbstractExpressionNode }
219218 @assert l isa N
220219 return constructorof (N)(1 , false , nothing , 0 , op, l)
221220end
222-
223- """
224- Node(op::Integer, l::Node, r::Node)
225-
226- Apply binary operator `op` (enumerating over the order given) to `Node`s `l` and `r`
227- """
228221function (:: Type{N} )(
229222 op:: Integer , l:: AbstractExpressionNode{T1} , r:: AbstractExpressionNode{T2}
230223) where {T1,T2,N<: AbstractExpressionNode }
@@ -238,31 +231,40 @@ function (::Type{N})(
238231 end
239232 return constructorof (N)(2 , false , nothing , 0 , op, l, r)
240233end
241-
242- """
243- Node(var_string::String)
244-
245- Create a variable node, using the format `"x1"` to mean feature 1
246- """
247234function (:: Type{N} )(var_string:: String ) where {N<: AbstractExpressionNode }
248- return constructorof (N)(; feature= parse (UInt16, var_string[2 : end ]))
235+ Base. depwarn (
236+ " Creating a node using a string is deprecated and will be removed in a future version." ,
237+ :string_tree ,
238+ )
239+ return N (; feature= parse (UInt16, var_string[2 : end ]))
249240end
250-
251- # TODO : Include helpful check if in the wrong format!
252-
253- """
254- Node(var_string::String, variable_names::Array{String, 1})
255-
256- Create a variable node, using a user-passed format
257- """
258241function (:: Type{N} )(
259242 var_string:: String , variable_names:: Array{String,1}
260243) where {N<: AbstractExpressionNode }
261- return constructorof (N)(;
262- feature= [
263- i for (i, _variable) in enumerate (variable_names) if _variable == var_string
264- ][1 ]:: Int ,
265- )
244+ i = findfirst (== (var_string), variable_names):: Int
245+ return N (; feature= i)
246+ end
247+
248+ @inline function compute_value_output_type (
249+ :: Type{N} , :: Type{T} , :: Type{T1}
250+ ) where {N<: AbstractExpressionNode ,T,T1}
251+ ! (N isa UnionAll) &&
252+ T != = Undefined &&
253+ error (
254+ " Ambiguous type for node. Please either use `Node{T}(; val, feature)` or `Node(T; val, feature)`." ,
255+ )
256+
257+ if T === Undefined && N isa UnionAll
258+ if T1 <: Nothing
259+ return DEFAULT_NODE_TYPE
260+ else
261+ return T1
262+ end
263+ elseif T === Undefined
264+ return eltype (N)
265+ else
266+ return T
267+ end
266268end
267269
268270function Base. promote_rule (:: Type{Node{T1}} , :: Type{Node{T2}} ) where {T1,T2}
277279Base. eltype (:: Type{<:AbstractExpressionNode{T}} ) where {T} = T
278280Base. eltype (:: AbstractExpressionNode{T} ) where {T} = T
279281
280- function create_dummy_node (:: Type{N} ) where {T,N<: AbstractExpressionNode{T} }
281- # TODO : Verify using this helps with garbage collection
282- return constructorof (N)(T; feature= zero (UInt16))
283- end
282+ # TODO : Verify using this helps with garbage collection
283+ create_dummy_node (:: Type{N} ) where {N<: AbstractExpressionNode } = N (; feature= zero (UInt16))
284284
285285"""
286286 set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T}
0 commit comments