|
| 1 | +using DynamicExpressions |
| 2 | +using Random |
| 3 | +using Test |
| 4 | + |
| 5 | +operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin]) |
| 6 | +x = GraphNode(Float64; feature=1) |
| 7 | +tree = cos(x) + x |
| 8 | + |
| 9 | +num_samples = 10_000 |
| 10 | +atol = 200 |
| 11 | + |
| 12 | +@testset "Basic random sampling" begin |
| 13 | + uniform_samples = let rng = Random.MersenneTwister(0) |
| 14 | + [rand(rng, tree) for _ in 1:num_samples] |
| 15 | + end |
| 16 | + |
| 17 | + num_plus = count(Base.Fix1(===, tree), uniform_samples) |
| 18 | + num_x = count(Base.Fix1(===, x), uniform_samples) |
| 19 | + num_cos = count(Base.Fix1(===, tree.l), uniform_samples) |
| 20 | + |
| 21 | + @test isapprox(num_plus, num_samples ÷ 3; atol) |
| 22 | + @test isapprox(num_x, num_samples ÷ 3; atol) |
| 23 | + @test isapprox(num_cos, num_samples ÷ 3; atol) |
| 24 | + |
| 25 | + # Now, we sample without sharing |
| 26 | + broken_sharing_samples = let rng = Random.MersenneTwister(0) |
| 27 | + [rand(rng, NodeSampler(; tree, break_sharing=Val(true))) for _ in 1:num_samples] |
| 28 | + end |
| 29 | + num_plus = count(Base.Fix1(===, tree), broken_sharing_samples) |
| 30 | + num_x = count(Base.Fix1(===, x), broken_sharing_samples) |
| 31 | + num_cos = count(Base.Fix1(===, tree.l), broken_sharing_samples) |
| 32 | + |
| 33 | + @test isapprox(num_plus, num_samples ÷ 4; atol) |
| 34 | + @test isapprox(num_x, num_samples ÷ 2; atol) |
| 35 | + @test isapprox(num_cos, num_samples ÷ 4; atol) |
| 36 | +end |
| 37 | + |
| 38 | +@testset "Weighted sampling" begin |
| 39 | + function weighting_1(t) |
| 40 | + if t == cos(x) |
| 41 | + 75.0 |
| 42 | + elseif t == x |
| 43 | + 10.0 |
| 44 | + else |
| 45 | + 15.0 |
| 46 | + end |
| 47 | + end |
| 48 | + specific_weighted_samples = let rng = Random.MersenneTwister(0) |
| 49 | + [rand(rng, NodeSampler(; tree, weighting=weighting_1)) for _ in 1:num_samples] |
| 50 | + end |
| 51 | + num_plus = count(Base.Fix1(===, tree), specific_weighted_samples) |
| 52 | + num_x = count(Base.Fix1(===, x), specific_weighted_samples) |
| 53 | + num_cos = count(Base.Fix1(===, tree.l), specific_weighted_samples) |
| 54 | + |
| 55 | + @test isapprox(num_plus, num_samples * 15 ÷ 100; atol) |
| 56 | + @test isapprox(num_x, num_samples * 10 ÷ 100; atol) |
| 57 | + @test isapprox(num_cos, num_samples * 75 ÷ 100; atol) |
| 58 | + |
| 59 | + # Now, without sharing |
| 60 | + function weighting_2(t) |
| 61 | + if t == cos(x) |
| 62 | + 75.0 |
| 63 | + elseif t == x |
| 64 | + 10.0 # Will be doubled if we break sharing |
| 65 | + else |
| 66 | + 5.0 |
| 67 | + end |
| 68 | + end |
| 69 | + broken_sharing_weighted_samples = let rng = Random.MersenneTwister(0) |
| 70 | + [ |
| 71 | + rand(rng, NodeSampler(; tree, weighting=weighting_2, break_sharing=Val(true))) for _ in 1:num_samples |
| 72 | + ] |
| 73 | + end |
| 74 | + num_plus = count(Base.Fix1(===, tree), broken_sharing_weighted_samples) |
| 75 | + num_x = count(Base.Fix1(===, x), broken_sharing_weighted_samples) |
| 76 | + num_cos = count(Base.Fix1(===, tree.l), broken_sharing_weighted_samples) |
| 77 | + |
| 78 | + @test isapprox(num_plus, num_samples * 5 ÷ 100; atol) |
| 79 | + @test isapprox(num_x, num_samples * 20 ÷ 100; atol) |
| 80 | + @test isapprox(num_cos, num_samples * 75 ÷ 100; atol) |
| 81 | +end |
0 commit comments