Skip to content

Commit df03471

Browse files
authored
FactorState -> Recipestate + Recipehyper (#1894)
Co-authored-by: Johannes Terblanche <Affie@users.noreply.github.com>
1 parent 553be33 commit df03471

File tree

8 files changed

+31
-47
lines changed

8 files changed

+31
-47
lines changed

IncrementalInference/src/Deprecated.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,6 @@ end
195195
# _getZDim(fcd::DFG.GenericFunctionNodeData) = _getCCW(fcd) |> _getZDim
196196
# DFG.getDimension(fct::DFG.GenericFunctionNodeData) = _getZDim(fct)
197197

198-
# function resetData!(vdata::DFG.FunctionNodeData)
199-
# error("resetData!(vdata::FunctionNodeData) is deprecated, use resetData!(state::FactorState) instead")
200-
# end
201-
202198
function sampleTangent(x::ManifoldKernelDensity, p = mean(x))
203199
error("sampleTangent(x::ManifoldKernelDensity, p) should be replaced by sampleTangent(M<:AbstractManifold, x::ManifoldKernelDensity, p)")
204200
end

IncrementalInference/src/Serialization/services/DispatchPackedConversions.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,16 @@ function DFG.rebuildFactorCache!(
2121
# Set up the neighbor data
2222

2323
# Rebuilding the CCW
24-
state = DFG.getFactorState(factor)
25-
state, solvercache = getDefaultFactorData(
24+
_, _, solvercache = getDefaultFactorData(
2625
dfg,
2726
neighbors,
2827
DFG.getObservation(factor);
29-
multihypo = state.multihypo,
30-
nullhypo = state.nullhypo,
28+
multihypo = factor.hyper.multihypo,
29+
nullhypo = factor.hyper.nullhypo,
3130
# special inflation override
32-
inflation = state.inflation,
33-
eliminated = state.eliminated,
34-
potentialused = state.potentialused,
35-
# solveInProgress = state.solveInProgress,
31+
inflation = factor.hyper.inflation,
32+
eliminated = factor.state.eliminated,
33+
potentialused = factor.state.potentialused,
3634
_blockRecursion=_blockRecursionGradients
3735
)
3836
#

IncrementalInference/src/parametric/services/ParametricUtils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ function CalcFactorMahalanobis(fg, fct::FactorCompute)
146146

147147
cache = preambleCache(fg, getVariable.(fg, varOrder), getObservation(fct))
148148

149-
multihypo = DFG.getFactorState(fct).multihypo
150-
nullhypo = DFG.getFactorState(fct).nullhypo
149+
multihypo = fct.hyper.multihypo
150+
nullhypo = fct.hyper.nullhypo
151151

152152
# FIXME, type instability
153153
if length(multihypo) > 0

IncrementalInference/src/services/BayesNet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,15 @@ function buildBayesNet!(dfg::AbstractDFG, elimorder::Vector{Symbol}; solvable::I
151151
vert = DFG.getVariable(dfg, v)
152152
for fctId in listNeighbors(dfg, vert; solvable = solvable)
153153
fct = DFG.getFactor(dfg, fctId)
154-
if (DFG.getFactorState(fct).eliminated != true)
154+
if (fct.state.eliminated != true)
155155
push!(fi, fctId)
156156
for sepNode in listNeighbors(dfg, fct; solvable = solvable)
157157
# TODO -- validate !(sepNode.index in Si) vs. older !(sepNode in Si)
158158
if sepNode != v && !(sepNode in Si) # Symbol comparison!
159159
push!(Si, sepNode)
160160
end
161161
end
162-
DFG.getFactorState(fct).eliminated = true
162+
fct.state.eliminated = true
163163
end
164164

165165
if typeof(_getCCW(fct)) == CommonConvWrapper{GenericMarginal}

IncrementalInference/src/services/FGOSUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ function getFactorsAmongVariablesOnly(
409409
# now check if those factors have already been added
410410
for fct in prefcts
411411
vert = DFG.getFactor(dfg, fct)
412-
if !DFG.getFactorState(vert).potentialused
412+
if !vert.state.potentialused
413413
push!(almostfcts, fct)
414414
end
415415
end

IncrementalInference/src/services/FactorGraph.jl

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,6 @@ function getDefaultFactorData(
716716
# threadmodel = SingleThreaded,
717717
eliminated::Bool = false,
718718
potentialused::Bool = false,
719-
edgeIDs = Int[],
720-
# solveInProgress = 0,
721719
inflation::Real = getSolverParams(dfg).inflation,
722720
_blockRecursion::Bool = false,
723721
keepCalcFactor::Bool = false,
@@ -742,17 +740,11 @@ function getDefaultFactorData(
742740
keepCalcFactor,
743741
)
744742

745-
state = DFG.FactorState(
746-
eliminated,
747-
potentialused,
748-
multihypo,
749-
ccwl.hyporecipe.certainhypo,
750-
nullhypo,
751-
# solveInProgress,
752-
inflation,
753-
)
743+
state = DFG.Recipestate(; eliminated, potentialused)
744+
745+
hyper = DFG.Recipehyper(; nullhypo, multihypo, inflation)
754746

755-
return state, ccwl
747+
return hyper, state, ccwl
756748

757749
end
758750

@@ -838,7 +830,7 @@ function DFG.addFactor!(
838830
_zonedtime(s::DateTime) = ZonedDateTime(s, localzone())
839831

840832
varOrderLabels = Symbol[v.label for v in Xi]
841-
state, solvercache = getDefaultFactorData(
833+
hyper, state, solvercache = getDefaultFactorData(
842834
dfg,
843835
Xi,
844836
deepcopy(usrfnc);
@@ -854,6 +846,7 @@ function DFG.addFactor!(
854846
Symbol(namestring),
855847
varOrderLabels,
856848
usrfnc,
849+
hyper,
857850
state,
858851
solvercache;
859852
tags = Set(union(tags, [:FACTOR])),

IncrementalInference/src/services/JunctionTreeUtils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ function resetData!(vdata::State)
869869
return nothing
870870
end
871871

872-
function resetData!(state::DFG.FactorState)
872+
function resetData!(state::DFG.Recipestate)
873873
state.eliminated = false
874874
state.potentialused = false
875875
return nothing
@@ -886,7 +886,7 @@ function resetFactorGraphNewTree!(dfg::AbstractDFG)
886886
resetData!(getState(v, :default))
887887
end
888888
for f in DFG.getFactors(dfg)
889-
resetData!(DFG.getFactorState(f))
889+
resetData!(DFG.getRecipestate(f))
890890
end
891891
return nothing
892892
end
@@ -998,13 +998,13 @@ function getCliqFactorsFromFrontals(
998998
# usefcts = Int[]
999999
for fctid in ls(fgl, frsym)
10001000
fct = getFactor(fgl, fctid)
1001-
if !unused || !DFG.getFactorState(fct).potentialused
1001+
if !unused || !fct.state.potentialused
10021002
loutn = ls(fgl, fctid; solvable = solvable)
10031003
# deal with unary factors
10041004
if length(loutn) == 1
10051005
union!(usefcts, Symbol[Symbol(fct.label);])
10061006
# appendUseFcts!(usefcts, loutn, fct) # , frsym)
1007-
DFG.getFactorState(fct).potentialused = true
1007+
fct.state.potentialused = true
10081008
end
10091009
# deal with n-ary factors
10101010
for sep in loutn
@@ -1014,7 +1014,7 @@ function getCliqFactorsFromFrontals(
10141014
insep = sep in allids
10151015
if !inseparator || insep
10161016
union!(usefcts, Symbol[Symbol(fct.label);])
1017-
DFG.getFactorState(fct).potentialused = true
1017+
fct.state.potentialused = true
10181018
if !insep
10191019
@debug "cliq=$(cliq.id) adding factor that is not in separator, $sep"
10201020
end
@@ -1061,7 +1061,7 @@ function setCliqPotentials!(
10611061
fcts = map(x -> getFactor(dfg, x), fctsyms)
10621062
getCliqueData(cliq).partialpotential = map(x -> isPartial(x), fcts)
10631063
for fct in fcts
1064-
DFG.getFactorState(fct).potentialused = true
1064+
fct.state.potentialused = true
10651065
end
10661066

10671067
@debug "finding all frontals for down WIP"

IncrementalInference/test/testSaveLoadDFG.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@ Base.rm(saveFolder*".tar.gz")
2121
@test symdiff(ls(fg), ls(retDFG)) == []
2222
@test symdiff(lsf(fg), lsf(retDFG)) == []
2323

24-
@show getFactorState(fg, :x2x3x4f1).multihypo
25-
@show getFactorState(retDFG, :x2x3x4f1).multihypo
24+
@show DFG.getRecipehyper(fg, :x2x3x4f1).multihypo
25+
@show DFG.getRecipehyper(retDFG, :x2x3x4f1).multihypo
2626

2727
# check for match
28-
@test getFactorState(fg, :x2x3x4f1).multihypo - getFactorState(retDFG, :x2x3x4f1).multihypo |> norm < 1e-10
29-
@test getFactorState(fg, :x2x3x4f1).certainhypo - getFactorState(retDFG, :x2x3x4f1).certainhypo |> norm < 1e-10
28+
@test DFG.getRecipehyper(fg, :x2x3x4f1).multihypo - DFG.getRecipehyper(retDFG, :x2x3x4f1).multihypo |> norm < 1e-10
3029

3130
##
3231
end
@@ -46,9 +45,8 @@ solveTree!(fg)
4645

4746
#manually change a few fields to test if they are preserved
4847
fa = getFactor(fg, :x2x3x4f1)
49-
getFactorState(fa).eliminated = true
50-
# getFactorState(fa).solveInProgress = 1
51-
getFactorState(fa).nullhypo = 0.5
48+
fa.state.eliminated = true
49+
fa.hyper.nullhypo = 0.5
5250

5351

5452
saveFolder = "/tmp/dfg_test"
@@ -62,12 +60,11 @@ Base.rm(saveFolder*".tar.gz")
6260
@test issetequal(ls(fg), ls(retDFG))
6361
@test issetequal(lsf(fg), lsf(retDFG))
6462

65-
@show getFactorState(fg, :x2x3x4f1).multihypo
66-
@show getFactorState(retDFG, :x2x3x4f1).multihypo
63+
@show DFG.getRecipehyper(fg, :x2x3x4f1).multihypo
64+
@show DFG.getRecipehyper(retDFG, :x2x3x4f1).multihypo
6765

6866
# check for match
69-
@test isapprox(getFactorState(fg, :x2x3x4f1).multihypo, getFactorState(retDFG, :x2x3x4f1).multihypo)
70-
@test isapprox(getFactorState(fg, :x2x3x4f1).certainhypo, getFactorState(retDFG, :x2x3x4f1).certainhypo)
67+
@test isapprox(DFG.getRecipehyper(fg, :x2x3x4f1).multihypo, DFG.getRecipehyper(retDFG, :x2x3x4f1).multihypo)
7168

7269

7370
fb = getFactor(retDFG, :x2x3x4f1)

0 commit comments

Comments
 (0)