Skip to content

Commit 1a44e33

Browse files
authored
Merge pull request #26 from TensorBFS/jg/fix-tropicalgemm
Add tropicalgemm and precompile
2 parents f6298c5 + bfedcb3 commit 1a44e33

File tree

9 files changed

+28
-15
lines changed

9 files changed

+28
-15
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
12+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1213
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1314
TropicalGEMM = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
1415
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
@@ -18,6 +19,7 @@ CUDA = "4"
1819
DocStringExtensions = "0.8.6, 0.9"
1920
OMEinsum = "0.7"
2021
Requires = "1"
22+
PrecompileTools = "1"
2123
TropicalGEMM = "0.1"
2224
TropicalNumbers = "0.5.4"
2325
julia = "1.3"

src/TensorInference.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module TensorInference
33
using OMEinsum, LinearAlgebra
44
using DocStringExtensions, TropicalNumbers
55
using Artifacts
6+
# The Tropical GEMM support
67
using TropicalGEMM
78

89
# reexport OMEinsum functions
@@ -34,4 +35,13 @@ function __init__()
3435
@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
3536
end
3637

38+
import PrecompileTools
39+
PrecompileTools.@setup_workload begin
40+
# Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
41+
# precompile file and potentially make loading faster.
42+
PrecompileTools.@compile_workload begin
43+
include("../example/asia/asia.jl")
44+
end
45+
end
46+
3747
end # module

src/maxprob.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,5 @@ Returns an output array containing largest log-probabilities.
6060
function maximum_logp(tn::TensorNetworkModel; usecuda = false)::AbstractArray{<:Real}
6161
# generate tropical tensors with its elements being log(p).
6262
tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
63-
return map(content, tn.code(tensors...))
63+
return broadcasted_content(tn.code(tensors...))
6464
end

src/mmap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ end
183183

184184
function maximum_logp(mmap::MMAPModel; usecuda = false)::AbstractArray{<:Real}
185185
tensors = map(t -> OMEinsum.asarray(Tropical.(log.(t)), t), adapt_tensors(mmap; usecuda, rescale = false))
186-
return map(content, mmap.code(tensors...))
186+
return broadcasted_content(mmap.code(tensors...))
187187
end
188188

189189
function log_probability(mmap::MMAPModel, config::Union{Dict, AbstractVector}; rescale = true, usecuda = false)::Real

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,6 @@ function uai_problem_from_file(uai_filepath::String; uai_evid_filepath="", uai_m
192192
reference_marginals = isempty(uai_mar_filepath) ? Vector{eltype}[] : read_uai_mar_file(uai_mar_filepath)
193193
return UAIInstance(nvars, ncliques, cards, factors, obsvars, obsvals, reference_marginals)
194194
end
195+
196+
# patch to get content by broadcasting into array, while keep array size unchanged.
197+
broadcasted_content(x) = asarray(content.(x), x)

test/cuda.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ CUDA.allowscalar(false)
99

1010
# does not optimize over open vertices
1111
tn = TensorNetworkModel(instance; optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40))
12-
@info contraction_complexity(tn)
12+
@debug contraction_complexity(tn)
1313
@time marginals2 = marginals(tn; usecuda = true)
1414
@test all(x -> x isa CuArray, marginals2)
1515
# for dangling vertices, the output size is 1.
@@ -26,10 +26,10 @@ end
2626

2727
# does not optimize over open vertices
2828
tn = TensorNetworkModel(instance; optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40))
29-
@info contraction_complexity(tn)
29+
@debug contraction_complexity(tn)
3030
most_probable_config(tn)
3131
@time logp, config = most_probable_config(tn; usecuda = true)
32-
@test log_probability(tn, config) ≈ logp.n
32+
@test log_probability(tn, config) ≈ logp
3333
culogp = maximum_logp(tn; usecuda = true)
3434
@test culogp isa CuArray
3535
@test Array(culogp)[] ≈ logp
@@ -54,10 +54,10 @@ end
5454
culogp = maximum_logp(tn2; usecuda = true)
5555
@test cup isa RescaledArray{T, N, <:CuArray} where {T, N}
5656
@test culogp isa CuArray
57-
@test Array(cup)[] ≈ exp(Array(culogp)[].n)
57+
@test Array(cup)[] ≈ exp(Array(culogp)[])
5858

5959
# does not optimize over open vertices
6060
tn3 = MMAPModel(instance; marginalized = [2, 4, 6], optimizer)
6161
logp, config = most_probable_config(tn3; usecuda = true)
62-
@test log_probability(tn3, config) ≈ logp.n
62+
@test log_probability(tn3, config) ≈ logp
6363
end

test/inference.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ end
6363
x -> map(first, x) # get the first capture of each element
6464

6565
for problem in problems
66-
@show problem
67-
66+
@info "Testing: $problem"
6867
@testset "$(problem)" begin
6968
problem = read_uai_problem(problem)
7069

@@ -74,8 +73,7 @@ end
7473
if sc > 28
7574
error("space complexity too large! got $(sc)")
7675
end
77-
# @info(tn)
78-
@info contraction_complexity(tn)
76+
@debug contraction_complexity(tn)
7977
marginals2 = marginals(tn)
8078
# for dangling vertices, the output size is 1.
8179
npass = 0

test/maxprob.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using TensorInference
88

99
# does not optimize over open vertices
1010
tn = TensorNetworkModel(instance; optimizer = TreeSA(ntrials = 3, niters = 2, βs = 1:0.1:80))
11-
@info contraction_complexity(tn)
11+
@debug contraction_complexity(tn)
1212
most_probable_config(tn)
1313
@time logp, config = most_probable_config(tn)
1414
@test log_probability(tn, config) ≈ logp

test/mmap.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@ end
1515
tn_ref = TensorNetworkModel(instance; optimizer)
1616
# does not marginalize any var
1717
mmap = MMAPModel(instance; marginalized = Int[], optimizer)
18-
@info(mmap)
18+
@debug(mmap)
1919
@test maximum_logp(tn_ref) ≈ maximum_logp(mmap)
2020

2121
# marginalize all vars
2222
mmap2 = MMAPModel(instance; marginalized = collect(1:(instance.nvars)), optimizer)
23-
@info(mmap2)
23+
@debug(mmap2)
2424
@test Array(probability(tn_ref))[] ≈ exp(maximum_logp(mmap2)[])
2525

2626
# does not optimize over open vertices
2727
mmap3 = MMAPModel(instance; marginalized = [2, 4, 6], optimizer)
28-
@info(mmap3)
28+
@debug(mmap3)
2929
logp, config = most_probable_config(mmap3)
3030
@test log_probability(mmap3, config) ≈ logp
3131
end

0 commit comments

Comments
 (0)