Skip to content

Commit 59aff22

Browse files
committed
sampling save
1 parent f6298c5 commit 59aff22

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

src/sampling.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
############ Sampling ############
2+
3+
########### Backward propagating sampling process ##############
4+
function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy::Samples)
5+
return backward_sampling(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict)
6+
end
7+
8+
"""
9+
$(TYPEDSIGNATURES)
10+
11+
The backward rule for tropical einsum.
12+
13+
* `ixs` and `xs` are labels and tensor data for input tensors,
14+
* `iy` and `y` are labels and tensor data for the output tensor,
15+
* `ysamples` is the samples generated on the output tensor,
16+
* `size_dict` is a key-value map from tensor label to dimension size.
17+
"""
18+
function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), @nospecialize(ysamples), size_dict)
19+
xsamples = []
20+
for i in eachindex(ixs)
21+
nixs = OMEinsum._insertat(ixs, i, iy)
22+
nxs = OMEinsum._insertat(xs, i, y)
23+
niy = ixs[i]
24+
A = einsum(EinCode(nixs, niy), nxs, size_dict)
25+
26+
# compute the mask, one of its entry in `A^{-1}` that equal to the corresponding entry in `X` is masked to true.
27+
j = argmax(xs[i] ./ inv.(A))
28+
mask = onehot_like(A, j)
29+
push!(xsamples, mask)
30+
end
31+
return xsamples
32+
end
33+
34+
"""
35+
$(TYPEDSIGNATURES)
36+
37+
Sample a tensor network based probabilistic model.
38+
"""
39+
function sample(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Real}
40+
# generate tropical tensors with its elements being log(p).
41+
tensors = adapt_tensors(tn; usecuda, rescale = false)
42+
return tn.code(tensors...)
43+
end

0 commit comments

Comments
 (0)