Skip to content

Commit 3414f6a

Browse files
committed
Document random sampling
1 parent 0e2a990 commit 3414f6a

File tree

3 files changed

+33
-20
lines changed

3 files changed

+33
-20
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Documenter
22
using DynamicExpressions
3+
using Random: AbstractRNG
34

45
makedocs(;
56
sitename="DynamicExpressions.jl",

docs/src/utils.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@ convert(::Type{<:AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) whe
2020
hash(tree::AbstractExpressionNode{T}, h::UInt; break_sharing::Val=Val(false)) where {T}
2121
```
2222

23+
## Sampling
24+
25+
There are also methods for random sampling of nodes:
26+
27+
```@docs
28+
NodeSampler
29+
rand(rng::AbstractRNG, tree::AbstractNode; break_sharing::Val=Val(false))
30+
rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
31+
```
32+
2333
## Internal utilities
2434

2535
Almost all node utilities are crafted using the `tree_mapreduce` function,

src/Random.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ module RandomModule
22

33
import Compat: Returns
44
import Random: AbstractRNG
5+
import Base: rand
56
import ..EquationModule: AbstractNode, tree_mapreduce, filter_map
67

78
"""
8-
NodeSampler(; tree, filter=Returns(true), weighting=nothing, break_sharing=Val(false))
9+
NodeSampler(; tree, filter::Function=Returns(true), weighting::Union{Nothing,Function}=nothing, break_sharing::Val=Val(false))
910
1011
Defines a sampler of nodes in a tree. `filter` can be used to pre-filter
1112
nodes on which to sample.
@@ -23,51 +24,52 @@ nodes on which to sample.
2324
returns a weight for the node, if it passes the filter, proportional
2425
to the probability of sampling the node. If `nothing`, all nodes are
2526
sampled uniformly.
26-
- `break_sharing::Union{Bool,Val}`: If `true` or `Val(true)`, the
27+
- `break_sharing::Val`: If `Val(true)`, the
2728
sampler will break sharing in the tree, and sample nodes uniformly
2829
from the tree.
2930
"""
3031
Base.@kwdef struct NodeSampler{
31-
N<:AbstractNode,F<:Function,W<:Union{Nothing,Function},B<:Union{Bool,Val}
32+
N<:AbstractNode,F<:Function,W<:Union{Nothing,Function},B<:Val
3233
}
3334
tree::N
3435
weighting::W = nothing
3536
filter::F = Returns(true)
3637
break_sharing::B = Val(false)
3738
end
3839

39-
Base.rand(rng::AbstractRNG, tree::AbstractNode) = rand(rng, NodeSampler(; tree))
40-
function Base.rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
41-
break_sharing = if sampler.break_sharing isa Val
42-
sampler.break_sharing
43-
else
44-
sampler.break_sharing ? Val(true) : Val(false)
45-
end
46-
n = count(sampler.filter, sampler.tree; break_sharing)
40+
"""
41+
rand(rng::AbstractRNG, tree::AbstractNode; break_sharing::Val=Val(false))
42+
43+
Sample a node from a tree according to the default sampler `NodeSampler(; tree, break_sharing)`.
44+
"""
45+
rand(rng::AbstractRNG, tree::AbstractNode; break_sharing::Val=Val(false)) = rand(rng, NodeSampler(; tree, break_sharing))
46+
47+
"""
48+
rand(rng::AbstractRNG, sampler::NodeSampler)
49+
50+
Sample a node from a tree according to the sampler `sampler`.
51+
"""
52+
function rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
53+
n = count(sampler.filter, sampler.tree; sampler.break_sharing)
4754
idx = rand(rng, 1:n)
4855
i = Ref(0)
4956
out = Ref(sampler.tree)
50-
foreach(sampler.tree; break_sharing) do node
57+
foreach(sampler.tree; sampler.break_sharing) do node
5158
if @inline(sampler.filter(node)) && (i[] += 1) == idx
5259
out[] = node
5360
end
5461
nothing
5562
end
5663
return out[]
5764
end
58-
function Base.rand(rng::AbstractRNG, sampler::NodeSampler{N,F,W}) where {N,F,W<:Function}
59-
break_sharing = if sampler.break_sharing isa Val
60-
sampler.break_sharing
61-
else
62-
sampler.break_sharing ? Val(true) : Val(false)
63-
end
65+
function rand(rng::AbstractRNG, sampler::NodeSampler{N,F,W}) where {N,F,W<:Function}
6466
weights = filter_map(
65-
sampler.filter, sampler.weighting, sampler.tree, Float64; break_sharing
67+
sampler.filter, sampler.weighting, sampler.tree, Float64; sampler.break_sharing
6668
)
6769
idx = sample_idx(rng, weights)
6870
i = Ref(0)
6971
out = Ref(sampler.tree)
70-
foreach(sampler.tree; break_sharing) do node
72+
foreach(sampler.tree; sampler.break_sharing) do node
7173
if @inline(sampler.filter(node)) && (i[] += 1) == idx
7274
out[] = node
7375
end

0 commit comments

Comments
 (0)