@@ -22,21 +22,40 @@ The backward rule for tropical einsum.
2222* `size_dict` is a key-value map from tensor label to dimension size.
2323"""
2424function backward_sampling (ixs, @nospecialize (xs:: Tuple ), iy, @nospecialize (y), samples:: Samples , size_dict)
25+ idx4label (totalset, labels) = map (v-> findfirst (== (v), totalset), labels)
2526 eliminated_variables = setdiff (vcat (ixs... ), iy)
27+ eliminated_locs = idx4label (samples. labels, eliminated_variables)
28+ samples. setmask[eliminated_locs] .= true
29+
30+ # the contraction code to get probability
2631 newiy = eliminated_variables
27- newixs = eliminated_variables
32+ iy_in_sample = idx4labels (samples. labels, iy)
33+ slice_y_dim = collect (1 : length (iy))
34+ newixs = map (ix-> setdiff (ix, iy), ixs)
35+ ix_in_sample = map (ix-> idx4labels (samples. labels, ix ∩ iy), ixs)
36+ slice_xs_dim = map (ix-> idx4label (ix, ix ∩ iy), ixs)
2837 code = DynamicEinCode (newixs, newiy)
29- totalset = CartesianIndices (map (x-> size_dict[x], eliminated_variables))
38+
39+ totalset = CartesianIndices (map (x-> size_dict[x], eliminated_variables)... )
3040 for (i, sample) in enumerate (samples. samples)
31- newxs = [get_slice (x, ix, iy => sample) for (x, ix ) in zip (xs, ixs )]
32- newy = Array (get_slice (y, iy, iy => sample))[]
41+ newxs = [get_slice (x, dimx, sample[ixloc] ) for (x, dimx, ixloc ) in zip (xs, slice_xs_dim, ix_in_sample )]
42+ newy = Array (get_slice (y, slice_y_dim, sample[iy_in_sample] ))[]
3343 probabilities = einsum (code, newxs, size_dict) / newy
3444 config = StatsBase. sample (totalset, weights= StatsBase. Weights (probabilities))
35- update_sample! (samples, i, eliminated_variables=> config)
45+ # update the samples
46+ samples. samples[i][eliminated_locs] .= config. I
3647 end
3748 return xsamples
3849end
3950
51+ # type unstable
52+ function get_slice (x, dim, config)
53+ for (d, c) in zip (dim, config)
54+ x = selectdim (x, d, c)
55+ end
56+ return x
57+ end
58+
4059"""
4160$(TYPEDSIGNATURES)
4261
0 commit comments