Skip to content

Commit 4f3079c

Browse files
committed
Create random sampling procedure
1 parent 8109f9c commit 4f3079c

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

src/DynamicExpressions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ include("EvaluationHelpers.jl")
1010
include("SimplifyEquation.jl")
1111
include("OperatorEnumConstruction.jl")
1212
include("ExtensionInterface.jl")
13+
include("Random.jl")
1314

1415
import PackageExtensionCompat: @require_extensions
1516
import Reexport: @reexport
@@ -44,6 +45,7 @@ import .EquationModule: constructorof, preserve_sharing
4445
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
4546
@reexport import .EvaluationHelpersModule
4647
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
48+
@reexport import .RandomModule: NodeSampler
4749

4850
function __init__()
4951
@require_extensions

src/Random.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
module RandomModule
2+
3+
import Compat: Returns
4+
import Random: AbstractRNG
5+
import ..EquationModule: AbstractNode, tree_mapreduce, filter_map
6+
7+
"""
8+
NodeSampler(; tree, filter=Returns(true), weighting=nothing, break_sharing=Val(false))
9+
10+
Defines a sampler of nodes in a tree. `filter` can be used to pre-filter
11+
nodes on which to sample.
12+
13+
# Arguments
14+
15+
- `tree`: The tree to sample nodes from. For a regular `Node`,
16+
nodes are sampled uniformly. For a `GraphNode`, nodes are also
17+
sampled uniformly (e.g., in `sin(x) + {x}`, the `x` has equal
18+
probability of being sampled from the `sin` or the `+` node, because
19+
it is shared), unless `break_sharing` is set to `true` or `Val(true)`.
20+
- `filter::Function`: A function that takes a node and returns a boolean
21+
indicating whether the node should be sampled. Defaults to `Returns(true)`.
22+
- `weighting::Union{Nothing,Function}`: A function that takes a node and
23+
returns a weight for the node, if it passes the filter, proportional
24+
to the probability of sampling the node. If `nothing`, all nodes are
25+
sampled uniformly.
26+
- `break_sharing::Union{Bool,Val}`: If `true` or `Val(true)`, the
27+
sampler will break sharing in the tree, and sample nodes uniformly
28+
from the tree.
29+
"""
30+
Base.@kwdef struct NodeSampler{
    N<:AbstractNode,
    F<:Function,
    W<:Union{Nothing,Function},
    B<:Union{Bool,Val}
}
    # Tree whose nodes are sampled.
    tree::N
    # Optional function mapping a node to its sampling weight;
    # `nothing` (the default) requests uniform sampling.
    weighting::W = nothing
    # Predicate deciding which nodes are eligible; accepts every node by default.
    filter::F = Returns(true)
    # `true`/`Val(true)` asks the sampler to break sharing in the tree.
    break_sharing::B = Val(false)
end
38+
39+
"""
    rand(rng::AbstractRNG, tree::AbstractNode)

Sample a node from `tree` by wrapping it in a `NodeSampler` constructed with
its default settings and drawing from that sampler with `rng`.
"""
function Base.rand(rng::AbstractRNG, tree::AbstractNode)
    default_sampler = NodeSampler(; tree=tree)
    return rand(rng, default_sampler)
end
40+
# Normalise the sampler's `break_sharing` setting to a `Val`, so the tree
# traversal calls below always receive it in the same form.
_break_sharing_val(flag::Val) = flag
_break_sharing_val(flag::Bool) = flag ? Val(true) : Val(false)

# Return the `idx`-th node (1-based, in `foreach` visiting order) among the
# nodes of `sampler.tree` that are accepted by `sampler.filter`.
function _nth_filtered_node(sampler::NodeSampler, idx::Integer, break_sharing::Val)
    seen = Ref(0)
    # Start from the root so the `Ref` is typed like the tree; it is
    # overwritten when the `idx`-th accepted node is visited.
    picked = Ref(sampler.tree)
    foreach(sampler.tree; break_sharing) do node
        # The counter only advances for nodes that pass the filter.
        if @inline(sampler.filter(node)) && (seen[] += 1) == idx
            picked[] = node
        end
        nothing
    end
    return picked[]
end

"""
    rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing})

Sample uniformly among the nodes of `sampler.tree` accepted by
`sampler.filter` (the case `weighting === nothing`).

The accepted nodes are counted, an index is drawn uniformly from `1:n`, and the
tree is traversed once more to fetch the node at that position.

# Throws
- `ArgumentError`: if no node passes the filter.
"""
function Base.rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
    break_sharing = _break_sharing_val(sampler.break_sharing)
    n = count(sampler.filter, sampler.tree; break_sharing)
    n > 0 || throw(ArgumentError("no node in the tree passes the sampler's filter"))
    idx = rand(rng, 1:n)
    return _nth_filtered_node(sampler, idx, break_sharing)
end

"""
    rand(rng::AbstractRNG, sampler::NodeSampler{N,F,<:Function})

Sample a node of `sampler.tree` accepted by `sampler.filter`, using the values
returned by `sampler.weighting` as sampling weights.

The weights of the accepted nodes are collected with `filter_map`, an index is
drawn from them with `sample_idx`, and the tree is traversed once more to fetch
the node at that position.

# Throws
- `ArgumentError`: if no node passes the filter.
"""
function Base.rand(rng::AbstractRNG, sampler::NodeSampler{N,F,W}) where {N,F,W<:Function}
    break_sharing = _break_sharing_val(sampler.break_sharing)
    # NOTE(review): this relies on `filter_map` visiting nodes in the same
    # order as `foreach`, so that `idx` refers to the same node in both passes.
    weights = filter_map(
        sampler.filter, sampler.weighting, sampler.tree, Float64; break_sharing
    )
    isempty(weights) &&
        throw(ArgumentError("no node in the tree passes the sampler's filter"))
    idx = sample_idx(rng, weights)
    return _nth_filtered_node(sampler, idx, break_sharing)
end
78+
"""
    sample_idx(rng::AbstractRNG, weights)

Draw an index into `weights`, where index `i` is chosen with probability
`weights[i] / sum(weights)`. The weights do not need to be normalised.

# Arguments
- `rng`: random number generator used for the draw.
- `weights`: non-empty collection of non-negative weights with a positive sum.

# Returns
- The sampled index as an `Int`.

# Throws
- `ArgumentError`: if `weights` is empty or its sum is not positive.
"""
function sample_idx(rng::AbstractRNG, weights)
    isempty(weights) &&
        throw(ArgumentError("cannot sample an index from an empty set of weights"))
    csum = cumsum(weights)
    total = last(csum)
    total > 0 || throw(ArgumentError("weights must have a positive sum, got $total"))
    # Scale the uniform draw from [0, 1) by the total weight. Comparing the raw
    # draw against un-normalised cumulative weights biases the result toward
    # the first indices (sum > 1) or finds no index at all (sum < 1).
    threshold = rand(rng) * total
    return findfirst(>(threshold), csum)::Int
end
79+
80+
end

0 commit comments

Comments
 (0)