Skip to content

Commit 13af59a

Browse files
committed
update
1 parent 91fb22b commit 13af59a

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

src/sampling.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,40 @@ The backward rule for tropical einsum.
2222
* `size_dict` is a key-value map from tensor label to dimension size.
2323
"""
2424
function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
25+
idx4label(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
2526
eliminated_variables = setdiff(vcat(ixs...), iy)
27+
eliminated_locs = idx4label(samples.labels, eliminated_variables)
28+
samples.setmask[eliminated_locs] .= true
29+
30+
# the contraction code to get probability
2631
newiy = eliminated_variables
27-
newixs = eliminated_variables
32+
iy_in_sample = idx4labels(samples.labels, iy)
33+
slice_y_dim = collect(1:length(iy))
34+
newixs = map(ix->setdiff(ix, iy), ixs)
35+
ix_in_sample = map(ix->idx4labels(samples.labels, ix iy), ixs)
36+
slice_xs_dim = map(ix->idx4label(ix, ix iy), ixs)
2837
code = DynamicEinCode(newixs, newiy)
29-
totalset = CartesianIndices(map(x->size_dict[x], eliminated_variables))
38+
39+
totalset = CartesianIndices(map(x->size_dict[x], eliminated_variables)...)
3040
for (i, sample) in enumerate(samples.samples)
31-
newxs = [get_slice(x, ix, iy=>sample) for (x, ix) in zip(xs, ixs)]
32-
newy = Array(get_slice(y, iy, iy=>sample))[]
41+
newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
42+
newy = Array(get_slice(y, slice_y_dim, sample[iy_in_sample]))[]
3343
probabilities = einsum(code, newxs, size_dict) / newy
3444
config = StatsBase.sample(totalset, weights=StatsBase.Weights(probabilities))
35-
update_sample!(samples, i, eliminated_variables=>config)
45+
# update the samples
46+
samples.samples[i][eliminated_locs] .= config.I
3647
end
3748
return xsamples
3849
end
3950

51+
# type unstable
52+
function get_slice(x, dim, config)
53+
for (d, c) in zip(dim, config)
54+
x = selectdim(x, d, c)
55+
end
56+
return x
57+
end
58+
4059
"""
4160
$(TYPEDSIGNATURES)
4261

0 commit comments

Comments
 (0)