|
| 1 | +module RandomModule |
| 2 | + |
| 3 | +import Compat: Returns |
| 4 | +import Random: AbstractRNG |
| 5 | +import ..EquationModule: AbstractNode, tree_mapreduce, filter_map |
| 6 | + |
"""
    NodeSampler(; tree, filter=Returns(true), weighting=nothing, break_sharing=Val(false))

Describe how a single node should be drawn at random from a tree. The optional
`filter` restricts which nodes are eligible before any sampling takes place.

# Arguments

- `tree`: The tree whose nodes are sampled. With a regular `Node`, every node
    is equally likely. With a `GraphNode`, nodes are likewise sampled uniformly
    (e.g., in `sin(x) + {x}`, the `x` has equal probability of being sampled
    from the `sin` or the `+` node, because it is shared), unless
    `break_sharing` is `true` or `Val(true)`.
- `filter::Function`: Predicate taking a node and returning a boolean that says
    whether the node is eligible to be sampled. Defaults to `Returns(true)`.
- `weighting::Union{Nothing,Function}`: Function taking a node that passed the
    filter and returning its weight, proportional to the probability of that
    node being sampled. When `nothing`, all eligible nodes are equally likely.
- `break_sharing::Union{Bool,Val}`: When `true` or `Val(true)`, sharing in the
    tree is broken by the sampler, and nodes are sampled uniformly from the tree.
"""
Base.@kwdef struct NodeSampler{
    N<:AbstractNode,
    F<:Function,
    W<:Union{Nothing,Function},
    B<:Union{Bool,Val}
}
    tree::N
    weighting::W = nothing
    filter::F = Returns(true)
    break_sharing::B = Val(false)
end
| 38 | + |
"""
    rand(rng::AbstractRNG, tree::AbstractNode)

Draw one node from `tree` by wrapping it in a `NodeSampler` that uses the
sampler's default settings and sampling from that.
"""
function Base.rand(rng::AbstractRNG, tree::AbstractNode)
    default_sampler = NodeSampler(; tree=tree)
    return rand(rng, default_sampler)
end
"""
    rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing})

Draw one node from `sampler.tree`, chosen uniformly among the nodes accepted by
`sampler.filter` (this method applies when `sampler.weighting` is `nothing`).

# Arguments

- `rng`: Random number generator used for the draw.
- `sampler`: The `NodeSampler` describing the tree, filter and sharing mode.

# Returns

The selected node of `sampler.tree`.

# Throws

- `ArgumentError`: if no node passes `sampler.filter`.
"""
function Base.rand(rng::AbstractRNG, sampler::NodeSampler{N,F,Nothing}) where {N,F}
    # `break_sharing` may be stored as a `Bool`; the traversal calls below are
    # always given a `Val`.
    break_sharing = if sampler.break_sharing isa Val
        sampler.break_sharing
    else
        sampler.break_sharing ? Val(true) : Val(false)
    end
    n = count(sampler.filter, sampler.tree; break_sharing)
    # Fail with an explicit message instead of the generic error raised by
    # sampling from the empty range `1:0`.
    n > 0 || throw(
        ArgumentError("cannot sample a node: no node in the tree passes the filter")
    )
    idx = rand(rng, 1:n)
    # `Ref`s let the closure below update the counter and the result in place.
    i = Ref(0)
    out = Ref(sampler.tree)
    foreach(sampler.tree; break_sharing) do node
        # Count only accepted nodes; keep the one whose running count hits `idx`.
        if @inline(sampler.filter(node)) && (i[] += 1) == idx
            out[] = node
        end
        nothing
    end
    return out[]
end
"""
    rand(rng::AbstractRNG, sampler::NodeSampler{N,F,W}) where {W<:Function}

Draw one node from `sampler.tree` among the nodes accepted by `sampler.filter`,
using the per-node weights produced by `sampler.weighting`.

The weights of the accepted nodes are collected as `Float64` values, an index
into them is chosen by `sample_idx`, and the accepted node at that position in
traversal order is returned.

# Arguments

- `rng`: Random number generator used for the draw.
- `sampler`: The `NodeSampler` describing the tree, filter, weighting and
    sharing mode.

# Returns

The selected node of `sampler.tree`.

# Throws

- `ArgumentError`: if no node passes `sampler.filter`.
"""
function Base.rand(rng::AbstractRNG, sampler::NodeSampler{N,F,W}) where {N,F,W<:Function}
    # `break_sharing` may be stored as a `Bool`; the traversal calls below are
    # always given a `Val`.
    break_sharing = if sampler.break_sharing isa Val
        sampler.break_sharing
    else
        sampler.break_sharing ? Val(true) : Val(false)
    end
    weights = filter_map(
        sampler.filter, sampler.weighting, sampler.tree, Float64; break_sharing
    )
    # Mirror the unweighted method: report an empty selection explicitly rather
    # than failing inside the index sampling.
    isempty(weights) && throw(
        ArgumentError("cannot sample a node: no node in the tree passes the filter")
    )
    idx = sample_idx(rng, weights)
    # `Ref`s let the closure below update the counter and the result in place.
    i = Ref(0)
    out = Ref(sampler.tree)
    foreach(sampler.tree; break_sharing) do node
        # Count only accepted nodes; keep the one whose running count hits `idx`.
        if @inline(sampler.filter(node)) && (i[] += 1) == idx
            out[] = node
        end
        nothing
    end
    return out[]
end
"""
    sample_idx(rng::AbstractRNG, weights)

Return a random index into `weights`, where index `i` is chosen with
probability `weights[i] / sum(weights)`.

The weights do not need to be normalised: the uniform draw is scaled by the
total weight before being compared against the cumulative sums.

# Arguments

- `rng`: Random number generator used for the draw.
- `weights`: Non-empty collection of non-negative weights with a positive,
    finite sum.

# Returns

The selected index as an `Int`.

# Throws

- `ArgumentError`: if `weights` is empty, or if its sum is not positive and
    finite.
"""
function sample_idx(rng::AbstractRNG, weights)
    isempty(weights) &&
        throw(ArgumentError("cannot sample an index from an empty set of weights"))
    csum = cumsum(weights)
    total = last(csum)
    (isfinite(total) && total > 0) || throw(
        ArgumentError("weights must have a positive, finite sum, got $total")
    )
    # `rand(rng)` lies in [0, 1), so the threshold lies in [0, total). Scaling
    # the draw (rather than assuming the weights sum to one) keeps the selection
    # proportional to the weights for any positive total.
    threshold = rand(rng) * total
    # The strict comparison means leading zero-weight entries are never picked.
    return findfirst(>(threshold), csum)::Int
end
| 79 | + |
| 80 | +end |
0 commit comments