From 051012c12dcb2454d42c85400c49a95e167e13c4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 18 Aug 2025 18:57:28 +0100 Subject: [PATCH 1/2] Add a different struct that can pass kwargs on to Libtask --- Project.toml | 2 +- ext/AdvancedPSLibtaskExt.jl | 31 +++++++++++++++++++++++++++---- src/AdvancedPS.jl | 4 ++++ 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 491eebe2..28a5155f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.7.1" +version = "0.7.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/ext/AdvancedPSLibtaskExt.jl b/ext/AdvancedPSLibtaskExt.jl index d3ba9e02..ed31d225 100644 --- a/ext/AdvancedPSLibtaskExt.jl +++ b/ext/AdvancedPSLibtaskExt.jl @@ -35,7 +35,32 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`. function AdvancedPS.LibtaskModel( f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... ) # Changed the API, need to take care of the RNG properly - return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...)) + return AdvancedPS.LibtaskModel( + f, Libtask.TapedTask(TapedGlobals(rng), f, args...) + ) +end +# TODO: Upstream this to Turing +function AdvancedPS.LibtaskModel( + f::AdvancedPS.AbstractTuringLibtaskModel, rng::Random.AbstractRNG +) + return AdvancedPS.LibtaskModel( + f, Libtask.TapedTask(TapedGlobals(rng), f.fargs...; f.kwargs...) + ) +end + +const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R} + +function to_tapedtask( + newf::AdvancedPS.AbstractGenericModel, trace::LibtaskTrace, rng::Random.AbstractRNG +) + return Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf) +end +function to_tapedtask( + newf::AdvancedPS.AbstractTuringLibtaskModel, trace::LibtaskTrace, rng::Random.AbstractRNG +) + return Libtask.TapedTask( + TapedGlobals(rng, get_other_global(trace)), newf.fargs...; newf.kwargs... + ) end """ @@ -47,8 +72,6 @@ function Base.copy(model::AdvancedPS.LibtaskModel) return AdvancedPS.LibtaskModel(deepcopy(model.f), copy(model.ctask)) end -const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R} - function Base.copy(trace::LibtaskTrace) newtrace = AdvancedPS.Trace(copy(trace.model), deepcopy(trace.rng)) set_other_global!(newtrace, newtrace) @@ -114,7 +137,7 @@ function AdvancedPS.forkr(trace::LibtaskTrace) newf = AdvancedPS.reset_model(trace.model.f) Random123.set_counter!(rng, 1) - ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf) + ctask = to_tapedtask(newf, trace, rng) new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask) # add backward reference diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index faa673e8..93254f93 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -16,6 +16,10 @@ abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end abstract type AbstractStateSpaceModel <: AbstractParticleModel end abstract type AbstractGenericModel <: AbstractParticleModel end +# TODO(penelopeysm): This should be upstreamed to Turing together with anything that is +# Turing-specific in LibtaskExt. +abstract type AbstractTuringLibtaskModel <: AbstractGenericModel end + include("resampling.jl") include("rng.jl") include("model.jl") From 904bc38e22e116775f9701e5f3ee65418d5a52b3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 12 Dec 2025 17:36:58 +0000 Subject: [PATCH 2/2] Format --- ext/AdvancedPSLibtaskExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/AdvancedPSLibtaskExt.jl b/ext/AdvancedPSLibtaskExt.jl index ed31d225..e85f5835 100644 --- a/ext/AdvancedPSLibtaskExt.jl +++ b/ext/AdvancedPSLibtaskExt.jl @@ -35,9 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`. function AdvancedPS.LibtaskModel( f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... ) # Changed the API, need to take care of the RNG properly - return AdvancedPS.LibtaskModel( - f, Libtask.TapedTask(TapedGlobals(rng), f, args...) - ) + return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...)) end # TODO: Upstream this to Turing function AdvancedPS.LibtaskModel( @@ -56,7 +54,9 @@ function to_tapedtask( return Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf) end function to_tapedtask( - newf::AdvancedPS.AbstractTuringLibtaskModel, trace::LibtaskTrace, rng::Random.AbstractRNG + newf::AdvancedPS.AbstractTuringLibtaskModel, + trace::LibtaskTrace, + rng::Random.AbstractRNG, ) return Libtask.TapedTask( TapedGlobals(rng, get_other_global(trace)), newf.fargs...; newf.kwargs...