Skip to content

Commit 91fb22b

Browse files
committed
update
1 parent 3a791a8 commit 91fb22b

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

src/sampling.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {
55
return backward_sampling(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict)
66
end
77

8+
struct Samples{L}
9+
samples::Vector{Vector{Int}}
10+
labels::Vector{L}
11+
setmask::Vector{Bool}
12+
end
13+
814
"""
915
$(TYPEDSIGNATURES)
1016
@@ -15,18 +21,18 @@ The backward rule for tropical einsum.
1521
* `ysamples` is the samples generated on the output tensor,
1622
* `size_dict` is a key-value map from tensor label to dimension size.
1723
"""
18-
function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), @nospecialize(ysamples), size_dict)
19-
xsamples = []
20-
for i in eachindex(ixs)
21-
nixs = OMEinsum._insertat(ixs, i, iy)
22-
nxs = OMEinsum._insertat(xs, i, y)
23-
niy = ixs[i]
24-
A = einsum(EinCode(nixs, niy), nxs, size_dict)
25-
26-
# compute the mask, one of its entry in `A^{-1}` that equal to the corresponding entry in `X` is masked to true.
27-
j = argmax(xs[i] ./ inv.(A))
28-
mask = onehot_like(A, j)
29-
push!(xsamples, mask)
24+
function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
25+
eliminated_variables = setdiff(vcat(ixs...), iy)
26+
newiy = eliminated_variables
27+
newixs = eliminated_variables
28+
code = DynamicEinCode(newixs, newiy)
29+
totalset = CartesianIndices(map(x->size_dict[x], eliminated_variables))
30+
for (i, sample) in enumerate(samples.samples)
31+
newxs = [get_slice(x, ix, iy=>sample) for (x, ix) in zip(xs, ixs)]
32+
newy = Array(get_slice(y, iy, iy=>sample))[]
33+
probabilities = einsum(code, newxs, size_dict) / newy
34+
config = StatsBase.sample(totalset, weights=StatsBase.Weights(probabilities))
35+
update_sample!(samples, i, eliminated_variables=>config)
3036
end
3137
return xsamples
3238
end

0 commit comments

Comments
 (0)