Skip to content

Commit 1103d14

Browse files
committed
Refactor constructors
1 parent 542413d commit 1103d14

File tree

3 files changed

+87
-86
lines changed

3 files changed

+87
-86
lines changed

src/Equation.jl

Lines changed: 79 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module EquationModule
22

33
import ..OperatorEnumModule: AbstractOperatorEnum
4-
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap
4+
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined
55

66
const DEFAULT_NODE_TYPE = Float32
77

@@ -76,6 +76,42 @@ nodes, you can evaluate or print a given expression.
7676
- `r::Node{T}`: Right child of the node. Only defined if `degree == 2`.
7777
Same type as the parent node. This is to be passed as the right
7878
argument to the binary operator.
79+
80+
# Constructors
81+
82+
## Leafs
83+
84+
Node(; val=nothing, feature::Union{Integer,Nothing}=nothing)
85+
Node{T}(; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T}
86+
87+
Create a leaf node: either a constant, or a variable.
88+
89+
- `::Type{T}`, optionally specify the type of the
90+
node, if not already given by the type of
91+
`val`.
92+
- `val`, if you are specifying a constant, pass
93+
the value of the constant here.
94+
- `feature::Integer`, if you are specifying a variable,
95+
pass the index of the variable here.
96+
97+
You can also create a leaf node from variable names:
98+
99+
Node(; var_string::String, variable_names::Array{String,1})
100+
Node{T}(; var_string::String, variable_names::Array{String,1}) where {T}
101+
102+
## Unary operator
103+
104+
Node(op::Integer, l::Node)
105+
106+
Apply unary operator `op` (enumerating over the order given in `OperatorEnum`)
107+
to `Node` `l`.
108+
109+
## Binary operator
110+
111+
Node(op::Integer, l::Node, r::Node)
112+
113+
Apply binary operator `op` (enumerating over the order given in `OperatorEnum`)
114+
to `Node`s `l` and `r`.
79115
"""
80116
mutable struct Node{T} <: AbstractExpressionNode{T}
81117
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
@@ -104,7 +140,8 @@ end
104140
Exactly the same as `Node{T}`, but with the assumption that some
105141
nodes will be shared. All copies of this graph-like structure will
106142
be performed with this assumption, to preserve structure of the graph.
107-
For example:
143+
144+
# Examples
108145
109146
```julia
110147
julia> operators = OperatorEnum(;
@@ -158,73 +195,29 @@ preserve_sharing(::Type{<:GraphNode}) = true
158195

159196
include("base.jl")
160197

161-
"""
162-
Node([::Type{T}]; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T}
163-
164-
Create a leaf node: either a constant, or a variable.
165-
166-
# Arguments:
167-
168-
- `::Type{T}`, optionally specify the type of the
169-
node, if not already given by the type of
170-
`val`.
171-
- `val`, if you are specifying a constant, pass
172-
the value of the constant here.
173-
- `feature::Integer`, if you are specifying a variable,
174-
pass the index of the variable here.
175-
"""
176-
function (::Type{N})(;
177-
val::T1=nothing, feature::T2=nothing
178-
) where {T1,T2<:Union{Integer,Nothing},N<:AbstractExpressionNode}
179-
if T1 <: Nothing && T2 <: Nothing
180-
error("You must specify either `val` or `feature` when creating a leaf node.")
181-
elseif !(T1 <: Nothing || T2 <: Nothing)
182-
error(
183-
"You must specify either `val` or `feature` when creating a leaf node, not both.",
184-
)
185-
elseif T2 <: Nothing
186-
return constructorof(N)(0, true, val)
187-
else
188-
return constructorof(N)(DEFAULT_NODE_TYPE, 0, false, nothing, feature)
189-
end
190-
end
191198
function (::Type{N})(
192-
::Type{T}; val::T1=nothing, feature::T2=nothing
199+
::Type{T}=Undefined; val::T1=nothing, feature::T2=nothing
193200
) where {T,T1,T2<:Union{Integer,Nothing},N<:AbstractExpressionNode}
194-
if T1 <: Nothing && T2 <: Nothing
195-
error("You must specify either `val` or `feature` when creating a leaf node.")
196-
elseif !(T1 <: Nothing || T2 <: Nothing)
197-
error(
198-
"You must specify either `val` or `feature` when creating a leaf node, not both.",
199-
)
200-
elseif T2 <: Nothing
201+
((T1 <: Nothing) (T2 <: Nothing)) || error(
202+
"You must specify exactly one of `val` or `feature` when creating a leaf node."
203+
)
204+
Tout = compute_value_output_type(N, T, T1)
205+
if T2 <: Nothing
201206
if !(T1 <: T)
202207
# Only convert if not already in the type union.
203-
val = convert(T, val)
208+
val = convert(Tout, val)
204209
end
205-
return constructorof(N)(T, 0, true, val)
210+
return constructorof(N)(Tout, 0, true, val)
206211
else
207-
return constructorof(N)(T, 0, false, nothing, feature)
212+
return constructorof(N)(Tout, 0, false, nothing, feature)
208213
end
209214
end
210-
211-
"""
212-
Node(op::Integer, l::Node)
213-
214-
Apply unary operator `op` (enumerating over the order given) to `Node` `l`
215-
"""
216215
function (::Type{N})(
217216
op::Integer, l::AbstractExpressionNode{T}
218217
) where {T,N<:AbstractExpressionNode}
219218
@assert l isa N
220219
return constructorof(N)(1, false, nothing, 0, op, l)
221220
end
222-
223-
"""
224-
Node(op::Integer, l::Node, r::Node)
225-
226-
Apply binary operator `op` (enumerating over the order given) to `Node`s `l` and `r`
227-
"""
228221
function (::Type{N})(
229222
op::Integer, l::AbstractExpressionNode{T1}, r::AbstractExpressionNode{T2}
230223
) where {T1,T2,N<:AbstractExpressionNode}
@@ -238,31 +231,40 @@ function (::Type{N})(
238231
end
239232
return constructorof(N)(2, false, nothing, 0, op, l, r)
240233
end
241-
242-
"""
243-
Node(var_string::String)
244-
245-
Create a variable node, using the format `"x1"` to mean feature 1
246-
"""
247234
function (::Type{N})(var_string::String) where {N<:AbstractExpressionNode}
248-
return constructorof(N)(; feature=parse(UInt16, var_string[2:end]))
235+
Base.depwarn(
236+
"Creating a node using a string is deprecated and will be removed in a future version.",
237+
:string_tree,
238+
)
239+
return N(; feature=parse(UInt16, var_string[2:end]))
249240
end
250-
251-
# TODO: Include helpful check if in the wrong format!
252-
253-
"""
254-
Node(var_string::String, variable_names::Array{String, 1})
255-
256-
Create a variable node, using a user-passed format
257-
"""
258241
function (::Type{N})(
259242
var_string::String, variable_names::Array{String,1}
260243
) where {N<:AbstractExpressionNode}
261-
return constructorof(N)(;
262-
feature=[
263-
i for (i, _variable) in enumerate(variable_names) if _variable == var_string
264-
][1]::Int,
265-
)
244+
i = findfirst(==(var_string), variable_names)::Int
245+
return N(; feature=i)
246+
end
247+
248+
@inline function compute_value_output_type(
249+
::Type{N}, ::Type{T}, ::Type{T1}
250+
) where {N<:AbstractExpressionNode,T,T1}
251+
!(N isa UnionAll) &&
252+
T !== Undefined &&
253+
error(
254+
"Ambiguous type for node. Please either use `Node{T}(; val, feature)` or `Node(T; val, feature)`.",
255+
)
256+
257+
if T === Undefined && N isa UnionAll
258+
if T1 <: Nothing
259+
return DEFAULT_NODE_TYPE
260+
else
261+
return T1
262+
end
263+
elseif T === Undefined
264+
return eltype(N)
265+
else
266+
return T
267+
end
266268
end
267269

268270
function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2}
@@ -277,10 +279,8 @@ end
277279
Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T
278280
Base.eltype(::AbstractExpressionNode{T}) where {T} = T
279281

280-
function create_dummy_node(::Type{N}) where {T,N<:AbstractExpressionNode{T}}
281-
# TODO: Verify using this helps with garbage collection
282-
return constructorof(N)(T; feature=zero(UInt16))
283-
end
282+
# TODO: Verify using this helps with garbage collection
283+
create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N(; feature=zero(UInt16))
284284

285285
"""
286286
set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T}

src/Utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,4 +192,11 @@ function deprecate_varmap(variable_names, varMap, func_name)
192192
return variable_names
193193
end
194194

195+
"""
196+
Undefined
197+
198+
Just a type like `Nothing` to differentiate from a literal `Nothing`.
199+
"""
200+
struct Undefined end
201+
195202
end

src/base.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,8 @@ import Base:
2323
reduce,
2424
sum
2525
import Compat: @inline, Returns
26-
import ..UtilsModule: @memoize_on, @with_memoize
26+
import ..UtilsModule: @memoize_on, @with_memoize, Undefined
2727

28-
"""
29-
Undefined
30-
31-
Just a type like `Nothing` to differentiate from a literal `Nothing`.
32-
"""
33-
struct Undefined end
3428

3529
"""
3630
tree_mapreduce(

0 commit comments

Comments
 (0)