Skip to content

Commit 08a1ce4

Browse files
committed
new sampling (on cpu)
1 parent 13af59a commit 08a1ce4

File tree

8 files changed

+138
-31
lines changed

8 files changed

+138
-31
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
1212
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1313
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
14+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1415
TropicalGEMM = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
1516
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
1617

1718
[compat]
19+
Artifacts = "1"
1820
CUDA = "4"
1921
DocStringExtensions = "0.8.6, 0.9"
2022
OMEinsum = "0.7"
21-
Requires = "1"
2223
PrecompileTools = "1"
24+
Requires = "1"
25+
StatsBase = "0.34"
2326
TropicalGEMM = "0.1"
2427
TropicalNumbers = "0.5.4"
2528
julia = "1.3"

example/asia/asia.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ probability(tnet)
1010
# Get the marginal probabilities (MAR)
1111
marginals(tnet) .|> first
1212

13+
# The corresponding variables are
14+
get_vars(tnet)
15+
1316
# Set the evidence variables "X-ray" (7) to be positive.
1417
set_evidence!(instance, 7=>0)
1518

src/TensorInference.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using DocStringExtensions, TropicalNumbers
55
using Artifacts
66
# The Tropical GEMM support
77
using TropicalGEMM
8+
using StatsBase
89

910
# reexport OMEinsum functions
1011
export RescaledArray
@@ -20,6 +21,9 @@ export TensorNetworkModel, get_vars, get_cards, log_probability, probability, ma
2021
# MAP
2122
export most_probable_config, maximum_logp
2223

24+
# sampling
25+
export sample
26+
2327
# MMAP
2428
export MMAPModel
2529

@@ -29,6 +33,7 @@ include("utils.jl")
2933
include("inference.jl")
3034
include("maxprob.jl")
3135
include("mmap.jl")
36+
include("sampling.jl")
3237

3338
using Requires
3439
function __init__()

src/inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
4949
end
5050

5151
# computed gradient tree by back propagation
52-
function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
52+
function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy, size_dict::Dict) where {T}
5353
if length(se.slicing) != 0
5454
@warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
5555
end
@@ -58,7 +58,7 @@ end
5858

5959
# recursively compute the gradients and store it into a tree.
6060
# also known as the back-propagation algorithm.
61-
function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
61+
function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy, size_dict::Dict) where {T}
6262
if OMEinsum.isleaf(code)
6363
return CacheTree(dy, CacheTree{T}[])
6464
else

src/sampling.jl

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
############ Sampling ############
2-
3-
########### Backward propagating sampling process ##############
4-
function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy::Samples)
5-
return backward_sampling(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict)
6-
end
7-
82
struct Samples{L}
93
samples::Vector{Vector{Int}}
104
labels::Vector{L}
11-
setmask::Vector{Bool}
5+
setmask::BitVector
126
end
7+
function setmask!(samples::Samples, eliminated_variables)
8+
for var in eliminated_variables
9+
loc = findfirst(==(var), samples.labels)
10+
samples.setmask[loc] && error("varaible `$var` is already eliminated.")
11+
samples.setmask[loc] = true
12+
end
13+
return samples
14+
end
15+
16+
idx4labels(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
1317

1418
"""
1519
$(TYPEDSIGNATURES)
@@ -21,48 +25,74 @@ The backward rule for tropical einsum.
2125
* `ysamples` is the samples generated on the output tensor,
2226
* `size_dict` is a key-value map from tensor label to dimension size.
2327
"""
24-
function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
25-
idx4label(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
28+
function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
2629
eliminated_variables = setdiff(vcat(ixs...), iy)
27-
eliminated_locs = idx4label(samples.labels, eliminated_variables)
28-
samples.setmask[eliminated_locs] .= true
30+
eliminated_locs = idx4labels(samples.labels, eliminated_variables)
31+
setmask!(samples, eliminated_variables)
2932

3033
# the contraction code to get probability
3134
newiy = eliminated_variables
3235
iy_in_sample = idx4labels(samples.labels, iy)
3336
slice_y_dim = collect(1:length(iy))
3437
newixs = map(ix->setdiff(ix, iy), ixs)
3538
ix_in_sample = map(ix->idx4labels(samples.labels, ix iy), ixs)
36-
slice_xs_dim = map(ix->idx4label(ix, ix iy), ixs)
39+
slice_xs_dim = map(ix->idx4labels(ix, ix iy), ixs)
3740
code = DynamicEinCode(newixs, newiy)
3841

39-
totalset = CartesianIndices(map(x->size_dict[x], eliminated_variables)...)
42+
totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
4043
for (i, sample) in enumerate(samples.samples)
4144
newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
42-
newy = Array(get_slice(y, slice_y_dim, sample[iy_in_sample]))[]
43-
probabilities = einsum(code, newxs, size_dict) / newy
44-
config = StatsBase.sample(totalset, weights=StatsBase.Weights(probabilities))
45+
newy = get_element(y, slice_y_dim, sample[iy_in_sample])
46+
probabilities = einsum(code, (newxs...,), size_dict) / newy
47+
config = StatsBase.sample(totalset, Weights(vec(probabilities)))
4548
# update the samples
46-
samples.samples[i][eliminated_locs] .= config.I
49+
samples.samples[i][eliminated_locs] .= config.I .- 1
4750
end
48-
return xsamples
51+
return samples
4952
end
5053

5154
# type unstable
5255
function get_slice(x, dim, config)
53-
for (d, c) in zip(dim, config)
54-
x = selectdim(x, d, c)
55-
end
56-
return x
56+
asarray(x[[i dim ? config[findfirst(==(i), dim)]+1 : Colon() for i in 1:ndims(x)]...], x)
57+
end
58+
function get_element(x, dim, config)
59+
x[[config[findfirst(==(i), dim)]+1 for i in 1:ndims(x)]...]
5760
end
5861

5962
"""
6063
$(TYPEDSIGNATURES)
6164
6265
Sample a tensor network based probabilistic model.
6366
"""
64-
function sample(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Real}
67+
function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Samples
6568
# generate tropical tensors with its elements being log(p).
66-
tensors = adapt_tensors(tn; usecuda, rescale = false)
67-
return tn.code(tensors...)
69+
xs = adapt_tensors(tn; usecuda, rescale = false)
70+
# infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
71+
size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}())
72+
# forward compute and cache intermediate results.
73+
cache = cached_einsum(tn.code, xs, size_dict)
74+
# initialize `y̅` as the initial batch of samples.
75+
labels = OMEinsum.uniquelabels(tn.code)
76+
iy = getiyv(tn.code)
77+
setmask = falses(length(labels))
78+
idx = map(l->findfirst(==(l), labels), iy)
79+
setmask[idx] .= true
80+
indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(LinearAlgebra.normalize!(cache.content)))), n)
81+
configs = map(indices) do ind
82+
c=zeros(Int, length(labels))
83+
c[idx] .= ind.I .- 1
84+
c
85+
end
86+
samples = Samples(configs, labels, setmask)
87+
# back-propagate
88+
generate_samples(tn.code, cache, samples, size_dict)
89+
return samples
90+
end
91+
92+
function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
93+
if !OMEinsum.isleaf(code)
94+
xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
95+
backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
96+
generate_samples.(code.args, cache.siblings, Ref(samples), Ref(size_dict))
97+
end
6898
end

src/utils.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ The UAI file formats are defined in:
88
https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
99
"""
1010
function read_uai_file(uai_filepath; factor_eltype = Float64)
11-
1211
# Read the uai file into an array of lines
13-
rawlines = open(uai_filepath) do file
14-
readlines(file)
12+
str = open(uai_filepath) do file
13+
read(file, String)
1514
end
15+
return read_uai_string(str; factor_eltype)
16+
end
1617

18+
function read_uai_string(str; factor_eltype = Float64)
19+
rawlines = split(str, "\n")
1720
# Filter out empty lines
1821
lines = filter(!isempty, rawlines)
1922

@@ -193,5 +196,10 @@ function uai_problem_from_file(uai_filepath::String; uai_evid_filepath="", uai_m
193196
return UAIInstance(nvars, ncliques, cards, factors, obsvars, obsvals, reference_marginals)
194197
end
195198

199+
function uai_problem_from_string(uai::String; eltype=Float64)::UAIInstance
200+
nvars, cards, ncliques, factors = read_uai_string(uai; factor_eltype = eltype)
201+
return UAIInstance(nvars, ncliques, cards, factors, Int[], Int[], Vector{eltype}[])
202+
end
203+
196204
# patch to get content by broadcasting into array, while keep array size unchanged.
197205
broadcasted_content(x) = asarray(content.(x), x)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ end
1313
@testset "MMAP" begin
1414
include("mmap.jl")
1515
end
16+
@testset "MMAP" begin
17+
include("sampling.jl")
18+
end
1619

1720
using CUDA
1821
if CUDA.functional()

test/sampling.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using TensorInference, Test
2+
3+
@testset "sampling" begin
4+
instance = TensorInference.uai_problem_from_string("""MARKOV
5+
8
6+
2 2 2 2 2 2 2 2
7+
8
8+
1 0
9+
2 1 0
10+
1 2
11+
2 3 2
12+
2 4 2
13+
3 5 3 1
14+
2 6 5
15+
3 7 5 4
16+
17+
2
18+
0.01
19+
0.99
20+
21+
4
22+
0.05 0.01
23+
0.95 0.99
24+
25+
2
26+
0.5
27+
0.5
28+
29+
4
30+
0.1 0.01
31+
0.9 0.99
32+
33+
4
34+
0.6 0.3
35+
0.4 0.7
36+
37+
8
38+
1 1 1 0
39+
0 0 0 1
40+
41+
4
42+
0.98 0.05
43+
0.02 0.95
44+
45+
8
46+
0.9 0.7 0.8 0.1
47+
0.1 0.3 0.2 0.9
48+
""")
49+
n = 10000
50+
tnet = TensorNetworkModel(instance)
51+
samples = sample(tnet, n)
52+
mars = getindex.(marginals(tnet), 2)
53+
mars_sample = [count(s->s[k]==(1), samples.samples) for k=1:8] ./ n
54+
@test isapprox(mars, mars_sample, atol=0.05)
55+
end

0 commit comments

Comments
 (0)