Skip to content

Commit 147f9dd

Browse files
authored
Merge pull request #22 from TensorBFS/jg/asia-example
Add Asia example and polish the interfaces
2 parents 066c4d3 + 4ce52ba commit 147f9dd

File tree

15 files changed

+189
-57
lines changed

15 files changed

+189
-57
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ pkg> add TensorInference
2828
```
2929

3030
To update, just type `up` in the package mode.
31+
32+
## Example
33+
Please check the [example](example) folder.

benchmark/bench_mmap.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ problem = read_uai_problem("Promedus_14")
1010
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)
1111

1212
# Does not marginalize any var
13-
mmap1 = MMAPModel(problem; marginalizedvertices = Int[], optimizer)
13+
mmap1 = MMAPModel(problem; marginalized = Int[], optimizer)
1414
SUITE["mmap-1"] = @benchmarkable maximum_logp(mmap1)
1515

1616
# Marginalizes all vars
17-
mmap2 = MMAPModel(problem; marginalizedvertices = collect(1:(problem.nvars)), optimizer)
17+
mmap2 = MMAPModel(problem; marginalized = collect(1:(problem.nvars)), optimizer)
1818
SUITE["mmap-2"] = @benchmarkable maximum_logp(mmap2)
1919

2020
# Does not optimize over open vertices
21-
mmap3 = MMAPModel(problem; marginalizedvertices = [2, 4, 6], optimizer)
21+
mmap3 = MMAPModel(problem; marginalized = [2, 4, 6], optimizer)
2222
SUITE["mmap-3"] = @benchmarkable most_probable_config(mmap3)
2323

2424
end # module

example/asia/README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# The "Asia" Bayesian network
2+
3+
Please check the Julia code [asia.jl](asia.jl).
4+
5+
The variables and factors for the asia model is described in the [asia.uai](asia.uai) file.
6+
The UAI file format is detailed in:
7+
https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
8+
9+
The meanings of variables and factors as listed bellow.
10+
11+
## Variables
12+
index 0 is mapped to yes, 1 is mapped to no.
13+
14+
1. visit to Asia (a)
15+
2. tuberculosis (t)
16+
3. smoking (s)
17+
4. lung cancer (l)
18+
5. bronchitis (b)
19+
6. either tub. or lung cancer (e)
20+
7. positive X-ray (x)
21+
8. dyspnoea (d)
22+
23+
## Factors
24+
1. p(a)
25+
2. p(t|a)
26+
3. p(s)
27+
4. p(l|s)
28+
5. p(b|s)
29+
6. p(e|l,t)
30+
7. p(x|e)
31+
8. p(d|e,b)

example/asia/asia.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using TensorInference
2+
3+
# Load the model that detailed in the README and `asia.uai`.
4+
instance = uai_problem_from_file(joinpath(@__DIR__, "asia.uai"))
5+
tnet = TensorNetworkModel(instance)
6+
7+
# Get the probabilities (PR)
8+
probability(tnet)
9+
10+
# Get the marginal probabilities (MAR)
11+
marginals(tnet) .|> first
12+
13+
# Set the evidence variables "X-ray" (7) to be positive.
14+
set_evidence!(instance, 7=>0)
15+
16+
# Since the evidence variable may change the contraction order, we re-compute the tensor network.
17+
tnet = TensorNetworkModel(instance)
18+
19+
# Get the maximum log-probabilities (MAP)
20+
maximum_logp(tnet)
21+
22+
# Get not only the maximum log-probability, but also the most probable conifguration
23+
# In the most probable configuration, the most probable one is the patient smoke (3) and has lung cancer (4)
24+
logp, cfg = most_probable_config(tnet)
25+
26+
# Get the maximum log-probabilities (MMAP)
27+
# To get the probability of lung cancer, we need to marginalize out other variables.
28+
mmap = MMAPModel(instance; marginalized=[1,2,3,5,6,8])
29+
# We get the most probable configurations on [4, 7]
30+
most_probable_config(mmap)
31+
# The total probability of having lung cancer is roughly half.
32+
log_probability(mmap, [1, 0])
33+
log_probability(mmap, [0, 0])

example/asia/asia.uai

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
MARKOV
2+
8
3+
2 2 2 2 2 2 2 2
4+
8
5+
1 0
6+
2 1 0
7+
1 2
8+
2 3 2
9+
2 4 2
10+
3 5 3 1
11+
2 6 5
12+
3 7 5 4
13+
14+
2
15+
0.01
16+
0.99
17+
18+
4
19+
0.05 0.01
20+
0.95 0.99
21+
22+
2
23+
0.5
24+
0.5
25+
26+
4
27+
0.1 0.01
28+
0.9 0.99
29+
30+
4
31+
0.6 0.3
32+
0.4 0.7
33+
34+
8
35+
1 1 1 0
36+
0 0 0 1
37+
38+
4
39+
0.98 0.05
40+
0.02 0.95
41+
42+
8
43+
0.9 0.7 0.8 0.1
44+
0.1 0.3 0.2 0.9

src/Core.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,30 @@ struct UAIInstance{ET, FT <: Factor{ET}}
3737
reference_marginals::Vector{Vector{ET}}
3838
end
3939

40+
"""
41+
$TYPEDSIGNATURES
42+
43+
Set the evidence of an UAI instance.
44+
45+
### Examples
46+
```jldoctest; setup=:(using TensorInference)
47+
julia> problem = read_uai_problem("Promedus_14"); problem.obsvars, problem.obsvals
48+
([42, 48, 27, 30, 29, 15, 124, 5, 148], [1, 1, 1, 1, 1, 1, 1, 1, 1])
49+
50+
julia> set_evidence!(problem, 2=>0, 4=>1); problem.obsvars, problem.obsvals
51+
([2, 4], [0, 1])
52+
```
53+
"""
54+
function set_evidence!(uai::UAIInstance, pairs::Pair{Int}...)
55+
empty!(uai.obsvars)
56+
empty!(uai.obsvals)
57+
for (var, val) in pairs
58+
push!(uai.obsvars, var)
59+
push!(uai.obsvals, val)
60+
end
61+
return uai
62+
end
63+
4064
"""
4165
$(TYPEDEF)
4266
@@ -58,7 +82,7 @@ end
5882
function Base.show(io::IO, tn::TensorNetworkModel)
5983
open = getiyv(tn.code)
6084
variables = join([string_var(var, open, tn.fixedvertices) for var in tn.vars], ", ")
61-
tc, sc, rw = timespacereadwrite_complexity(tn)
85+
tc, sc, rw = contraction_complexity(tn)
6286
println(io, "$(typeof(tn))")
6387
println(io, "variables: $variables")
6488
print_tcscrw(io, tc, sc, rw)
@@ -95,7 +119,7 @@ function TensorNetworkModel(
95119
instance.cards,
96120
instance.factors;
97121
openvertices,
98-
fixedvertices = Dict(zip(instance.obsvars, instance.obsvals .- 1)),
122+
fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)),
99123
optimizer,
100124
simplifier
101125
)

src/RescaledArray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Base.show(io::IO, c::RescaledArray) = print(io, "exp($(c.log_factor)) * $(c.norm
1515
Base.show(io::IO, ::MIME"text/plain", c::RescaledArray) = Base.show(io, c)
1616
Base.Array(c::RescaledArray) = rmul!(Array(c.normalized_value), exp(c.log_factor))
1717
Base.copy(c::RescaledArray) = RescaledArray(c.log_factor, copy(c.normalized_value))
18+
Base.getindex(r::RescaledArray, indices...) = map(x->x * exp(r.log_factor), getindex(r.normalized_value, indices...))
1819

1920
"""
2021
$(TYPEDSIGNATURES)

src/TensorInference.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ using TropicalGEMM
77

88
# reexport OMEinsum functions
99
export RescaledArray
10-
export timespace_complexity, timespacereadwrite_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABipartite, MergeGreedy, MergeVectors
10+
export contraction_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABipartite, MergeGreedy, MergeVectors
1111

1212
# read and load uai files
13-
export read_uai_file, read_td_file, read_uai_evid_file, read_uai_mar_file, read_uai_problem
13+
export read_uai_file, read_td_file, read_uai_evid_file, read_uai_mar_file, read_uai_problem, uai_problem_from_file
14+
export set_evidence!
1415

1516
# marginals
1617
export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals

src/maxprob.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The backward rule for tropical einsum.
1919
function backward_tropical(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), @nospecialize(ymask), size_dict)
2020
y .= masked_inv.(y, ymask)
2121
masks = []
22-
for i in 1:length(ixs)
22+
for i in eachindex(ixs)
2323
nixs = OMEinsum._insertat(ixs, i, iy)
2424
nxs = OMEinsum._insertat(xs, i, y)
2525
niy = ixs[i]
@@ -44,21 +44,21 @@ $(TYPEDSIGNATURES)
4444
4545
Returns the largest log-probability and the most probable configuration.
4646
"""
47-
function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Tropical, Vector}
47+
function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector}
4848
vars = get_vars(tn)
4949
tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
5050
logp, grads = cost_and_gradient(tn.code, tensors)
5151
# use Array to convert CuArray to CPU arrays
52-
return Array(logp)[], map(k -> haskey(tn.fixedvertices, vars[k]) ? tn.fixedvertices[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
52+
return content(Array(logp)[]), map(k -> haskey(tn.fixedvertices, vars[k]) ? tn.fixedvertices[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
5353
end
5454

5555
"""
5656
$(TYPEDSIGNATURES)
5757
5858
Returns an output array containing largest log-probabilities.
5959
"""
60-
function maximum_logp(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Tropical}
60+
function maximum_logp(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Real}
6161
# generate tropical tensors with its elements being log(p).
6262
tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
63-
return tn.code(tensors...)
63+
return map(content, tn.code(tensors...))
6464
end

src/mmap.jl

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434
function Base.show(io::IO, mmap::MMAPModel)
3535
open = getiyv(mmap.code)
3636
variables = join([string_var(var, open, mmap.fixedvertices) for var in mmap.vars], ", ")
37-
tc, sc, rw = timespacereadwrite_complexity(mmap)
37+
tc, sc, rw = contraction_complexity(mmap)
3838
println(io, "$(typeof(mmap))")
3939
println(io, "variables: $variables")
4040
println(io, "marginalized variables: $(map(x->x.eliminated_vars, mmap.clusters))")
@@ -58,16 +58,16 @@ end
5858
"""
5959
$(TYPEDSIGNATURES)
6060
"""
61-
function MMAPModel(instance::UAIInstance; marginalizedvertices, openvertices = (), optimizer = GreedyMethod(), simplifier = nothing)::MMAPModel
61+
function MMAPModel(instance::UAIInstance; marginalized, openvertices = (), optimizer = GreedyMethod(), simplifier = nothing)::MMAPModel
6262
return MMAPModel(
63-
1:(instance.nvars), instance.factors; marginalizedvertices, fixedvertices = Dict(zip(instance.obsvars, instance.obsvals .- 1)), optimizer, simplifier, openvertices
63+
1:(instance.nvars), instance.factors; marginalized, fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)), optimizer, simplifier, openvertices
6464
)
6565
end
6666

6767
"""
6868
$(TYPEDSIGNATURES)
6969
"""
70-
function MMAPModel(vars::AbstractVector{LT}, factors::Vector{<:Factor{T}}; marginalizedvertices, openvertices = (),
70+
function MMAPModel(vars::AbstractVector{LT}, factors::Vector{<:Factor{T}}; marginalized, openvertices = (),
7171
fixedvertices = Dict{LT, Int}(),
7272
optimizer = GreedyMethod(), simplifier = nothing,
7373
marginalize_optimizer = GreedyMethod(), marginalize_simplifier = nothing
@@ -81,7 +81,7 @@ function MMAPModel(vars::AbstractVector{LT}, factors::Vector{<:Factor{T}}; margi
8181
size_dict = OMEinsum.get_size_dict(all_ixs, all_tensors)
8282

8383
# detect clusters for marginalize variables
84-
subsets = connected_clusters(all_ixs, marginalizedvertices)
84+
subsets = connected_clusters(all_ixs, marginalized)
8585
clusters = Cluster{LT}[]
8686
ixs = Vector{LT}[]
8787
for (contracted, cluster) in subsets
@@ -96,10 +96,10 @@ function MMAPModel(vars::AbstractVector{LT}, factors::Vector{<:Factor{T}}; margi
9696
rem_indices = setdiff(1:length(all_ixs), vcat([c.second for c in subsets]...))
9797
remaining_tensors = all_tensors[rem_indices]
9898
code = optimize_code(EinCode([all_ixs[rem_indices]..., ixs...], iy), size_dict, optimizer, simplifier)
99-
return MMAPModel(setdiff(vars, marginalizedvertices), code, remaining_tensors, clusters, fixedvertices)
99+
return MMAPModel(setdiff(vars, marginalized), code, remaining_tensors, clusters, fixedvertices)
100100
end
101101

102-
function OMEinsum.timespacereadwrite_complexity(mmap::MMAPModel{LT}) where {LT}
102+
function OMEinsum.contraction_complexity(mmap::MMAPModel{LT}) where {LT}
103103
# extract size
104104
size_dict = Dict(zip(get_vars(mmap), get_cards(mmap; fixedisone = true)))
105105
sc = -Inf
@@ -111,18 +111,17 @@ function OMEinsum.timespacereadwrite_complexity(mmap::MMAPModel{LT}) where {LT}
111111
# the head sector are for unity tensors.
112112
size_dict[cluster.eliminated_vars[k]] = length(cluster.tensors[k])
113113
end
114-
tc, sci, rw = timespacereadwrite_complexity(cluster.code, size_dict)
114+
tc, sci, rw = contraction_complexity(cluster.code, size_dict)
115115
push!(tcs, tc)
116116
push!(rws, rw)
117117
sc = max(sc, sci)
118118
end
119119

120-
tc, sci, rw = timespacereadwrite_complexity(mmap.code, size_dict)
120+
tc, sci, rw = contraction_complexity(mmap.code, size_dict)
121121
push!(tcs, tc)
122122
push!(rws, tc)
123123
OMEinsum.OMEinsumContractionOrders.log2sumexp2(tcs), max(sc, sci), OMEinsum.OMEinsumContractionOrders.log2sumexp2(rws)
124124
end
125-
OMEinsum.timespace_complexity(mmap::MMAPModel) = timespacereadwrite_complexity(mmap)[1:2]
126125

127126
function adapt_tensors(mmap::MMAPModel; usecuda, rescale)
128127
return [adapt_tensors(mmap.code, mmap.tensors, mmap.fixedvertices; usecuda, rescale)...,
@@ -174,35 +173,29 @@ function visit_var!(var, vars::AbstractVector{LT}, ixs, visited_ixs, visited_var
174173
end
175174
end
176175

177-
"""
178-
$(TYPEDSIGNATURES)
179-
"""
180-
function most_probable_config(mmap::MMAPModel; usecuda = false)::Tuple{Tropical, Vector}
176+
function most_probable_config(mmap::MMAPModel; usecuda = false)::Tuple{Real, Vector}
181177
vars = get_vars(mmap)
182178
tensors = map(t -> OMEinsum.asarray(Tropical.(log.(t)), t), adapt_tensors(mmap; usecuda, rescale = false))
183179
logp, grads = cost_and_gradient(mmap.code, tensors)
184180
# use Array to convert CuArray to CPU arrays
185-
return Array(logp)[], map(k -> haskey(mmap.fixedvertices, vars[k]) ? mmap.fixedvertices[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
181+
return content(Array(logp)[]), map(k -> haskey(mmap.fixedvertices, vars[k]) ? mmap.fixedvertices[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
186182
end
187183

188-
"""
189-
$(TYPEDSIGNATURES)
190-
"""
191-
function maximum_logp(mmap::MMAPModel; usecuda = false)::AbstractArray{<:Tropical}
184+
function maximum_logp(mmap::MMAPModel; usecuda = false)::AbstractArray{<:Real}
192185
tensors = map(t -> OMEinsum.asarray(Tropical.(log.(t)), t), adapt_tensors(mmap; usecuda, rescale = false))
193-
return mmap.code(tensors...)
186+
return map(content, mmap.code(tensors...))
194187
end
195188

196-
"""
197-
$(TYPEDSIGNATURES)
198-
"""
199189
function log_probability(mmap::MMAPModel, config::Union{Dict, AbstractVector}; rescale = true, usecuda = false)::Real
200190
@assert length(get_vars(mmap)) == length(config)
201191
fixedvertices = config isa AbstractVector ? Dict(zip(get_vars(mmap), config)) : config
202192
assign = merge(mmap.fixedvertices, fixedvertices)
203193
# two contributions to the probability, not-clustered tensors and clusters.
204194
m1 = sum(x -> log(x[2][(getindex.(Ref(assign), x[1]) .+ 1)...]), zip(getixsv(mmap.code), mmap.tensors))
205-
m2 = sum(cluster -> probability(cluster; fixedvertices, usecuda, rescale).log_factor, mmap.clusters)
195+
m2 = sum(mmap.clusters) do cluster
196+
p = probability(cluster; fixedvertices, usecuda, rescale)
197+
rescale ? p.log_factor : log(p[])
198+
end
206199
return m1 + m2
207200
end
208201

0 commit comments

Comments
 (0)