Commit 365b689: let marginals return dict

1 parent a066560

File tree: 7 files changed, +78 -24 lines changed
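In short, this commit changes `marginals` to return a `Dict` keyed by the query variables rather than a positionally indexed `Vector`. A minimal before/after sketch of the calling convention, based on the asia-network doctest added in `src/mar.jl` below (the numeric values are the ones shown there):

```julia
using TensorInference

model = read_model_file(pkgdir(TensorInference, "examples", "asia", "asia.uai"))
tn = TensorNetworkModel(model; evidence=Dict(1=>0))

mars = marginals(tn)   # now a Dict{Vector{Int64}, Vector{Float64}}

# Before this commit: positional indexing into a Vector, e.g. mars[2].
# After this commit: index by the query variables themselves.
mars[[2]]              # marginal of variable 2, e.g. [0.05, 0.95]
```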

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "TensorInference"
 uuid = "c2297e78-99bd-40ad-871d-f50e56b81012"
 authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
-version = "0.3.0"
+version = "0.4.0"

 [deps]
 Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

examples/hard-core-lattice-gas/main.jl

Lines changed: 3 additions & 3 deletions

@@ -55,14 +55,14 @@ partition_func[]

 # The marginal probabilities can be computed with the [`marginals`](@ref) function, which measures how likely a site is occupied.
 mars = marginals(pmodel)
-show_graph(graph; locs=sites, vertex_colors=[(1-b, 1-b, 1-b) for b in getindex.(mars, 2)], texts=fill("", nv(graph)))
+show_graph(graph; locs=sites, vertex_colors=[(b = mars[[i]][2]; (1-b, 1-b, 1-b)) for i in vertices(graph)], texts=fill("", nv(graph)))
 # One can see that the sites at the corners are more likely to be occupied.
 # To obtain two-site correlations, one can specify the variables to query marginal probabilities for manually.
 pmodel2 = TensorNetworkModel(problem, β; mars=[[e.src, e.dst] for e in edges(graph)])
 mars = marginals(pmodel2);

 # We show the probability that both sites on an edge are not occupied
-show_graph(graph; locs=sites, edge_colors=[(b=mar[1, 1]; (1-b, 1-b, 1-b)) for mar in mars], texts=fill("", nv(graph)), edge_line_width=5)
+show_graph(graph; locs=sites, edge_colors=[(b = mars[[e.src, e.dst]][1, 1]; (1-b, 1-b, 1-b)) for e in edges(graph)], texts=fill("", nv(graph)), edge_line_width=5)

 # ## The most likely configuration
 # The MAP and MMAP can be used to get the most likely configuration given some evidence.
@@ -91,4 +91,4 @@ sum(config2)
 # The return value is a matrix, with the columns corresponding to different samples.
 configs = sample(pmodel3, 1000)
 sizes = sum(configs; dims=1)
-[count(==(i), sizes) for i=0:34] # counting sizes
+[count(==(i), sizes) for i=0:34] # counting sizes
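For scripts being updated to this commit, a condensed sketch of the new access pattern used in the example above; `pmodel`, `pmodel2`, and `graph` are the objects constructed earlier in this file:

```julia
mars = marginals(pmodel)                  # Dict keyed by single-site queries [1], [2], ...
occupation = [mars[[i]][2] for i in vertices(graph)]              # P(site i occupied)

mars2 = marginals(pmodel2)                # Dict keyed by edge queries [src, dst]
both_empty = [mars2[[e.src, e.dst]][1, 1] for e in edges(graph)]  # P(both ends empty)
```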

src/mar.jl

Lines changed: 56 additions & 5 deletions

@@ -124,16 +124,67 @@ end
 """
 $(TYPEDSIGNATURES)

-Returns the marginal probability distribution of variables.
-One can use `get_vars(tn)` to get the full list of variables in this tensor network.
+Query the marginals of the variables in a [`TensorNetworkModel`](@ref).
+The returned value is a dictionary mapping each queried set of variables to its marginal, where a marginal is a joint probability distribution over the associated variables.
+By default, the marginals of all individual variables are returned.
+The marginal variables to query can be specified when constructing a [`TensorNetworkModel`](@ref) via its `mars` field.
+Note that this choice affects the contraction order of the tensor network.
+
+### Arguments
+- `tn`: the [`TensorNetworkModel`](@ref) to query.
+- `usecuda`: whether to use CUDA for tensor contraction.
+- `rescale`: whether to rescale the tensors during contraction.
+
+### Example
+The following example is from [`examples/asia/main.jl`](@ref).
+
+```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
+julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia", "asia.uai"));
+
+julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
+TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
+variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
+contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077
+
+julia> marginals(tn)
+Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
+  [8] => [0.450138, 0.549863]
+  [3] => [0.5, 0.5]
+  [1] => [1.0]
+  [5] => [0.45, 0.55]
+  [4] => [0.055, 0.945]
+  [6] => [0.10225, 0.89775]
+  [7] => [0.145092, 0.854908]
+  [2] => [0.05, 0.95]
+
+julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
+TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
+variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
+contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443
+
+julia> marginals(tn2)
+Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
+  [2, 3] => [0.025 0.025; 0.475 0.475]
+  [3, 4] => [0.05 0.45; 0.005 0.495]
+```
+
+In this example, we first set the evidence of variable 1 to 0 and then query the marginals of all individual variables.
+The returned value is a dictionary whose keys are the query variables and whose values are the corresponding marginals.
+Each marginal is a vector whose entries are the probabilities of the variable taking the values 0 and 1, respectively.
+For the evidence variable 1, the marginal is always `[1.0]`, since it is fixed to 0.
+
+We then query the joint marginals of variables 2 and 3, and of variables 3 and 4.
+Querying joint marginals may or may not increase the contraction time and space.
+Here, the contraction space complexity increases from 2^2.0 to 2^5.0, and the contraction time complexity increases from 2^6.022 to 2^7.781.
+The output marginals are the joint probabilities of the query variables, represented as tensors.
 """
-function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Vector
+function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
     # Sometimes the cost can overflow, in which case we rescale the tensors during contraction.
     cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale))
     @debug "cost = $cost"
     if rescale
-        return LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1)
+        return Dict(zip(tn.mars, LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1)))
     else
-        return LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1)
+        return Dict(zip(tn.mars, LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1)))
     end
 end
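The new return path pairs each query in `tn.mars` with its L1-normalized marginal tensor via `Dict(zip(...))`. A standalone sketch of that assembly step, with random matrices standing in for the real contraction gradients:

```julia
using LinearAlgebra

queries = [[2, 3], [3, 4]]            # hypothetical tn.mars
tensors = [rand(2, 2), rand(2, 2)]    # hypothetical gradient tensors
# normalize!(t, 1) rescales entrywise, so nonnegative entries sum to one
result = Dict(zip(queries, LinearAlgebra.normalize!.(tensors, 1)))
sum(result[[2, 3]]) ≈ 1.0             # each value is a probability tensor
```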

test/cuda.jl

Lines changed: 2 additions & 2 deletions

@@ -11,11 +11,11 @@ CUDA.allowscalar(false)
 tn = TensorNetworkModel(model; optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40), evidence)
 @debug contraction_complexity(tn)
 @time marginals2 = marginals(tn; usecuda = true)
-@test all(x -> x isa CuArray, marginals2)
+@test all(x -> x.second isa CuArray, marginals2)
 # for dangling vertices, the output size is 1.
 npass = 0
 for i in 1:(model.nvars)
-    npass += (length(marginals2[i]) == 1 && reference_solution[i] == [0.0, 1]) || isapprox(Array(marginals2[i]), reference_solution[i]; atol = 1e-6)
+    npass += (length(marginals2[[i]]) == 1 && reference_solution[i] == [0.0, 1]) || isapprox(Array(marginals2[[i]]), reference_solution[i]; atol = 1e-6)
 end
 @test npass == model.nvars
 end
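The first change here relies on the fact that iterating a `Dict` yields `key => value` `Pair`s, so the predicate must inspect `x.second` rather than the element itself. A tiny illustration with placeholder data:

```julia
d = Dict([1] => [0.4, 0.6], [2] => [0.1, 0.9])
all(x -> x.second isa Vector, d)   # true: each element is a Pair
d[[1]]                             # look up a marginal by its query key
```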

test/generictensornetworks.jl

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ using GenericTensorNetworks, TensorInference
 g = GenericTensorNetworks.Graphs.smallgraph(:petersen)
 problem = IndependentSet(g)
 model = TensorNetworkModel(problem, β; mars=[[2, 3]])
-mars = marginals(model)[1]
+mars = marginals(model)[[2, 3]]
 problem2 = IndependentSet(g; openvertices=[2,3])
 mars2 = TensorInference.normalize!(GenericTensorNetworks.solve(problem2, PartitionFunction(β)), 1)
 @test mars ≈ mars2

test/mar.jl

Lines changed: 11 additions & 8 deletions

@@ -31,7 +31,7 @@ end
 # compute marginals
 ti_sol = marginals(tn)
 ref_sol[collect(keys(evidence))] .= fill([1.0], length(evidence)) # imitate dummy vars
-@test isapprox(ti_sol, ref_sol; atol = 1e-5)
+@test isapprox([ti_sol[[i]] for i=1:length(ref_sol)], ref_sol; atol = 1e-5)
 end

 @testset "UAI Reference Solution Comparison" begin
@@ -63,7 +63,7 @@ end
 @debug contraction_complexity(tn)
 ti_sol = marginals(tn)
 ref_sol[collect(keys(evidence))] .= fill([1.0], length(evidence)) # imitate dummy vars
-@test isapprox(ti_sol, ref_sol; atol = 1e-4)
+@test isapprox([ti_sol[[i]] for i=1:length(ref_sol)], ref_sol; atol = 1e-4)
 end
 end
 end
@@ -120,15 +120,18 @@ end
 mars = marginals(tnet)
 tnet23 = TensorNetworkModel(model; openvars=[2,3])
 tnet34 = TensorNetworkModel(model; openvars=[3,4])
-@test mars[1] ≈ probability(tnet23)
-@test mars[2] ≈ probability(tnet34)
+@test mars[[2, 3]] ≈ probability(tnet23)
+@test mars[[3, 4]] ≈ probability(tnet34)

-tnet1 = TensorNetworkModel(model; mars=[[2, 3], [3, 4]], evidence=Dict(3=>1))
-tnet2 = TensorNetworkModel(model; mars=[[2, 3], [3, 4]], evidence=Dict(3=>0))
+vars = [[2, 4], [3, 5]]
+tnet1 = TensorNetworkModel(model; mars=vars, evidence=Dict(3=>1))
+tnet2 = TensorNetworkModel(model; mars=vars, evidence=Dict(3=>0))
 mars1 = marginals(tnet1)
 mars2 = marginals(tnet2)
 update_evidence!(tnet1, Dict(3=>0))
 mars1b = marginals(tnet1)
-@test !(mars1 ≈ mars2)
-@test mars1b ≈ mars2
+for k in vars
+    @test !(mars1[k] ≈ mars2[k])
+    @test mars1b[k] ≈ mars2[k]
+end
 end

test/sampling.jl

Lines changed: 4 additions & 4 deletions

@@ -49,14 +49,14 @@ using TensorInference, Test
 n = 10000
 tnet = TensorNetworkModel(model)
 samples = sample(tnet, n)
-mars = getindex.(marginals(tnet), 2)
+mars = marginals(tnet)
 mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
-@test isapprox(mars, mars_sample, atol=0.05)
+@test isapprox([mars[[i]][2] for i=1:8], mars_sample, atol=0.05)

 # fix the evidence
 tnet = TensorNetworkModel(model, optimizer=TreeSA(), evidence=Dict(7=>1))
 samples = sample(tnet, n)
-mars = getindex.(marginals(tnet), 1)
+mars = marginals(tnet)
 mars_sample = [count(s->s[k]==(0), samples) for k=1:8] ./ n
-@test isapprox([mars[1:6]..., mars[8]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
+@test isapprox([[mars[[i]][1] for i=1:6]..., mars[[8]][1]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
 end
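The updated assertions fetch each exact marginal by its query key and compare it against the empirical frequency from the samples. A condensed restatement of the first check, using the same names as the test:

```julia
n = 10000
samples = sample(tnet, n)
mars = marginals(tnet)
freq = [count(s -> s[k] == 1, samples) / n for k in 1:8]   # empirical P(x_k = 1)
exact = [mars[[k]][2] for k in 1:8]                        # exact P(x_k = 1)
isapprox(exact, freq, atol = 0.05)
```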
