3434function Base. show (io:: IO , mmap:: MMAPModel )
3535 open = getiyv (mmap. code)
3636 variables = join ([string_var (var, open, mmap. fixedvertices) for var in mmap. vars], " , " )
37- tc, sc, rw = timespacereadwrite_complexity (mmap)
37+ tc, sc, rw = contraction_complexity (mmap)
3838 println (io, " $(typeof (mmap)) " )
3939 println (io, " variables: $variables " )
4040 println (io, " marginalized variables: $(map (x-> x. eliminated_vars, mmap. clusters)) " )
5858"""
5959$(TYPEDSIGNATURES)
6060"""
61- function MMAPModel (instance:: UAIInstance ; marginalizedvertices , openvertices = (), optimizer = GreedyMethod (), simplifier = nothing ):: MMAPModel
61+ function MMAPModel (instance:: UAIInstance ; marginalized , openvertices = (), optimizer = GreedyMethod (), simplifier = nothing ):: MMAPModel
6262 return MMAPModel (
63- 1 : (instance. nvars), instance. factors; marginalizedvertices , fixedvertices = Dict (zip (instance. obsvars, instance. obsvals .- 1 )), optimizer, simplifier, openvertices
63+ 1 : (instance. nvars), instance. factors; marginalized , fixedvertices = Dict (zip (instance. obsvars, instance. obsvals)), optimizer, simplifier, openvertices
6464 )
6565end
6666
6767"""
6868$(TYPEDSIGNATURES)
6969"""
70- function MMAPModel (vars:: AbstractVector{LT} , factors:: Vector{<:Factor{T}} ; marginalizedvertices , openvertices = (),
70+ function MMAPModel (vars:: AbstractVector{LT} , factors:: Vector{<:Factor{T}} ; marginalized , openvertices = (),
7171 fixedvertices = Dict {LT, Int} (),
7272 optimizer = GreedyMethod (), simplifier = nothing ,
7373 marginalize_optimizer = GreedyMethod (), marginalize_simplifier = nothing
@@ -81,7 +81,7 @@ function MMAPModel(vars::AbstractVector{LT}, factors::Vector{<:Factor{T}}; margi
8181 size_dict = OMEinsum. get_size_dict (all_ixs, all_tensors)
8282
8383 # detect clusters for marginalize variables
84- subsets = connected_clusters (all_ixs, marginalizedvertices )
84+ subsets = connected_clusters (all_ixs, marginalized )
8585 clusters = Cluster{LT}[]
8686 ixs = Vector{LT}[]
8787 for (contracted, cluster) in subsets
@@ -96,10 +96,10 @@ function MMAPModel(vars::AbstractVector{LT}, factors::Vector{<:Factor{T}}; margi
9696 rem_indices = setdiff (1 : length (all_ixs), vcat ([c. second for c in subsets]. .. ))
9797 remaining_tensors = all_tensors[rem_indices]
9898 code = optimize_code (EinCode ([all_ixs[rem_indices]. .. , ixs... ], iy), size_dict, optimizer, simplifier)
99- return MMAPModel (setdiff (vars, marginalizedvertices ), code, remaining_tensors, clusters, fixedvertices)
99+ return MMAPModel (setdiff (vars, marginalized ), code, remaining_tensors, clusters, fixedvertices)
100100end
101101
102- function OMEinsum. timespacereadwrite_complexity (mmap:: MMAPModel{LT} ) where {LT}
102+ function OMEinsum. contraction_complexity (mmap:: MMAPModel{LT} ) where {LT}
103103 # extract size
104104 size_dict = Dict (zip (get_vars (mmap), get_cards (mmap; fixedisone = true )))
105105 sc = - Inf
@@ -111,18 +111,17 @@ function OMEinsum.timespacereadwrite_complexity(mmap::MMAPModel{LT}) where {LT}
111111 # the head sector are for unity tensors.
112112 size_dict[cluster. eliminated_vars[k]] = length (cluster. tensors[k])
113113 end
114- tc, sci, rw = timespacereadwrite_complexity (cluster. code, size_dict)
114+ tc, sci, rw = contraction_complexity (cluster. code, size_dict)
115115 push! (tcs, tc)
116116 push! (rws, rw)
117117 sc = max (sc, sci)
118118 end
119119
120- tc, sci, rw = timespacereadwrite_complexity (mmap. code, size_dict)
120+ tc, sci, rw = contraction_complexity (mmap. code, size_dict)
121121 push! (tcs, tc)
122122 push! (rws, tc)
123123 OMEinsum. OMEinsumContractionOrders. log2sumexp2 (tcs), max (sc, sci), OMEinsum. OMEinsumContractionOrders. log2sumexp2 (rws)
124124end
125- OMEinsum. timespace_complexity (mmap:: MMAPModel ) = timespacereadwrite_complexity (mmap)[1 : 2 ]
126125
127126function adapt_tensors (mmap:: MMAPModel ; usecuda, rescale)
128127 return [adapt_tensors (mmap. code, mmap. tensors, mmap. fixedvertices; usecuda, rescale)... ,
@@ -174,35 +173,29 @@ function visit_var!(var, vars::AbstractVector{LT}, ixs, visited_ixs, visited_var
174173 end
175174end
176175
177- """
178- $(TYPEDSIGNATURES)
179- """
180- function most_probable_config (mmap:: MMAPModel ; usecuda = false ):: Tuple{Tropical, Vector}
176+ function most_probable_config (mmap:: MMAPModel ; usecuda = false ):: Tuple{Real, Vector}
181177 vars = get_vars (mmap)
182178 tensors = map (t -> OMEinsum. asarray (Tropical .(log .(t)), t), adapt_tensors (mmap; usecuda, rescale = false ))
183179 logp, grads = cost_and_gradient (mmap. code, tensors)
184180 # use Array to convert CuArray to CPU arrays
185- return Array (logp)[], map (k -> haskey (mmap. fixedvertices, vars[k]) ? mmap. fixedvertices[vars[k]] : argmax (grads[k]) - 1 , 1 : length (vars))
181+ return content ( Array (logp)[]) , map (k -> haskey (mmap. fixedvertices, vars[k]) ? mmap. fixedvertices[vars[k]] : argmax (grads[k]) - 1 , 1 : length (vars))
186182end
187183
188- """
189- $(TYPEDSIGNATURES)
190- """
191- function maximum_logp (mmap:: MMAPModel ; usecuda = false ):: AbstractArray{<:Tropical}
184+ function maximum_logp (mmap:: MMAPModel ; usecuda = false ):: AbstractArray{<:Real}
192185 tensors = map (t -> OMEinsum. asarray (Tropical .(log .(t)), t), adapt_tensors (mmap; usecuda, rescale = false ))
193- return mmap. code (tensors... )
186+ return map (content, mmap. code (tensors... ) )
194187end
195188
196- """
197- $(TYPEDSIGNATURES)
198- """
199189function log_probability (mmap:: MMAPModel , config:: Union{Dict, AbstractVector} ; rescale = true , usecuda = false ):: Real
200190 @assert length (get_vars (mmap)) == length (config)
201191 fixedvertices = config isa AbstractVector ? Dict (zip (get_vars (mmap), config)) : config
202192 assign = merge (mmap. fixedvertices, fixedvertices)
203193 # two contributions to the probability, not-clustered tensors and clusters.
204194 m1 = sum (x -> log (x[2 ][(getindex .(Ref (assign), x[1 ]) .+ 1 ). .. ]), zip (getixsv (mmap. code), mmap. tensors))
205- m2 = sum (cluster -> probability (cluster; fixedvertices, usecuda, rescale). log_factor, mmap. clusters)
195+ m2 = sum (mmap. clusters) do cluster
196+ p = probability (cluster; fixedvertices, usecuda, rescale)
197+ rescale ? p. log_factor : log (p[])
198+ end
206199 return m1 + m2
207200end
208201
0 commit comments