Skip to content

Commit 7c23c46

Browse files
committed
update
1 parent 066c4d3 commit 7c23c46

File tree

9 files changed

+87
-19
lines changed

9 files changed

+87
-19
lines changed

example/asia/README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# The "Asia" Bayesian network
2+
3+
The variables and factors in `.uai` files are labelled as below.
4+
5+
## Variables
6+
index 0 is mapped to yes, 1 is mapped to no.
7+
8+
1. visit to Asia (a)
9+
2. tuberculosis (t)
10+
3. smoking (s)
11+
4. lung cancer (l)
12+
5. bronchitis (b)
13+
6. either tub. or lung cancer (e)
14+
7. positive X-ray (x)
15+
8. dyspnoea (d)
16+
17+
## Factors
18+
1. p(a)
19+
2. p(t|a)
20+
3. p(s)
21+
4. p(l|s)
22+
5. p(b|s)
23+
6. p(e|l,t)
24+
7. p(x|e)
25+
8. p(d|e,b)

example/asia/asia.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
using TensorInference
2+
3+
problem = uai_problem_from_file(joinpath(@__DIR__, "data/asia.uai"))
4+
tnet = TensorNetworkModel(problem)
5+
marginals(problem)

example/asia/data/asia.uai

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
MARKOV
2+
8
3+
2 2 2 2 2 2 2 2
4+
8
5+
1 0
6+
2 1 0
7+
1 2
8+
2 3 2
9+
2 4 2
10+
3 5 3 1
11+
2 6 5
12+
3 7 5 4
13+
14+
2
15+
0.01 0.99
16+
17+
4
18+
0.05 0.01 0.95 0.99
19+
20+
2
21+
0.5 0.5
22+
23+
4
24+
0.1 0.01 0.9 0.99
25+
26+
4
27+
0.6 0.3 0.4 0.7
28+
29+
8
30+
1 1 1 0 0 0 0 1
31+
32+
4
33+
0.98 0.05 0.02 0.95
34+
35+
8
36+
0.9 0.7 0.8 0.1 0.1 0.3 0.2 0.9

example/asia/data/asia.uai.evid

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
2 0 0 6 0

src/Core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ end
5858
function Base.show(io::IO, tn::TensorNetworkModel)
5959
open = getiyv(tn.code)
6060
variables = join([string_var(var, open, tn.fixedvertices) for var in tn.vars], ", ")
61-
tc, sc, rw = timespacereadwrite_complexity(tn)
61+
tc, sc, rw = contraction_complexity(tn)
6262
println(io, "$(typeof(tn))")
6363
println(io, "variables: $variables")
6464
print_tcscrw(io, tc, sc, rw)

src/TensorInference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ using TropicalGEMM
77

88
# reexport OMEinsum functions
99
export RescaledArray
10-
export timespace_complexity, timespacereadwrite_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABipartite, MergeGreedy, MergeVectors
10+
export contraction_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABipartite, MergeGreedy, MergeVectors
1111

1212
# read and load uai files
13-
export read_uai_file, read_td_file, read_uai_evid_file, read_uai_mar_file, read_uai_problem
13+
export read_uai_file, read_td_file, read_uai_evid_file, read_uai_mar_file, read_uai_problem, uai_problem_from_file
1414

1515
# marginals
1616
export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals

src/mmap.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434
function Base.show(io::IO, mmap::MMAPModel)
3535
open = getiyv(mmap.code)
3636
variables = join([string_var(var, open, mmap.fixedvertices) for var in mmap.vars], ", ")
37-
tc, sc, rw = timespacereadwrite_complexity(mmap)
37+
tc, sc, rw = contraction_complexity(mmap)
3838
println(io, "$(typeof(mmap))")
3939
println(io, "variables: $variables")
4040
println(io, "marginalized variables: $(map(x->x.eliminated_vars, mmap.clusters))")
@@ -99,7 +99,7 @@ function MMAPModel(vars::AbstractVector{LT}, factors::Vector{<:Factor{T}}; margi
9999
return MMAPModel(setdiff(vars, marginalizedvertices), code, remaining_tensors, clusters, fixedvertices)
100100
end
101101

102-
function OMEinsum.timespacereadwrite_complexity(mmap::MMAPModel{LT}) where {LT}
102+
function OMEinsum.contraction_complexity(mmap::MMAPModel{LT}) where {LT}
103103
# extract size
104104
size_dict = Dict(zip(get_vars(mmap), get_cards(mmap; fixedisone = true)))
105105
sc = -Inf
@@ -111,18 +111,17 @@ function OMEinsum.timespacereadwrite_complexity(mmap::MMAPModel{LT}) where {LT}
111111
# the head sector are for unity tensors.
112112
size_dict[cluster.eliminated_vars[k]] = length(cluster.tensors[k])
113113
end
114-
tc, sci, rw = timespacereadwrite_complexity(cluster.code, size_dict)
114+
tc, sci, rw = contraction_complexity(cluster.code, size_dict)
115115
push!(tcs, tc)
116116
push!(rws, rw)
117117
sc = max(sc, sci)
118118
end
119119

120-
tc, sci, rw = timespacereadwrite_complexity(mmap.code, size_dict)
120+
tc, sci, rw = contraction_complexity(mmap.code, size_dict)
121121
push!(tcs, tc)
122122
push!(rws, tc)
123123
OMEinsum.OMEinsumContractionOrders.log2sumexp2(tcs), max(sc, sci), OMEinsum.OMEinsumContractionOrders.log2sumexp2(rws)
124124
end
125-
OMEinsum.timespace_complexity(mmap::MMAPModel) = timespacereadwrite_complexity(mmap)[1:2]
126125

127126
function adapt_tensors(mmap::MMAPModel; usecuda, rescale)
128127
return [adapt_tensors(mmap.code, mmap.tensors, mmap.fixedvertices; usecuda, rescale)...,

src/utils.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
7171
function read_uai_evid_file(uai_evid_filepath::AbstractString)
7272
if isempty(uai_evid_filepath)
7373
# No evidence
74-
obsvars = Int64[]
75-
obsvals = Int64[]
74+
return Int64[], Int64[]
7675
else
7776
# Read the last line of the uai evid file
7877
line = open(uai_evid_filepath) do file
@@ -90,10 +89,6 @@ function read_uai_evid_file(uai_evid_filepath::AbstractString)
9089
@assert nobsvars == length(obsvars)
9190
end
9291

93-
# # DEBUG:
94-
# print(" "); @show obsvars
95-
# print(" "); @show obsvals
96-
9792
return obsvars, obsvals
9893
end
9994

@@ -179,14 +174,21 @@ $(TYPEDSIGNATURES)
179174
180175
Read a UAI problem from an artifact.
181176
"""
182-
function read_uai_problem(problem::AbstractString)::UAIInstance
177+
function read_uai_problem(problem::AbstractString; eltype=Float64)::UAIInstance
183178
uai_filepath = joinpath(artifact"MAR_prob", problem * ".uai")
184179
uai_evid_filepath = joinpath(artifact"MAR_prob", problem * ".uai.evid")
185180
uai_mar_filepath = joinpath(artifact"MAR_sol", problem * ".uai.MAR")
181+
return uai_problem_from_file(uai_filepath; uai_evid_filepath, uai_mar_filepath, eltype)
182+
end
186183

187-
nvars, cards, ncliques, factors = read_uai_file(uai_filepath; factor_eltype = Float64)
188-
obsvars, obsvals = read_uai_evid_file(uai_evid_filepath)
189-
reference_marginals = read_uai_mar_file(uai_mar_filepath)
184+
"""
185+
$(TYPEDSIGNATURES)
190186
187+
Read a UAI problem from a file.
188+
"""
189+
function uai_problem_from_file(uai_filepath::String; uai_evid_filepath="", uai_mar_filepath="", eltype=Float64)::UAIInstance
190+
nvars, cards, ncliques, factors = read_uai_file(uai_filepath; factor_eltype = eltype)
191+
obsvars, obsvals = read_uai_evid_file(uai_evid_filepath)
192+
reference_marginals = isempty(uai_mar_filepath) ? Vector{eltype}[] : read_uai_mar_file(uai_mar_filepath)
191193
return UAIInstance(nvars, ncliques, cards, factors, obsvars, obsvals, reference_marginals)
192194
end

test/inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
error("space complexity too large! got $(sc)")
7676
end
7777
# @info(tn)
78-
# @info timespace_complexity(tn)
78+
@info contraction_complexity(tn)
7979
marginals2 = marginals(tn)
8080
# for dangling vertices, the output size is 1.
8181
npass = 0

0 commit comments

Comments
 (0)