Skip to content

Commit 3145d8f

Browse files
committed
new asia example
1 parent 7c23c46 commit 3145d8f

File tree

15 files changed

+140
-79
lines changed

15 files changed

+140
-79
lines changed

benchmark/bench_mmap.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ problem = read_uai_problem("Promedus_14")
1010
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)
1111

1212
# Does not marginalize any var
13-
mmap1 = MMAPModel(problem; marginalizedvertices = Int[], optimizer)
13+
mmap1 = MMAPModel(problem; marginalized = Int[], optimizer)
1414
SUITE["mmap-1"] = @benchmarkable maximum_logp(mmap1)
1515

1616
# Marginalizes all vars
17-
mmap2 = MMAPModel(problem; marginalizedvertices = collect(1:(problem.nvars)), optimizer)
17+
mmap2 = MMAPModel(problem; marginalized = collect(1:(problem.nvars)), optimizer)
1818
SUITE["mmap-2"] = @benchmarkable maximum_logp(mmap2)
1919

2020
# Does not optimize over open vertices
21-
mmap3 = MMAPModel(problem; marginalizedvertices = [2, 4, 6], optimizer)
21+
mmap3 = MMAPModel(problem; marginalized = [2, 4, 6], optimizer)
2222
SUITE["mmap-3"] = @benchmarkable most_probable_config(mmap3)
2323

2424
end # module

example/asia/README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# The "Asia" Bayesian network
22

3-
The variables and factors in `.uai` files are labelled as below.
3+
Please check the Julia code [asia.jl](asia.jl).
4+
5+
The variables and factors for the asia model is described in the [asia.uai](asia.uai) file.
6+
The UAI file format is detailed in:
7+
https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
8+
9+
The meanings of variables and factors as listed bellow.
410

511
## Variables
612
index 0 is mapped to yes, 1 is mapped to no.

example/asia/asia.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
11
using TensorInference
22

3-
problem = uai_problem_from_file(joinpath(@__DIR__, "data/asia.uai"))
4-
tnet = TensorNetworkModel(problem)
5-
marginals(problem)
3+
# Load the model that detailed in the README and `asia.uai`.
4+
instance = uai_problem_from_file(joinpath(@__DIR__, "asia.uai"))
5+
tnet = TensorNetworkModel(instance)
6+
7+
# Get the probabilities (PR)
8+
probability(tnet)
9+
10+
# Get the marginal probabilities (MAR)
11+
marginals(tnet) .|> first
12+
13+
# Set the evidence variables "X-ray" (7) to be positive.
14+
set_evidence!(instance, 7=>0)
15+
16+
# Since the evidence variable may change the contraction order, we re-compute the tensor network.
17+
tnet = TensorNetworkModel(instance)
18+
19+
# Get the maximum log-probabilities (MAP)
20+
maximum_logp(tnet)
21+
22+
# Get not only the maximum log-probability, but also the most probable conifguration
23+
# In the most probable configuration, the most probable one is the patient smoke (3) and has lung cancer (4)
24+
logp, cfg = most_probable_config(tnet)
25+
26+
# Get the maximum log-probabilities (MMAP)
27+
# To get the probability of lung cancer, we need to marginalize out other variables.
28+
mmap = MMAPModel(instance; marginalized=[1,2,3,5,6,8])
29+
# We get the most probable configurations on [4, 7]
30+
most_probable_config(mmap)
31+
# The total probability of having lung cancer is roughly half.
32+
log_probability(mmap, [1, 0])
33+
log_probability(mmap, [0, 0])

example/asia/asia.uai

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
MARKOV
2+
8
3+
2 2 2 2 2 2 2 2
4+
8
5+
1 0
6+
2 1 0
7+
1 2
8+
2 3 2
9+
2 4 2
10+
3 5 3 1
11+
2 6 5
12+
3 7 5 4
13+
14+
2
15+
0.01
16+
0.99
17+
18+
4
19+
0.05 0.01
20+
0.95 0.99
21+
22+
2
23+
0.5
24+
0.5
25+
26+
4
27+
0.1 0.01
28+
0.9 0.99
29+
30+
4
31+
0.6 0.3
32+
0.4 0.7
33+
34+
8
35+
1 1 1 0
36+
0 0 0 1
37+
38+
4
39+
0.98 0.05
40+
0.02 0.95
41+
42+
8
43+
0.9 0.7 0.8 0.1
44+
0.1 0.3 0.2 0.9

example/asia/data/asia.uai

Lines changed: 0 additions & 36 deletions
This file was deleted.

example/asia/data/asia.uai.evid

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/Core.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,30 @@ struct UAIInstance{ET, FT <: Factor{ET}}
3737
reference_marginals::Vector{Vector{ET}}
3838
end
3939

40+
"""
41+
$TYPEDSIGNATURES
42+
43+
Set the evidence of an UAI instance.
44+
45+
### Examples
46+
```jldoctest; setup=:(using TensorInference)
47+
julia> problem = read_uai_problem("Promedus_14"); problem.obsvars, problem.obsvals
48+
([42, 48, 27, 30, 29, 15, 124, 5, 148], [1, 1, 1, 1, 1, 1, 1, 1, 1])
49+
50+
julia> set_evidence!(problem, 2=>0, 4=>1); problem.obsvars, problem.obsvals
51+
([2, 4], [0, 1])
52+
```
53+
"""
54+
function set_evidence!(uai::UAIInstance, pairs::Pair{Int}...)
55+
empty!(uai.obsvars)
56+
empty!(uai.obsvals)
57+
for (var, val) in pairs
58+
push!(uai.obsvars, var)
59+
push!(uai.obsvals, val)
60+
end
61+
return uai
62+
end
63+
4064
"""
4165
$(TYPEDEF)
4266
@@ -95,7 +119,7 @@ function TensorNetworkModel(
95119
instance.cards,
96120
instance.factors;
97121
openvertices,
98-
fixedvertices = Dict(zip(instance.obsvars, instance.obsvals .- 1)),
122+
fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)),
99123
optimizer,
100124
simplifier
101125
)

src/RescaledArray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Base.show(io::IO, c::RescaledArray) = print(io, "exp($(c.log_factor)) * $(c.norm
1515
Base.show(io::IO, ::MIME"text/plain", c::RescaledArray) = Base.show(io, c)
1616
Base.Array(c::RescaledArray) = rmul!(Array(c.normalized_value), exp(c.log_factor))
1717
Base.copy(c::RescaledArray) = RescaledArray(c.log_factor, copy(c.normalized_value))
18+
Base.getindex(r::RescaledArray, indices...) = map(x->x * exp(r.log_factor), getindex(r.normalized_value, indices...))
1819

1920
"""
2021
$(TYPEDSIGNATURES)

src/TensorInference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export contraction_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABiparti
1111

1212
# read and load uai files
1313
export read_uai_file, read_td_file, read_uai_evid_file, read_uai_mar_file, read_uai_problem, uai_problem_from_file
14+
export set_evidence!
1415

1516
# marginals
1617
export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals

src/maxprob.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The backward rule for tropical einsum.
1919
function backward_tropical(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), @nospecialize(ymask), size_dict)
2020
y .= masked_inv.(y, ymask)
2121
masks = []
22-
for i in 1:length(ixs)
22+
for i in eachindex(ixs)
2323
nixs = OMEinsum._insertat(ixs, i, iy)
2424
nxs = OMEinsum._insertat(xs, i, y)
2525
niy = ixs[i]
@@ -44,21 +44,21 @@ $(TYPEDSIGNATURES)
4444
4545
Returns the largest log-probability and the most probable configuration.
4646
"""
47-
function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Tropical, Vector}
47+
function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector}
4848
vars = get_vars(tn)
4949
tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
5050
logp, grads = cost_and_gradient(tn.code, tensors)
5151
# use Array to convert CuArray to CPU arrays
52-
return Array(logp)[], map(k -> haskey(tn.fixedvertices, vars[k]) ? tn.fixedvertices[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
52+
return content(Array(logp)[]), map(k -> haskey(tn.fixedvertices, vars[k]) ? tn.fixedvertices[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
5353
end
5454

5555
"""
5656
$(TYPEDSIGNATURES)
5757
5858
Returns an output array containing largest log-probabilities.
5959
"""
60-
function maximum_logp(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Tropical}
60+
function maximum_logp(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Real}
6161
# generate tropical tensors with its elements being log(p).
6262
tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
63-
return tn.code(tensors...)
63+
return map(content, tn.code(tensors...))
6464
end

0 commit comments

Comments
 (0)