From 070bcbab5e26418cc8d68ec2bbf396fcb36d855c Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Tue, 2 Dec 2025 19:58:38 +0100 Subject: [PATCH 1/3] WIP: support multiple simultaneous events in a single VectorContinuousCallback --- src/callbacks.jl | 106 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 90 insertions(+), 16 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index a84d11e4c..6fdfaa60d 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -188,15 +188,13 @@ end end integrator.sol.stats.ncondition += 1 - ivec = integrator.vector_event_last_time + ivec = integrator.callback_cache.prev_simultaneous_events prev_sign = @view(integrator.callback_cache.prev_sign[1:(callback.len)]) next_sign = @view(integrator.callback_cache.next_sign[1:(callback.len)]) - if integrator.event_last_time == counter && - minimum(ODE_DEFAULT_NORM( - ArrayInterface.allowed_getindex(previous_condition, - ivec), integrator.t)) <= - 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) + if integrator.event_last_time == counter && minimum(minimum(idx -> ODE_DEFAULT_NORM(ArrayInterface.allowed_getindex(previous_condition, idx), integrator.t), + ivec, init=typemax(typeof(integrator.t)))) <= + 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) # If there was a previous event, utilize the derivative at the start to # chose the previous sign. If the derivative is positive at tprev, then @@ -215,7 +213,9 @@ end abst = integrator.tprev + integrator.dt * callback.repeat_nudge tmp_condition = get_condition(integrator, callback, abst) @. prev_sign = sign(previous_condition) - prev_sign[ivec] = sign(tmp_condition[ivec]) + for idx ∈ ivec + prev_sign[idx] = sign(tmp_condition[idx]) + end else @. prev_sign = sign(previous_condition) end @@ -263,7 +263,6 @@ end interp_index = callback.interp_points end end - event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx end @@ -466,6 +465,13 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun callback, counter) if event_occurred + (; simultaneous_events, prev_simultaneous_events) = integrator.callback_cache + empty!(prev_simultaneous_events) + for idx ∈ simultaneous_events + push!(prev_simultaneous_events, idx) + end + empty!(simultaneous_events) + if callback.condition === nothing new_t = zero(typeof(integrator.t)) min_event_idx = findfirst(isequal(1), event_idx) @@ -492,14 +498,13 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun Θ = top_t else if integrator.event_last_time == counter && - integrator.vector_event_last_time == idx && + idx ∈ prev_simultaneous_events && abs(zero_func(bottom_t)) <= 100abs(integrator.last_event_error) && prev_sign_index == 1 # Determined that there is an event by derivative # But floating point error may make the end point negative - bottom_t += integrator.dt * callback.repeat_nudge sign_top = sign(zero_func(top_t)) sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) && @@ -515,8 +520,12 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun end end if integrator.tdir * Θ < integrator.tdir * min_t + empty!(simultaneous_events) + end + if integrator.tdir * Θ <= integrator.tdir * min_t min_event_idx = idx min_t = Θ + push!(simultaneous_events, idx) end end end @@ -532,9 +541,19 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun elseif interp_index != callback.interp_points && !isdiscrete(integrator.alg) new_t = ts[interp_index] - integrator.tprev min_event_idx = findfirst(isequal(1), event_idx) + for (i, idx) ∈ enumerate(event_idx) + if idx == 1 + push!(simultaneous_events, i) + end + end else # If no solve and no interpolants, just use endpoint new_t = integrator.dt + for (i, idx) ∈ enumerate(event_idx) + if idx == 1 + push!(simultaneous_events, i) + end + end min_event_idx = findfirst(isequal(1), event_idx) end end @@ -546,13 +565,12 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun if event_occurred && min_event_idx < 0 error("Callback handling failed. Please file an issue with code to reproduce.") end - - new_t, ArrayInterface.allowed_getindex(prev_sign, min_event_idx), - event_occurred::Bool, min_event_idx::Int + # We still pass around the min_event_idx for now because some stuff in OrdinaryDiffEqCore expects it to be an Int + new_t, prev_sign, event_occurred::Bool, min_event_idx::Int end function apply_callback!(integrator, - callback::Union{ContinuousCallback, VectorContinuousCallback}, + callback::ContinuousCallback, cb_time, prev_sign, event_idx) if isadaptive(integrator) set_proposed_dt!(integrator, @@ -610,6 +628,56 @@ function apply_callback!(integrator, false, saved_in_cb end +function apply_callback!(integrator, + callback::VectorContinuousCallback, + cb_time, prev_sign, min_event_idx) + if isadaptive(integrator) + set_proposed_dt!(integrator, + integrator.tdir * max(nextfloat(integrator.opts.dtmin), + integrator.tdir * callback.dtrelax * integrator.dt)) + end + + change_t_via_interpolation!( + integrator, integrator.tprev + cb_time, Val{:false}, callback.initializealg) + + # handle saveat + _, savedexactly = savevalues!(integrator) + saved_in_cb = true + + @inbounds if callback.save_positions[1] + # if already saved then skip saving + savedexactly || savevalues!(integrator, true) + end + + u_modified = false + for i ∈ integrator.callback_cache.simultaneous_events + if prev_sign[i] < 0 callback.affect! !== nothing + callback.affect!(integrator, i) + u_modified = true + elseif prev_sign[i] > 0 && callback.affect_neg! !== nothing + callback.affect_neg!(integrator, i) + u_modified = true + end + end + integrator.u_modified = u_modified + if u_modified + reeval_internals_due_to_modification!( + integrator, callback_initializealg = callback.initializealg) + + @inbounds if callback.save_positions[2] + savevalues!(integrator, true) + if !isdefined(integrator.opts, :save_discretes) || integrator.opts.save_discretes + for i ∈ integrator.callback_cache.simultaneous_events + SciMLBase.save_discretes!(integrator, callback, i) + end + end + saved_in_cb = true + end + return true, saved_in_cb + end + false, saved_in_cb +end + #Base Case: Just one @inline function apply_discrete_callback!(integrator, callback::DiscreteCallback) saved_in_cb = false @@ -698,6 +766,8 @@ mutable struct CallbackCache{conditionType, signType} previous_condition::conditionType next_sign::signType prev_sign::signType + simultaneous_events::Vector{Int} + prev_simultaneous_events::Vector{Int} end function CallbackCache(u, max_len, ::Type{conditionType}, @@ -706,7 +776,9 @@ function CallbackCache(u, max_len, ::Type{conditionType}, previous_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) prev_sign = similar(u, signType, max_len) - CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign) + simultaneous_events = sizehint!(Int[], max_len) + prev_simultaneous_events = sizehint!(Int[], max_len) + CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign, simultaneous_events, prev_simultaneous_events) end function CallbackCache(max_len, ::Type{conditionType}, @@ -715,5 +787,7 @@ function CallbackCache(max_len, ::Type{conditionType}, previous_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) prev_sign = zeros(signType, max_len) - CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign) + prev_simultaneous_events = sizehint!(Int[], max_len) + simultaneous_events = sizehint!(Int[], max_len) + CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign, simultaneous_events, prev_simultaneous_events) end From d2699fffa8e1cac1d34b5b1a119342e6c15dee85 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Tue, 2 Dec 2025 20:42:13 +0100 Subject: [PATCH 2/3] formatting --- src/callbacks.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 6fdfaa60d..5817de7c9 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -192,9 +192,10 @@ end prev_sign = @view(integrator.callback_cache.prev_sign[1:(callback.len)]) next_sign = @view(integrator.callback_cache.next_sign[1:(callback.len)]) - if integrator.event_last_time == counter && minimum(minimum(idx -> ODE_DEFAULT_NORM(ArrayInterface.allowed_getindex(previous_condition, idx), integrator.t), - ivec, init=typemax(typeof(integrator.t)))) <= - 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) + if integrator.event_last_time == counter && + minimum(minimum(idx -> ODE_DEFAULT_NORM(ArrayInterface.allowed_getindex(previous_condition, idx), integrator.t), + ivec, init=typemax(typeof(integrator.t)))) <= + 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) # If there was a previous event, utilize the derivative at the start to # chose the previous sign. If the derivative is positive at tprev, then From 8ea1cd6ccb0c07f709402cae46a5718cd8a07fbf Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Wed, 3 Dec 2025 14:58:25 +0100 Subject: [PATCH 3/3] Switch to `Vector{Bool}` for the simultaneous events --- src/callbacks.jl | 59 ++++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 5817de7c9..804f9fcf1 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -188,13 +188,13 @@ end end integrator.sol.stats.ncondition += 1 - ivec = integrator.callback_cache.prev_simultaneous_events + ivec = enumerate(integrator.callback_cache.prev_simultaneous_events) prev_sign = @view(integrator.callback_cache.prev_sign[1:(callback.len)]) next_sign = @view(integrator.callback_cache.next_sign[1:(callback.len)]) if integrator.event_last_time == counter && minimum(minimum(idx -> ODE_DEFAULT_NORM(ArrayInterface.allowed_getindex(previous_condition, idx), integrator.t), - ivec, init=typemax(typeof(integrator.t)))) <= + (idx for (idx, triggered) ∈ ivec if triggered), init=typemax(typeof(integrator.t)))) <= 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) # If there was a previous event, utilize the derivative at the start to @@ -214,8 +214,10 @@ end abst = integrator.tprev + integrator.dt * callback.repeat_nudge tmp_condition = get_condition(integrator, callback, abst) @. prev_sign = sign(previous_condition) - for idx ∈ ivec - prev_sign[idx] = sign(tmp_condition[idx]) + for (idx, triggered) ∈ ivec + if triggered + prev_sign[idx] = sign(tmp_condition[idx]) + end end else @. prev_sign = sign(previous_condition) @@ -467,11 +469,8 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun counter) if event_occurred (; simultaneous_events, prev_simultaneous_events) = integrator.callback_cache - empty!(prev_simultaneous_events) - for idx ∈ simultaneous_events - push!(prev_simultaneous_events, idx) - end - empty!(simultaneous_events) + prev_simultaneous_events .= simultaneous_events + simultaneous_events .= false if callback.condition === nothing new_t = zero(typeof(integrator.t)) @@ -499,7 +498,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun Θ = top_t else if integrator.event_last_time == counter && - idx ∈ prev_simultaneous_events && + prev_simultaneous_events[idx] && abs(zero_func(bottom_t)) <= 100abs(integrator.last_event_error) && prev_sign_index == 1 @@ -521,12 +520,12 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun end end if integrator.tdir * Θ < integrator.tdir * min_t - empty!(simultaneous_events) + simultaneous_events .= false end if integrator.tdir * Θ <= integrator.tdir * min_t min_event_idx = idx min_t = Θ - push!(simultaneous_events, idx) + simultaneous_events[idx] = true end end end @@ -544,7 +543,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun min_event_idx = findfirst(isequal(1), event_idx) for (i, idx) ∈ enumerate(event_idx) if idx == 1 - push!(simultaneous_events, i) + simultaneous_events[i] = true end end else @@ -552,7 +551,7 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun new_t = integrator.dt for (i, idx) ∈ enumerate(event_idx) if idx == 1 - push!(simultaneous_events, i) + simultaneous_events[i] = true end end min_event_idx = findfirst(isequal(1), event_idx) @@ -651,13 +650,15 @@ function apply_callback!(integrator, end u_modified = false - for i ∈ integrator.callback_cache.simultaneous_events - if prev_sign[i] < 0 callback.affect! !== nothing - callback.affect!(integrator, i) - u_modified = true - elseif prev_sign[i] > 0 && callback.affect_neg! !== nothing - callback.affect_neg!(integrator, i) - u_modified = true + for (i, triggered) ∈ enumerate(integrator.callback_cache.simultaneous_events) + if triggered + if prev_sign[i] < 0 callback.affect! !== nothing + callback.affect!(integrator, i) + u_modified = true + elseif prev_sign[i] > 0 && callback.affect_neg! !== nothing + callback.affect_neg!(integrator, i) + u_modified = true + end end end integrator.u_modified = u_modified @@ -767,28 +768,28 @@ mutable struct CallbackCache{conditionType, signType} previous_condition::conditionType next_sign::signType prev_sign::signType - simultaneous_events::Vector{Int} - prev_simultaneous_events::Vector{Int} + simultaneous_events::Vector{Bool} + prev_simultaneous_events::Vector{Bool} end function CallbackCache(u, max_len, ::Type{conditionType}, - ::Type{signType}) where {conditionType, signType} + ::Type{signType}) where {conditionType, signType} tmp_condition = similar(u, conditionType, max_len) previous_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) prev_sign = similar(u, signType, max_len) - simultaneous_events = sizehint!(Int[], max_len) - prev_simultaneous_events = sizehint!(Int[], max_len) + simultaneous_events = zeros(Bool, max_len) + prev_simultaneous_events = zeros(Bool, max_len) CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign, simultaneous_events, prev_simultaneous_events) end function CallbackCache(max_len, ::Type{conditionType}, - ::Type{signType}) where {conditionType, signType} + ::Type{signType}) where {conditionType, signType} tmp_condition = zeros(conditionType, max_len) previous_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) prev_sign = zeros(signType, max_len) - prev_simultaneous_events = sizehint!(Int[], max_len) - simultaneous_events = sizehint!(Int[], max_len) + simultaneous_events = zeros(Bool, max_len) + prev_simultaneous_events = zeros(Bool, max_len) CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign, simultaneous_events, prev_simultaneous_events) end