############ Sampling ############

########### Backward propagating sampling process ##############
function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy::Samples)
    return backward_sampling(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict)
end

"""
$(TYPEDSIGNATURES)

The backward sampling rule for tropical einsum.

* `ixs` and `xs` are the labels and tensor data of the input tensors,
* `iy` and `y` are the label and tensor data of the output tensor,
* `ysamples` are the samples generated on the output tensor,
* `size_dict` is a key-value map from tensor label to dimension size.
"""
function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), @nospecialize(ysamples), size_dict)
    xsamples = []
    for i in eachindex(ixs)
        # contract the output tensor with all inputs except the i-th one,
        # projecting onto the labels of the i-th input tensor.
        nixs = OMEinsum._insertat(ixs, i, iy)
        nxs = OMEinsum._insertat(xs, i, y)
        niy = ixs[i]
        A = einsum(EinCode(nixs, niy), nxs, size_dict)

        # compute the mask: the entry of `A^{-1}` that equals the
        # corresponding entry of `xs[i]` is marked as true.
        j = argmax(xs[i] ./ inv.(A))
        mask = onehot_like(A, j)
        push!(xsamples, mask)
    end
    return xsamples
end

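# `onehot_like` is used above but is not defined in this diff. A minimal
# sketch of what such a helper could look like (an assumption, not
# necessarily the package's actual implementation): a Boolean mask with the
# same shape as `A` that is `true` only at index `j`, e.g. the
# `CartesianIndex` returned by `argmax`.
function onehot_like(A::AbstractArray, j)
    mask = falses(size(A)...)  # BitArray of the same shape as `A`, all false
    mask[j] = true             # flip only the sampled entry
    return mask
end
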
"""
$(TYPEDSIGNATURES)

Sample from a tensor-network-based probabilistic model.
"""
function sample(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Real}
    # generate tropical tensors whose elements are log(p).
    tensors = adapt_tensors(tn; usecuda, rescale = false)
    return tn.code(tensors...)
end
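
# A hypothetical usage sketch (the `model` argument below is illustrative;
# any valid `TensorNetworkModel` constructor input would do):
#
#     tn = TensorNetworkModel(model)
#     s = sample(tn)                    # sample on the CPU
#     s = sample(tn; usecuda = true)    # sample on a CUDA device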