 # ########### Sampling ############
-
-# ########## Backward propagating sampling process ##############
-function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy::Samples)
-    return backward_sampling(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict)
-end
-
 struct Samples{L}
     samples::Vector{Vector{Int}}
     labels::Vector{L}
-    setmask::Vector{Bool}
+    setmask::BitVector
 end
+function setmask!(samples::Samples, eliminated_variables)
+    for var in eliminated_variables
+        loc = findfirst(==(var), samples.labels)
+        samples.setmask[loc] && error("variable `$var` is already eliminated.")
+        samples.setmask[loc] = true
+    end
+    return samples
+end
+
+idx4labels(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
 
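As a quick illustration of the two helpers added above, here is a minimal sketch (a hypothetical toy `Samples` instance with `Int` labels; not part of the commit):

    # two samples over variables 1, 2, 3; no variable eliminated yet
    s = Samples([[0, 1, 0], [1, 0, 1]], [1, 2, 3], falses(3))
    idx4labels(s.labels, [3, 1])  # -> [3, 1], the positions of labels 3 and 1
    setmask!(s, [3])              # marks variable 3 as eliminated
    s.setmask                     # -> Bool[0, 0, 1]
    setmask!(s, [3])              # throws: variable `3` is already eliminated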
1418"""
1519$(TYPEDSIGNATURES)
@@ -21,48 +25,74 @@ The backward rule for tropical einsum.
2125* `ysamples` is the samples generated on the output tensor,
2226* `size_dict` is a key-value map from tensor label to dimension size.
2327"""
-function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
-    idx4label(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
+function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
     eliminated_variables = setdiff(vcat(ixs...), iy)
-    eliminated_locs = idx4label(samples.labels, eliminated_variables)
-    samples.setmask[eliminated_locs] .= true
+    eliminated_locs = idx4labels(samples.labels, eliminated_variables)
+    setmask!(samples, eliminated_variables)
 
     # the contraction code to get probability
     newiy = eliminated_variables
     iy_in_sample = idx4labels(samples.labels, iy)
     slice_y_dim = collect(1:length(iy))
     newixs = map(ix->setdiff(ix, iy), ixs)
     ix_in_sample = map(ix->idx4labels(samples.labels, ix ∩ iy), ixs)
-    slice_xs_dim = map(ix->idx4label(ix, ix ∩ iy), ixs)
+    slice_xs_dim = map(ix->idx4labels(ix, ix ∩ iy), ixs)
     code = DynamicEinCode(newixs, newiy)
 
-    totalset = CartesianIndices(map(x->size_dict[x], eliminated_variables)...)
+    totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
     for (i, sample) in enumerate(samples.samples)
         newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
-        newy = Array(get_slice(y, slice_y_dim, sample[iy_in_sample]))[]
-        probabilities = einsum(code, newxs, size_dict) / newy
-        config = StatsBase.sample(totalset, weights=StatsBase.Weights(probabilities))
+        newy = get_element(y, slice_y_dim, sample[iy_in_sample])
+        probabilities = einsum(code, (newxs...,), size_dict) / newy
+        config = StatsBase.sample(totalset, Weights(vec(probabilities)))
         # update the samples
-        samples.samples[i][eliminated_locs] .= config.I
+        samples.samples[i][eliminated_locs] .= config.I .- 1
     end
-    return xsamples
+    return samples
 end
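For intuition about the rule above: for each cached sample, the eliminated variables are drawn from the conditional distribution obtained by slicing every input tensor at the already-fixed labels and dividing by the cached output element. A minimal sanity check on a hypothetical two-tensor network `A[i,j] * B[j,k]` (the names `A`, `B`, `i`, `k` are illustrative, not from the commit):

    using OMEinsum
    A, B = rand(2, 3), rand(3, 2)
    y = ein"ij,jk->ik"(A, B)             # forward contraction, output labels (i, k)
    i, k = 1, 2                          # a sample already fixed on the output labels
    pj = A[i, :] .* B[:, k] ./ y[i, k]   # conditional weights for the eliminated label j
    sum(pj) ≈ 1.0                        # normalized by construction, as in the division by `newy`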
 
 # type unstable
 function get_slice(x, dim, config)
-    for (d, c) in zip(dim, config)
-        x = selectdim(x, d, c)
-    end
-    return x
+    asarray(x[[i ∈ dim ? config[findfirst(==(i), dim)] + 1 : Colon() for i in 1:ndims(x)]...], x)
+end
+function get_element(x, dim, config)
+    x[[config[findfirst(==(i), dim)] + 1 for i in 1:ndims(x)]...]
 end
 
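Both helpers index with 0-based configurations, hence the `+ 1`; `get_element` assumes `dim` covers every dimension of `x`, which holds for `slice_y_dim = collect(1:length(iy))` above. A small sketch of what they return (illustrative values only; `asarray` is the OMEinsum utility that keeps scalars as 0-dimensional arrays):

    x = reshape(collect(1:8), 2, 2, 2)
    get_slice(x, [2], [0])               # pins dimension 2 at index 1, keeps dimensions 1 and 3
    get_element(x, [1, 2, 3], [1, 0, 1]) # == x[2, 1, 2], i.e. 6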
 """
 $(TYPEDSIGNATURES)
 
 Sample a tensor-network-based probabilistic model.
 """
-function sample(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Real}
+function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Samples
     # generate tropical tensors with their elements being log(p).
-    tensors = adapt_tensors(tn; usecuda, rescale = false)
-    return tn.code(tensors...)
+    xs = adapt_tensors(tn; usecuda, rescale = false)
+    # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
+    size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}())
+    # forward compute and cache intermediate results.
+    cache = cached_einsum(tn.code, xs, size_dict)
+    # initialize `y̅` as the initial batch of samples.
+    labels = OMEinsum.uniquelabels(tn.code)
+    iy = getiyv(tn.code)
+    setmask = falses(length(labels))
+    idx = map(l->findfirst(==(l), labels), iy)
+    setmask[idx] .= true
+    indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(LinearAlgebra.normalize!(cache.content)))), n)
+    configs = map(indices) do ind
+        c = zeros(Int, length(labels))
+        c[idx] .= ind.I .- 1
+        c
+    end
+    samples = Samples(configs, labels, setmask)
+    # back-propagate
+    generate_samples(tn.code, cache, samples, size_dict)
+    return samples
+end
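A hypothetical usage sketch (assuming `tn` is a `TensorNetworkModel` built elsewhere in the package; not part of the commit):

    samples = sample(tn, 100)    # draw 100 joint configurations
    samples.labels               # variable labels, one per entry of each configuration
    first(samples.samples)       # a single configuration, 0-based, ordered as `samples.labels`
    all(samples.setmask)         # expected to be true once every variable has been sampled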
+
+function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
+    if !OMEinsum.isleaf(code)
+        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
+        backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
+        generate_samples.(code.args, cache.siblings, Ref(samples), Ref(size_dict))
+    end
 end
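Note the traversal order: `generate_samples` starts at the root of the contraction tree, so `backward_sampling!` first fixes the labels eliminated by the outermost contraction and then recurses into the children, whose cached intermediates provide the conditional distributions for the remaining variables. The `setmask!` guard ensures no variable is sampled twice along the way.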