@@ -2,10 +2,11 @@ module RandomModule
22
33import Compat: Returns
44import Random: AbstractRNG
5+ import Base: rand
56import .. EquationModule: AbstractNode, tree_mapreduce, filter_map
67
78"""
8- NodeSampler(; tree, filter=Returns(true), weighting=nothing, break_sharing=Val(false))
9+ NodeSampler(; tree, filter::Function =Returns(true), weighting::Union{Nothing,Function} =nothing, break_sharing::Val =Val(false))
910
1011Defines a sampler of nodes in a tree. `filter` can be used to pre-filter
1112nodes on which to sample.
@@ -23,51 +24,52 @@ nodes on which to sample.
2324 returns a weight for the node, if it passes the filter, proportional
2425 to the probability of sampling the node. If `nothing`, all nodes are
2526 sampled uniformly.
26- - `break_sharing::Union{Bool, Val} `: If `true` or `Val(true)`, the
27+ - `break_sharing::Val`: If `Val(true)`, the
2728 sampler will break sharing in the tree, and sample nodes uniformly
2829 from the tree.
2930"""
3031Base. @kwdef struct NodeSampler{
31- N<: AbstractNode ,F<: Function ,W<: Union{Nothing,Function} ,B<: Union{Bool, Val}
32+ N<: AbstractNode ,F<: Function ,W<: Union{Nothing,Function} ,B<: Val
3233}
3334 tree:: N
3435 weighting:: W = nothing
3536 filter:: F = Returns (true )
3637 break_sharing:: B = Val (false )
3738end
3839
39- Base. rand (rng:: AbstractRNG , tree:: AbstractNode ) = rand (rng, NodeSampler (; tree))
40- function Base. rand (rng:: AbstractRNG , sampler:: NodeSampler{N,F,Nothing} ) where {N,F}
41- break_sharing = if sampler. break_sharing isa Val
42- sampler. break_sharing
43- else
44- sampler. break_sharing ? Val (true ) : Val (false )
45- end
46- n = count (sampler. filter, sampler. tree; break_sharing)
40+ """
41+ rand(rng::AbstractRNG, tree::AbstractNode; break_sharing::Val=Val(false))
42+
43+ Sample a node from a tree according to the default sampler `NodeSampler(; tree, break_sharing)`.
44+ """
45+ rand (rng:: AbstractRNG , tree:: AbstractNode ; break_sharing:: Val = Val (false )) = rand (rng, NodeSampler (; tree, break_sharing))
46+
47+ """
48+ rand(rng::AbstractRNG, sampler::NodeSampler)
49+
50+ Sample a node from a tree according to the sampler `sampler`.
51+ """
52+ function rand (rng:: AbstractRNG , sampler:: NodeSampler{N,F,Nothing} ) where {N,F}
53+ n = count (sampler. filter, sampler. tree; sampler. break_sharing)
4754 idx = rand (rng, 1 : n)
4855 i = Ref (0 )
4956 out = Ref (sampler. tree)
50- foreach (sampler. tree; break_sharing) do node
57+ foreach (sampler. tree; sampler . break_sharing) do node
5158 if @inline (sampler. filter (node)) && (i[] += 1 ) == idx
5259 out[] = node
5360 end
5461 nothing
5562 end
5663 return out[]
5764end
58- function Base. rand (rng:: AbstractRNG , sampler:: NodeSampler{N,F,W} ) where {N,F,W<: Function }
59- break_sharing = if sampler. break_sharing isa Val
60- sampler. break_sharing
61- else
62- sampler. break_sharing ? Val (true ) : Val (false )
63- end
65+ function rand (rng:: AbstractRNG , sampler:: NodeSampler{N,F,W} ) where {N,F,W<: Function }
6466 weights = filter_map (
65- sampler. filter, sampler. weighting, sampler. tree, Float64; break_sharing
67+ sampler. filter, sampler. weighting, sampler. tree, Float64; sampler . break_sharing
6668 )
6769 idx = sample_idx (rng, weights)
6870 i = Ref (0 )
6971 out = Ref (sampler. tree)
70- foreach (sampler. tree; break_sharing) do node
72+ foreach (sampler. tree; sampler . break_sharing) do node
7173 if @inline (sampler. filter (node)) && (i[] += 1 ) == idx
7274 out[] = node
7375 end
0 commit comments