Skip to content

Commit e072f1d

Browse files
committed
Switch to Int16 for Node fields
1 parent 5836ba2 commit e072f1d

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

src/Equation.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import ..OperatorEnumModule: AbstractOperatorEnum
44
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap
55

66
const DEFAULT_NODE_TYPE = Float32
7+
const FIELD_TYPE = Int16
78

9+
#! format: off
810
"""
911
Node{T}
1012
@@ -36,31 +38,27 @@ nodes, you can evaluate or print a given expression.
3638
argument to the binary operator.
3739
"""
3840
mutable struct Node{T}
39-
degree::Int # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
41+
degree::FIELD_TYPE # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
4042
constant::Bool # false if variable
4143
val::Union{T,Nothing} # If is a constant, this stores the actual value
4244
# ------------------- (possibly undefined below)
43-
feature::Int # If is a variable (e.g., x in cos(x)), this stores the feature index.
44-
op::Int # If operator, this is the index of the operator in operators.binary_operators, or operators.unary_operators
45+
feature::FIELD_TYPE # If is a variable (e.g., x in cos(x)), this stores the feature index.
46+
op::FIELD_TYPE # If operator, this is the index of the operator in operators.binary_operators, or operators.unary_operators
4547
l::Node{T} # Left child node. Only defined for degree=1 or degree=2.
4648
r::Node{T} # Right child node. Only defined for degree=2.
4749

4850
#################
4951
## Constructors:
5052
#################
51-
Node(d::Int, c::Bool, v::_T) where {_T} = new{_T}(d, c, v)
52-
Node(::Type{_T}, d::Int, c::Bool, v::_T) where {_T} = new{_T}(d, c, v)
53-
Node(::Type{_T}, d::Int, c::Bool, v::Nothing, f::Int) where {_T} = new{_T}(d, c, v, f)
54-
function Node(d::Int, c::Bool, v::Nothing, f::Int, o::Int, l::Node{_T}) where {_T}
55-
return new{_T}(d, c, v, f, o, l)
56-
end
57-
function Node(
58-
d::Int, c::Bool, v::Nothing, f::Int, o::Int, l::Node{_T}, r::Node{_T}
59-
) where {_T}
60-
return new{_T}(d, c, v, f, o, l, r)
61-
end
53+
Node(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(FIELD_TYPE(d), c, v)
54+
Node(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(FIELD_TYPE(d), c, v)
55+
Node(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(FIELD_TYPE(d), c, v, FIELD_TYPE(f))
56+
Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}) where {_T} = new{_T}(FIELD_TYPE(d), c, v, FIELD_TYPE(f), FIELD_TYPE(o), l)
57+
Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(FIELD_TYPE(d), c, v, FIELD_TYPE(f), FIELD_TYPE(o), l, r)
58+
6259
end
6360
################################################################################
61+
#! format: on
6462

6563
include("base.jl")
6664

@@ -119,14 +117,14 @@ end
119117
120118
Apply unary operator `op` (enumerating over the order given) to `Node` `l`
121119
"""
122-
Node(op::Int, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)
120+
Node(op::Integer, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)
123121

124122
"""
125123
Node(op::Int, l::Node, r::Node)
126124
127125
Apply binary operator `op` (enumerating over the order given) to `Node`s `l` and `r`
128126
"""
129-
function Node(op::Int, l::Node{T1}, r::Node{T2}) where {T1,T2}
127+
function Node(op::Integer, l::Node{T1}, r::Node{T2}) where {T1,T2}
130128
# Get highest type:
131129
if T1 != T2
132130
T = promote_type(T1, T2)
@@ -150,9 +148,11 @@ Create a variable node, using a user-passed format
150148
"""
151149
function Node(var_string::String, variable_names::Array{String,1})
152150
return Node(;
153-
feature=[
154-
i for (i, _variable) in enumerate(variable_names) if _variable == var_string
155-
][1]::Int,
151+
feature=FIELD_TYPE(
152+
[
153+
i for (i, _variable) in enumerate(variable_names) if _variable == var_string
154+
][1]::Int,
155+
),
156156
)
157157
end
158158

src/EquationUtils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ has_constants(tree::Node) = any(is_node_constant, tree)
4646
4747
Check if a tree has any operators.
4848
"""
49-
has_operators(tree::Node) = tree.degree !== 0
49+
has_operators(tree::Node) = tree.degree != 0
5050

5151
"""
5252
is_constant(tree::Node)::Bool
5353
5454
Check if an expression is a constant numerical value, or
5555
whether it depends on input features.
5656
"""
57-
is_constant(tree::Node) = all(t -> t.degree !== 0 || t.constant, tree)
57+
is_constant(tree::Node) = all(t -> t.degree != 0 || t.constant, tree)
5858

5959
"""
6060
get_constants(tree::Node{T})::Vector{T} where {T}

0 commit comments

Comments
 (0)