Skip to content

Commit 64c4e45

Browse files
committed
fix tests
1 parent 4a14c7c commit 64c4e45

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

src/sampling.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
############ Sampling ############
2+
3+
########### Backward propagating sampling process ##############
4+
function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy::Samples)
5+
return backward_sampling(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict)
6+
end
7+
8+
"""
9+
$(TYPEDSIGNATURES)
10+
11+
The backward rule for tropical einsum.
12+
13+
* `ixs` and `xs` are labels and tensor data for input tensors,
14+
* `iy` and `y` are labels and tensor data for the output tensor,
15+
* `ysamples` is the samples generated on the output tensor,
16+
* `size_dict` is a key-value map from tensor label to dimension size.
17+
"""
18+
function backward_sampling(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), @nospecialize(ysamples), size_dict)
19+
xsamples = []
20+
for i in eachindex(ixs)
21+
nixs = OMEinsum._insertat(ixs, i, iy)
22+
nxs = OMEinsum._insertat(xs, i, y)
23+
niy = ixs[i]
24+
A = einsum(EinCode(nixs, niy), nxs, size_dict)
25+
26+
# compute the mask, one of its entry in `A^{-1}` that equal to the corresponding entry in `X` is masked to true.
27+
j = argmax(xs[i] ./ inv.(A))
28+
mask = onehot_like(A, j)
29+
push!(xsamples, mask)
30+
end
31+
return xsamples
32+
end
33+
34+
"""
35+
$(TYPEDSIGNATURES)
36+
37+
Sample a tensor network based probabilistic model.
38+
"""
39+
function sample(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Real}
40+
# generate tropical tensors with its elements being log(p).
41+
tensors = adapt_tensors(tn; usecuda, rescale = false)
42+
return tn.code(tensors...)
43+
end

test/mmap.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
using Test
22
using OMEinsum
33
using TensorInference
4-
using Random
54

65
@testset "clustering" begin
76
ixs = [[1, 2, 3], [2, 3, 4], [4, 5, 6]]
87
@test TensorInference.connected_clusters(ixs, [2, 3, 6]) == [[2, 3] => [1, 2], [6] => [3]]
98
end
109

1110
@testset "mmap" begin
12-
Random.seed!(5)
1311
################# Load problem ####################
1412
instance = read_uai_problem("Promedus_14")
1513

0 commit comments

Comments
 (0)