Skip to content

Commit 2da0155

Browse files
committed
Add unittests for random sampling
1 parent a8a553b commit 2da0155

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

test/test_random.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
using DynamicExpressions
2+
using Random
3+
using Test
4+
5+
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
6+
x = GraphNode(Float64; feature=1)
7+
tree = cos(x) + x
8+
9+
num_samples = 10_000
10+
atol = 200
11+
12+
@testset "Basic random sampling" begin
13+
uniform_samples = let rng = Random.MersenneTwister(0)
14+
[rand(rng, tree) for _ in 1:num_samples]
15+
end
16+
17+
num_plus = count(Base.Fix1(===, tree), uniform_samples)
18+
num_x = count(Base.Fix1(===, x), uniform_samples)
19+
num_cos = count(Base.Fix1(===, tree.l), uniform_samples)
20+
21+
@test isapprox(num_plus, num_samples ÷ 3; atol)
22+
@test isapprox(num_x, num_samples ÷ 3; atol)
23+
@test isapprox(num_cos, num_samples ÷ 3; atol)
24+
25+
# Now, we sample without sharing
26+
broken_sharing_samples = let rng = Random.MersenneTwister(0)
27+
[rand(rng, NodeSampler(; tree, break_sharing=Val(true))) for _ in 1:num_samples]
28+
end
29+
num_plus = count(Base.Fix1(===, tree), broken_sharing_samples)
30+
num_x = count(Base.Fix1(===, x), broken_sharing_samples)
31+
num_cos = count(Base.Fix1(===, tree.l), broken_sharing_samples)
32+
33+
@test isapprox(num_plus, num_samples ÷ 4; atol)
34+
@test isapprox(num_x, num_samples ÷ 2; atol)
35+
@test isapprox(num_cos, num_samples ÷ 4; atol)
36+
end
37+
38+
@testset "Weighted sampling" begin
39+
function weighting_1(t)
40+
if t == cos(x)
41+
75.0
42+
elseif t == x
43+
10.0
44+
else
45+
15.0
46+
end
47+
end
48+
specific_weighted_samples = let rng = Random.MersenneTwister(0)
49+
[rand(rng, NodeSampler(; tree, weighting=weighting_1)) for _ in 1:num_samples]
50+
end
51+
num_plus = count(Base.Fix1(===, tree), specific_weighted_samples)
52+
num_x = count(Base.Fix1(===, x), specific_weighted_samples)
53+
num_cos = count(Base.Fix1(===, tree.l), specific_weighted_samples)
54+
55+
@test isapprox(num_plus, num_samples * 15 ÷ 100; atol)
56+
@test isapprox(num_x, num_samples * 10 ÷ 100; atol)
57+
@test isapprox(num_cos, num_samples * 75 ÷ 100; atol)
58+
59+
# Now, without sharing
60+
function weighting_2(t)
61+
if t == cos(x)
62+
75.0
63+
elseif t == x
64+
10.0 # Will be doubled if we break sharing
65+
else
66+
5.0
67+
end
68+
end
69+
broken_sharing_weighted_samples = let rng = Random.MersenneTwister(0)
70+
[
71+
rand(rng, NodeSampler(; tree, weighting=weighting_2, break_sharing=Val(true))) for _ in 1:num_samples
72+
]
73+
end
74+
num_plus = count(Base.Fix1(===, tree), broken_sharing_weighted_samples)
75+
num_x = count(Base.Fix1(===, x), broken_sharing_weighted_samples)
76+
num_cos = count(Base.Fix1(===, tree.l), broken_sharing_weighted_samples)
77+
78+
@test isapprox(num_plus, num_samples * 5 ÷ 100; atol)
79+
@test isapprox(num_x, num_samples * 20 ÷ 100; atol)
80+
@test isapprox(num_cos, num_samples * 75 ÷ 100; atol)
81+
end

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,7 @@ end
9595
@safetestset "Test custom node type" begin
9696
include("test_custom_node_type.jl")
9797
end
98+
99+
@safetestset "Test random sampling" begin
100+
include("test_random.jl")
101+
end

0 commit comments

Comments
 (0)