@@ -5,6 +5,12 @@ function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {
55 return backward_sampling (OMEinsum. getixs (eins), xs, OMEinsum. getiy (eins), y, dy, size_dict)
66end
77
"""
    Samples{L}

Container for configurations drawn during tensor-network sampling.

# Fields
- `samples`: one vector of integer assignments per drawn sample,
- `labels`: the variable labels of type `L`, one per entry of a sample vector,
- `setmask`: per-label flags — presumably marking which labels have already
  been assigned a value during sampling (NOTE(review): inferred from the
  name; confirm against the sampler that fills it).
"""
struct Samples{L}
    samples::Vector{Vector{Int}}
    labels::Vector{L}
    setmask::Vector{Bool}
end
13+
814"""
915$(TYPEDSIGNATURES)
1016
@@ -15,18 +21,18 @@ The backward rule for tropical einsum.
1521* `samples` is the `Samples` object holding the configurations drawn on the output tensor,
1622* `size_dict` is a key-value map from tensor label to dimension size.
1723"""
18- function backward_sampling (ixs, @nospecialize (xs:: Tuple ), iy, @nospecialize (y), @nospecialize (ysamples) , size_dict)
19- xsamples = []
20- for i in eachindex (ixs)
21- nixs = OMEinsum . _insertat (ixs, i, iy)
22- nxs = OMEinsum . _insertat (xs, i, y )
23- niy = ixs[i]
24- A = einsum ( EinCode (nixs, niy), nxs, size_dict )
25-
26- # compute the mask, one of its entry in `A^{-1}` that equal to the corresponding entry in `X` is masked to true.
27- j = argmax (xs[i] ./ inv .(A))
28- mask = onehot_like (A, j )
29- push! (xsamples, mask )
# Backward pass of sampling: given samples drawn for the output labels `iy`,
# draw the variables that were eliminated by this contraction, conditioned on
# each existing sample, and record them into `samples`.
#
# Arguments: `ixs`/`xs` are the input index-lists and tensors, `iy`/`y` the
# output index-list and tensor, `samples` the accumulated `Samples`, and
# `size_dict` maps labels to dimension sizes. Returns the mutated `samples`.
function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
    eliminated_variables = setdiff(vcat(ixs...), iy)
    newiy = eliminated_variables
    # After slicing each input at the sampled `iy` configuration, the indices
    # remaining on the i-th slice are `setdiff(ixs[i], iy)` — one index list
    # per sliced tensor, matching `newxs` below.  (The previous code passed the
    # flat `eliminated_variables` vector here, which cannot line up with
    # `length(ixs)` input tensors.)
    newixs = map(ix -> setdiff(ix, iy), ixs)
    code = DynamicEinCode(newixs, newiy)
    # `CartesianIndices` needs a dims tuple, not a Vector of sizes.
    totalset = CartesianIndices((map(x -> size_dict[x], eliminated_variables)...,))
    for (i, sample) in enumerate(samples.samples)
        # Condition every input (and the output scalar) on the current sample.
        newxs = [get_slice(x, ix, iy => sample) for (x, ix) in zip(xs, ixs)]
        newy = Array(get_slice(y, iy, iy => sample))[]
        # Joint distribution over the eliminated variables, normalized by the
        # marginal probability of the conditioning sample.
        probabilities = einsum(code, (newxs...,), size_dict) / newy
        # `StatsBase.sample` takes weights positionally as an `AbstractWeights`;
        # flatten the probability tensor to a vector for `Weights`.
        config = StatsBase.sample(totalset, StatsBase.Weights(vec(probabilities)))
        update_sample!(samples, i, eliminated_variables => config)
    end
    return samples  # `xsamples` no longer exists; the updated `samples` is the result
end
0 commit comments