From 034b2c4cf6b1fc69175c406440f6e5eb806d68ee Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 24 Nov 2025 17:17:40 +0100
Subject: [PATCH 01/36] correct rk45 and add tsit5

---
 bayesflow/utils/integrate.py       | 253 ++++++++++++++++++++++++-----
 tests/test_utils/test_integrate.py |  28 ++++
 2 files changed, 238 insertions(+), 43 deletions(-)

diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index b197ea975..a9bd6ea3f 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -29,6 +29,7 @@ def euler_step(
     k1 = fn(time, **filter_kwargs(state, fn))
 
     if use_adaptive_step_size:
+        # Use Heun's method (RK2) as embedded pair for proper error estimation
         intermediate_state = state.copy()
         for key, delta in k1.items():
             intermediate_state[key] = state[key] + step_size * delta
@@ -39,18 +40,23 @@ def euler_step(
         if set(k1.keys()) != set(k2.keys()):
             raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.")
 
-        # compute next step size
-        intermediate_error = keras.ops.stack([keras.ops.norm(k2[key] - k1[key], ord=2, axis=-1) for key in k1])
-        new_step_size = step_size * tolerance / (intermediate_error + 1e-9)
+        # Heun's (RK2) solution
+        heun_state = state.copy()
+        for key in k1.keys():
+            heun_state[key] = state[key] + 0.5 * step_size * (k1[key] + k2[key])
 
-        new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
+        # Error estimate: difference between Euler and Heun
+        intermediate_error = keras.ops.stack(
+            [keras.ops.norm(heun_state[key] - intermediate_state[key], ord=2, axis=-1) for key in k1]
+        )
 
-        # consolidate step size
-        new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size)))
+        max_error = keras.ops.max(intermediate_error)
+        new_step_size = step_size * keras.ops.sqrt(tolerance / (max_error + 1e-9))
+
+        new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
     else:
         new_step_size = step_size
 
-    # apply updates
     new_state = state.copy()
     for key in k1.keys():
         new_state[key] = state[key] + step_size * k1[key]
@@ -60,6 +66,16 @@ def euler_step(
     return new_state, new_time, new_step_size
 
 
+def add_scaled(state, ks, coeffs, h):
+    out = {}
+    for key, y in state.items():
+        acc = keras.ops.zeros_like(y)
+        for c, k in zip(coeffs, ks):
+            acc = acc + c * k[key]
+        out[key] = y + h * acc
+    return out
+
+
 def rk45_step(
     fn: Callable,
     state: dict[str, ArrayLike],
     time: ArrayLike,
     last_step_size: ArrayLike,
     tolerance: ArrayLike = 1e-6,
     min_step_size: ArrayLike = -float("inf"),
     max_step_size: ArrayLike = float("inf"),
     use_adaptive_step_size: bool = False,
 ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike):
+    """
+    Dormand-Prince 5(4) method with embedded error estimation.
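+
+    The embedded 4th-order solution reuses a 7th stage evaluated at the
+    5th-order result (the FSAL stage), so the error estimate costs one extra
+    derivative evaluation per step.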
+    """
     step_size = last_step_size
+    h = step_size
 
     k1 = fn(time, **filter_kwargs(state, fn))
+    k2 = fn(time + h * (1 / 5), **add_scaled(state, [k1], [1 / 5], h))
+    k3 = fn(time + h * (3 / 10), **add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h))
+    k4 = fn(time + h * (4 / 5), **add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h))
+    k5 = fn(
+        time + h * (8 / 9),
+        **add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h),
+    )
+    k6 = fn(
+        time + h,
+        **add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h),
+    )
 
-    intermediate_state = state.copy()
-    for key, delta in k1.items():
-        intermediate_state[key] = state[key] + 0.5 * step_size * delta
+    # check all keys are equal
+    if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5, k6]):
+        raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.")
 
-    k2 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn))
+    # 5th order solution
+    new_state = {}
+    for key in k1.keys():
+        new_state[key] = state[key] + h * (
+            35 / 384 * k1[key] + 500 / 1113 * k3[key] + 125 / 192 * k4[key] - 2187 / 6784 * k5[key] + 11 / 84 * k6[key]
+        )
 
-    intermediate_state = state.copy()
-    for key, delta in k2.items():
-        intermediate_state[key] = state[key] + 0.5 * step_size * delta
+    if use_adaptive_step_size:
+        k7 = fn(time + h, **filter_kwargs(new_state, fn))
+
+        # 4th order embedded solution
+        err_state = {}
+        for key in k1.keys():
+            y4 = state[key] + h * (
+                5179 / 57600 * k1[key]
+                + 7571 / 16695 * k3[key]
+                + 393 / 640 * k4[key]
+                - 92097 / 339200 * k5[key]
+                + 187 / 2100 * k6[key]
+                + 1 / 40 * k7[key]
+            )
+            err_state[key] = new_state[key] - y4
 
-    k3 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn))
+        err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()])
+        err = keras.ops.max(err_norm)
 
-    intermediate_state = state.copy()
-    for key, delta in k3.items():
-        intermediate_state[key] = state[key] + step_size * delta
+        new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0)
+        new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
+    else:
+        new_step_size = step_size
 
-    k4 = fn(time + step_size, **filter_kwargs(intermediate_state, fn))
+    new_time = time + h
+    return new_state, new_time, new_step_size
 
-    if use_adaptive_step_size:
-        intermediate_state = state.copy()
-        for key, delta in k4.items():
-            intermediate_state[key] = state[key] + 0.5 * step_size * delta
-        k5 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn))
 
-        # check all keys are equal
-        if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5]):
-            raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.")
-
-        # compute next step size
-        intermediate_error = keras.ops.stack([keras.ops.norm(k5[key] - k4[key], ord=2, axis=-1) for key in k5.keys()])
-        new_step_size = step_size * tolerance / (intermediate_error + 1e-9)
-
-        new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
-
-        # consolidate step size
-        new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size)))
-    else:
-        new_step_size = step_size
-
-    # apply updates
-    new_state = state.copy()
-    for key in k1.keys():
-        new_state[key] = state[key] + (step_size / 6.0) * (k1[key] + 2.0 * k2[key] + 2.0 * k3[key] + k4[key])
-
-    new_time = time + step_size
+def tsit5_step(
+    fn: Callable,
+    state: dict[str, ArrayLike],
+    time: ArrayLike,
+    last_step_size: ArrayLike,
+    tolerance: ArrayLike = 1e-6,
+    min_step_size: ArrayLike = -float("inf"),
+    max_step_size: ArrayLike = float("inf"),
+    use_adaptive_step_size: bool = False,
+):
+    """
+    Implements a single step of the Tsitouras 5/4 Runge-Kutta method.
+    """
+    step_size = last_step_size
+    h = step_size
+
+    # Butcher tableau coefficients
+    c2 = 0.161
+    c3 = 0.327
+    c4 = 0.9
+    c5 = 0.9800255409045097
+
+    k1 = fn(time, **filter_kwargs(state, fn))
+    k2 = fn(time + h * c2, **add_scaled(state, [k1], [0.161], h))
+    k3 = fn(time + h * c3, **add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h))
+    k4 = fn(
+        time + h * c4, **add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h)
+    )
+    k5 = fn(
+        time + h * c5,
+        **add_scaled(
+            state, [k1, k2, k3, k4], [4.325279681768730, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h
+        ),
+    )
+    k6 = fn(
+        time + h,
+        **add_scaled(
+            state,
+            [k1, k2, k3, k4, k5],
+            [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838],
+            h,
+        ),
+    )
+
+    # 5th order solution: b coefficients
+    new_state = {}
+    for key in state.keys():
+        new_state[key] = state[key] + h * (
+            0.09646076681806523 * k1[key]
+            + 0.01 * k2[key]
+            + 0.4798896504144996 * k3[key]
+            + 1.379008574103742 * k4[key]
+            - 3.290069515436081 * k5[key]
+            + 2.324710524099774 * k6[key]
+        )
+
+    if use_adaptive_step_size:
+        # 7th stage evaluation
+        k7 = fn(time + h, **filter_kwargs(new_state, fn))
+
+        # 4th order embedded solution: b_hat coefficients
+        y4 = {}
+        for key in state.keys():
+            y4[key] = state[key] + h * (
+                0.001780011052226 * k1[key]
+                + 0.000816434459657 * k2[key]
+                - 0.007880878010262 * k3[key]
+                + 0.144711007173263 * k4[key]
+                - 0.582357165452555 * k5[key]
+                + 0.458082105929187 * k6[key]
+                + (1.0 / 66.0) * k7[key]
+            )
+
+        # Error estimate
+        err_state = {}
+        for key in state.keys():
+            err_state[key] = new_state[key] - y4[key]
+
+        err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()])
+        err = keras.ops.max(err_norm)
+
+        new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0)
+        new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
+    else:
+        new_step_size = h
+
+    new_time = time + h
     return new_state, new_time, new_step_size
@@ -141,6 +251,8 @@
             step_fn = euler_step
         case "rk45":
             step_fn = rk45_step
+        case "tsit5":
+            step_fn = tsit5_step
         case str() as name:
             raise ValueError(f"Unknown integration method name: {name!r}")
         case other:
@@ -180,6 +292,8 @@
             step_fn = euler_step
         case "rk45":
             step_fn = rk45_step
+        case "tsit5":
+            step_fn = tsit5_step
         case str() as name:
             raise ValueError(f"Unknown integration method name: {name!r}")
         case other:
@@ -249,6 +363,8 @@
             step_fn = euler_step
         case "rk45":
             step_fn = rk45_step
+        case "tsit5":
+            step_fn = tsit5_step
         case str() as name:
             raise ValueError(f"Unknown integration method name: {name!r}")
         case other:
@@ -401,11 +517,19 @@
     steps: int,
     seed: keras.random.SeedGenerator,
     method: str = "euler_maruyama",
+    score_fn: Callable = None,
+    corrector_steps: int = 0,
+    noise_schedule=None,
+    step_size_factor: float = 0.1,
     **kwargs,
 ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]:
     """
     Integrates a stochastic differential equation from start_time to stop_time.
 
+    When score_fn is provided, performs predictor-corrector sampling where:
+    - Predictor: reverse diffusion SDE solver
+    - Corrector: annealed Langevin dynamics with step size e = 2 * alpha_t * (r * ||z|| / ||score||) ** 2
+
     Args:
         drift_fn: Function that computes the drift term.
         diffusion_fn: Function that computes the diffusion term.
@@ -415,11 +539,15 @@
         steps: Number of integration steps.
         seed: Random seed for noise generation.
         method: Integration method to use, e.g., 'euler_maruyama'.
+        score_fn: Optional score function for predictor-corrector sampling.
+            Should take (time, **state) and return score dict.
+        corrector_steps: Number of corrector steps to take after each predictor step.
+        noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
+        step_size_factor: Scaling factor for corrector step size.
         **kwargs: Additional arguments to pass to the step function.
 
     Returns:
-        If return_noise is False, returns the final state dictionary.
-        If return_noise is True, returns a tuple of (final_state, noise_history).
+        Final state dictionary after integration.
     """
     if steps <= 0:
         raise ValueError("Number of steps must be positive.")
@@ -438,17 +566,56 @@
     step_size = (stop_time - start_time) / steps
     sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size))
 
-    # Pre-generate noise history: shape = (steps, *state_shape)
+    # Pre-generate noise history for predictor: shape = (steps, *state_shape)
     noise_history = {}
     for key, val in state.items():
         noise_history[key] = (
             keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt
         )
 
+    # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape)
+    corrector_noise_history = {}
+    if corrector_steps > 0:
+        if score_fn is None or noise_schedule is None:
+            raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.")
+
+        for key, val in state.items():
+            corrector_noise_history[key] = keras.random.normal(
+                (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed
+            )
+
     def body(_loop_var, _loop_state):
         _current_state, _current_time = _loop_state
         _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()}
+
+        # Predictor step
         new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i)
+
+        # Corrector steps: annealed Langevin dynamics if score_fn is provided
+        if corrector_steps > 0:
+            for corrector_step in range(corrector_steps):
+                score = score_fn(new_time, **filter_kwargs(new_state, score_fn))
+                _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()}
+
+                # Compute noise schedule components for corrector step size
+                log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False)
+                alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
+
+                # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector
+                # where e = 2*alpha_t * (r * ||z|| / ||score||)**2
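+                # This mirrors the signal-to-noise-scaled corrector step of
+                # score-based SDE samplers (cf. Song et al., 2021), where
+                # r = step_size_factor acts as a target signal-to-noise
+                # ratio, so e shrinks whenever the score is large.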
+                for k in new_state.keys():
+                    if k in score:
+                        z_norm = keras.ops.norm(_corrector_noise[k], axis=-1, keepdims=True)
+                        score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True)
+
+                        # Prevent division by zero
+                        score_norm = keras.ops.maximum(score_norm, 1e-8)
+
+                        e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2
+                        sqrt_2e = keras.ops.sqrt(2.0 * e)
+
+                        new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k]
+
         return new_state, new_time
 
     final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time))
diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py
index db5c448d7..75e65ca1d 100644
--- a/tests/test_utils/test_integrate.py
+++ b/tests/test_utils/test_integrate.py
@@ -1,4 +1,11 @@
 import numpy as np
+import keras
+import pytest
+from bayesflow.utils import integrate
+
+
+TOLERANCE_ADAPTIVE = 1e-6  # Adaptive solvers should be very accurate.
+TOLERANCE_EULER = 1e-3  # Euler with fixed steps requires a larger tolerance
 
 
 def test_scheduled_integration():
@@ -34,3 +41,24 @@ def fn(t, x):
         scipy_kwargs={"atol": 1e-6, "rtol": 1e-6},
     )["x"]
     np.testing.assert_allclose(exact_result, result, atol=1e-6, rtol=1e-6)
+
+
+@pytest.mark.parametrize(
+    "method, atol", [("euler", TOLERANCE_EULER), ("rk45", TOLERANCE_ADAPTIVE), ("tsit5", TOLERANCE_ADAPTIVE)]
+)
+def test_analytical_integration(method, atol):
+    def fn(t, x):
+        return {"x": keras.ops.convert_to_tensor([2.0 * t])}
+
+    initial_state = {"x": keras.ops.convert_to_tensor([1.0])}
+    T_final = 2.0
+    num_steps = 100
+    analytical_result = 1.0 + T_final**2
+
+    result = integrate(fn, initial_state, start_time=0.0, stop_time=T_final, steps=num_steps, method=method)["x"]
+    result_adaptive = integrate(
+        fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000
+    )["x"]
+    np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1)
+
+    np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.01)

From 3cae88379fe4d05a5ef741ea1fd591d19378e231 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 24 Nov 2025 17:19:58 +0100
Subject: [PATCH 02/36] add predictor corrector

---
 .../diffusion_model/diffusion_model.py | 76 ++++++++++++++++---
 1 file changed, 66 insertions(+), 10 deletions(-)

diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py
index ca8a634e9..bc77a884d 100644
--- a/bayesflow/networks/diffusion_model/diffusion_model.py
+++ b/bayesflow/networks/diffusion_model/diffusion_model.py
@@ -243,6 +243,55 @@ def _apply_subnet(
         else:
             return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)
 
+    def score(
+        self,
+        xz: Tensor,
+        time: float | Tensor = None,
+        log_snr_t: Tensor = None,
+        conditions: Tensor = None,
+        training: bool = False,
+    ) -> Tensor:
+        """
+        Computes the score of the target or latent variable `xz`.
+
+        Parameters
+        ----------
+        xz : Tensor
+            The current state of the latent variable `z`, typically of shape (..., D),
+            where D is the dimensionality of the latent space.
+        time : float or Tensor
+            Scalar or tensor representing the time (or noise level) at which the score
+            should be computed. Will be broadcasted to xz. If None, log_snr_t must be provided.
+        log_snr_t : Tensor
+            The log signal-to-noise ratio at time `t`. If None, time must be provided.
+        conditions : Tensor, optional
+            Conditional inputs to the network, such as conditioning variables
+            or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
+        training : bool, optional
+            Whether the model is in training mode. Affects behavior of dropout, batch norm,
+            or other stochastic layers. Default is False.
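+            # Conditional on the full increment W_h, the half increment W_{h/2}
+            # has law N(W_h / 2, h / 4); it is realized as
+            # dW1 = W_h / 2 + (sqrt(h) / 2) * Z with auxiliary Z ~ N(0, I),
+            # so dW1 + dW2 = W_h holds exactly (a Brownian bridge split).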
+        sqrt_h = keras.ops.sqrt(h + 1e-12)
+        dW1 = {k: 0.5 * noise[k] + 0.5 * sqrt_h * bridge_aux[k] for k in noise}
+        dW2 = {k: noise[k] - dW1[k] for k in noise}
+
+    half = 0.5 * h
+
+    # first half step on [t, t + h/2]
+    k1h = drift_fn(t, **filter_kwargs(state, drift_fn))
+    mid1 = {k: state[k] + 0.5 * half * k1h[k] for k in state}
+    k2h = drift_fn(t + 0.5 * half, **filter_kwargs(mid1, drift_fn))
+    g_q1 = diffusion_fn(t + 0.5 * half, **filter_kwargs(state, diffusion_fn))
+    y_half = {k: state[k] + half * k2h[k] + g_q1.get(k, 0) * dW1.get(k, 0) for k in state}
+
+    # second half step on [t + h/2, t + h]
+    k1h2 = drift_fn(t + half, **filter_kwargs(y_half, drift_fn))
+    mid2 = {k: y_half[k] + 0.5 * half * k1h2[k] for k in y_half}
+    k2h2 = drift_fn(t + 1.5 * half, **filter_kwargs(mid2, drift_fn))
+    g_q2 = diffusion_fn(t + 1.5 * half, **filter_kwargs(state, diffusion_fn))
+    y_twohalf = {k: y_half[k] + half * k2h2[k] + g_q2.get(k, 0) * dW2.get(k, 0) for k in y_half}
+
+    # error estimate
+    parts = []
+    for k in y_full:
+        v = y_full[k] - y_twohalf[k]
+        if not hasattr(v, "shape") or len(v.shape) == 0:
+            v = keras.ops.reshape(v, (1,))
+        parts.append(keras.ops.norm(v, ord=2, axis=-1))
+    err = keras.ops.max(keras.ops.stack(parts))
+
+    # controller for strong order one on additive noise; local error ~ h^{3/2}
+    factor = 0.9 * (tolerance / (err + 1e-12)) ** (2.0 / 3.0)
+    h_new = keras.ops.clip(h * keras.ops.clip(factor, 0.2, 5.0), min_step_size, max_step_size)
+
+    return y_full, t + h, h_new
+
+
 def integrate_stochastic(
     drift_fn: Callable,
     diffusion_fn: Callable,
-    state: dict[str, ArrayLike],
+    state: Dict[str, ArrayLike],
     start_time: ArrayLike,
     stop_time: ArrayLike,
-    steps: int,
     seed: keras.random.SeedGenerator,
+    min_steps: int = 10,
+    max_steps: int = 10_000,
+    steps: int | Literal["adaptive"] = 100,
     method: str = "euler_maruyama",
     score_fn: Callable = None,
     corrector_steps: int = 0,
     noise_schedule=None,
     step_size_factor: float = 0.1,
     **kwargs,
-) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]:
+) -> Union[Dict[str, ArrayLike], Tuple[Dict[str, ArrayLike], Dict[str, Sequence[ArrayLike]]]]:
     """
     Integrates a stochastic differential equation from start_time to stop_time.
 
@@ ...
     Args:
         drift_fn: Function that computes the drift term.
         diffusion_fn: Function that computes the diffusion term.
         state: Dictionary containing the initial state.
         start_time: Starting time for integration.
-        stop_time: Ending time for integration.
-        steps: Number of integration steps.
+        stop_time: Ending time for integration.
         seed: Random seed for noise generation.
-        method: Integration method to use, e.g., 'euler_maruyama'.
+        min_steps: Minimum number of steps for adaptive integration.
+        max_steps: Maximum number of steps for adaptive integration.
+        steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps.
+        method: Integration method to use, e.g., 'euler_maruyama' or 'shark'.
         score_fn: Optional score function for predictor-corrector sampling.
             Should take (time, **state) and return score dict.
         corrector_steps: Number of corrector steps to take after each predictor step.
         noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
         step_size_factor: Scaling factor for corrector step size.
         **kwargs: Additional arguments to pass to the step function.
 
-    Returns:
-        Final state dictionary after integration.
+    Returns: Final state dictionary after integration.
     """
-    if steps <= 0:
-        raise ValueError("Number of steps must be positive.")
+    use_adaptive = False
+    if isinstance(steps, str) and steps in ["adaptive", "dynamic"]:
+        if start_time is None or stop_time is None:
+            raise ValueError(
+                "Please provide start_time and stop_time for the integration, was "
+                f"'start_time={start_time}', 'stop_time={stop_time}'."
+            )
+        if min_steps <= 0 or max_steps <= 0:
+            raise ValueError("min_steps and max_steps must be positive.")
+        if max_steps < min_steps:
+            raise ValueError("max_steps must be greater or equal to min_steps.")
+        use_adaptive = True
+        loop_steps = max_steps
+        initial_step = (stop_time - start_time) / float(min_steps)
+    elif isinstance(steps, int):
+        if steps <= 0:
+            raise ValueError("Number of steps must be positive.")
+        use_adaptive = False
+        loop_steps = steps
+        initial_step = (stop_time - start_time) / float(steps)
+    else:
+        raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")
 
-    # Select step function based on method
     match method:
         case "euler_maruyama":
             step_fn = euler_maruyama_step
+            if use_adaptive:
+                raise ValueError("Adaptive step size is not supported for Euler Maruyama method.")
+        case "shark":
+            step_fn = shark_step
         case other:
             raise TypeError(f"Invalid integration method: {other!r}")
 
-    # Prepare step function with partial application
     step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs)
 
-    # Time step
-    step_size = (stop_time - start_time) / steps
-    sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size))
-
-    # Pre-generate noise history for predictor: shape = (steps, *state_shape)
-    noise_history = {}
-    for key, val in state.items():
-        noise_history[key] = (
-            keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt
-        )
+    # pre generate standard normals; scale by sqrt(dt) inside the loop using the current dt
+    z_history = {}
+    bridge_history = {}
+    for key, val in state.items():
+        shape = keras.ops.shape(val)
+        z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed)
+        if method == "shark" and use_adaptive:
+            bridge_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed)
 
-    # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape)
+    # pre generate corrector noise if requested
     corrector_noise_history = {}
     if corrector_steps > 0:
         if score_fn is None or noise_schedule is None:
             raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.")
         for key, val in state.items():
+            shape = keras.ops.shape(val)
             corrector_noise_history[key] = keras.random.normal(
-                (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed
+                (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed
             )
 
-    def body(_loop_var, _loop_state):
-        _current_state, _current_time = _loop_state
-        _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()}
-
-        # Predictor step
-        new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i)
+    def body(_i, _loop_state):
+        _current_state, _current_time, _current_step = _loop_state
+
+        # clamp last step to hit stop_time
+        remaining = stop_time - _current_time
+        dt = keras.ops.minimum(_current_step, remaining)
+
+        sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt))
+        _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()}
+
+        if method == "shark" and use_adaptive:
+            _bridge = {k: bridge_history[k][_i] for k in _current_state.keys()}
+            out = step_fn(
+                state=_current_state,
+                time=_current_time,
+                step_size=dt,
+                noise=_noise_i,
+                bridge_aux=_bridge,
+                use_adaptive_step_size=True,
+            )
+            new_state, new_time, new_step = out
+        else:
+            out = step_fn(state=_current_state, time=_current_time, step_size=dt, noise=_noise_i)
+            if isinstance(out, tuple) and len(out) == 2:
+                new_state, new_time = out
+                new_step = _current_step
+            else:
+                new_state, new_time, new_step = out
 
-        # Corrector steps: annealed Langevin dynamics if score_fn is provided
+        # corrector
         if corrector_steps > 0:
-            for corrector_step in range(corrector_steps):
+            for j in range(corrector_steps):
                 score = score_fn(new_time, **filter_kwargs(new_state, score_fn))
-                _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()}
+                _z_corr = {k: corrector_noise_history[k][_i, j] for k in new_state.keys()}
 
-                # Compute noise schedule components for corrector step size
                 log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False)
                 alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
 
-                # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector
-                # where e = 2*alpha_t * (r * ||z|| / ||score||)**2
                 for k in new_state.keys():
                     if k in score:
-                        z_norm = keras.ops.norm(_corrector_noise[k], axis=-1, keepdims=True)
+                        z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True)
                         score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True)
-
-                        # Prevent division by zero
                         score_norm = keras.ops.maximum(score_norm, 1e-8)
 
                         e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2
-                        sqrt_2e = keras.ops.sqrt(2.0 * e)
-
-                        new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k]
+                        new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k]
 
-        return new_state, new_time
+        return new_state, new_time, new_step
 
-    final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time))
+    final_state, final_time, last_step = keras.ops.fori_loop(0, loop_steps, body, (state, start_time, initial_step))
+
+    if use_adaptive and float(final_time) < float(stop_time):
+        logging.warning(
+            f"Reached max_steps={max_steps} before stop_time. "
+            f"final_time={float(final_time)} stop_time={float(stop_time)}"
+        )
+
     return final_state
diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py
index 75e65ca1d..142a0bbb7 100644
--- a/tests/test_utils/test_integrate.py
+++ b/tests/test_utils/test_integrate.py
@@ -1,12 +1,17 @@
 import numpy as np
 import keras
 import pytest
-from bayesflow.utils import integrate
+from bayesflow.utils import integrate, integrate_stochastic
 
 
 TOLERANCE_ADAPTIVE = 1e-6  # Adaptive solvers should be very accurate.
 TOLERANCE_EULER = 1e-3  # Euler with fixed steps requires a larger tolerance
+
+# tolerances for SDE tests
+TOL_MEAN = 3e-2
+TOL_VAR = 5e-2
+TOL_DET = 1e-3
 
 
 def test_scheduled_integration():
@@ ...
+
+
+@pytest.mark.parametrize(
+    "method,use_adapt",
+    [
+        ("euler_maruyama", False),
+        ("shark", False),
+        ("shark", True),
+    ],
+)
+def test_additive_OU_weak_means_and_vars(method, use_adapt):
+    """
+    Ornstein-Uhlenbeck with additive noise
+        dX = a X dt + sigma dW
+    Exact at time T:
+        E[X_T] = x0 * exp(a T)
+        Var[X_T] = sigma^2 * (exp(2 a T) - 1) / (2 a)
+    We verify weak accuracy by matching empirical mean and variance.
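+    As a hand-computed sanity check of these formulas (not asserted below):
+    for a = -1, sigma = 0.5, x0 = 1.2, T = 1 they give
+    E[X_T] = 1.2 * exp(-1) ~= 0.4415 and
+    Var[X_T] = 0.25 * (1 - exp(-2)) / 2 ~= 0.1081.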
+    """
+    # SDE parameters
+    a = -1.0
+    sigma = 0.5
+    x0 = 1.2
+    T = 1.0
+
+    # batch of trajectories
+    N = 20000  # large enough to control sampling error
+    seed = keras.random.SeedGenerator(42)
+
+    def drift_fn(t, x):
+        return {"x": a * x}
+
+    def diffusion_fn(t, x):
+        # additive noise, independent of state
+        return {"x": keras.ops.convert_to_tensor([sigma])}
+
+    initial_state = {"x": keras.ops.ones((N,)) * x0}
+    steps = 200 if not use_adapt else "adaptive"
+
+    # expected mean and variance
+    exp_mean = x0 * np.exp(a * T)
+    exp_var = sigma**2 * (np.exp(2.0 * a * T) - 1.0) / (2.0 * a)
+
+    out = integrate_stochastic(
+        drift_fn=drift_fn,
+        diffusion_fn=diffusion_fn,
+        state=initial_state,
+        start_time=0.0,
+        stop_time=T,
+        steps=steps,
+        seed=seed,
+        method=method,
+    )
+
+    xT = np.array(out["x"])
+    emp_mean = float(xT.mean())
+    emp_var = float(xT.var())
+    np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0)
+    np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0)
+
+
+@pytest.mark.parametrize(
+    "method,use_adapt",
+    [
+        ("euler_maruyama", False),
+        ("shark", False),
+        ("shark", True),
+    ],
+)
+def test_zero_noise_reduces_to_deterministic(method, use_adapt):
+    """
+    With zero diffusion the SDE reduces to the ODE
+        dX = a X dt
+    """
+    a = 0.7
+    x0 = 0.9
+    T = 1.25
+    steps = 200 if not use_adapt else "adaptive"
+    seed = keras.random.SeedGenerator(999)
+
+    def drift_fn(t, x):
+        return {"x": a * x}
+
+    def diffusion_fn(t, x):
+        # identically zero diffusion
+        return {"x": keras.ops.convert_to_tensor([0.0])}
+
+    initial_state = {"x": keras.ops.ones((256,)) * x0}
+    out = integrate_stochastic(
+        drift_fn=drift_fn,
+        diffusion_fn=diffusion_fn,
+        state=initial_state,
+        start_time=0.0,
+        stop_time=T,
+        steps=steps,
+        seed=seed,
+        method=method,
+    )["x"]
+
+    exact = x0 * np.exp(a * T)
+    np.testing.assert_allclose(np.array(out).mean(), exact, atol=TOL_DET, rtol=0.1)

From 770abc72ea00af98226581382b2a13d6c3f34dda Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 24 Nov 2025 18:24:53 +0100
Subject: [PATCH 04/36] add shark

---
 bayesflow/networks/diffusion_model/diffusion_model.py | 6 +++---
 bayesflow/utils/integrate.py                          | 7 +------
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py
index bc77a884d..659641f51 100644
--- a/bayesflow/networks/diffusion_model/diffusion_model.py
+++ b/bayesflow/networks/diffusion_model/diffusion_model.py
@@ -408,7 +408,7 @@ def _forward(
         integrate_kwargs = integrate_kwargs | self.integrate_kwargs
         integrate_kwargs = integrate_kwargs | kwargs
 
-        if integrate_kwargs["method"] == "euler_maruyama":
+        if integrate_kwargs["method"] in ["euler_maruyama", "shark"]:
             raise ValueError("Stochastic methods are not supported for forward integration.")
 
         if density:
@@ -458,7 +458,7 @@ def _inverse(
         integrate_kwargs = integrate_kwargs | self.integrate_kwargs
         integrate_kwargs = integrate_kwargs | kwargs
         if density:
-            if integrate_kwargs["method"] == "euler_maruyama":
+            if integrate_kwargs["method"] in ["euler_maruyama", "shark"]:
                 raise ValueError("Stochastic methods are not supported for density computation.")
 
             def deltas(time, xz):
@@ -477,7 +477,7 @@ def _inverse(
             return x, log_density
 
         state = {"xz": z}
-        if integrate_kwargs["method"] == "euler_maruyama":
+        if integrate_kwargs["method"] in ["euler_maruyama", "shark"]:
 
             def deltas(time, xz):
                 return {
diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index 86d94c1e1..d7935c1ee 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -717,14 +717,9 @@ def body(_i, _loop_state):
                 bridge_aux=_bridge,
                 use_adaptive_step_size=True,
             )
-            new_state, new_time, new_step = out
         else:
             out = step_fn(state=_current_state, time=_current_time, step_size=dt, noise=_noise_i)
-            if isinstance(out, tuple) and len(out) == 2:
-                new_state, new_time = out
-                new_step = _current_step
-            else:
-                new_state, new_time, new_step = out
+        new_state, new_time, new_step = out
 
         # corrector
         if corrector_steps > 0:

From eba68922c73bc699d1d1f1eecaf6215e3597e9a2 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 24 Nov 2025 18:30:26 +0100
Subject: [PATCH 05/36] rm warn

---
 bayesflow/utils/integrate.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index d7935c1ee..ed35f9b71 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -742,11 +742,4 @@
         return new_state, new_time, new_step
 
     final_state, final_time, last_step = keras.ops.fori_loop(0, loop_steps, body, (state, start_time, initial_step))
-
-    if use_adaptive and float(final_time) < float(stop_time):
-        logging.warning(
-            f"Reached max_steps={max_steps} before stop_time. "
-            f"final_time={float(final_time)} stop_time={float(stop_time)}"
-        )
-
     return final_state

From e901b733c8610f53f59f4b69ad142b8ab9704eba Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 24 Nov 2025 18:35:03 +0100
Subject: [PATCH 06/36] fix dt

---
 bayesflow/utils/integrate.py       | 6 +++---
 tests/test_utils/test_integrate.py | 2 ++
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index ed35f9b71..bd0a580ae 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -699,10 +699,10 @@
     def body(_i, _loop_state):
         _current_state, _current_time, _current_step = _loop_state
-
-        # clamp last step to hit stop_time
         remaining = stop_time - _current_time
-        dt = keras.ops.minimum(_current_step, remaining)
+        sign = keras.ops.sign(remaining)
+        dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining))
+        dt = sign * dt_mag
 
         sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt))
         _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()}
diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py
index 142a0bbb7..d328f9476 100644
--- a/tests/test_utils/test_integrate.py
+++ b/tests/test_utils/test_integrate.py
@@ -119,6 +119,7 @@
         steps=steps,
         seed=seed,
         method=method,
+        max_steps=1_000,
     )
 
     xT = np.array(out["x"])
@@ -164,6 +165,7 @@
         steps=steps,
         seed=seed,
         method=method,
+        max_steps=1_000,
     )["x"]
 
     exact = x0 * np.exp(a * T)

From e8be555ed14730786799e7a6d6a8f131d9cd6b36 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 24 Nov 2025 19:06:16 +0100
Subject: [PATCH 07/36] fix adaptive step size

---
 bayesflow/utils/integrate.py | 44 ++++++++++++++++++++++++++++--------
 1 file changed, 34 insertions(+), 10 deletions(-)

diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index bd0a580ae..9e61534c4 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -475,6 +475,8 @@ def euler_maruyama_step(
     time: ArrayLike,
     step_size: ArrayLike,
     noise: dict[str, ArrayLike],
+    min_step_size: ArrayLike = None,
+    max_step_size: ArrayLike = None,
 ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike):
     """
     Performs a single
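+        # with these bounds the controller can never propose a step finer
+        # than span / max_steps or coarser than span / min_steps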
     elif isinstance(steps, int):
         if steps <= 0:
             raise ValueError("Number of steps must be positive.")
         use_adaptive = False
         loop_steps = steps
         initial_step = (stop_time - start_time) / float(steps)
+        min_step_size, max_step_size = initial_step, initial_step
     else:
         raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")
 
@@ ...
         case other:
             raise TypeError(f"Invalid integration method: {other!r}")
 
-    step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs)
+    step_fn = partial(
+        step_fn,
+        drift_fn=drift_fn,
+        diffusion_fn=diffusion_fn,
+        min_step_size=min_step_size,
+        max_step_size=max_step_size,
+        **kwargs,
+    )

From 36a16b35c5d3ee04652252ea2d902585cfedc2d4 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 24 Nov 2025 19:31:32 +0100
Subject: [PATCH 08/36] refactor stochastic integrator

---
 bayesflow/utils/integrate.py | 307 ++++++++++++++++++++++++++++++++++-
 1 file changed, 302 insertions(+), 5 deletions(-)

diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index 9e61534c4..428e4e3b5 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -15,6 +15,7 @@
 from . import logging
 
 ArrayLike = int | float | Tensor
+StateDict = Dict[str, ArrayLike]
 
 
 def euler_step(
@@ -475,6 +476,7 @@ def euler_maruyama_step(
     time: ArrayLike,
     step_size: ArrayLike,
     noise: dict[str, ArrayLike],
+    use_adaptive_step_size: bool = False,
     min_step_size: ArrayLike = None,
     max_step_size: ArrayLike = None,
 ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike):
@@ -488,6 +490,7 @@
         time: Current time scalar tensor.
         step_size: Time increment dt.
         noise: Mapping of variable names to dW noise tensors.
+        use_adaptive_step_size: Whether to use adaptive step sizing (not used here).
         min_step_size: Minimum allowed step size (not used here).
         max_step_size: Maximum allowed step size (not used here).
 
     Returns:
@@ -613,21 +613,21 @@
     return y_full, t + h, h_new
 
 
-def integrate_stochastic(
+def integrate_stochastic_old(
     drift_fn: Callable,
     diffusion_fn: Callable,
     state: Dict[str, ArrayLike],
     start_time: ArrayLike,
     stop_time: ArrayLike,
     seed: keras.random.SeedGenerator,
     steps: int | Literal["adaptive"] = 100,
     method: str = "euler_maruyama",
-    min_steps: int = 10,
-    max_steps: int = 10_000,
     score_fn: Callable = None,
     corrector_steps: int = 0,
     noise_schedule=None,
     step_size_factor: float = 0.1,
+    min_steps: int = 10,
+    max_steps: int = 1_000,
     **kwargs,
 ) -> Union[Dict[str, ArrayLike], Tuple[Dict[str, ArrayLike], Dict[str, Sequence[ArrayLike]]]]:
     """
     Integrates a stochastic differential equation from start_time to stop_time.
@@ ...
         seed: Random seed for noise generation.
-        min_steps: Minimum number of steps for adaptive integration.
-        max_steps: Maximum number of steps for adaptive integration.
         steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps.
         method: Integration method to use, e.g., 'euler_maruyama' or 'shark'.
         score_fn: Optional score function for predictor-corrector sampling.
             Should take (time, **state) and return score dict.
         corrector_steps: Number of corrector steps to take after each predictor step.
         noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
         step_size_factor: Scaling factor for corrector step size.
+        min_steps: Minimum number of steps for adaptive integration.
+        max_steps: Maximum number of steps for adaptive integration.
         **kwargs: Additional arguments to pass to the step function.
 
     Returns: Final state dictionary after integration.
@@ -767,3 +770,297 @@
     final_state, final_time, last_step = keras.ops.fori_loop(0, loop_steps, body, (state, start_time, initial_step))
     return final_state
+
+
+def _apply_corrector(
+    new_state: StateDict,
+    new_time: ArrayLike,
+    i: ArrayLike,
+    corrector_steps: int,
+    score_fn: Optional[Callable],
+    step_size_factor: float,
+    corrector_noise_history: Dict[str, ArrayLike],
+    noise_schedule=None,
+) -> StateDict:
+    """Helper function to apply corrector steps."""
+    if corrector_steps <= 0:
+        return new_state
+
+    # Ensures score_fn and noise_schedule are present if needed, though checked in integrate_stochastic
+    if score_fn is None or noise_schedule is None:
+        return new_state  # Should not happen if checks are passed
+
+    for j in range(corrector_steps):
+        score = score_fn(new_time, **filter_kwargs(new_state, score_fn))
+        _z_corr = {k: corrector_noise_history[k][i, j] for k in new_state.keys()}
+
+        log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False)
+        alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
+
+        for k in new_state.keys():
+            if k in score:
+                # Calculate required norms for Langevin step
+                z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True)
+                score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True)
+                score_norm = keras.ops.maximum(score_norm, 1e-8)
+
+                # Compute step size 'e' for the Langevin update
+                e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2
+
+                # Annealed Langevin Dynamics update
+                new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k]
+    return new_state
+
+
+def integrate_stochastic_fixed(
+    step_fn: Callable,
+    state: StateDict,
+    start_time: ArrayLike,
+    stop_time: ArrayLike,
+    steps: int,
+    z_history: Dict[str, ArrayLike],
+    corrector_steps: int,
+    score_fn: Optional[Callable],
+    step_size_factor: float,
+    corrector_noise_history: Dict[str, ArrayLike],
+    noise_schedule=None,
+) -> StateDict:
+    """
+    Performs fixed-step SDE integration.
+    """
+    initial_step = (stop_time - start_time) / float(steps)
+
+    def body_fixed(_i, _loop_state):
+        _current_state, _current_time, _current_step = _loop_state
+
+        # Determine step size: either the constant size or the remainder to reach stop_time
+        remaining = stop_time - _current_time
+        sign = keras.ops.sign(remaining)
+        dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining))
+        dt = sign * dt_mag
+
+        # Generate noise increment scaled by sqrt(dt)
+        sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt))
+        _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()}
+
+        new_state, new_time, new_step = step_fn(
+            state=_current_state,
+            time=_current_time,
+            step_size=dt,
+            noise=_noise_i,
+            use_adaptive_step_size=False,
+        )
+
+        new_state = _apply_corrector(
+            new_state=new_state,
+            new_time=new_time,
+            i=_i,
+            corrector_steps=corrector_steps,
+            score_fn=score_fn,
+            noise_schedule=noise_schedule,
+            step_size_factor=step_size_factor,
+            corrector_noise_history=corrector_noise_history,
+        )
+        return new_state, new_time, initial_step
+
+    # Execute the fixed loop
+    final_state, final_time, _ = keras.ops.fori_loop(0, steps, body_fixed, (state, start_time, initial_step))
+    return final_state
+
+
+def integrate_stochastic_adaptive(
+    step_fn: Callable,
+    state: StateDict,
+    start_time: ArrayLike,
+    stop_time: ArrayLike,
+    max_steps: int,
+    initial_step: ArrayLike,
+    z_history: Dict[str, ArrayLike],
+    bridge_history: Dict[str, ArrayLike],
+    corrector_steps: int,
+    score_fn: Optional[Callable],
+    step_size_factor: float,
+    corrector_noise_history: Dict[str, ArrayLike],
+    noise_schedule=None,
+) -> StateDict:
+    """
+    Performs adaptive-step SDE integration.
+    """
+    initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step)
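+    # the loop terminates once stop_time is reached (up to floating-point
+    # tolerance in the isclose check) or the max_steps budget is exhausted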
+
+    def cond(i, current_state, current_time, current_step):
+        # We use a small epsilon check for floating point equality
+        time_reached = keras.ops.all(keras.ops.isclose(current_time, stop_time))
+        return keras.ops.logical_and(keras.ops.less(i, max_steps), keras.ops.logical_not(time_reached))
+
+    def body_adaptive(_i, _current_state, _current_time, _current_step):
+        # Step Size Control
+        remaining = stop_time - _current_time
+        sign = keras.ops.sign(remaining)
+        # Ensure the next step does not overshoot the stop_time
+        dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining))
+        dt = sign * dt_mag
+
+        sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt))
+        _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()}
+        _bridge = {k: bridge_history[k][_i] for k in _current_state.keys()}
+
+        new_state, new_time, new_step = step_fn(
+            state=_current_state,
+            time=_current_time,
+            step_size=dt,
+            noise=_noise_i,
+            bridge_aux=_bridge,
+            use_adaptive_step_size=True,
+        )
+
+        new_state = _apply_corrector(
+            new_state=new_state,
+            new_time=new_time,
+            i=_i,
+            corrector_steps=corrector_steps,
+            score_fn=score_fn,
+            noise_schedule=noise_schedule,
+            step_size_factor=step_size_factor,
+            corrector_noise_history=corrector_noise_history,
+        )
+
+        return _i + 1, new_state, new_time, new_step
+
+    # Execute the adaptive loop
+    _, final_state, _, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state)
+    return final_state
+
+
+def integrate_stochastic(
+    drift_fn: Callable,
+    diffusion_fn: Callable,
+    state: StateDict,
+    start_time: ArrayLike,
+    stop_time: ArrayLike,
+    seed: keras.random.SeedGenerator,
+    steps: int | Literal["adaptive"] = 100,
+    method: str = "euler_maruyama",
+    score_fn: Callable = None,
+    corrector_steps: int = 0,
+    noise_schedule=None,
+    step_size_factor: float = 0.1,
+    min_steps: int = 10,
+    max_steps: int = 10_000,
+    **kwargs,
+) -> StateDict:
+    """
+    Integrates a stochastic differential equation from start_time to stop_time.
+
+    Dispatches to fixed-step or adaptive-step integration logic.
+
+    Args:
+        drift_fn: Function that computes the drift term.
+        diffusion_fn: Function that computes the diffusion term.
+        state: Dictionary containing the initial state.
+        start_time: Starting time for integration.
+        stop_time: Ending time for integration.
+        seed: Random seed for noise generation.
+        steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps.
+        method: Integration method to use, e.g., 'euler_maruyama' or 'shark'.
+        score_fn: Optional score function for predictor-corrector sampling.
+        corrector_steps: Number of corrector steps to take after each predictor step.
+        noise_schedule: Noise schedule object for computing alpha_t in corrector.
+        step_size_factor: Scaling factor for corrector step size.
+        min_steps: Minimum number of steps for adaptive integration.
+        max_steps: Maximum number of steps for adaptive integration.
+        **kwargs: Additional arguments to pass to the step function.
+
+    Returns: Final state dictionary after integration.
+    """
+    is_adaptive = isinstance(steps, str) and steps in ["adaptive", "dynamic"]
+    if is_adaptive:
+        if start_time is None or stop_time is None:
+            raise ValueError("Please provide start_time and stop_time for adaptive integration.")
+        if min_steps <= 0 or max_steps <= 0 or max_steps < min_steps:
+            raise ValueError("min_steps and max_steps must be positive, and max_steps >= min_steps.")
+        if method != "shark":
+            raise ValueError("Adaptive step size is only supported for the 'shark' method.")
+
+        loop_steps = max_steps
+        initial_step = (stop_time - start_time) / float(min_steps)
+        span_mag = keras.ops.abs(stop_time - start_time)
+        min_step_size = span_mag / keras.ops.cast(max_steps, span_mag.dtype)
+        max_step_size = span_mag / keras.ops.cast(min_steps, span_mag.dtype)
+    else:
+        if steps <= 0:
+            raise ValueError("Number of steps must be positive.")
+        loop_steps = int(steps)
+        initial_step = (stop_time - start_time) / float(loop_steps)
+        # For fixed step, min/max step size are just the fixed step size
+        min_step_size, max_step_size = initial_step, initial_step
+
+    match method:
+        case "euler_maruyama":
+            step_fn_raw = euler_maruyama_step
+        case "shark":
+            step_fn_raw = shark_step
+        case other:
+            raise TypeError(f"Invalid integration method: {other!r}")
+
+    # Partial the step function with common arguments
+    step_fn = partial(
+        step_fn_raw,
+        drift_fn=drift_fn,
+        diffusion_fn=diffusion_fn,
+        min_step_size=min_step_size,
+        max_step_size=max_step_size,
+        **kwargs,
+    )
+
+    # Pre-generate standard normals for the predictor step (up to max_steps)
+    z_history = {}
+    bridge_history = {}
+    for key, val in state.items():
+        shape = keras.ops.shape(val)
+        z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed)
+        if is_adaptive and method == "shark":
+            # Only required for SHARK adaptive step (Brownian Bridge aux noise)
+            bridge_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed)
+
+    # Pre-generate corrector noise if requested
+    corrector_noise_history = {}
+    if corrector_steps > 0:
+        if score_fn is None or noise_schedule is None:
+            raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.")
+        for key, val in state.items():
+            shape = keras.ops.shape(val)
+            corrector_noise_history[key] = keras.random.normal(
+                (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed
+            )
+
+    if is_adaptive:
+        return integrate_stochastic_adaptive(
+            step_fn=step_fn,
+            state=state,
+            start_time=start_time,
+            stop_time=stop_time,
+            max_steps=max_steps,
+            initial_step=initial_step,
+            z_history=z_history,
+            bridge_history=bridge_history,
+            corrector_steps=corrector_steps,
+            score_fn=score_fn,
+            noise_schedule=noise_schedule,
+            step_size_factor=step_size_factor,
+            corrector_noise_history=corrector_noise_history,
+        )
+    else:
+        return integrate_stochastic_fixed(
+            step_fn=step_fn,
+            state=state,
+            start_time=start_time,
+            stop_time=stop_time,
+            steps=loop_steps,
+            z_history=z_history,
+            corrector_steps=corrector_steps,
+            score_fn=score_fn,
+            noise_schedule=noise_schedule,
+            step_size_factor=step_size_factor,
+            corrector_noise_history=corrector_noise_history,
+        )

From de57eaf100106ae525091e79ff5248e93fca3099 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 24 Nov 2025 19:36:06 +0100
Subject: [PATCH 09/36] refactor stochastic integrator

---
 bayesflow/utils/integrate.py | 177 ++---------------------------------
 1 file changed, 10 insertions(+), 167 deletions(-)

diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index 428e4e3b5..3f9cdee33 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -613,165 +613,6 @@
     return y_full, t + h, h_new
 
 
-def integrate_stochastic_old(
-    drift_fn: Callable,
-    diffusion_fn: Callable,
-    state: Dict[str, ArrayLike],
-    start_time: ArrayLike,
-    stop_time: ArrayLike,
-    seed: keras.random.SeedGenerator,
-    steps: int | Literal["adaptive"] = 100,
-    method: str = "euler_maruyama",
-    score_fn: Callable = None,
-    corrector_steps: int = 0,
-    noise_schedule=None,
-    step_size_factor: float = 0.1,
-    min_steps: int = 10,
-    max_steps: int = 1_000,
-    **kwargs,
-) -> Union[Dict[str, ArrayLike], Tuple[Dict[str, ArrayLike], Dict[str, Sequence[ArrayLike]]]]:
-    """
-    Integrates a stochastic differential equation from start_time to stop_time.
-
-    When score_fn is provided, performs predictor-corrector sampling where:
-    - Predictor: reverse diffusion SDE solver
-    - Corrector: annealed Langevin dynamics with step size e = 2 * alpha_t * (r * ||z|| / ||score||) ** 2
-
-    Args:
-        drift_fn: Function that computes the drift term.
-        diffusion_fn: Function that computes the diffusion term.
-        state: Dictionary containing the initial state.
-        start_time: Starting time for integration.
-        stop_time: Ending time for integration.
-        seed: Random seed for noise generation.
-        steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps.
-        method: Integration method to use, e.g., 'euler_maruyama' or 'shark'.
-        score_fn: Optional score function for predictor-corrector sampling.
-            Should take (time, **state) and return score dict.
-        corrector_steps: Number of corrector steps to take after each predictor step.
-        noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
-        step_size_factor: Scaling factor for corrector step size.
-        min_steps: Minimum number of steps for adaptive integration.
-        max_steps: Maximum number of steps for adaptive integration.
-        **kwargs: Additional arguments to pass to the step function.
-
-    Returns: Final state dictionary after integration.
-    """
-    use_adaptive = False
-    if isinstance(steps, str) and steps in ["adaptive", "dynamic"]:
-        if start_time is None or stop_time is None:
-            raise ValueError(
-                "Please provide start_time and stop_time for the integration, was "
-                f"'start_time={start_time}', 'stop_time={stop_time}'."
-            )
-        if min_steps <= 0 or max_steps <= 0:
-            raise ValueError("min_steps and max_steps must be positive.")
-        if max_steps < min_steps:
-            raise ValueError("max_steps must be greater or equal to min_steps.")
-        use_adaptive = True
-        loop_steps = max_steps
-        initial_step = (stop_time - start_time) / float(min_steps)
-
-        span_mag = keras.ops.abs(stop_time - start_time)
-        min_step_size = span_mag / keras.ops.cast(max_steps, span_mag.dtype)
-        max_step_size = span_mag / keras.ops.cast(min_steps, span_mag.dtype)
-    elif isinstance(steps, int):
-        if steps <= 0:
-            raise ValueError("Number of steps must be positive.")
-        use_adaptive = False
-        loop_steps = steps
-        initial_step = (stop_time - start_time) / float(steps)
-        min_step_size, max_step_size = initial_step, initial_step
-    else:
-        raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")
-
-    match method:
-        case "euler_maruyama":
-            step_fn = euler_maruyama_step
-            if use_adaptive:
-                raise ValueError("Adaptive step size is not supported for Euler Maruyama method.")
-        case "shark":
-            step_fn = shark_step
-        case other:
-            raise TypeError(f"Invalid integration method: {other!r}")
-
-    step_fn = partial(
-        step_fn,
-        drift_fn=drift_fn,
-        diffusion_fn=diffusion_fn,
-        min_step_size=min_step_size,
-        max_step_size=max_step_size,
-        **kwargs,
-    )
-
-    # pre generate standard normals; scale by sqrt(dt) inside the loop using the current dt
-    z_history = {}
-    bridge_history = {}
-    for key, val in state.items():
-        shape = keras.ops.shape(val)
-        z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed)
-        if method == "shark" and use_adaptive:
-            bridge_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed)
-
-    # pre generate corrector noise if requested
-    corrector_noise_history = {}
-    if corrector_steps > 0:
-        if score_fn is None or noise_schedule is None:
-            raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.")
-        for key, val in state.items():
-            shape = keras.ops.shape(val)
-            corrector_noise_history[key] = keras.random.normal(
-                (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed
-            )
-
-    def body(_i, _loop_state):
-        _current_state, _current_time, _current_step = _loop_state
-        remaining = stop_time - _current_time
-        sign = keras.ops.sign(remaining)
-        dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining))
-        dt = sign * dt_mag
-
-        sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt))
-        _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()}
-
-        if method == "shark" and use_adaptive:
-            _bridge = {k: bridge_history[k][_i] for k in _current_state.keys()}
-            out = step_fn(
-                state=_current_state,
-                time=_current_time,
-                step_size=dt,
-                noise=_noise_i,
-                bridge_aux=_bridge,
-                use_adaptive_step_size=True,
-            )
-        else:
-            out = step_fn(state=_current_state, time=_current_time, step_size=dt, noise=_noise_i)
-        new_state, new_time, new_step = out
-
-        # corrector
-        if corrector_steps > 0:
-            for j in range(corrector_steps):
-                score = score_fn(new_time, **filter_kwargs(new_state, score_fn))
-                _z_corr = {k: corrector_noise_history[k][_i, j] for k in new_state.keys()}
-
-                log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False)
-                alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
-
-                for k in new_state.keys():
-                    if k in score:
-                        z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True)
-                        score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True)
-                        score_norm = keras.ops.maximum(score_norm, 1e-8)
-
-                        e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2
-                        new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k]
-
-        return new_state, new_time, new_step
-
-    final_state, final_time, last_step = keras.ops.fori_loop(0, loop_steps, body, (state, start_time, initial_step))
-    return final_state
 
 
 def _apply_corrector(
@@ -886,20 +727,21 @@ def integrate_stochastic_adaptive(
     """
     Performs adaptive-step SDE integration.
     """
-    initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step)
+    initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0)
 
-    def cond(i, current_state, current_time, current_step):
+    def cond(i, current_state, current_time, current_step, _counter):
         # We use a small epsilon check for floating point equality
         time_reached = keras.ops.all(keras.ops.isclose(current_time, stop_time))
         return keras.ops.logical_and(keras.ops.less(i, max_steps), keras.ops.logical_not(time_reached))
 
-    def body_adaptive(_i, _current_state, _current_time, _current_step):
+    def body_adaptive(_i, _current_state, _current_time, _current_step, _counter):
         # Step Size Control
         remaining = stop_time - _current_time
         sign = keras.ops.sign(remaining)
         # Ensure the next step does not overshoot the stop_time
         dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining))
         dt = sign * dt_mag
+        _counter += 1
@@ ...
-        return _i + 1, new_state, new_time, new_step
+        return _i + 1, new_state, new_time, new_step, _counter
 
     # Execute the adaptive loop
-    _, final_state, _, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state)
+    _, final_state, _, _, counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state)
+    logging.debug("Finished integration after {} steps.", counter)
     return final_state
@@ ...
 def integrate_stochastic(
     drift_fn: Callable,
     diffusion_fn: Callable,
     state: StateDict,
     start_time: ArrayLike,
     stop_time: ArrayLike,
     seed: keras.random.SeedGenerator,
     steps: int | Literal["adaptive"] = 100,
     method: str = "euler_maruyama",
+    min_steps: int = 10,
+    max_steps: int = 10_000,
     score_fn: Callable = None,
     corrector_steps: int = 0,
     noise_schedule=None,
     step_size_factor: float = 0.1,
-    min_steps: int = 10,
-    max_steps: int = 10_000,
     **kwargs,
 ) -> StateDict:
     """
     Integrates a stochastic differential equation from start_time to stop_time.
 
     Dispatches to fixed-step or adaptive-step integration logic.
 
     Args:
         drift_fn: Function that computes the drift term.
         diffusion_fn: Function that computes the diffusion term.
         state: Dictionary containing the initial state.
         start_time: Starting time for integration.
         stop_time: Ending time for integration.
         seed: Random seed for noise generation.
         steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps.
         method: Integration method to use, e.g., 'euler_maruyama' or 'shark'.
+        min_steps: Minimum number of steps for adaptive integration.
+        max_steps: Maximum number of steps for adaptive integration.
         score_fn: Optional score function for predictor-corrector sampling.
         corrector_steps: Number of corrector steps to take after each predictor step.
         noise_schedule: Noise schedule object for computing alpha_t in corrector.
         step_size_factor: Scaling factor for corrector step size.
-        min_steps: Minimum number of steps for adaptive integration.
- max_steps: Maximum number of steps for adaptive integration. **kwargs: Additional arguments to pass to the step function. Returns: Final state dictionary after integration. From dde5451e5026636a0657530b0264ac72dee6dfe1 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 19:38:25 +0100 Subject: [PATCH 10/36] refactor stochastic integrator --- bayesflow/utils/integrate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 3f9cdee33..0249bf131 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -729,7 +729,7 @@ def integrate_stochastic_adaptive( """ initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0) - def cond(i, current_state, current_time, current_step): + def cond(i, current_state, current_time, current_step, counter): # We use a small epsilon check for floating point equality time_reached = keras.ops.all(keras.ops.isclose(current_time, stop_time)) return keras.ops.logical_and(keras.ops.less(i, max_steps), keras.ops.logical_not(time_reached)) @@ -770,8 +770,8 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): return _i + 1, new_state, new_time, new_step, _counter # Execute the adaptive loop - _, final_state, _, counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) - logging.debug("Finished integration after {} steps.", counter) + _, final_state, _, final_counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + logging.debug("Finished integration after {} steps.", final_counter) return final_state From 3d2c80ea125ee84fc9090295448835861ba279f2 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 19:44:14 +0100 Subject: [PATCH 11/36] fix adaptive --- bayesflow/utils/integrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 0249bf131..9d4683dc7 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -770,7 +770,7 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): return _i + 1, new_state, new_time, new_step, _counter # Execute the adaptive loop - _, final_state, _, final_counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + _, final_state, _, _, final_counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) logging.debug("Finished integration after {} steps.", final_counter) return final_state From ed5e89fd1706eee0e3b5636d5a362d3707dc46e0 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 25 Nov 2025 17:16:40 +0100 Subject: [PATCH 12/36] fix Tsit5 --- bayesflow/utils/integrate.py | 58 +++++++++++------------------- tests/test_utils/test_integrate.py | 2 +- 2 files changed, 21 insertions(+), 39 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 9d4683dc7..6ae7ed0b8 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -31,30 +31,19 @@ def euler_step( k1 = fn(time, **filter_kwargs(state, fn)) if use_adaptive_step_size: - # Use Heun's method (RK2) as embedded pair for proper error estimation - intermediate_state = state.copy() - for key, delta in k1.items(): - intermediate_state[key] = state[key] + step_size * delta + # Euler step + y_euler = {k: state[k] + step_size * k1[k] for k in state} - k2 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) + # Heun slope + k2 = fn(time + step_size, 
**filter_kwargs(y_euler, fn)) - # check all keys are equal - if set(k1.keys()) != set(k2.keys()): - raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.") + # error = (h/2) (k2 - k1) + err_state = {k: 0.5 * step_size * (k2[k] - k1[k]) for k in state} - # Heun's (RK2) solution - heun_state = state.copy() - for key in k1.keys(): - heun_state[key] = state[key] + 0.5 * step_size * (k1[key] + k2[key]) - - # Error estimate: difference between Euler and Heun - intermediate_error = keras.ops.stack( - [keras.ops.norm(heun_state[key] - intermediate_state[key], ord=2, axis=-1) for key in k1] - ) - - max_error = keras.ops.max(intermediate_error) - new_step_size = step_size * keras.ops.sqrt(tolerance / (max_error + 1e-9)) + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) + new_step_size = step_size * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.5, 0.2, 5.0) new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) else: new_step_size = step_size @@ -177,7 +166,7 @@ def tsit5_step( k5 = fn( time + h * c5, **add_scaled( - state, [k1, k2, k3, k4], [4.325279681768730, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h + state, [k1, k2, k3, k4], [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h ), ) k6 = fn( @@ -203,26 +192,19 @@ def tsit5_step( ) if use_adaptive_step_size: - # 7th stage evaluation k7 = fn(time + h, **filter_kwargs(new_state, fn)) - # 4th order embedded solution: b_hat coefficients - y4 = {} - for key in state.keys(): - y4[key] = state[key] + h * ( - 0.001780011052226 * k1[key] - + 0.000816434459657 * k2[key] - - 0.007880878010262 * k3[key] - + 0.144711007173263 * k4[key] - - 0.582357165452555 * k5[key] - + 0.458082105929187 * k6[key] - + (1.0 / 66.0) * k7[key] - ) - - # Error estimate err_state = {} for key in state.keys(): - err_state[key] = new_state[key] - y4[key] + err_state[key] = h * ( + -0.00178001105222577714 * k1[key] + - 0.0008164344596567469 * k2[key] + + 0.007880878010261995 * k3[key] + - 0.1447110071732629 * k4[key] + + 0.5823571654525552 * k5[key] + - 0.45808210592918697 * k6[key] + + 0.015151515151515152 * k7[key] + ) err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) err = keras.ops.max(err_norm) @@ -230,7 +212,7 @@ def tsit5_step( new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) else: - new_step_size = h + new_step_size = step_size new_time = time + h return new_state, new_time, new_step_size diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index d328f9476..765032c43 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -64,8 +64,8 @@ def fn(t, x): result_adaptive = integrate( fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000 )["x"] - np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) + np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.01) From c4b52a770c90d33c852911bb3387a1fd0b29e258 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 25 Nov 2025 20:18:34 +0100 Subject: [PATCH 13/36] fix sampler --- bayesflow/utils/integrate.py | 39 +++++++++--------------------- 
tests/test_utils/test_integrate.py | 11 ++++++--- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 6ae7ed0b8..a79c98b18 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -28,33 +28,17 @@ def euler_step( max_step_size: ArrayLike = float("inf"), use_adaptive_step_size: bool = False, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): - k1 = fn(time, **filter_kwargs(state, fn)) - if use_adaptive_step_size: - # Euler step - y_euler = {k: state[k] + step_size * k1[k] for k in state} - - # Heun slope - k2 = fn(time + step_size, **filter_kwargs(y_euler, fn)) - - # error = (h/2) (k2 - k1) - err_state = {k: 0.5 * step_size * (k2[k] - k1[k]) for k in state} - - err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) - err = keras.ops.max(err_norm) + raise ValueError("Adaptive step size not supported for Euler method.") - new_step_size = step_size * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.5, 0.2, 5.0) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - else: - new_step_size = step_size + k1 = fn(time, **filter_kwargs(state, fn)) new_state = state.copy() for key in k1.keys(): new_state[key] = state[key] + step_size * k1[key] - new_time = time + step_size - return new_state, new_time, new_step_size + return new_state, new_time, step_size def add_scaled(state, ks, coeffs, h): @@ -224,7 +208,7 @@ def integrate_fixed( start_time: ArrayLike, stop_time: ArrayLike, steps: int, - method: str = "rk45", + method: str, **kwargs, ) -> dict[str, ArrayLike]: if steps <= 0: @@ -263,17 +247,15 @@ def integrate_adaptive( state: dict[str, ArrayLike], start_time: ArrayLike, stop_time: ArrayLike, - min_steps: int = 10, - max_steps: int = 1000, - method: str = "rk45", + min_steps: int, + max_steps: int, + method: str, **kwargs, ) -> dict[str, ArrayLike]: if max_steps <= min_steps: raise ValueError("Maximum number of steps must be greater than minimum number of steps.") match method: - case "euler": - step_fn = euler_step case "rk45": step_fn = rk45_step case "tsit5": @@ -339,7 +321,7 @@ def integrate_scheduled( fn: Callable, state: dict[str, ArrayLike], steps: Tensor | np.ndarray, - method: str = "rk45", + method: str, **kwargs, ) -> dict[str, ArrayLike]: match method: @@ -422,7 +404,7 @@ def integrate( min_steps: int = 10, max_steps: int = 10_000, steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, - method: str = "euler", + method: str = "rk45", **kwargs, ) -> dict[str, ArrayLike]: if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: @@ -480,6 +462,9 @@ def euler_maruyama_step( new_state: Updated state after one Euler-Maruyama step. new_time: time + dt. 
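+
+    Example: a minimal sketch of one step on the toy SDE dx = -x dt + 0.1 dW
+    (the drift/diffusion functions, shapes, and dt below are illustrative
+    assumptions, not part of the library API). Note that the integrators in
+    this module pre-scale the noise by sqrt(dt) before calling the step:
+
+        import keras
+
+        def drift_fn(t, x):
+            return {"x": -x}
+
+        def diffusion_fn(t, x):
+            return {"x": 0.1 * keras.ops.ones_like(x)}
+
+        state = {"x": keras.ops.ones((8,))}
+        dt = 0.01
+        noise = {"x": (dt**0.5) * keras.random.normal((8,))}  # dW ~ N(0, dt)
+        new_state, new_time, _ = euler_maruyama_step(
+            drift_fn, diffusion_fn, state, time=0.0, step_size=dt, noise=noise
+        )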
""" + if use_adaptive_step_size: + raise ValueError("Adaptive step size not supported for Euler method.") + # Compute drift and diffusion drift = drift_fn(time, **filter_kwargs(state, drift_fn)) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 765032c43..4f76cc5da 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -61,12 +61,15 @@ def fn(t, x): analytical_result = 1.0 + T_final**2 result = integrate(fn, initial_state, start_time=0.0, stop_time=T_final, steps=num_steps, method=method)["x"] - result_adaptive = integrate( - fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000 - )["x"] + if method == "euler": + result_adaptive = result + else: + result_adaptive = integrate( + fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000 + )["x"] np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) - np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.01) + np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1) @pytest.mark.parametrize( From ac22af55e96b6c46ec1a398bd9e865bbb92b9163 Mon Sep 17 00:00:00 2001 From: arrjon Date: Thu, 27 Nov 2025 15:04:45 +0100 Subject: [PATCH 14/36] updated stochastic solvers --- bayesflow/utils/integrate.py | 328 +++++++++++++++++++++-------- tests/test_utils/test_integrate.py | 6 +- 2 files changed, 246 insertions(+), 88 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a79c98b18..bc7c62ae3 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -440,6 +440,7 @@ def euler_maruyama_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], + noise_aux: dict[str, ArrayLike] = None, use_adaptive_step_size: bool = False, min_step_size: ArrayLike = None, max_step_size: ArrayLike = None, @@ -454,6 +455,7 @@ def euler_maruyama_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise (not used here). use_adaptive_step_size: Whether to use adaptive step sizing (not used here). min_step_size: Minimum allowed step size (not used here). max_step_size: Maximum allowed step size (not used here). @@ -483,101 +485,244 @@ def euler_maruyama_step( return new_state, time + step_size, step_size +def sea_step( + drift_fn: Callable, + diffusion_fn: Callable, + state: dict[str, ArrayLike], + time: ArrayLike, + step_size: ArrayLike, + noise: dict[str, ArrayLike], + noise_aux: dict[str, ArrayLike] = None, + use_adaptive_step_size: bool = False, + min_step_size: ArrayLike = None, + max_step_size: ArrayLike = None, +) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + """ + Performs a single shifted Euler step for SDEs with additive noise [1]. + + Compared to Euler-Maruyama, this evaluates the drift at a shifted state, + which improves the local error and the global error constant for additive noise. + + The scheme is + X_{n+1} = X_n + f(t_n, X_n + 0.5 * g(t_n) * ΔW_n) * h + g(t_n) * ΔW_n + + [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). 
+ state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise (not used here). + use_adaptive_step_size: Whether to use adaptive step sizing (not used here). + min_step_size: Minimum allowed step size (not used here). + max_step_size: Maximum allowed step size (not used here). + + Returns: + new_state: Updated state after one SEA step. + new_time: time + dt. + """ + if use_adaptive_step_size: + raise ValueError("Adaptive step size not supported for Euler method.") + + # Compute diffusion (assumed additive or weakly state dependent) + diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) + + # Check noise keys + if set(diffusion.keys()) != set(noise.keys()): + raise ValueError("Keys of diffusion terms and noise do not match.") + + # Build shifted state: X_shift = X + 0.5 * g * ΔW + shifted_state = {} + for key, x in state.items(): + if key in diffusion: + shifted_state[key] = x + 0.5 * diffusion[key] * noise[key] + else: + shifted_state[key] = x + + # Drift evaluated at shifted state + drift_shifted = drift_fn(time, **filter_kwargs(shifted_state, drift_fn)) + + # Final update + new_state = {} + for key, d in drift_shifted.items(): + base = state[key] + step_size * d + if key in diffusion: + base = base + diffusion[key] * noise[key] + new_state[key] = base + + return new_state, time + step_size, step_size + + def shark_step( drift_fn: Callable, diffusion_fn: Callable, state: Dict[str, ArrayLike], time: ArrayLike, step_size: ArrayLike, - noise: Dict[str, ArrayLike], - min_step_size: ArrayLike, - max_step_size: ArrayLike, + noise: Dict[str, ArrayLike], # w_k = ΔW_k (already scaled by sqrt(|h|)) + noise_aux: Dict[str, ArrayLike], # Z_k ~ N(0,1), used to build H_k use_adaptive_step_size: bool = False, - tolerance: ArrayLike = 1e-3, - half_noises: Optional[Tuple[Dict[str, ArrayLike], Dict[str, ArrayLike]]] = None, - bridge_aux: Optional[Dict[str, ArrayLike]] = None, - validate_split: bool = True, + min_step_size: ArrayLike = -float("inf"), + max_step_size: ArrayLike = float("inf"), + tolerance: float = 1e-3, ) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: """ - Shifted Additive noise Runge Kutta for additive SDEs. + Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion + per step and has a strong order 1.5. + + SHARK method as specified: + + 1) ỹ_k = y_k + g(y_k) H_k + 2) ỹ_{k+5/6} = ỹ_k + (5/6)[ f(ỹ_k) h + g(ỹ_k) W_k ] + 3) y_{k+1} = y_k + + (2/5) f(ỹ_k) h + + (3/5) f(ỹ_{k+5/6}) h + + g(ỹ_k) ( 2/5 W_k + 6/5 H_k ) + + g(ỹ_{k+5/6}) ( 3/5 W_k - 6/5 H_k ) + + with + H_k = 0.5 * |h| * W_k + (|h| ** 1.5) / (2 * sqrt(3)) * Z_k + + [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) + + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise. + use_adaptive_step_size: Whether to use adaptive step sizing (not used here). + min_step_size: Minimum allowed step size (not used here). 
+ max_step_size: Maximum allowed step size (not used here). + tolerance: Tolerance for adaptive step sizing. + + Returns: + new_state: Updated state after one SHARK step. + new_time: time + dt. """ - # direction aware handling h = step_size t = time - h_sign = keras.ops.sign(h) - h_mag = keras.ops.abs(h) - # full step: midpoint drift, diffusion at midpoint time - k1 = drift_fn(t, **filter_kwargs(state, drift_fn)) - mid_state = {k: state[k] + 0.5 * h * k1[k] for k in state} - k2 = drift_fn(t + 0.5 * h, **filter_kwargs(mid_state, drift_fn)) - g_mid = diffusion_fn(t + 0.5 * h, **filter_kwargs(state, diffusion_fn)) + # Magnitude of the time step for stochastic scaling + h_mag = keras.ops.abs(h) + h_sign = keras.ops.sign(h) + sqrt_h_mag = keras.ops.sqrt(h_mag) + inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(h_mag)) + + # g(y_k) + g0 = diffusion_fn(t, **filter_kwargs(state, diffusion_fn)) + + # Build H_k from w_k and Z_k + H = {} + for k in state.keys(): + if k in g0: + w_k = noise[k] # already scaled by sqrt(|h|) + z_k = noise_aux[k] # standard normal + term1 = 0.5 * h_mag * w_k + term2 = 0.5 * h_mag * sqrt_h_mag * inv_sqrt3 * z_k + H[k] = term1 + term2 + else: + H[k] = keras.ops.zeros_like(state[k]) + + # === 1) shifted initial state === + y_tilde_k = {} + for k in state.keys(): + if k in g0: + y_tilde_k[k] = state[k] + g0[k] * H[k] + else: + y_tilde_k[k] = state[k] + + # === evaluate drift and diffusion at ỹ_k === + f_tilde_k = drift_fn(t, **filter_kwargs(y_tilde_k, drift_fn)) + g_tilde_k = diffusion_fn(t, **filter_kwargs(y_tilde_k, diffusion_fn)) + + # === 2) internal stage at 5/6 === + y_tilde_mid = {} + for k in state.keys(): + drift_part = (5.0 / 6.0) * f_tilde_k[k] * h + if k in g_tilde_k: + sto_part = (5.0 / 6.0) * g_tilde_k[k] * noise[k] + else: + sto_part = keras.ops.zeros_like(state[k]) + y_tilde_mid[k] = y_tilde_k[k] + drift_part + sto_part + + # === evaluate drift and diffusion at ỹ_(k+5/6) === + f_tilde_mid = drift_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, drift_fn)) + g_tilde_mid = diffusion_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, diffusion_fn)) + + # === 3) final update === + new_state = {} + for k in state.keys(): + # deterministic weights + det = state[k] + (2.0 / 5.0) * f_tilde_k[k] * h + (3.0 / 5.0) * f_tilde_mid[k] * h + + # stochastic parts + sto1 = ( + g_tilde_k[k] * ((2.0 / 5.0) * noise[k] + (6.0 / 5.0) * H[k]) + if k in g_tilde_k + else keras.ops.zeros_like(det) + ) + sto2 = ( + g_tilde_mid[k] * ((3.0 / 5.0) * noise[k] - (6.0 / 5.0) * H[k]) + if k in g_tilde_mid + else keras.ops.zeros_like(det) + ) - det_full = {k: state[k] + h * k2[k] for k in state} - sto_full = {k: g_mid[k] * noise[k] for k in g_mid} - y_full = {k: det_full[k] + sto_full.get(k, keras.ops.zeros_like(det_full[k])) for k in det_full} + new_state[k] = det + sto1 + sto2 if not use_adaptive_step_size: - return y_full, t + h, h - - # prepare two half step noises without drawing randomness here - if half_noises is not None: - dW1, dW2 = half_noises - if set(dW1.keys()) != set(noise.keys()) or set(dW2.keys()) != set(noise.keys()): - raise ValueError("half_noises must have the same keys as noise") - if validate_split: - sum_diff = {k: dW1[k] + dW2[k] - noise[k] for k in noise} - parts = [] - for v in sum_diff.values(): - if not hasattr(v, "shape") or len(v.shape) == 0: - v = keras.ops.reshape(v, (1,)) - parts.append(keras.ops.norm(v, ord=2, axis=-1)) - if float(keras.ops.max(keras.ops.stack(parts))) > 1e-6: - raise ValueError("half_noises do not sum to provided 
noise") + return new_state, t + h, h + + # embedded lower order solution y_low + # here: one stage strong order one method using y_tilde_k + y_low = {} + for k in state.keys(): + det_low = state[k] + f_tilde_k[k] * h + if k in g0: + sto_low = g0[k] * noise[k] + else: + sto_low = keras.ops.zeros_like(det_low) + y_low[k] = det_low + sto_low + + # error estimate as max over components of RMS norm + err_list = [] + for k in state.keys(): + diff = new_state[k] - y_low[k] + sq = keras.ops.square(diff) + mean_sq = keras.ops.mean(sq) + err_k = keras.ops.sqrt(mean_sq) + err_list.append(err_k) + + if len(err_list) == 0: + err = keras.ops.zeros_like(h_mag) else: - if bridge_aux is None: - raise ValueError("Provide either half_noises or bridge_aux when use_adaptive_step_size is True") - if set(bridge_aux.keys()) != set(noise.keys()): - raise ValueError("bridge_aux must have the same keys as noise") - sqrt_h = keras.ops.sqrt(h_mag + 1e-12) # use magnitude - dW1 = {k: 0.5 * noise[k] + 0.5 * sqrt_h * bridge_aux[k] for k in noise} - dW2 = {k: noise[k] - dW1[k] for k in noise} - - half = 0.5 * h - - # first half step - k1h = drift_fn(t, **filter_kwargs(state, drift_fn)) - mid1 = {k: state[k] + 0.5 * half * k1h[k] for k in state} - k2h = drift_fn(t + 0.5 * half, **filter_kwargs(mid1, drift_fn)) - g_q1 = diffusion_fn(t + 0.5 * half, **filter_kwargs(state, diffusion_fn)) - y_half = {k: state[k] + half * k2h[k] + g_q1.get(k, 0) * dW1.get(k, 0) for k in state} - - # second half step - k1h2 = drift_fn(t + half, **filter_kwargs(y_half, drift_fn)) - mid2 = {k: y_half[k] + 0.5 * half * k1h2[k] for k in y_half} - k2h2 = drift_fn(t + 1.5 * half, **filter_kwargs(mid2, drift_fn)) - g_q2 = diffusion_fn(t + 1.5 * half, **filter_kwargs(state, diffusion_fn)) - y_twohalf = {k: y_half[k] + half * k2h2[k] + g_q2.get(k, 0) * dW2.get(k, 0) for k in y_half} - - # error estimate - parts = [] - for k in y_full: - v = y_full[k] - y_twohalf[k] - if not hasattr(v, "shape") or len(v.shape) == 0: - v = keras.ops.reshape(v, (1,)) - parts.append(keras.ops.norm(v, ord=2, axis=-1)) - err = keras.ops.max(keras.ops.stack(parts)) - - # controller for strong order one on additive noise - factor = 0.9 * (tolerance / (err + 1e-12)) ** (2.0 / 3.0) - h_prop = h * keras.ops.clip(factor, 0.2, 5.0) - - # clip by magnitude bounds then restore original sign - mag = keras.ops.abs(h_prop) - mag_new = keras.ops.clip(mag, min_step_size, max_step_size) - h_new = h_sign * mag_new - - return y_full, t + h, h_new + err = err_list[0] + for e_k in err_list[1:]: + err = keras.ops.maximum(err, e_k) + + tiny = keras.ops.cast(1e12, dtype=keras.ops.dtype(h_mag)) + safety = keras.ops.cast(0.9, dtype=keras.ops.dtype(h_mag)) + # effective order between one and one point five + exponent = keras.ops.cast(0.5, dtype=keras.ops.dtype(h_mag)) + + factor = safety * keras.ops.power(tolerance / (err + tiny), exponent) + + # clamp factor + factor_min = keras.ops.cast(0.2, dtype=keras.ops.dtype(h_mag)) + factor_max = keras.ops.cast(5.0, dtype=keras.ops.dtype(h_mag)) + factor = keras.ops.minimum(keras.ops.maximum(factor, factor_min), factor_max) + + new_h_mag = h_mag * factor + new_h_mag = keras.ops.maximum(new_h_mag, min_step_size) + new_h_mag = keras.ops.minimum(new_h_mag, max_step_size) + + new_h = h_sign * new_h_mag + + return new_state, t + h, new_h def _apply_corrector( @@ -627,6 +772,7 @@ def integrate_stochastic_fixed( stop_time: ArrayLike, steps: int, z_history: Dict[str, ArrayLike], + z_extra_history: Dict[str, ArrayLike], corrector_steps: int, score_fn: 
Optional[Callable], step_size_factor: float, @@ -650,12 +796,17 @@ def body_fixed(_i, _loop_state): # Generate noise increment scaled by sqrt(dt) sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} + if len(z_extra_history) == 0: + _noise_extra_i = None + else: + _noise_extra_i = {k: z_extra_history[k][_i] for k in _current_state.keys()} new_state, new_time, new_step = step_fn( state=_current_state, time=_current_time, step_size=dt, noise=_noise_i, + noise_aux=_noise_extra_i, use_adaptive_step_size=False, ) @@ -684,7 +835,7 @@ def integrate_stochastic_adaptive( max_steps: int, initial_step: ArrayLike, z_history: Dict[str, ArrayLike], - bridge_history: Dict[str, ArrayLike], + z_extra_history: Dict[str, ArrayLike], corrector_steps: int, score_fn: Optional[Callable], step_size_factor: float, @@ -712,14 +863,17 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} - _bridge = {k: bridge_history[k][_i] for k in _current_state.keys()} + if len(z_extra_history) == 0: + _noise_extra_i = None + else: + _noise_extra_i = {k: z_extra_history[k][_i] for k in _current_state.keys()} new_state, new_time, new_step = step_fn( state=_current_state, time=_current_time, step_size=dt, noise=_noise_i, - bridge_aux=_bridge, + noise_aux=_noise_extra_i, use_adaptive_step_size=True, ) @@ -808,6 +962,8 @@ def integrate_stochastic( match method: case "euler_maruyama": step_fn_raw = euler_maruyama_step + case "sea": + step_fn_raw = sea_step case "shark": step_fn_raw = shark_step case other: @@ -825,13 +981,12 @@ def integrate_stochastic( # Pre-generate standard normals for the predictor step (up to max_steps) z_history = {} - bridge_history = {} + z_extra_history = {} for key, val in state.items(): shape = keras.ops.shape(val) z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - if is_adaptive and method == "shark": - # Only required for SHARK adaptive step (Brownian Bridge aux noise) - bridge_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + if method == "shark": + z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) # Pre-generate corrector noise if requested corrector_noise_history = {} @@ -853,7 +1008,7 @@ def integrate_stochastic( max_steps=max_steps, initial_step=initial_step, z_history=z_history, - bridge_history=bridge_history, + z_extra_history=z_extra_history, corrector_steps=corrector_steps, score_fn=score_fn, noise_schedule=noise_schedule, @@ -868,6 +1023,7 @@ def integrate_stochastic( stop_time=stop_time, steps=loop_steps, z_history=z_history, + z_extra_history=z_extra_history, corrector_steps=corrector_steps, score_fn=score_fn, noise_schedule=noise_schedule, diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 4f76cc5da..160e7f228 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -76,6 +76,7 @@ def fn(t, x): "method,use_adapt", [ ("euler_maruyama", False), + ("sea", False), ("shark", False), ("shark", True), ], @@ -96,7 +97,7 @@ def test_additive_OU_weak_means_and_vars(method, use_adapt): T = 1.0 # batch of trajectories - N = 20000 # large enough to control sampling error + N = 10000 # large enough to control sampling error seed = keras.random.SeedGenerator(42) def 
drift_fn(t, x): @@ -136,6 +137,7 @@ def diffusion_fn(t, x): "method,use_adapt", [ ("euler_maruyama", False), + ("sea", False), ("shark", False), ("shark", True), ], @@ -149,7 +151,7 @@ def test_zero_noise_reduces_to_deterministic(method, use_adapt): x0 = 0.9 T = 1.25 steps = 200 if not use_adapt else "adaptive" - seed = keras.random.SeedGenerator(999) + seed = keras.random.SeedGenerator(0) def drift_fn(t, x): return {"x": a * x} From 44570cfe4dd313776ab01eabc0e66388d151bd72 Mon Sep 17 00:00:00 2001 From: arrjon Date: Thu, 27 Nov 2025 16:11:47 +0100 Subject: [PATCH 15/36] add Langevin --- .../diffusion_model/diffusion_model.py | 7 +- bayesflow/utils/__init__.py | 2 +- bayesflow/utils/integrate.py | 146 ++++++++++++++++-- tests/test_utils/test_integrate.py | 73 +++++++++ 4 files changed, 207 insertions(+), 21 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 659641f51..dfd1c28e4 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -16,6 +16,7 @@ integrate_stochastic, logging, tensor_utils, + STOCHASTIC_METHODS, ) from bayesflow.utils.serialization import serialize, deserialize, serializable @@ -408,7 +409,7 @@ def _forward( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for forward integration.") if density: @@ -458,7 +459,7 @@ def _inverse( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs if density: - if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): @@ -477,7 +478,7 @@ def deltas(time, xz): return x, log_density state = {"xz": z} - if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: + if integrate_kwargs["method"] in STOCHASTIC_METHODS: def deltas(time, xz): return { diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index a8d28a50a..25b7dd920 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -47,7 +47,7 @@ ) from .hparam_utils import find_batch_size, find_memory_budget -from .integrate import integrate, integrate_stochastic +from .integrate import integrate, integrate_stochastic, DETERMINISTIC_METHODS, STOCHASTIC_METHODS from .io import ( pickle_load, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index bc7c62ae3..dddae5f9c 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -18,6 +18,10 @@ StateDict = Dict[str, ArrayLike] +DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"] +STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin"] + + def euler_step( fn: Callable, state: dict[str, ArrayLike], @@ -731,11 +735,14 @@ def _apply_corrector( i: ArrayLike, corrector_steps: int, score_fn: Optional[Callable], - step_size_factor: float, corrector_noise_history: Dict[str, ArrayLike], + step_size_factor: float = 0.01, noise_schedule=None, ) -> StateDict: - """Helper function to apply corrector steps.""" + """Helper function to apply corrector steps [1]. 
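+
+    Each corrector step is one annealed Langevin update per state component.
+    A sketch of the update computed in the loop below, writing
+    r = step_size_factor and z for the pre-generated standard normal:
+
+        e = 2 * alpha_t * (r * ||z|| / ||score||) ** 2
+        x <- x + e * score + sqrt(2 * e) * z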
+ + [1] Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations" (2020) + """ if corrector_steps <= 0: return new_state @@ -773,10 +780,10 @@ def integrate_stochastic_fixed( steps: int, z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], - corrector_steps: int, score_fn: Optional[Callable], step_size_factor: float, corrector_noise_history: Dict[str, ArrayLike], + corrector_steps: int = 0, noise_schedule=None, ) -> StateDict: """ @@ -793,7 +800,7 @@ def body_fixed(_i, _loop_state): dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) dt = sign * dt_mag - # Generate noise increment scaled by sqrt(dt) + # Generate noise increment sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} if len(z_extra_history) == 0: @@ -836,10 +843,10 @@ def integrate_stochastic_adaptive( initial_step: ArrayLike, z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], - corrector_steps: int, score_fn: Optional[Callable], step_size_factor: float, corrector_noise_history: Dict[str, ArrayLike], + corrector_steps: int = 0, noise_schedule=None, ) -> StateDict: """ @@ -896,6 +903,89 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): return final_state +def integrate_langevin( + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + steps: int, + z_history: Dict[str, ArrayLike], + score_fn: Callable, + noise_schedule, + corrector_noise_history: Dict[str, ArrayLike], + step_size_factor: float = 0.01, + corrector_steps: int = 0, +) -> StateDict: + """ + Annealed Langevin dynamics using the given score_fn and noise_schedule [1]. + + At each step i with time t_i, performs for every state component k: + state_k <- state_k + e * score_k + sqrt(2 * e) * z + + Times are stepped linearly from start_time to stop_time. 
+ + [1] Song et al., "Generative Modeling by Estimating Gradients of the Data Distribution" (2020) + """ + + if steps <= 0: + raise ValueError("Number of Langevin steps must be positive.") + if score_fn is None or noise_schedule is None: + raise ValueError("score_fn and noise_schedule must be provided.") + # basic shape check + for k, v in state.items(): + if k not in z_history: + raise ValueError(f"Missing noise for key {k!r} in z_history.") + if keras.ops.shape(z_history[k])[0] < steps: + raise ValueError(f"z_history[{k!r}] has fewer than {steps} steps.") + + # Linear time grid + dt = (stop_time - start_time) / float(steps) + effective_factor = step_size_factor * 100 / np.sqrt(steps) + + def body(_i, loop_state): + current_state, current_time = loop_state + t = current_time + + # score at current time + score = score_fn(t, **filter_kwargs(current_state, score_fn)) + + # noise schedule + log_snr_t = noise_schedule.get_log_snr(t=t, training=False) + _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + new_state: StateDict = {} + for k in current_state.keys(): + s_k = score.get(k, None) + if s_k is None: + new_state[k] = current_state[k] + continue + + e = effective_factor * sigma_t**2 + new_state[k] = current_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i] + + new_time = current_time + dt + + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + + return new_state, new_time + + final_state, _ = keras.ops.fori_loop( + 0, + steps, + body, + (state, start_time), + ) + return final_state + + def integrate_stochastic( drift_fn: Callable, diffusion_fn: Callable, @@ -910,7 +1000,7 @@ def integrate_stochastic( score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, - step_size_factor: float = 0.1, + step_size_factor: float = 0.01, **kwargs, ) -> StateDict: """ @@ -938,6 +1028,7 @@ def integrate_stochastic( Returns: Final state dictionary after integration. 
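+
+    Example: a minimal usage sketch (the toy Ornstein-Uhlenbeck drift and
+    diffusion, the shapes, and all parameter values below are illustrative
+    assumptions for demonstration only):
+
+        import keras
+
+        seed = keras.random.SeedGenerator(0)
+        state = {"x": keras.random.normal((64,), seed=seed)}
+        final = integrate_stochastic(
+            drift_fn=lambda t, x: {"x": -x},
+            diffusion_fn=lambda t, x: {"x": 0.5 * keras.ops.ones_like(x)},
+            state=state,
+            start_time=0.0,
+            stop_time=1.0,
+            seed=seed,
+            steps=100,
+            method="euler_maruyama",
+        )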
""" is_adaptive = isinstance(steps, str) and steps in ["adaptive", "dynamic"] + if is_adaptive: if start_time is None or stop_time is None: raise ValueError("Please provide start_time and stop_time for adaptive integration.") @@ -959,6 +1050,17 @@ def integrate_stochastic( # For fixed step, min/max step size are just the fixed step size min_step_size, max_step_size = initial_step, initial_step + # Pre-generate corrector noise if requested + corrector_noise_history = {} + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") + for key, val in state.items(): + shape = keras.ops.shape(val) + corrector_noise_history[key] = keras.random.normal( + (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed + ) + match method: case "euler_maruyama": step_fn_raw = euler_maruyama_step @@ -966,6 +1068,27 @@ def integrate_stochastic( step_fn_raw = sea_step case "shark": step_fn_raw = shark_step + case "langevin": + if is_adaptive: + raise ValueError("Langevin sampling does not support adaptive steps.") + + z_history = {} + for key, val in state.items(): + shape = keras.ops.shape(val) + z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + + return integrate_langevin( + state=state, + start_time=start_time, + stop_time=stop_time, + steps=loop_steps, + z_history=z_history, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_steps=corrector_steps, + corrector_noise_history=corrector_noise_history, + ) case other: raise TypeError(f"Invalid integration method: {other!r}") @@ -988,17 +1111,6 @@ def integrate_stochastic( if method == "shark": z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - # Pre-generate corrector noise if requested - corrector_noise_history = {} - if corrector_steps > 0: - if score_fn is None or noise_schedule is None: - raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") - for key, val in state.items(): - shape = keras.ops.shape(val) - corrector_noise_history[key] = keras.random.normal( - (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed - ) - if is_adaptive: return integrate_stochastic_adaptive( step_fn=step_fn, diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 160e7f228..c846679cf 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -175,3 +175,76 @@ def diffusion_fn(t, x): exact = x0 * np.exp(a * T) np.testing.assert_allclose(np.array(out).mean(), exact, atol=TOL_DET, rtol=0.1) + + +@pytest.mark.parametrize("steps", [500]) +def test_langevin_gaussian_sampling(steps): + """ + Test annealed Langevin dynamics on a 1D Gaussian target. + + Target distribution: N(mu, sigma^2), with score + ∇_x log p(x) = -(x - mu) / sigma^2 + + We verify that the empirical mean and variance after Langevin sampling + match the target within a loose tolerance (to allow for Monte Carlo noise). 
+ """ + # target parameters + mu = 0.3 + sigma = 0.7 + + # number of particles + N = 20000 + start_time = 0.0 + stop_time = 1.0 + + # tolerances for mean and variance + tol_mean = 5e-2 + tol_var = 5e-2 + + # initial state: broad Gaussian, independent of target + seed = keras.random.SeedGenerator(42) + x0 = keras.random.normal((N,), dtype="float32", seed=seed) + initial_state = {"x": x0} + + # simple dummy noise schedule: constant alpha + class DummyNoiseSchedule: + def get_log_snr(self, t, training=False): + return keras.ops.zeros_like(t) + + def get_alpha_sigma(self, log_snr_t): + alpha_t = keras.ops.ones_like(log_snr_t) + sigma_t = keras.ops.ones_like(log_snr_t) + return alpha_t, sigma_t + + noise_schedule = DummyNoiseSchedule() + + # score of the target Gaussian + def score_fn(t, x): + s = -(x - mu) / (sigma**2) + return {"x": s} + + # run Langevin + final_state = integrate_stochastic( + drift_fn=None, + diffusion_fn=None, + score_fn=score_fn, + noise_schedule=noise_schedule, + state=initial_state, + start_time=start_time, + stop_time=stop_time, + steps=steps, + seed=seed, + method="langevin", + max_steps=1_000, + corrector_steps=1, + ) + + xT = np.array(final_state["x"]) + emp_mean = float(xT.mean()) + emp_var = float(xT.var()) + + exp_mean = mu + exp_var = sigma**2 + + np.testing.assert_allclose(emp_mean, exp_mean, atol=tol_mean, rtol=0.0) + np.testing.assert_allclose(emp_var, exp_var, atol=tol_var, rtol=0.0) From 531c6109afc58b70f4be9559d03b2bd241b84a37 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 28 Nov 2025 08:55:32 +0100 Subject: [PATCH 16/36] add Langevin --- .../diffusion_model/diffusion_model.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index dfd1c28e4..0e38ea4f1 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -489,18 +489,17 @@ def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} score_fn = None - if "corrector_steps" in integrate_kwargs: - if integrate_kwargs["corrector_steps"] > 0: - - def score_fn(time, xz): - return { - "xz": self.score( - xz, - time=time, - conditions=conditions, - training=training, - ) - } + if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin": + + def score_fn(time, xz): + return { + "xz": self.score( + xz, + time=time, + conditions=conditions, + training=training, + ) + } state = integrate_stochastic( drift_fn=deltas, From fdeeb2f291c54c14b157bf4b20a7e99701c534a7 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 28 Nov 2025 09:32:21 +0100 Subject: [PATCH 17/36] add adaptive step size --- bayesflow/utils/integrate.py | 228 ++++++++++++++++++----------- tests/test_utils/test_integrate.py | 4 + 2 files changed, 146 insertions(+), 86 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index dddae5f9c..83fc6b74c 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -27,10 +27,8 @@ def euler_step( state: dict[str, ArrayLike], time: ArrayLike, step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), use_adaptive_step_size: bool = False, + **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): if use_adaptive_step_size: raise ValueError("Adaptive step size not supported for Euler method.") @@ -437,6 
+435,38 @@ def integrate( raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") +def adaptive_step_size_controller(state, drift, adaptive_factor, min_step_size, max_step_size): + """ + Adaptive step size controller based on [1]. + + Adaptive step sizing uses: + h = max(1, ||x||**2) / max(1, ||f(x)||**2) * adaptive_factor + + + [1] Fang & Giles, Adaptive Euler-Maruyama Method for SDEs with Non-Globally Lipschitz Drift Coefficients (2020) + + Returns + ------- + New step size. + """ + state_norms = [] + drift_norms = [] + for key in state.keys(): + state_norms.append(keras.ops.norm(state[key], ord=2, axis=-1)) + drift_norms.append(keras.ops.norm(drift[key], ord=2, axis=-1)) + state_norm = keras.ops.stack(state_norms) + drift_norm = keras.ops.stack(drift_norms) + max_state_norm = keras.ops.maximum( + keras.ops.cast(1.0, dtype=keras.ops.dtype(state_norm)), keras.ops.max(state_norm) ** 2 + ) + max_drift_norm = keras.ops.maximum( + keras.ops.cast(1.0, dtype=keras.ops.dtype(drift_norm)), keras.ops.max(drift_norm) ** 2 + ) + new_step_size = max_state_norm / max_drift_norm * adaptive_factor + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + return new_step_size + + def euler_maruyama_step( drift_fn: Callable, diffusion_fn: Callable, @@ -444,10 +474,11 @@ def euler_maruyama_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], - noise_aux: dict[str, ArrayLike] = None, use_adaptive_step_size: bool = False, - min_step_size: ArrayLike = None, - max_step_size: ArrayLike = None, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), + adaptive_factor: float = 1.0, + **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ Performs a single Euler-Maruyama step for stochastic differential equations. @@ -459,18 +490,15 @@ def euler_maruyama_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. - noise_aux: Mapping of variable names to auxiliary noise (not used here). - use_adaptive_step_size: Whether to use adaptive step sizing (not used here). - min_step_size: Minimum allowed step size (not used here). - max_step_size: Maximum allowed step size (not used here). + use_adaptive_step_size: Whether to use adaptive step sizing. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). Returns: new_state: Updated state after one Euler-Maruyama step. new_time: time + dt. 
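+
+    Worked example of the adaptive rule (illustrative numbers only): with
+    adaptive_factor = 0.1, a state of norm 2, and a drift of norm 4, the
+    controller proposes
+
+        h = max(1, 2**2) / max(1, 4**2) * 0.1 = (4 / 16) * 0.1 = 0.025
+
+    which is then clipped to [min_step_size, max_step_size].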
""" - if use_adaptive_step_size: - raise ValueError("Adaptive step size not supported for Euler method.") - # Compute drift and diffusion drift = drift_fn(time, **filter_kwargs(state, drift_fn)) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) @@ -486,7 +514,17 @@ def euler_maruyama_step( base = base + diffusion[key] * noise[key] new_state[key] = base - return new_state, time + step_size, step_size + new_step_size = step_size + if use_adaptive_step_size: + new_step_size = adaptive_step_size_controller( + state=state, + drift=drift, + adaptive_factor=adaptive_factor, + min_step_size=min_step_size, + max_step_size=max_step_size, + ) + + return new_state, time + step_size, new_step_size def sea_step( @@ -496,10 +534,11 @@ def sea_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], - noise_aux: dict[str, ArrayLike] = None, use_adaptive_step_size: bool = False, - min_step_size: ArrayLike = None, - max_step_size: ArrayLike = None, + min_step_size: ArrayLike = -float("inf"), + max_step_size: ArrayLike = float("inf"), + adaptive_factor: ArrayLike = 1.0, + **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ Performs a single shifted Euler step for SDEs with additive noise [1]. @@ -518,18 +557,15 @@ def sea_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. - noise_aux: Mapping of variable names to auxiliary noise (not used here). - use_adaptive_step_size: Whether to use adaptive step sizing (not used here). - min_step_size: Minimum allowed step size (not used here). - max_step_size: Maximum allowed step size (not used here). + use_adaptive_step_size: Whether to use adaptive step sizing. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). Returns: new_state: Updated state after one SEA step. new_time: time + dt. """ - if use_adaptive_step_size: - raise ValueError("Adaptive step size not supported for Euler method.") - # Compute diffusion (assumed additive or weakly state dependent) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) @@ -556,7 +592,17 @@ def sea_step( base = base + diffusion[key] * noise[key] new_state[key] = base - return new_state, time + step_size, step_size + new_step_size = step_size + if use_adaptive_step_size: + new_step_size = adaptive_step_size_controller( + state=state, + drift=drift_shifted, + adaptive_factor=adaptive_factor, + min_step_size=min_step_size, + max_step_size=max_step_size, + ) + + return new_state, time + step_size, new_step_size def shark_step( @@ -570,7 +616,7 @@ def shark_step( use_adaptive_step_size: bool = False, min_step_size: ArrayLike = -float("inf"), max_step_size: ArrayLike = float("inf"), - tolerance: float = 1e-3, + adaptive_factor: ArrayLike = 1.0, ) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: """ Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion @@ -599,10 +645,10 @@ def shark_step( step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. noise_aux: Mapping of variable names to auxiliary noise. - use_adaptive_step_size: Whether to use adaptive step sizing (not used here). - min_step_size: Minimum allowed step size (not used here). - max_step_size: Maximum allowed step size (not used here). - tolerance: Tolerance for adaptive step sizing. 
+ use_adaptive_step_size: Whether to use adaptive step sizing. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). Returns: new_state: Updated state after one SHARK step. @@ -613,7 +659,7 @@ def shark_step( # Magnitude of the time step for stochastic scaling h_mag = keras.ops.abs(h) - h_sign = keras.ops.sign(h) + # h_sign = keras.ops.sign(h) sqrt_h_mag = keras.ops.sqrt(h_mag) inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(h_mag)) @@ -678,55 +724,67 @@ def shark_step( new_state[k] = det + sto1 + sto2 - if not use_adaptive_step_size: - return new_state, t + h, h + # if not use_adaptive_step_size: + # return new_state, t + h, h - # embedded lower order solution y_low - # here: one stage strong order one method using y_tilde_k - y_low = {} - for k in state.keys(): - det_low = state[k] + f_tilde_k[k] * h - if k in g0: - sto_low = g0[k] * noise[k] - else: - sto_low = keras.ops.zeros_like(det_low) - y_low[k] = det_low + sto_low - - # error estimate as max over components of RMS norm - err_list = [] - for k in state.keys(): - diff = new_state[k] - y_low[k] - sq = keras.ops.square(diff) - mean_sq = keras.ops.mean(sq) - err_k = keras.ops.sqrt(mean_sq) - err_list.append(err_k) - - if len(err_list) == 0: - err = keras.ops.zeros_like(h_mag) - else: - err = err_list[0] - for e_k in err_list[1:]: - err = keras.ops.maximum(err, e_k) - - tiny = keras.ops.cast(1e12, dtype=keras.ops.dtype(h_mag)) - safety = keras.ops.cast(0.9, dtype=keras.ops.dtype(h_mag)) - # effective order between one and one point five - exponent = keras.ops.cast(0.5, dtype=keras.ops.dtype(h_mag)) - - factor = safety * keras.ops.power(tolerance / (err + tiny), exponent) - - # clamp factor - factor_min = keras.ops.cast(0.2, dtype=keras.ops.dtype(h_mag)) - factor_max = keras.ops.cast(5.0, dtype=keras.ops.dtype(h_mag)) - factor = keras.ops.minimum(keras.ops.maximum(factor, factor_min), factor_max) - - new_h_mag = h_mag * factor - new_h_mag = keras.ops.maximum(new_h_mag, min_step_size) - new_h_mag = keras.ops.minimum(new_h_mag, max_step_size) - - new_h = h_sign * new_h_mag + new_step_size = h + if use_adaptive_step_size: + new_step_size = adaptive_step_size_controller( + state=state, + drift=f_tilde_k, + adaptive_factor=adaptive_factor, + min_step_size=min_step_size, + max_step_size=max_step_size, + ) - return new_state, t + h, new_h + return new_state, t + h, new_step_size + + # # embedded lower order solution y_low + # # here: one stage strong order one method using y_tilde_k + # y_low = {} + # for k in state.keys(): + # det_low = state[k] + f_tilde_k[k] * h + # if k in g0: + # sto_low = g0[k] * noise[k] + # else: + # sto_low = keras.ops.zeros_like(det_low) + # y_low[k] = det_low + sto_low + # + # # error estimate as max over components of RMS norm + # err_list = [] + # for k in state.keys(): + # diff = new_state[k] - y_low[k] + # sq = keras.ops.square(diff) + # mean_sq = keras.ops.mean(sq) + # err_k = keras.ops.sqrt(mean_sq) + # err_list.append(err_k) + # + # if len(err_list) == 0: + # err = keras.ops.zeros_like(h_mag) + # else: + # err = err_list[0] + # for e_k in err_list[1:]: + # err = keras.ops.maximum(err, e_k) + # + # tiny = keras.ops.cast(1e12, dtype=keras.ops.dtype(h_mag)) + # safety = keras.ops.cast(0.9, dtype=keras.ops.dtype(h_mag)) + # # effective order between one and one point five + # exponent = keras.ops.cast(0.5, dtype=keras.ops.dtype(h_mag)) + # + # factor = safety * 
keras.ops.power(tolerance / (err + tiny), exponent) + # + # # clamp factor + # factor_min = keras.ops.cast(0.2, dtype=keras.ops.dtype(h_mag)) + # factor_max = keras.ops.cast(5.0, dtype=keras.ops.dtype(h_mag)) + # factor = keras.ops.minimum(keras.ops.maximum(factor, factor_min), factor_max) + # + # new_h_mag = h_mag * factor + # new_h_mag = keras.ops.maximum(new_h_mag, min_step_size) + # new_h_mag = keras.ops.minimum(new_h_mag, max_step_size) + # + # new_h = h_sign * new_h_mag + # + # return new_state, t + h, new_h def _apply_corrector( @@ -736,7 +794,7 @@ def _apply_corrector( corrector_steps: int, score_fn: Optional[Callable], corrector_noise_history: Dict[str, ArrayLike], - step_size_factor: float = 0.01, + step_size_factor: ArrayLike = 0.01, noise_schedule=None, ) -> StateDict: """Helper function to apply corrector steps [1]. @@ -764,7 +822,7 @@ def _apply_corrector( score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) score_norm = keras.ops.maximum(score_norm, 1e-8) - # Compute step size 'e' for the Langevin update + # Compute step size for the Langevin update e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 # Annealed Langevin Dynamics update @@ -781,7 +839,7 @@ def integrate_stochastic_fixed( z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], score_fn: Optional[Callable], - step_size_factor: float, + step_size_factor: ArrayLike, corrector_noise_history: Dict[str, ArrayLike], corrector_steps: int = 0, noise_schedule=None, @@ -844,7 +902,7 @@ def integrate_stochastic_adaptive( z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], score_fn: Optional[Callable], - step_size_factor: float, + step_size_factor: ArrayLike, corrector_noise_history: Dict[str, ArrayLike], corrector_steps: int = 0, noise_schedule=None, @@ -912,7 +970,7 @@ def integrate_langevin( score_fn: Callable, noise_schedule, corrector_noise_history: Dict[str, ArrayLike], - step_size_factor: float = 0.01, + step_size_factor: ArrayLike = 0.01, corrector_steps: int = 0, ) -> StateDict: """ @@ -1000,7 +1058,7 @@ def integrate_stochastic( score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, - step_size_factor: float = 0.01, + step_size_factor: ArrayLike = 0.01, **kwargs, ) -> StateDict: """ @@ -1034,8 +1092,6 @@ def integrate_stochastic( raise ValueError("Please provide start_time and stop_time for adaptive integration.") if min_steps <= 0 or max_steps <= 0 or max_steps < min_steps: raise ValueError("min_steps and max_steps must be positive, and max_steps >= min_steps.") - if method != "shark": - raise ValueError("Adaptive step size is only supported for the 'shark' method.") loop_steps = max_steps initial_step = (stop_time - start_time) / float(min_steps) diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index c846679cf..ba286d857 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -76,7 +76,9 @@ def fn(t, x): "method,use_adapt", [ ("euler_maruyama", False), + ("euler_maruyama", True), ("sea", False), + ("sea", True), ("shark", False), ("shark", True), ], @@ -137,7 +139,9 @@ def diffusion_fn(t, x): "method,use_adapt", [ ("euler_maruyama", False), + ("euler_maruyama", True), ("sea", False), + ("sea", True), ("shark", False), ("shark", True), ], From f45c2dcf69991b0e4c63cd33251e8df1e7f31e69 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 29 Nov 2025 11:15:06 +0100 Subject: [PATCH 18/36] tune adaptive step size --- bayesflow/utils/integrate.py | 70 
+++++++----------------------------- 1 file changed, 13 insertions(+), 57 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 83fc6b74c..984fd3267 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -435,7 +435,13 @@ def integrate( raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") -def adaptive_step_size_controller(state, drift, adaptive_factor, min_step_size, max_step_size): +def adaptive_step_size_controller( + state, + drift, + adaptive_factor: ArrayLike = 1.0, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), +) -> ArrayLike: """ Adaptive step size controller based on [1]. @@ -477,7 +483,7 @@ def euler_maruyama_step( use_adaptive_step_size: bool = False, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), - adaptive_factor: float = 1.0, + adaptive_factor: float = 0.1, **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ @@ -493,7 +499,7 @@ def euler_maruyama_step( use_adaptive_step_size: Whether to use adaptive step sizing. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). + adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one Euler-Maruyama step. @@ -537,7 +543,7 @@ def sea_step( use_adaptive_step_size: bool = False, min_step_size: ArrayLike = -float("inf"), max_step_size: ArrayLike = float("inf"), - adaptive_factor: ArrayLike = 1.0, + adaptive_factor: ArrayLike = 0.1, **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ @@ -560,7 +566,7 @@ def sea_step( use_adaptive_step_size: Whether to use adaptive step sizing. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). + adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one SEA step. @@ -616,7 +622,7 @@ def shark_step( use_adaptive_step_size: bool = False, min_step_size: ArrayLike = -float("inf"), max_step_size: ArrayLike = float("inf"), - adaptive_factor: ArrayLike = 1.0, + adaptive_factor: ArrayLike = 0.1, ) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: """ Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion @@ -648,7 +654,7 @@ def shark_step( use_adaptive_step_size: Whether to use adaptive step sizing. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). + adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one SHARK step. 
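The controller above shrinks the step where the drift is large relative to the state, via h = adaptive_factor * max(1, ||x||**2) / max(1, ||f(x)||**2). A minimal standalone sketch of that rule, assuming a single tensor state (the helper name and bounds are illustrative, not library API):

import keras

def controller_sketch(x, fx, adaptive_factor=0.1, min_h=1e-4, max_h=1.0):
    # h = adaptive_factor * max(1, ||x||^2) / max(1, ||f(x)||^2), clipped to bounds;
    # a tamed-Euler-style heuristic: fast dynamics => small steps
    x_norm2 = keras.ops.sum(keras.ops.square(x))
    f_norm2 = keras.ops.sum(keras.ops.square(fx))
    h = adaptive_factor * keras.ops.maximum(1.0, x_norm2) / keras.ops.maximum(1.0, f_norm2)
    return keras.ops.clip(h, min_h, max_h)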
@@ -724,9 +730,6 @@ def shark_step( new_state[k] = det + sto1 + sto2 - # if not use_adaptive_step_size: - # return new_state, t + h, h - new_step_size = h if use_adaptive_step_size: new_step_size = adaptive_step_size_controller( @@ -739,53 +742,6 @@ def shark_step( return new_state, t + h, new_step_size - # # embedded lower order solution y_low - # # here: one stage strong order one method using y_tilde_k - # y_low = {} - # for k in state.keys(): - # det_low = state[k] + f_tilde_k[k] * h - # if k in g0: - # sto_low = g0[k] * noise[k] - # else: - # sto_low = keras.ops.zeros_like(det_low) - # y_low[k] = det_low + sto_low - # - # # error estimate as max over components of RMS norm - # err_list = [] - # for k in state.keys(): - # diff = new_state[k] - y_low[k] - # sq = keras.ops.square(diff) - # mean_sq = keras.ops.mean(sq) - # err_k = keras.ops.sqrt(mean_sq) - # err_list.append(err_k) - # - # if len(err_list) == 0: - # err = keras.ops.zeros_like(h_mag) - # else: - # err = err_list[0] - # for e_k in err_list[1:]: - # err = keras.ops.maximum(err, e_k) - # - # tiny = keras.ops.cast(1e12, dtype=keras.ops.dtype(h_mag)) - # safety = keras.ops.cast(0.9, dtype=keras.ops.dtype(h_mag)) - # # effective order between one and one point five - # exponent = keras.ops.cast(0.5, dtype=keras.ops.dtype(h_mag)) - # - # factor = safety * keras.ops.power(tolerance / (err + tiny), exponent) - # - # # clamp factor - # factor_min = keras.ops.cast(0.2, dtype=keras.ops.dtype(h_mag)) - # factor_max = keras.ops.cast(5.0, dtype=keras.ops.dtype(h_mag)) - # factor = keras.ops.minimum(keras.ops.maximum(factor, factor_min), factor_max) - # - # new_h_mag = h_mag * factor - # new_h_mag = keras.ops.maximum(new_h_mag, min_step_size) - # new_h_mag = keras.ops.minimum(new_h_mag, max_step_size) - # - # new_h = h_sign * new_h_mag - # - # return new_state, t + h, new_h - def _apply_corrector( new_state: StateDict, From 9fd77074325a9f50a74e990ed19b8a8aa76dc08b Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 30 Nov 2025 17:18:40 +0100 Subject: [PATCH 19/36] add Gotta Go Fast SDE sampler --- bayesflow/utils/integrate.py | 327 ++++++++++++++++++++--------- tests/test_utils/test_integrate.py | 109 ++++++++-- 2 files changed, 315 insertions(+), 121 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 984fd3267..c9f951982 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -19,7 +19,7 @@ DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"] -STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin"] +STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin", "fast_adaptive"] def euler_step( @@ -27,12 +27,8 @@ def euler_step( state: dict[str, ArrayLike], time: ArrayLike, step_size: ArrayLike, - use_adaptive_step_size: bool = False, **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): - if use_adaptive_step_size: - raise ValueError("Adaptive step size not supported for Euler method.") - k1 = fn(time, **filter_kwargs(state, fn)) new_state = state.copy() @@ -82,10 +78,6 @@ def rk45_step( **add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), ) - # check all keys are equal - if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5, k6]): - raise ValueError("Keys of the deltas do not match. 
Please return zero for unchanged variables.")
-
     # 5th order solution
     new_state = {}
     for key in k1.keys():
@@ -262,6 +254,8 @@ def integrate_adaptive(
             step_fn = rk45_step
         case "tsit5":
             step_fn = tsit5_step
+        case "euler":
+            raise ValueError("Adaptive step sizing is not supported for the 'euler' method.")
         case str() as name:
             raise ValueError(f"Unknown integration method name: {name!r}")
         case other:
@@ -438,12 +432,12 @@ def integrate(
 def adaptive_step_size_controller(
     state,
     drift,
-    adaptive_factor: ArrayLike = 1.0,
+    adaptive_factor: ArrayLike,
     min_step_size: float = -float("inf"),
     max_step_size: float = float("inf"),
 ) -> ArrayLike:
     """
-    Adaptive step size controller based on [1].
+    Adaptive step size controller based on [1]. Similar to a tamed explicit Euler method when used in Euler-Maruyama.
 
     Adaptive step sizing uses:
         h = max(1, ||x||**2) / max(1, ||f(x)||**2) * adaptive_factor
@@ -483,9 +477,12 @@ def euler_maruyama_step(
     use_adaptive_step_size: bool = False,
     min_step_size: float = -float("inf"),
     max_step_size: float = float("inf"),
-    adaptive_factor: float = 0.1,
+    adaptive_factor: float = 0.01,
     **kwargs,
-) -> (dict[str, ArrayLike], ArrayLike, ArrayLike):
+) -> Union[
+    Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike],
+    Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike, Dict[str, ArrayLike]],
+]:
     """
     Performs a single Euler-Maruyama step for stochastic differential equations.
 
@@ -509,19 +506,9 @@ def euler_maruyama_step(
     drift = drift_fn(time, **filter_kwargs(state, drift_fn))
     diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))
 
-    # Check noise keys
-    if set(diffusion.keys()) != set(noise.keys()):
-        raise ValueError("Keys of diffusion terms and noise do not match.")
-
-    new_state = {}
-    for key, d in drift.items():
-        base = state[key] + step_size * d
-        if key in diffusion:  # stochastic update
-            base = base + diffusion[key] * noise[key]
-        new_state[key] = base
-
     new_step_size = step_size
     if use_adaptive_step_size:
+        sign_step = keras.ops.sign(step_size)
         new_step_size = adaptive_step_size_controller(
             state=state,
             drift=drift,
@@ -529,8 +516,168 @@ def euler_maruyama_step(
             min_step_size=min_step_size,
             max_step_size=max_step_size,
         )
+        new_step_size = sign_step * keras.ops.abs(new_step_size)
+
+    sqrt_step_size = keras.ops.sqrt(keras.ops.abs(new_step_size))
+
+    new_state = {}
+    for key, d in drift.items():
+        base = state[key] + new_step_size * d
+        if key in diffusion:
+            base = base + diffusion[key] * sqrt_step_size * noise[key]
+        new_state[key] = base
+
+    if use_adaptive_step_size:
+        return new_state, time + new_step_size, new_step_size, state
+    return new_state, time + new_step_size, new_step_size
+
+
+def fast_adaptive_step(
+    drift_fn: Callable,
+    diffusion_fn: Callable,
+    state: dict[str, ArrayLike],
+    time: ArrayLike,
+    step_size: ArrayLike,
+    noise: dict[str, ArrayLike],
+    last_state: dict[str, ArrayLike] = None,
+    use_adaptive_step_size: bool = True,
+    min_step_size: float = -float("inf"),
+    max_step_size: float = float("inf"),
+    e_abs: float = 0.01,
+    e_rel: float = 0.01,
+    r: float = 0.9,
+    adapt_safety: float = 0.9,
+    **kwargs,
+) -> Union[
+    Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike],
+    Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike, Dict[str, ArrayLike]],
+]:
+    """
+    Performs a single adaptive step for stochastic differential equations based on [1].
+
+    This method uses a predictor-corrector approach with error estimation:
+    1. Take an Euler-Maruyama step (predictor)
+    2. Take another Euler-Maruyama step from the predicted state
+    3. 
Average the two predictions (corrector)
+    4. Estimate error and adapt step size
+
+    When step_size reaches min_step_size, steps are always accepted regardless of
+    error to ensure progress and termination within max_steps.
+
+    [1] Jolicoeur-Martineau et al. (2021) "Gotta Go Fast When Generating Data with Score-Based Models"
+
+    Args:
+        drift_fn: Function computing the drift term f(t, **state).
+        diffusion_fn: Function computing the diffusion term g(t, **state).
+        state: Current state, mapping variable names to tensors.
+        time: Current time scalar tensor.
+        step_size: Time increment dt.
+        noise: Mapping of variable names to standard normal noise tensors (scaled by sqrt(dt) inside the step).
+        last_state: Previous state for error estimation.
+        use_adaptive_step_size: Whether to adapt step size.
+        min_step_size: Minimum allowed step size.
+        max_step_size: Maximum allowed step size.
+        e_abs: Absolute error tolerance.
+        e_rel: Relative error tolerance.
+        r: Order of the method for step size adaptation.
+        adapt_safety: Safety factor for step size adaptation.
+        **kwargs: Additional arguments passed to drift_fn and diffusion_fn.
+
+    Returns:
+        new_state: Updated state after one adaptive step.
+        new_time: time + dt (or time if step rejected).
+        new_step_size: Adapted step size for next iteration.
+        prev_state: Reference state for the next error estimate (returned only in adaptive mode).
+    """
+    state_euler, time_mid, _ = euler_maruyama_step(
+        drift_fn=drift_fn,
+        diffusion_fn=diffusion_fn,
+        state=state,
+        time=time,
+        step_size=step_size,
+        min_step_size=min_step_size,
+        max_step_size=max_step_size,
+        noise=noise,
+        use_adaptive_step_size=False,
+    )
+
+    # Compute drift and diffusion at new state, but update from old state
+    drift_mid = drift_fn(time_mid, **filter_kwargs(state_euler, drift_fn))
+    diffusion_mid = diffusion_fn(time_mid, **filter_kwargs(state_euler, diffusion_fn))
+    sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size))
+
+    state_euler_mid = {}
+    for key, d in drift_mid.items():
+        base = state[key] + step_size * d
+        if key in diffusion_mid:
+            base = base + diffusion_mid[key] * sqrt_step_size * noise[key]
+        state_euler_mid[key] = base
+
+    # average the two predictions
+    state_heun = {}
+    for key in state.keys():
+        state_heun[key] = 0.5 * (state_euler[key] + state_euler_mid[key])
+
+    # Error estimation
+    if use_adaptive_step_size:
+        # Check if we're at minimum step size - if so, force acceptance
+        at_min_step = keras.ops.less_equal(step_size, min_step_size)
+
+        # Compute error tolerance for each component
+        e_abs_tensor = keras.ops.cast(e_abs, dtype=keras.ops.dtype(list(state.values())[0]))
+        e_rel_tensor = keras.ops.cast(e_rel, dtype=keras.ops.dtype(list(state.values())[0]))
+
+        max_error = keras.ops.cast(0.0, dtype=keras.ops.dtype(list(state.values())[0]))
+
+        for key in state.keys():
+            # Local error estimate: difference between Heun and first Euler step
+            error_estimate = keras.ops.abs(state_heun[key] - state_euler[key])
+
+            # Tolerance threshold
+            delta = keras.ops.maximum(
+                e_abs_tensor,
+                e_rel_tensor * keras.ops.maximum(keras.ops.abs(state_euler[key]), keras.ops.abs(last_state[key])),
+            )
+
+            # Normalized error
+            normalized_error = error_estimate / (delta + 1e-10)
+
+            # Maximum error across all components and batch dimensions
+            component_max_error = keras.ops.max(normalized_error)
+            max_error = keras.ops.maximum(max_error, component_max_error)
+
+        error_scale = 1  # could be set to 1 / sqrt(n_params) to average the error over dimensions
+        E2 = error_scale * max_error
+
+        # Accept step if error is acceptable OR if at minimum step size
+        error_acceptable = keras.ops.less_equal(E2, 
keras.ops.cast(1.0, dtype=keras.ops.dtype(E2))) + accepted = keras.ops.logical_or(error_acceptable, at_min_step) + + # Adapt step size for next iteration (only if not at minimum) + # Ensure E2 is not zero to avoid division issues + E2_safe = keras.ops.maximum(E2, 1e-10) + + # New step size based on error estimate + adapt_factor = adapt_safety * keras.ops.power(E2_safe, -r) + new_step_candidate = step_size * adapt_factor + + # Clamp to valid range + sign_step = keras.ops.sign(step_size) + new_step_size = keras.ops.minimum(keras.ops.maximum(new_step_candidate, min_step_size), max_step_size) + new_step_size = sign_step * keras.ops.abs(new_step_size) + + # Return appropriate state based on acceptance + new_state = keras.ops.cond(accepted, lambda: state_heun, lambda: state) + + new_time = keras.ops.cond(accepted, lambda: time_mid, lambda: time) + + prev_state = keras.ops.cond(accepted, lambda: state_euler, lambda: state) + + return new_state, new_time, new_step_size, prev_state + + else: + return state_heun, time_mid, step_size def sea_step( @@ -540,12 +687,8 @@ def sea_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], - use_adaptive_step_size: bool = False, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - adaptive_factor: ArrayLike = 0.1, **kwargs, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): +) -> Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: """ Performs a single shifted Euler step for SDEs with additive noise [1]. @@ -563,10 +706,6 @@ def sea_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. - use_adaptive_step_size: Whether to use adaptive step sizing. - min_step_size: Minimum allowed step size. - max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one SEA step. 
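For a scalar SDE, the predictor-corrector error control of fast_adaptive_step reduces to the following NumPy sketch (illustrative names, not library API; note the same draw z is reused for both drift evaluations, as in the step above):

import numpy as np

def two_step_trial(f, g, x, x_prev, t, h, z, e_abs=0.01, e_rel=0.01, r=0.9, safety=0.9):
    sqrt_h = np.sqrt(abs(h))
    x_euler = x + h * f(t, x) + g(t, x) * sqrt_h * z                # predictor
    x_mid = x + h * f(t + h, x_euler) + g(t + h, x_euler) * sqrt_h * z
    x_heun = 0.5 * (x_euler + x_mid)                                # averaged corrector
    delta = max(e_abs, e_rel * max(abs(x_euler), abs(x_prev)))      # mixed abs/rel tolerance
    err = abs(x_heun - x_euler) / (delta + 1e-10)                   # normalized local error
    accepted = err <= 1.0
    h_next = h * safety * max(err, 1e-10) ** (-r)                   # grow or shrink the step
    # in practice h_next is additionally clipped to [min_step_size, max_step_size]
    return (x_heun if accepted else x), accepted, h_next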
@@ -574,16 +713,13 @@ def sea_step( """ # Compute diffusion (assumed additive or weakly state dependent) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) - - # Check noise keys - if set(diffusion.keys()) != set(noise.keys()): - raise ValueError("Keys of diffusion terms and noise do not match.") + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) # Build shifted state: X_shift = X + 0.5 * g * ΔW shifted_state = {} for key, x in state.items(): if key in diffusion: - shifted_state[key] = x + 0.5 * diffusion[key] * noise[key] + shifted_state[key] = x + 0.5 * diffusion[key] * sqrt_step_size * noise[key] else: shifted_state[key] = x @@ -595,20 +731,10 @@ def sea_step( for key, d in drift_shifted.items(): base = state[key] + step_size * d if key in diffusion: - base = base + diffusion[key] * noise[key] + base = base + diffusion[key] * sqrt_step_size * noise[key] new_state[key] = base - new_step_size = step_size - if use_adaptive_step_size: - new_step_size = adaptive_step_size_controller( - state=state, - drift=drift_shifted, - adaptive_factor=adaptive_factor, - min_step_size=min_step_size, - max_step_size=max_step_size, - ) - - return new_state, time + step_size, new_step_size + return new_state, time + step_size, step_size def shark_step( @@ -617,13 +743,10 @@ def shark_step( state: Dict[str, ArrayLike], time: ArrayLike, step_size: ArrayLike, - noise: Dict[str, ArrayLike], # w_k = ΔW_k (already scaled by sqrt(|h|)) - noise_aux: Dict[str, ArrayLike], # Z_k ~ N(0,1), used to build H_k - use_adaptive_step_size: bool = False, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - adaptive_factor: ArrayLike = 0.1, -) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: + noise: Dict[str, ArrayLike], + noise_aux: Dict[str, ArrayLike], + **kwargs, +) -> Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: """ Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion per step and has a strong order 1.5. @@ -651,10 +774,6 @@ def shark_step( step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. noise_aux: Mapping of variable names to auxiliary noise. - use_adaptive_step_size: Whether to use adaptive step sizing. - min_step_size: Minimum allowed step size. - max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one SHARK step. 
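For reference, term1 and term2 in the H construction below realize the standard decomposition of the integrated Brownian increment over one step of size h, with ΔW_n = sqrt(h) Z_1 and an independent standard normal Z_2 (stated here as a textbook identity, not quoted from [1]):

\int_{t_n}^{t_n+h} \big(W_s - W_{t_n}\big)\,ds \;=\; \frac{h}{2}\,\Delta W_n \;+\; \frac{h^{3/2}}{2\sqrt{3}}\,Z_2

The second term contributes variance h^3/12 and is independent of ΔW_n; together the two pieces give the total variance h^3/3 of the time integral.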
@@ -676,7 +795,7 @@ def shark_step( H = {} for k in state.keys(): if k in g0: - w_k = noise[k] # already scaled by sqrt(|h|) + w_k = sqrt_h_mag * noise[k] z_k = noise_aux[k] # standard normal term1 = 0.5 * h_mag * w_k term2 = 0.5 * h_mag * sqrt_h_mag * inv_sqrt3 * z_k @@ -701,7 +820,7 @@ def shark_step( for k in state.keys(): drift_part = (5.0 / 6.0) * f_tilde_k[k] * h if k in g_tilde_k: - sto_part = (5.0 / 6.0) * g_tilde_k[k] * noise[k] + sto_part = (5.0 / 6.0) * g_tilde_k[k] * sqrt_h_mag * noise[k] else: sto_part = keras.ops.zeros_like(state[k]) y_tilde_mid[k] = y_tilde_k[k] + drift_part + sto_part @@ -718,29 +837,19 @@ def shark_step( # stochastic parts sto1 = ( - g_tilde_k[k] * ((2.0 / 5.0) * noise[k] + (6.0 / 5.0) * H[k]) + g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * H[k]) if k in g_tilde_k else keras.ops.zeros_like(det) ) sto2 = ( - g_tilde_mid[k] * ((3.0 / 5.0) * noise[k] - (6.0 / 5.0) * H[k]) + g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * H[k]) if k in g_tilde_mid else keras.ops.zeros_like(det) ) new_state[k] = det + sto1 + sto2 - new_step_size = h - if use_adaptive_step_size: - new_step_size = adaptive_step_size_controller( - state=state, - drift=f_tilde_k, - adaptive_factor=adaptive_factor, - min_step_size=min_step_size, - max_step_size=max_step_size, - ) - - return new_state, t + h, new_step_size + return new_state, t + h, h def _apply_corrector( @@ -792,6 +901,8 @@ def integrate_stochastic_fixed( start_time: ArrayLike, stop_time: ArrayLike, steps: int, + min_step_size: ArrayLike, + max_step_size: ArrayLike, z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], score_fn: Optional[Callable], @@ -809,14 +920,13 @@ def body_fixed(_i, _loop_state): _current_state, _current_time, _current_step = _loop_state # Determine step size: either the constant size or the remainder to reach stop_time - remaining = stop_time - _current_time - sign = keras.ops.sign(remaining) - dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) + remaining = keras.ops.abs(stop_time - _current_time) + sign = keras.ops.sign(_current_step) + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) dt = sign * dt_mag # Generate noise increment - sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) - _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} + _noise_i = {k: z_history[k][_i] for k in _current_state.keys()} if len(z_extra_history) == 0: _noise_extra_i = None else: @@ -826,6 +936,8 @@ def body_fixed(_i, _loop_state): state=_current_state, time=_current_time, step_size=dt, + min_step_size=min_step_size, + max_step_size=keras.ops.minimum(max_step_size, remaining), noise=_noise_i, noise_aux=_noise_extra_i, use_adaptive_step_size=False, @@ -854,6 +966,8 @@ def integrate_stochastic_adaptive( start_time: ArrayLike, stop_time: ArrayLike, max_steps: int, + min_step_size: ArrayLike, + max_step_size: ArrayLike, initial_step: ArrayLike, z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], @@ -866,33 +980,34 @@ def integrate_stochastic_adaptive( """ Performs adaptive-step SDE integration. 
""" - initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0) + initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0, state) - def cond(i, current_state, current_time, current_step, counter): - # We use a small epsilon check for floating point equality - time_reached = keras.ops.all(keras.ops.isclose(current_time, stop_time)) - return keras.ops.logical_and(keras.ops.less(i, max_steps), keras.ops.logical_not(time_reached)) + def cond(i, current_state, current_time, current_step, counter, last_state): + # time remaining after the next step + time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step)) + return keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) - def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): + def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _last_state): # Step Size Control - remaining = stop_time - _current_time - sign = keras.ops.sign(remaining) + remaining = keras.ops.abs(stop_time - _current_time) + sign = keras.ops.sign(_current_step) # Ensure the next step does not overshoot the stop_time - dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) dt = sign * dt_mag _counter += 1 - sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) - _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} - if len(z_extra_history) == 0: - _noise_extra_i = None - else: + _noise_i = {k: z_history[k][_i] for k in _current_state.keys()} + _noise_extra_i = None + if len(z_extra_history) > 0: _noise_extra_i = {k: z_extra_history[k][_i] for k in _current_state.keys()} - new_state, new_time, new_step = step_fn( + new_state, new_time, new_step, _new_current_state = step_fn( state=_current_state, + last_state=_last_state, time=_current_time, step_size=dt, + min_step_size=min_step_size, + max_step_size=keras.ops.minimum(max_step_size, remaining), noise=_noise_i, noise_aux=_noise_extra_i, use_adaptive_step_size=True, @@ -909,10 +1024,10 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): corrector_noise_history=corrector_noise_history, ) - return _i + 1, new_state, new_time, new_step, _counter + return _i + 1, new_state, new_time, new_step, _counter, _new_current_state # Execute the adaptive loop - _, final_state, _, _, final_counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + _, final_state, _, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) logging.debug("Finished integration after {} steps.", final_counter) return final_state @@ -1009,7 +1124,7 @@ def integrate_stochastic( seed: keras.random.SeedGenerator, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", - min_steps: int = 10, + min_steps: int = 20, max_steps: int = 10_000, score_fn: Callable = None, corrector_steps: int = 0, @@ -1078,8 +1193,14 @@ def integrate_stochastic( step_fn_raw = euler_maruyama_step case "sea": step_fn_raw = sea_step + if is_adaptive: + raise ValueError("SEA SDE solver does not support adaptive steps.") case "shark": step_fn_raw = shark_step + if is_adaptive: + raise ValueError("SHARK SDE solver does not support adaptive steps.") + case "fast_adaptive": + step_fn_raw = fast_adaptive_step case "langevin": if is_adaptive: raise ValueError("Langevin sampling does not support adaptive steps.") @@ 
-1109,8 +1230,6 @@ def integrate_stochastic( step_fn_raw, drift_fn=drift_fn, diffusion_fn=diffusion_fn, - min_step_size=min_step_size, - max_step_size=max_step_size, **kwargs, ) @@ -1130,6 +1249,8 @@ def integrate_stochastic( start_time=start_time, stop_time=stop_time, max_steps=max_steps, + min_step_size=min_step_size, + max_step_size=max_step_size, initial_step=initial_step, z_history=z_history, z_extra_history=z_extra_history, @@ -1145,6 +1266,8 @@ def integrate_stochastic( state=state, start_time=start_time, stop_time=stop_time, + min_step_size=min_step_size, + max_step_size=max_step_size, steps=loop_steps, z_history=z_history, z_extra_history=z_extra_history, diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index ba286d857..8a6de502c 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -78,42 +78,44 @@ def fn(t, x): ("euler_maruyama", False), ("euler_maruyama", True), ("sea", False), - ("sea", True), ("shark", False), - ("shark", True), + ("fast_adaptive", False), + ("fast_adaptive", True), ], ) -def test_additive_OU_weak_means_and_vars(method, use_adapt): +def test_forward_additive_ou_weak_means_and_vars(method, use_adapt): """ - Ornstein Uhlenbeck with additive noise + Ornstein-Uhlenbeck with additive noise, integrated FORWARD in time. + This serves as a sanity check that forward integration still works correctly. + + Forward SDE: dX = a X dt + sigma dW - Exact at time T: - E[X_T] = x0 * exp(a T) - Var[X_T] = sigma^2 * (exp(2 a T) - 1) / (2 a) - We verify weak accuracy by matching empirical mean and variance. + + Exact at time T starting from X(0) = x_0: + E[X(T)] = x_0 * exp(a T) + Var[X(T)] = sigma^2 * (exp(2 a T) - 1) / (2 a) """ # SDE parameters a = -1.0 sigma = 0.5 - x0 = 1.2 + x_0 = 1.2 # initial condition at time 0 T = 1.0 # batch of trajectories - N = 10000 # large enough to control sampling error + N = 10000 seed = keras.random.SeedGenerator(42) def drift_fn(t, x): return {"x": a * x} def diffusion_fn(t, x): - # additive noise, independent of state return {"x": keras.ops.convert_to_tensor([sigma])} - initial_state = {"x": keras.ops.ones((N,)) * x0} + initial_state = {"x": keras.ops.ones((N,)) * x_0} steps = 200 if not use_adapt else "adaptive" - # expected mean and variance - exp_mean = x0 * np.exp(a * T) + # Expected mean and variance at t=T + exp_mean = x_0 * np.exp(a * T) exp_var = sigma**2 * (np.exp(2.0 * a * T) - 1.0) / (2.0 * a) out = integrate_stochastic( @@ -128,9 +130,78 @@ def diffusion_fn(t, x): max_steps=1_000, ) - xT = np.array(out["x"]) - emp_mean = float(xT.mean()) - emp_var = float(xT.var()) + x_T = np.array(out["x"]) + emp_mean = float(x_T.mean()) + emp_var = float(x_T.var()) + + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0) + np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0) + + +@pytest.mark.parametrize( + "method,use_adapt", + [ + ("euler_maruyama", False), + ("euler_maruyama", True), + ("sea", False), + ("shark", False), + ("fast_adaptive", False), + ("fast_adaptive", True), + ], +) +def test_backward_additive_ou_weak_means_and_vars(method, use_adapt): + """ + Ornstein-Uhlenbeck with additive noise, integrated BACKWARD in time. + + When integrating from t=T back to t=0 with initial condition X(T) = x_T, + we get X(0) which should satisfy: + E[X(0)] = x_T * exp(-a T) (-a because we go backward) + Var[X(0)] = sigma^2 * (exp(-2 a T) - 1) / (-2 a) + + We verify weak accuracy by matching empirical mean and variance. 
+ """ + # SDE parameters + a = -1.0 + sigma = 0.5 + x_T = 1.2 # initial condition at time T + T = 1.0 + + # batch of trajectories + N = 10000 # large enough to control sampling error + seed = keras.random.SeedGenerator(42) + + def drift_fn(t, x): + return {"x": a * x} + + def diffusion_fn(t, x): + # additive noise, independent of state + return {"x": keras.ops.convert_to_tensor([sigma])} + + # Start at time T with value x_T + initial_state = {"x": keras.ops.ones((N,)) * x_T} + steps = 200 if not use_adapt else "adaptive" + + # Expected mean and variance at t=0 after integrating backward from t=T + # For backward integration, the effective drift coefficient changes sign + exp_mean = x_T * np.exp(-a * T) + exp_var = sigma**2 * (np.exp(-2.0 * a * T) - 1.0) / (-2.0 * a) + + out = integrate_stochastic( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=initial_state, + start_time=T, + stop_time=0.0, + steps=steps, + seed=seed, + method=method, + max_steps=1_000, + ) + + x_0 = np.array(out["x"]) + emp_mean = float(x_0.mean()) + emp_var = float(x_0.var()) + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0) np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0) @@ -141,9 +212,9 @@ def diffusion_fn(t, x): ("euler_maruyama", False), ("euler_maruyama", True), ("sea", False), - ("sea", True), ("shark", False), - ("shark", True), + ("fast_adaptive", False), + ("fast_adaptive", True), ], ) def test_zero_noise_reduces_to_deterministic(method, use_adapt): From 5c5abd361efc34267f05805838b8c1522ffaa45d Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 30 Nov 2025 18:33:45 +0100 Subject: [PATCH 20/36] improve adaptive ODE samplers --- bayesflow/utils/integrate.py | 358 ++++++++++++++++------------- tests/test_utils/test_integrate.py | 27 ++- 2 files changed, 220 insertions(+), 165 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index c9f951982..f0e049094 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -24,11 +24,11 @@ def euler_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, **kwargs, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): +) -> Tuple[StateDict, ArrayLike, None, ArrayLike]: k1 = fn(time, **filter_kwargs(state, fn)) new_state = state.copy() @@ -36,7 +36,7 @@ def euler_step( new_state[key] = state[key] + step_size * k1[key] new_time = time + step_size - return new_state, new_time, step_size + return new_state, new_time, None, 0.0 def add_scaled(state, ks, coeffs, h): @@ -51,21 +51,19 @@ def add_scaled(state, ks, coeffs, h): def rk45_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, - last_step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - use_adaptive_step_size: bool = False, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + step_size: ArrayLike, + k1: StateDict = None, + use_adaptive_step_size: bool = True, +) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: """ Dormand-Prince 5(4) method with embedded error estimation. 
""" - step_size = last_step_size h = step_size - k1 = fn(time, **filter_kwargs(state, fn)) + if k1 is None: # reuse k1 if available + k1 = fn(time, **filter_kwargs(state, fn)) k2 = fn(time + h * (1 / 5), **add_scaled(state, [k1], [1 / 5], h)) k3 = fn(time + h * (3 / 10), **add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h)) k4 = fn(time + h * (4 / 5), **add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h)) @@ -85,48 +83,42 @@ def rk45_step( 35 / 384 * k1[key] + 500 / 1113 * k3[key] + 125 / 192 * k4[key] - 2187 / 6784 * k5[key] + 11 / 84 * k6[key] ) - if use_adaptive_step_size: - k7 = fn(time + h, **filter_kwargs(new_state, fn)) - - # 4th order embedded solution - err_state = {} - for key in k1.keys(): - y4 = state[key] + h * ( - 5179 / 57600 * k1[key] - + 7571 / 16695 * k3[key] - + 393 / 640 * k4[key] - - 92097 / 339200 * k5[key] - + 187 / 2100 * k6[key] - + 1 / 40 * k7[key] - ) - err_state[key] = new_state[key] - y4 + new_time = time + h + if not use_adaptive_step_size: + return new_state, new_time, None, 0.0 - err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) - err = keras.ops.max(err_norm) + k7 = fn(time + h, **filter_kwargs(new_state, fn)) - new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - else: - new_step_size = step_size + # 4th order embedded solution + err_state = {} + for key in k1.keys(): + y4 = state[key] + h * ( + 5179 / 57600 * k1[key] + + 7571 / 16695 * k3[key] + + 393 / 640 * k4[key] + - 92097 / 339200 * k5[key] + + 187 / 2100 * k6[key] + + 1 / 40 * k7[key] + ) + err_state[key] = new_state[key] - y4 - new_time = time + h - return new_state, new_time, new_step_size + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) + + return new_state, new_time, k7, err def tsit5_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, - last_step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - use_adaptive_step_size: bool = False, -): + step_size: ArrayLike, + k1: StateDict = None, + use_adaptive_step_size: bool = True, +) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: """ Implements a single step of the Tsitouras 5/4 Runge-Kutta method. 
""" - step_size = last_step_size h = step_size # Butcher tableau coefficients @@ -135,7 +127,8 @@ def tsit5_step( c4 = 0.9 c5 = 0.9800255409045097 - k1 = fn(time, **filter_kwargs(state, fn)) + if k1 is None: # reuse k1 if available + k1 = fn(time, **filter_kwargs(state, fn)) k2 = fn(time + h * c2, **add_scaled(state, [k1], [0.161], h)) k3 = fn(time + h * c3, **add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h)) k4 = fn( @@ -169,42 +162,39 @@ def tsit5_step( + 2.324710524099774 * k6[key] ) - if use_adaptive_step_size: - k7 = fn(time + h, **filter_kwargs(new_state, fn)) + new_time = time + h + if not use_adaptive_step_size: + return new_state, new_time, None, 0.0 - err_state = {} - for key in state.keys(): - err_state[key] = h * ( - -0.00178001105222577714 * k1[key] - - 0.0008164344596567469 * k2[key] - + 0.007880878010261995 * k3[key] - - 0.1447110071732629 * k4[key] - + 0.5823571654525552 * k5[key] - - 0.45808210592918697 * k6[key] - + 0.015151515151515152 * k7[key] - ) + k7 = fn(time + h, **filter_kwargs(new_state, fn)) - err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) - err = keras.ops.max(err_norm) + err_state = {} + for key in state.keys(): + err_state[key] = h * ( + -0.00178001105222577714 * k1[key] + - 0.0008164344596567469 * k2[key] + + 0.007880878010261995 * k3[key] + - 0.1447110071732629 * k4[key] + + 0.5823571654525552 * k5[key] + - 0.45808210592918697 * k6[key] + + 0.015151515151515152 * k7[key] + ) - new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - else: - new_step_size = step_size + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) - new_time = time + h - return new_state, new_time, new_step_size + return new_state, new_time, k7, err def integrate_fixed( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, steps: int, method: str, **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: if steps <= 0: raise ValueError("Number of steps must be positive.") @@ -227,7 +217,7 @@ def integrate_fixed( def body(_loop_var, _loop_state): _state, _time = _loop_state - _state, _time, _ = step_fn(_state, _time, step_size) + _state, _time, _, _ = step_fn(_state, _time, step_size) return _state, _time @@ -236,6 +226,37 @@ def body(_loop_var, _loop_state): return state +def integrate_scheduled( + fn: Callable, + state: StateDict, + steps: Tensor | np.ndarray, + method: str, + **kwargs, +) -> StateDict: + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + + def body(_loop_var, _loop_state): + _time = steps[_loop_var] + step_size = steps[_loop_var + 1] - steps[_loop_var] + _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_state + + state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + return state + + def integrate_adaptive( fn: Callable, state: dict[str, ArrayLike], @@ -261,98 +282,106 @@ def integrate_adaptive( case other: raise TypeError(f"Invalid integration method: {other!r}") + tolerance = 
keras.ops.convert_to_tensor(kwargs.pop("tolerance", 1e-6), dtype="float32")
     step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True)
 
-    def cond(_state, _time, _step_size, _step):
-        # while step < min_steps or time_remaining > 0 and step < max_steps
+    # Initial (conservative) step size guess
+    total_time = stop_time - start_time
+    step_size0 = keras.ops.convert_to_tensor(total_time / max_steps, dtype="float32")
 
-        # time remaining after the next step
-        time_remaining = keras.ops.abs(stop_time - (_time + _step_size))
+    # Track step and rejection counts as scalar tensors
+    step0 = keras.ops.convert_to_tensor(0.0, dtype="float32")
+    count_not_accepted = keras.ops.zeros((), dtype="int32")
+
+    # "First Same As Last" (FSAL) property
+    k1_0 = fn(start_time, **filter_kwargs(state, fn))
+
+    def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted):
+        time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (_time + _step_size))
+        step_lt_min = keras.ops.less(_step, float(min_steps))
+        step_lt_max = keras.ops.less(_step, float(max_steps))
         return keras.ops.logical_or(
-            keras.ops.all(_step < min_steps),
-            keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.all(_step < max_steps)),
+            step_lt_min,
+            keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max),
         )
 
-    def body(_state, _time, _step_size, _step):
-        _step = _step + 1
-
-        # time remaining after the next step
-        time_remaining = stop_time - (_time + _step_size)
+    def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
+        # Time remaining from current point
+        time_remaining = stop_time - _time
 
+        # Per-step min/max step sizes (like original code)
         min_step_size = time_remaining / (max_steps - _step)
         max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0)
 
-        # reorder
-        min_step_size, max_step_size = (
-            keras.ops.minimum(min_step_size, max_step_size),
-            keras.ops.maximum(min_step_size, max_step_size),
-        )
-
-        _state, _time, _step_size = step_fn(
-            _state, _time, _step_size, min_step_size=min_step_size, max_step_size=max_step_size
+        # Ensure ordering: min_step_size <= max_step_size
+        lower = keras.ops.minimum(min_step_size, max_step_size)
+        upper = keras.ops.maximum(min_step_size, max_step_size)
+        min_step_size = lower
+        max_step_size = upper
+        h = keras.ops.clip(_step_size, min_step_size, max_step_size)
+
+        # Take one trial step
+        new_state, new_time, new_k1, err = step_fn(
+            state=_state,
+            time=_time,
+            step_size=h,
+            k1=_k1,
         )
-        return _state, _time, _step_size, _step
-
-    # select initial step size conservatively
-    step_size = (stop_time - start_time) / max_steps
-
-    step = 0
-    time = start_time
-
-    state, time, step_size, step = keras.ops.while_loop(cond, body, [state, time, step_size, step])
-
-    # do the last step
-    step_size = stop_time - time
-    state, _, _ = step_fn(state, time, step_size)
-    step = step + 1
-
-    logging.debug("Finished integration after {} steps.", step)
+        new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0)
+        new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
 
-    return state
+        # Error control: reject if err > tolerance
+        too_big = keras.ops.greater(err, tolerance)
+        at_min = keras.ops.less_equal(
+            keras.ops.abs(h),
+            keras.ops.abs(min_step_size),
+        )
+        accepted = keras.ops.logical_or(keras.ops.logical_not(too_big), at_min)
 
+        updated_state = keras.ops.cond(accepted, lambda: new_state, lambda: _state)
+        updated_time = keras.ops.cond(accepted, lambda: new_time, lambda: _time)
+        updated_k1 = keras.ops.cond(accepted, lambda: 
new_k1, lambda: _k1)
+
+        # Step counter: increment only on accepted steps
+        updated_step = _step + keras.ops.where(accepted, 1.0, 0.0)
+        # keep this branch-free so the loop body also traces on graph backends
+        _count_not_accepted = _count_not_accepted + keras.ops.where(accepted, 0, 1)
+
+        # For the next iteration, always use the new suggested step size
+        return updated_state, updated_time, new_step_size, updated_step, updated_k1, _count_not_accepted
+
+    # Run the adaptive loop
+    state, time, step_size, step, k1, count_not_accepted = keras.ops.while_loop(
+        cond,
+        body,
+        [state, start_time, step_size0, step0, k1_0, count_not_accepted],
+    )
+
+    # Final step to hit stop_time exactly
+    time_diff = stop_time - time
+    time_remaining = keras.ops.sign(stop_time - start_time) * time_diff
+    if keras.ops.all(time_remaining > 0):
+        state, time, _, _ = step_fn(
+            state=state,
+            time=time,
+            step_size=time_diff,
+            k1=k1,
+        )
+        step = step + 1.0
+
+    logging.debug(f"Finished integration after {step} steps with {count_not_accepted} rejected steps.")
 
     return state
 
 
 def integrate_scipy(
     fn: Callable,
-    state: dict[str, ArrayLike],
+    state: StateDict,
     start_time: ArrayLike,
     stop_time: ArrayLike,
     scipy_kwargs: dict | None = None,
     **kwargs,
-) -> dict[str, ArrayLike]:
+) -> StateDict:
     import scipy.integrate
 
     scipy_kwargs = scipy_kwargs or {}
@@ -394,7 +423,7 @@ def scipy_wrapper_fn(time, x):
 
 def integrate(
     fn: Callable,
-    state: dict[str, ArrayLike],
+    state: StateDict,
     start_time: ArrayLike | None = None,
     stop_time: ArrayLike | None = None,
     min_steps: int = 10,
@@ -402,7 +431,7 @@ def integrate(
     steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100,
     method: str = "rk45",
     **kwargs,
-) -> dict[str, ArrayLike]:
+) -> StateDict:
     if isinstance(steps, str) and steps in ["adaptive", "dynamic"]:
         if start_time is None or stop_time is None:
             raise ValueError(
@@ -429,7 +458,10 @@ def integrate(
     raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")
 
 
-def adaptive_step_size_controller(
+############ SDE Solvers #############
+
+
+def stochastic_adaptive_step_size_controller(
     state,
     drift,
     adaptive_factor: ArrayLike,
@@ -470,19 +502,16 @@ def euler_maruyama_step(
 def euler_maruyama_step(
     drift_fn: Callable,
     diffusion_fn: Callable,
-    state: dict[str, ArrayLike],
+    state: StateDict,
     time: ArrayLike,
     step_size: ArrayLike,
-    noise: dict[str, ArrayLike],
+    noise: StateDict,
     use_adaptive_step_size: bool = False,
     min_step_size: float = -float("inf"),
     max_step_size: float = float("inf"),
     adaptive_factor: float = 0.01,
     **kwargs,
-) -> Union[
-    Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike],
-    Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike, Dict[str, ArrayLike]],
-]:
+) -> Union[Tuple[StateDict, ArrayLike, ArrayLike], Tuple[StateDict, ArrayLike, ArrayLike, StateDict]]:
     """
     Performs a single 
Euler-Maruyama step for stochastic differential equations. @@ -509,7 +538,7 @@ def euler_maruyama_step( new_step_size = step_size if use_adaptive_step_size: sign_step = keras.ops.sign(step_size) - new_step_size = adaptive_step_size_controller( + new_step_size = stochastic_adaptive_step_size_controller( state=state, drift=drift, adaptive_factor=adaptive_factor, @@ -535,11 +564,11 @@ def euler_maruyama_step( def fast_adaptive_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: dict[str, ArrayLike], - last_state: dict[str, ArrayLike] = None, + noise: StateDict, + last_state: StateDict = None, use_adaptive_step_size: bool = True, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), @@ -549,8 +578,8 @@ def fast_adaptive_step( adapt_safety: float = 0.9, **kwargs, ) -> Union[ - Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike], - Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike, Dict[str, ArrayLike]], + Tuple[StateDict, ArrayLike, ArrayLike], + Tuple[StateDict, ArrayLike, ArrayLike, StateDict], ]: """ Performs a single adaptive step for stochastic differential equations based on [1]. @@ -683,12 +712,12 @@ def fast_adaptive_step( def sea_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: dict[str, ArrayLike], + noise: StateDict, **kwargs, -) -> Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: +) -> Tuple[StateDict, ArrayLike, ArrayLike]: """ Performs a single shifted Euler step for SDEs with additive noise [1]. @@ -740,13 +769,13 @@ def sea_step( def shark_step( drift_fn: Callable, diffusion_fn: Callable, - state: Dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: Dict[str, ArrayLike], - noise_aux: Dict[str, ArrayLike], + noise: StateDict, + noise_aux: StateDict, **kwargs, -) -> Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: +) -> Tuple[StateDict, ArrayLike, ArrayLike]: """ Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion per step and has a strong order 1.5. 
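Since rk45_step and tsit5_step now return the slope at the accepted endpoint, a driver can reuse it as the next step's k1 (the FSAL property). A minimal usage sketch, assuming the module path from this diff:

import keras
from bayesflow.utils.integrate import rk45_step  # path as in this patch series

def fn(t, x):
    return {"x": 2.0 * t * keras.ops.ones_like(x)}

state = {"x": keras.ops.convert_to_tensor([1.0])}
t = keras.ops.convert_to_tensor(0.0)
h = keras.ops.convert_to_tensor(0.1)
state, t, k_last, err = rk45_step(fn, state, t, h, use_adaptive_step_size=True)
# k_last is fn evaluated at (t, state), so it doubles as k1 of the next step
state, t, k_last, err = rk45_step(fn, state, t, h, k1=k_last, use_adaptive_step_size=True)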
@@ -858,7 +887,7 @@ def _apply_corrector( i: ArrayLike, corrector_steps: int, score_fn: Optional[Callable], - corrector_noise_history: Dict[str, ArrayLike], + corrector_noise_history: StateDict, step_size_factor: ArrayLike = 0.01, noise_schedule=None, ) -> StateDict: @@ -903,11 +932,11 @@ def integrate_stochastic_fixed( steps: int, min_step_size: ArrayLike, max_step_size: ArrayLike, - z_history: Dict[str, ArrayLike], - z_extra_history: Dict[str, ArrayLike], + z_history: StateDict, + z_extra_history: StateDict, score_fn: Optional[Callable], step_size_factor: ArrayLike, - corrector_noise_history: Dict[str, ArrayLike], + corrector_noise_history: StateDict, corrector_steps: int = 0, noise_schedule=None, ) -> StateDict: @@ -969,11 +998,11 @@ def integrate_stochastic_adaptive( min_step_size: ArrayLike, max_step_size: ArrayLike, initial_step: ArrayLike, - z_history: Dict[str, ArrayLike], - z_extra_history: Dict[str, ArrayLike], + z_history: StateDict, + z_extra_history: StateDict, score_fn: Optional[Callable], step_size_factor: ArrayLike, - corrector_noise_history: Dict[str, ArrayLike], + corrector_noise_history: StateDict, corrector_steps: int = 0, noise_schedule=None, ) -> StateDict: @@ -1027,8 +1056,9 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _l return _i + 1, new_state, new_time, new_step, _counter, _new_current_state # Execute the adaptive loop - _, final_state, _, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) - logging.debug("Finished integration after {} steps.", final_counter) + _, final_state, final_time, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + + logging.debug(f"Finished integration after {final_counter} steps at {final_time}.") return final_state @@ -1037,10 +1067,10 @@ def integrate_langevin( start_time: ArrayLike, stop_time: ArrayLike, steps: int, - z_history: Dict[str, ArrayLike], + z_history: StateDict, score_fn: Callable, noise_schedule, - corrector_noise_history: Dict[str, ArrayLike], + corrector_noise_history: StateDict, step_size_factor: ArrayLike = 0.01, corrector_steps: int = 0, ) -> StateDict: diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 8a6de502c..3f83a3a2d 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -56,7 +56,7 @@ def fn(t, x): return {"x": keras.ops.convert_to_tensor([2.0 * t])} initial_state = {"x": keras.ops.convert_to_tensor([1.0])} - T_final = 2.0 + T_final = 1.0 num_steps = 100 analytical_result = 1.0 + T_final**2 @@ -72,6 +72,31 @@ def fn(t, x): np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1) +@pytest.mark.parametrize( + "method, atol", [("euler", TOLERANCE_EULER), ("rk45", TOLERANCE_ADAPTIVE), ("tsit5", TOLERANCE_ADAPTIVE)] +) +def test_analytical_backward_integration(method, atol): + T_final = 1.0 + + def fn(t, x): + return {"x": keras.ops.convert_to_tensor([2.0 * t])} + + num_steps = 100 + analytical_result = 1.0 + initial_state = {"x": keras.ops.convert_to_tensor([1.0 + T_final**2])} + + result = integrate(fn, initial_state, start_time=T_final, stop_time=0.0, steps=num_steps, method=method)["x"] + if method == "euler": + result_adaptive = result + else: + result_adaptive = integrate( + fn, initial_state, start_time=T_final, stop_time=0.0, steps="adaptive", method=method, max_steps=1_000 + )["x"] + + np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) + 
np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1) + + @pytest.mark.parametrize( "method,use_adapt", [ From dd021bb782cf97ae359bd91de08d244495fec07e Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 30 Nov 2025 18:57:53 +0100 Subject: [PATCH 21/36] fix schedule test --- tests/test_utils/test_integrate.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 3f83a3a2d..8d4a06d7f 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -13,23 +13,20 @@ TOL_DET = 1e-3 -def test_scheduled_integration(): - import keras - from bayesflow.utils import integrate - +@pytest.mark.parametrize("method", ["euler", "rk45", "tsit5"]) +def test_scheduled_integration(method): def fn(t, x): return {"x": t**2} - steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0]) - approximate_result = 0.0 + 0.5**2 * 0.5 - result = integrate(fn, {"x": 0.0}, steps=steps)["x"] - assert result == approximate_result + def analytical_result(t): + return (t**3) / 3.0 + steps = keras.ops.arange(0.0, 1.0 + 1e-6, 0.01) + result = integrate(fn, {"x": 0.0}, steps=steps, method=method)["x"] + np.testing.assert_allclose(result, analytical_result(steps[-1]), atol=1e-1, rtol=1e-1) -def test_scipy_integration(): - import keras - from bayesflow.utils import integrate +def test_scipy_integration(): def fn(t, x): return {"x": keras.ops.exp(t)} From 1fe2c60212b90686cc0997ce1b87c4bab17b6fe3 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 1 Dec 2025 10:39:08 +0100 Subject: [PATCH 22/36] improved defaults --- bayesflow/utils/integrate.py | 84 ++++++++++++++++-------------- tests/test_utils/test_integrate.py | 13 ++--- 2 files changed, 53 insertions(+), 44 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index f0e049094..e299fef7b 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -19,7 +19,7 @@ DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"] -STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin", "fast_adaptive"] +STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] def euler_step( @@ -509,7 +509,6 @@ def euler_maruyama_step( use_adaptive_step_size: bool = False, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), - adaptive_factor: float = 0.01, **kwargs, ) -> Union[Tuple[StateDict, ArrayLike, ArrayLike], Tuple[StateDict, ArrayLike, ArrayLike, StateDict]]: """ @@ -525,7 +524,6 @@ def euler_maruyama_step( use_adaptive_step_size: Whether to use adaptive step sizing. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one Euler-Maruyama step. 
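For the record, the closed-form Ornstein-Uhlenbeck moments checked by the forward and backward tests in this series are, for dX = a X dt + sigma dW:

\mathbb{E}[X_T] = x_0 e^{aT}, \qquad \mathrm{Var}[X_T] = \sigma^2 \frac{e^{2aT} - 1}{2a} \quad \text{(forward)}

\mathbb{E}[X_0] = x_T e^{-aT}, \qquad \mathrm{Var}[X_0] = \sigma^2 \frac{e^{-2aT} - 1}{-2a} \quad \text{(backward)}

With a = -1, sigma = 0.5, T = 1, the backward case gives mean 1.2 e ≈ 3.262 and variance 0.25 (e^2 - 1) / 2 ≈ 0.799, which is what the empirical mean and variance are compared against.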
@@ -541,7 +539,7 @@ def euler_maruyama_step(
         new_step_size = stochastic_adaptive_step_size_controller(
             state=state,
             drift=drift,
-            adaptive_factor=adaptive_factor,
+            adaptive_factor=max_step_size,
             min_step_size=min_step_size,
             max_step_size=max_step_size,
         )
@@ -561,7 +559,7 @@ def euler_maruyama_step(
     return new_state, time + new_step_size, new_step_size
 
 
-def fast_adaptive_step(
+def two_step_adaptive_step(
     drift_fn: Callable,
     diffusion_fn: Callable,
     state: StateDict,
@@ -572,8 +570,8 @@ def fast_adaptive_step(
     use_adaptive_step_size: bool = True,
     min_step_size: float = -float("inf"),
     max_step_size: float = float("inf"),
-    e_abs: float = 0.01,
-    e_rel: float = 0.01,
+    e_rel: float = 0.1,
+    e_abs: float = None,
     r: float = 0.9,
     adapt_safety: float = 0.9,
     **kwargs,
@@ -608,8 +606,8 @@ def fast_adaptive_step(
         use_adaptive_step_size: Whether to adapt step size.
         min_step_size: Minimum allowed step size.
         max_step_size: Maximum allowed step size.
-        e_abs: Absolute error tolerance.
         e_rel: Relative error tolerance.
+        e_abs: Absolute error tolerance. Default assumes standardized targets.
         r: Order of the method for step size adaptation.
         adapt_safety: Safety factor for step size adaptation.
         **kwargs: Additional arguments passed to drift_fn and diffusion_fn.
@@ -650,6 +648,8 @@ def fast_adaptive_step(
 
     # Error estimation
     if use_adaptive_step_size:
+        if e_abs is None:
+            e_abs = 0.02576  # 1% of the two-sided 99% CI bound (2.576) of a standard normal
         # Check if we're at minimum step size - if so, force acceptance
         at_min_step = keras.ops.less_equal(step_size, min_step_size)
 
@@ -709,13 +709,33 @@ def fast_adaptive_step(
     return state_heun, time_mid, step_size
 
 
+def compute_levy_area(
+    state: StateDict, diffusion: StateDict, noise: StateDict, noise_aux: StateDict, step_size: ArrayLike
+) -> StateDict:
+    step_size_abs = keras.ops.abs(step_size)
+    sqrt_step_size = keras.ops.sqrt(step_size_abs)
+    inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(step_size_abs))
+
+    # Build Lévy area H_k from w_k and Z_k
+    H = {}
+    for k in state.keys():
+        if k in diffusion:
+            term1 = 0.5 * step_size_abs * noise[k]
+            term2 = 0.5 * step_size_abs * sqrt_step_size * inv_sqrt3 * noise_aux[k]
+            H[k] = term1 + term2
+        else:
+            H[k] = keras.ops.zeros_like(state[k])
+    return H
+
+
 def sea_step(
     drift_fn: Callable,
     diffusion_fn: Callable,
     state: StateDict,
     time: ArrayLike,
     step_size: ArrayLike,
-    noise: StateDict,
+    noise: StateDict,  # standard normals
+    noise_aux: StateDict,  # standard normals
     **kwargs,
 ) -> Tuple[StateDict, ArrayLike, ArrayLike]:
     """
@@ -725,7 +745,7 @@ def sea_step(
     which improves the local error and the global error constant for additive noise.
 
     The scheme is
-        X_{n+1} = X_n + f(t_n, X_n + 0.5 * g(t_n) * ΔW_n) * h + g(t_n) * ΔW_n
+        X_{n+1} = X_n + f(t_n, X_n + g(t_n) * (0.5 * ΔW_n + ΔH_n)) * h + g(t_n) * ΔW_n
 
     [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023)
 
     Args:
@@ -735,20 +755,23 @@ def sea_step(
         time: Current time scalar tensor.
        step_size: Time increment dt.
         noise: Mapping of variable names to dW noise tensors.
+        noise_aux: Mapping of variable names to auxiliary noise.
 
     Returns:
         new_state: Updated state after one SEA step.
         new_time: time + dt. 
""" - # Compute diffusion (assumed additive or weakly state dependent) + # Compute diffusion diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) - # Build shifted state: X_shift = X + 0.5 * g * ΔW + la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size) + + # Build shifted state: X_shift = X + g * (0.5 * ΔW + ΔH) shifted_state = {} for key, x in state.items(): if key in diffusion: - shifted_state[key] = x + 0.5 * diffusion[key] * sqrt_step_size * noise[key] + shifted_state[key] = x + diffusion[key] * (0.5 * sqrt_step_size * noise[key] + la[key]) else: shifted_state[key] = x @@ -810,33 +833,18 @@ def shark_step( """ h = step_size t = time - - # Magnitude of the time step for stochastic scaling h_mag = keras.ops.abs(h) - # h_sign = keras.ops.sign(h) sqrt_h_mag = keras.ops.sqrt(h_mag) - inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(h_mag)) - # g(y_k) - g0 = diffusion_fn(t, **filter_kwargs(state, diffusion_fn)) + diffusion = diffusion_fn(t, **filter_kwargs(state, diffusion_fn)) - # Build H_k from w_k and Z_k - H = {} - for k in state.keys(): - if k in g0: - w_k = sqrt_h_mag * noise[k] - z_k = noise_aux[k] # standard normal - term1 = 0.5 * h_mag * w_k - term2 = 0.5 * h_mag * sqrt_h_mag * inv_sqrt3 * z_k - H[k] = term1 + term2 - else: - H[k] = keras.ops.zeros_like(state[k]) + la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size) # === 1) shifted initial state === y_tilde_k = {} for k in state.keys(): - if k in g0: - y_tilde_k[k] = state[k] + g0[k] * H[k] + if k in diffusion: + y_tilde_k[k] = state[k] + diffusion[k] * la[k] else: y_tilde_k[k] = state[k] @@ -866,12 +874,12 @@ def shark_step( # stochastic parts sto1 = ( - g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * H[k]) + g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * la[k]) if k in g_tilde_k else keras.ops.zeros_like(det) ) sto2 = ( - g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * H[k]) + g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * la[k]) if k in g_tilde_mid else keras.ops.zeros_like(det) ) @@ -1154,7 +1162,7 @@ def integrate_stochastic( seed: keras.random.SeedGenerator, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", - min_steps: int = 20, + min_steps: int = 10, max_steps: int = 10_000, score_fn: Callable = None, corrector_steps: int = 0, @@ -1229,8 +1237,8 @@ def integrate_stochastic( step_fn_raw = shark_step if is_adaptive: raise ValueError("SHARK SDE solver does not support adaptive steps.") - case "fast_adaptive": - step_fn_raw = fast_adaptive_step + case "two_step_adaptive": + step_fn_raw = two_step_adaptive_step case "langevin": if is_adaptive: raise ValueError("Langevin sampling does not support adaptive steps.") @@ -1269,7 +1277,7 @@ def integrate_stochastic( for key, val in state.items(): shape = keras.ops.shape(val) z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - if method == "shark": + if method in ["sea", "shark"]: z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) if is_adaptive: diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 8d4a06d7f..ceaa5851c 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -101,8 +101,8 @@ def 
fn(t, x): ("euler_maruyama", True), ("sea", False), ("shark", False), - ("fast_adaptive", False), - ("fast_adaptive", True), + ("two_step_adaptive", False), + ("two_step_adaptive", True), ], ) def test_forward_additive_ou_weak_means_and_vars(method, use_adapt): @@ -167,8 +167,8 @@ def diffusion_fn(t, x): ("euler_maruyama", True), ("sea", False), ("shark", False), - ("fast_adaptive", False), - ("fast_adaptive", True), + ("two_step_adaptive", False), + ("two_step_adaptive", True), ], ) def test_backward_additive_ou_weak_means_and_vars(method, use_adapt): @@ -218,6 +218,7 @@ def diffusion_fn(t, x): seed=seed, method=method, max_steps=1_000, + min_steps=100, ) x_0 = np.array(out["x"]) @@ -235,8 +236,8 @@ def diffusion_fn(t, x): ("euler_maruyama", True), ("sea", False), ("shark", False), - ("fast_adaptive", False), - ("fast_adaptive", True), + ("two_step_adaptive", False), + ("two_step_adaptive", True), ], ) def test_zero_noise_reduces_to_deterministic(method, use_adapt): From a771e3230b76d079ebc3d257b752cc16c94c0b92 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 1 Dec 2025 10:43:43 +0100 Subject: [PATCH 23/36] improved defaults --- bayesflow/utils/integrate.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index e299fef7b..34217908f 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -906,10 +906,6 @@ def _apply_corrector( if corrector_steps <= 0: return new_state - # Ensures score_fn and noise_schedule are present if needed, though checked in integrate_stochastic - if score_fn is None or noise_schedule is None: - return new_state # Should not happen if checks are passed - for j in range(corrector_steps): score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) _z_corr = {k: corrector_noise_history[k][i, j] for k in new_state.keys()} From a7adea2bcb9cb086139153630ff4acb3aff9e7e6 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 1 Dec 2025 21:26:24 +0100 Subject: [PATCH 24/36] improved initial step size --- bayesflow/utils/integrate.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 34217908f..b8ed29afb 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -284,12 +284,7 @@ def integrate_adaptive( tolerance = keras.ops.convert_to_tensor(kwargs.get("tolerance", 1e-6), dtype="float32") step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True) - - # Initial (conservative) step size guess - total_time = stop_time - start_time - step_size0 = keras.ops.convert_to_tensor(total_time / max_steps, dtype="float32") - - # Track step count as scalar tensor + initial_step = (stop_time - start_time) / float(min_steps) step0 = keras.ops.convert_to_tensor(0.0, dtype="float32") count_not_accepted = 0 @@ -308,18 +303,10 @@ def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted): def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): # Time remaining from current point - time_remaining = stop_time - _time - - # Per-step min/max step sizes (like original code) + time_remaining = keras.ops.abs(stop_time - _time) min_step_size = time_remaining / (max_steps - _step) max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0) - - # Ensure ordering: min_step_size <= max_step_size - lower = keras.ops.minimum(min_step_size, max_step_size) - upper = keras.ops.maximum(min_step_size, max_step_size) - min_step_size = lower - 
max_step_size = upper - h = keras.ops.clip(_step_size, min_step_size, max_step_size) + h = keras.ops.sign(_step_size) * keras.ops.clip(keras.ops.abs(_step_size), min_step_size, max_step_size) # Take one trial step new_state, new_time, new_k1, err = step_fn( @@ -330,7 +317,9 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): ) new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + new_step_size = keras.ops.sign(new_step_size) * keras.ops.clip( + keras.ops.abs(new_step_size), min_step_size, max_step_size + ) # Error control: reject if err > tolerance too_big = keras.ops.greater(err, tolerance) @@ -355,7 +344,7 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): state, time, step_size, step, k1, count_not_accepted = keras.ops.while_loop( cond, body, - [state, start_time, step_size0, step0, k1_0, count_not_accepted], + [state, start_time, initial_step, step0, k1_0, count_not_accepted], ) # Final step to hit stop_time exactly From 08853fbe73fbd55c39b118d69f92a76706ded506 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 1 Dec 2025 21:59:58 +0100 Subject: [PATCH 25/36] improved initial step size --- bayesflow/utils/integrate.py | 9 ++++----- tests/test_utils/test_integrate.py | 1 - 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b8ed29afb..d947b54c0 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -640,7 +640,7 @@ def two_step_adaptive_step( if e_abs is None: e_abs = 0.02576 # 1% of 99% CI of standardized unit variance # Check if we're at minimum step size - if so, force acceptance - at_min_step = keras.ops.less_equal(step_size, min_step_size) + at_min_step = keras.ops.less_equal(keras.ops.abs(step_size), min_step_size) # Compute error tolerance for each component e_abs_tensor = keras.ops.cast(e_abs, dtype=keras.ops.dtype(list(state.values())[0])) @@ -681,9 +681,8 @@ def two_step_adaptive_step( new_step_candidate = step_size * adapt_factor # Clamp to valid range - sign_step = keras.ops.sign(step_size) - new_step_size = keras.ops.minimum(keras.ops.maximum(new_step_candidate, min_step_size), max_step_size) - new_step_size = sign_step * keras.ops.abs(new_step_size) + new_step_size = keras.ops.clip(keras.ops.abs(new_step_candidate), min_step_size, max_step_size) + new_step_size = keras.ops.sign(step_size) * new_step_size # Return appropriate state based on acceptance new_state = keras.ops.cond(accepted, lambda: state_heun, lambda: state) @@ -1147,7 +1146,7 @@ def integrate_stochastic( seed: keras.random.SeedGenerator, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", - min_steps: int = 10, + min_steps: int = 50, max_steps: int = 10_000, score_fn: Callable = None, corrector_steps: int = 0, diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index ceaa5851c..78adae35f 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -218,7 +218,6 @@ def diffusion_fn(t, x): seed=seed, method=method, max_steps=1_000, - min_steps=100, ) x_0 = np.array(out["x"]) From 23a69eac0be53ed05bcb2e78849056af88ef8849 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 2 Dec 2025 12:57:48 +0100 Subject: [PATCH 26/36] check nan in integrate --- bayesflow/utils/integrate.py | 70 +++++++++++++++++++++++++----- tests/test_utils/test_integrate.py | 1 - 2 files changed, 59 
insertions(+), 12 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index d947b54c0..2c027b3cc 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -22,6 +22,13 @@ STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] +def _check_all_nans(state: StateDict): + all_nans_flags = [] + for v in state.values(): + all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) + return keras.ops.all(keras.ops.stack(all_nans_flags)) + + def euler_step( fn: Callable, state: StateDict, @@ -218,7 +225,8 @@ def integrate_fixed( def body(_loop_var, _loop_state): _state, _time = _loop_state _state, _time, _, _ = step_fn(_state, _time, step_size) - + if _check_all_nans(_state): + raise RuntimeError(f"All values are NaNs in state during integration at {_time}.") return _state, _time state, time = keras.ops.fori_loop(0, steps, body, (state, time)) @@ -251,6 +259,9 @@ def body(_loop_var, _loop_state): _time = steps[_loop_var] step_size = steps[_loop_var + 1] - steps[_loop_var] _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) + + if _check_all_nans(_loop_state): + raise RuntimeError(f"All values are NaNs in state during integration at {_time}.") return _loop_state state = keras.ops.fori_loop(0, len(steps) - 1, body, state) @@ -296,10 +307,12 @@ def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted): step_lt_min = keras.ops.less(_step, float(min_steps)) step_lt_max = keras.ops.less(_step, float(max_steps)) - return keras.ops.logical_or( - step_lt_min, - keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max), + all_nans = _check_all_nans(_state) + + end_now = keras.ops.logical_or( + step_lt_min, keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max) ) + return keras.ops.logical_and(~all_nans, end_now) def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): # Time remaining from current point @@ -347,6 +360,9 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): [state, start_time, initial_step, step0, k1_0, count_not_accepted], ) + if _check_all_nans(state): + raise RuntimeError(f"All values are NaNs in state during integration at {time}.") + # Final step to hit stop_time exactly time_diff = stop_time - time time_remaining = keras.ops.sign(stop_time - start_time) * time_diff @@ -974,6 +990,9 @@ def body_fixed(_i, _loop_state): step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, ) + all_nans = _check_all_nans(new_state) + if all_nans: + raise RuntimeError(f"All values are NaNs in state during integration at {_current_time}.") return new_state, new_time, initial_step # Execute the fixed loop @@ -1004,9 +1023,10 @@ def integrate_stochastic_adaptive( initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0, state) def cond(i, current_state, current_time, current_step, counter, last_state): - # time remaining after the next step time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step)) - return keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) + all_nans = _check_all_nans(current_state) + end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) + return keras.ops.logical_and(~all_nans, end_now) def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _last_state): # Step Size Control @@ -1048,9 +1068,36 @@ def body_adaptive(_i, 
_current_state, _current_time, _current_step, _counter, _last_state):
         return _i + 1, new_state, new_time, new_step, _counter, _new_current_state
 
     # Execute the adaptive loop
-    _, final_state, final_time, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state)
+    _, final_state, final_time, _, final_counter, final_k1 = keras.ops.while_loop(
+        cond, body_adaptive, initial_loop_state
+    )
+
+    if _check_all_nans(final_state):
+        raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.")
+
+    # Final step to hit stop_time exactly
+    time_diff = stop_time - final_time
+    time_remaining = keras.ops.sign(stop_time - start_time) * time_diff
+    if keras.ops.all(time_remaining > 0):
+        noise_final = {k: z_history[k][-1] for k in final_state.keys()}
+        noise_extra_final = None
+        if len(z_extra_history) > 0:
+            noise_extra_final = {k: z_extra_history[k][-1] for k in final_state.keys()}
+
+        final_state, _, _ = step_fn(
+            state=final_state,
+            time=final_time,
+            step_size=time_diff,
+            last_state=final_k1,
+            min_step_size=min_step_size,
+            max_step_size=time_remaining,
+            noise=noise_final,
+            noise_aux=noise_extra_final,
+            use_adaptive_step_size=False,
+        )
+        final_counter = final_counter + 1
 
-    logging.debug(f"Finished integration after {final_counter} steps at {final_time}.")
+    logging.debug(f"Finished integration after {final_counter} steps.")
 
     return final_state
 
@@ -1094,13 +1141,12 @@ def integrate_langevin(
 
     def body(_i, loop_state):
         current_state, current_time = loop_state
-        t = current_time
 
         # score at current time
-        score = score_fn(t, **filter_kwargs(current_state, score_fn))
+        score = score_fn(current_time, **filter_kwargs(current_state, score_fn))
 
         # noise schedule
-        log_snr_t = noise_schedule.get_log_snr(t=t, training=False)
+        log_snr_t = noise_schedule.get_log_snr(t=current_time, training=False)
         _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
 
         new_state: StateDict = {}
@@ -1125,6 +1171,8 @@ def body(_i, loop_state):
             step_size_factor=step_size_factor,
             corrector_noise_history=corrector_noise_history,
         )
+        if _check_all_nans(new_state):
+            raise RuntimeError(f"All values are NaNs in state during integration at {current_time}.")
 
         return new_state, new_time
 
diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py
index 78adae35f..44a6fc60f 100644
--- a/tests/test_utils/test_integrate.py
+++ b/tests/test_utils/test_integrate.py
@@ -202,7 +202,6 @@ def diffusion_fn(t, x):
     # Start at time T with value x_T
     initial_state = {"x": keras.ops.ones((N,)) * x_T}
     steps = 200 if not use_adapt else "adaptive"
-
     # Expected mean and variance at t=0 after integrating backward from t=T
     # For backward integration, the effective drift coefficient changes sign
     exp_mean = x_T * np.exp(-a * T)

From b9e8c964cd31bd1b821bb039c3658ce63c75fbd7 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Tue, 2 Dec 2025 16:10:44 +0100
Subject: [PATCH 27/36] set default

---
 bayesflow/utils/integrate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index 2c027b3cc..f0a36b06c 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -631,7 +631,7 @@ def two_step_adaptive_step(
         min_step_size=min_step_size,
         max_step_size=max_step_size,
         noise=noise,
-        use_adaptive_step_size=False,
+        use_adaptive_step_size=True,
     )
 
     # Compute drift and diffusion at new state, but update from old state

From be78470602ad61b43c9d44540a0b59c58cb3f647 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Wed, 3 Dec 
16:43:13 +0100
Subject: [PATCH 28/36] update model defaults

---
 bayesflow/networks/diffusion_model/diffusion_model.py | 6 +++---
 bayesflow/networks/flow_matching/flow_matching.py     | 4 ++--
 bayesflow/utils/integrate.py                          | 8 ++++++--
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py
index 0e38ea4f1..8cbce1e87 100644
--- a/bayesflow/networks/diffusion_model/diffusion_model.py
+++ b/bayesflow/networks/diffusion_model/diffusion_model.py
@@ -40,13 +40,13 @@ class DiffusionModel(InferenceNetwork):
         "activation": "mish",
         "kernel_initializer": "he_normal",
         "residual": True,
-        "dropout": 0.0,
+        "dropout": 0.05,
         "spectral_normalization": False,
     }
 
     INTEGRATE_DEFAULT_CONFIG = {
-        "method": "rk45",
-        "steps": 100,
+        "method": "two_step_adaptive",
+        "steps": "adaptive",
     }
 
     def __init__(
diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py
index fa74089a4..485cbbd9e 100644
--- a/bayesflow/networks/flow_matching/flow_matching.py
+++ b/bayesflow/networks/flow_matching/flow_matching.py
@@ -53,8 +53,8 @@ class FlowMatching(InferenceNetwork):
     }
 
     INTEGRATE_DEFAULT_CONFIG = {
-        "method": "rk45",
-        "steps": 100,
+        "method": "tsit5",
+        "steps": "adaptive",
     }
 
     def __init__(
diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index f0a36b06c..16bca18a2 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -65,7 +65,9 @@ def rk45_step(
     use_adaptive_step_size: bool = True,
 ) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]:
     """
-    Dormand-Prince 5(4) method with embedded error estimation.
+    Dormand-Prince 5(4) method with embedded error estimation [1].
+
+    [1] Dormand (1996), Numerical Methods for Differential Equations: A Computational Approach
     """
     h = step_size
 
@@ -124,7 +126,9 @@ def tsit5_step(
     use_adaptive_step_size: bool = True,
 ) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]:
     """
-    Implements a single step of the Tsitouras 5/4 Runge-Kutta method.
+    Implements a single step of the Tsitouras 5/4 Runge-Kutta method [1].
+ + [1] Tsitouras (2011), Runge--Kutta pairs of order 5(4) satisfying only the first column simplifying assumption """ h = step_size From ad276063e5bcddf0c6bd03c85355b708e4b478e5 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 17:46:29 +0100 Subject: [PATCH 29/36] make loop jax compatible --- bayesflow/utils/integrate.py | 109 ++++++++++++++++++++--------------- 1 file changed, 64 insertions(+), 45 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 16bca18a2..db0ebc813 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -224,17 +224,22 @@ def integrate_fixed( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) step_size = (stop_time - start_time) / steps - time = start_time - - def body(_loop_var, _loop_state): - _state, _time = _loop_state - _state, _time, _, _ = step_fn(_state, _time, step_size) - if _check_all_nans(_state): - raise RuntimeError(f"All values are NaNs in state during integration at {_time}.") - return _state, _time + def cond(_loop_var, _loop_state, _loop_time): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) - state, time = keras.ops.fori_loop(0, steps, body, (state, time)) + def body(_loop_var, _loop_state, _loop_time): + _loop_state, _loop_time, _, _ = step_fn(_loop_state, _loop_time, step_size) + return _loop_var + 1, _loop_state, _loop_time + _, state, _ = keras.ops.while_loop( + cond, + body, + [0, state, start_time], + ) + if _check_all_nans(state): + raise RuntimeError("All values are NaNs in state during integration.") return state @@ -259,16 +264,25 @@ def integrate_scheduled( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + def cond(_loop_var, _loop_state): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, len(steps) - 1) + return keras.ops.logical_and(~all_nans, end_now) + def body(_loop_var, _loop_state): _time = steps[_loop_var] step_size = steps[_loop_var + 1] - steps[_loop_var] _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_var + 1, _loop_state - if _check_all_nans(_loop_state): - raise RuntimeError(f"All values are NaNs in state during integration at {_time}.") - return _loop_state + _, state = keras.ops.while_loop( + cond, + body, + [0, state], + ) - state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + if _check_all_nans(state): + raise RuntimeError("All values are NaNs in state during integration.") return state @@ -635,7 +649,7 @@ def two_step_adaptive_step( min_step_size=min_step_size, max_step_size=max_step_size, noise=noise, - use_adaptive_step_size=True, + use_adaptive_step_size=False, ) # Compute drift and diffusion at new state, but update from old state @@ -957,9 +971,12 @@ def integrate_stochastic_fixed( """ initial_step = (stop_time - start_time) / float(steps) - def body_fixed(_i, _loop_state): - _current_state, _current_time, _current_step = _loop_state + def cond(_loop_var, _loop_state, _loop_time, _loop_step): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) + def body(_i, _current_state, _current_time, _current_step): # Determine step size: either the constant size or the remainder to reach stop_time remaining = keras.ops.abs(stop_time - _current_time) sign = keras.ops.sign(_current_step) @@ -994,13 +1011,16 @@ def body_fixed(_i, _loop_state): 
step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, ) - all_nans = _check_all_nans(new_state) - if all_nans: - raise RuntimeError(f"All values are NaNs in state during integration at {_current_time}.") - return new_state, new_time, initial_step + return _i + 1, new_state, new_time, initial_step + + _, final_state, final_time, _ = keras.ops.while_loop( + cond, + body, + [0, state, start_time, initial_step], + ) + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") - # Execute the fixed loop - final_state, final_time, _ = keras.ops.fori_loop(0, steps, body_fixed, (state, start_time, initial_step)) return final_state @@ -1024,22 +1044,21 @@ def integrate_stochastic_adaptive( """ Performs adaptive-step SDE integration. """ - initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0, state) + initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, state) - def cond(i, current_state, current_time, current_step, counter, last_state): + def cond(i, current_state, current_time, current_step, last_state): time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step)) all_nans = _check_all_nans(current_state) end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) return keras.ops.logical_and(~all_nans, end_now) - def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _last_state): + def body_adaptive(_i, _current_state, _current_time, _current_step, _last_state): # Step Size Control remaining = keras.ops.abs(stop_time - _current_time) sign = keras.ops.sign(_current_step) # Ensure the next step does not overshoot the stop_time dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) dt = sign * dt_mag - _counter += 1 _noise_i = {k: z_history[k][_i] for k in _current_state.keys()} _noise_extra_i = None @@ -1069,12 +1088,10 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _l corrector_noise_history=corrector_noise_history, ) - return _i + 1, new_state, new_time, new_step, _counter, _new_current_state + return _i + 1, new_state, new_time, new_step, _new_current_state # Execute the adaptive loop - _, final_state, final_time, _, final_counter, final_k1 = keras.ops.while_loop( - cond, body_adaptive, initial_loop_state - ) + final_counter, final_state, final_time, _, final_k1 = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) if _check_all_nans(final_state): raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") @@ -1143,27 +1160,30 @@ def integrate_langevin( dt = (stop_time - start_time) / float(steps) effective_factor = step_size_factor * 100 / np.sqrt(steps) - def body(_i, loop_state): - current_state, current_time = loop_state + def cond(_loop_var, _loop_state, _loop_time): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) + def body(_i, _loop_state, _loop_time): # score at current time - score = score_fn(current_time, **filter_kwargs(current_state, score_fn)) + score = score_fn(_loop_time, **filter_kwargs(_loop_state, score_fn)) # noise schedule - log_snr_t = noise_schedule.get_log_snr(t=current_time, training=False) + log_snr_t = noise_schedule.get_log_snr(t=_loop_time, training=False) _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) 
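+        # Unadjusted Langevin update applied below (sketch):
+        #     x <- x + eps * score + sqrt(2 * eps) * z,  with eps = effective_factor * sigma_t**2
+        # where score approximates grad log p(x | t) and z ~ N(0, 1) is drawn from z_history.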
new_state: StateDict = {} - for k in current_state.keys(): + for k in _loop_state.keys(): s_k = score.get(k, None) if s_k is None: - new_state[k] = current_state[k] + new_state[k] = _loop_state[k] continue e = effective_factor * sigma_t**2 - new_state[k] = current_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i] + new_state[k] = _loop_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i] - new_time = current_time + dt + new_time = _loop_time + dt new_state = _apply_corrector( new_state=new_state, @@ -1175,17 +1195,16 @@ def body(_i, loop_state): step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, ) - if _check_all_nans(new_state): - raise RuntimeError(f"All values are NaNs in state during integration at {current_time}.") - return new_state, new_time + return _i + 1, new_state, new_time - final_state, _ = keras.ops.fori_loop( - 0, - steps, + _, final_state, final_time = keras.ops.while_loop( + cond, body, - (state, start_time), + (0, state, start_time), ) + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") return final_state From 5a1a3fa4a90395e21657b93cf69b138270c43ea1 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 18:05:48 +0100 Subject: [PATCH 30/36] filter kwargs --- bayesflow/utils/integrate.py | 49 +++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index db0ebc813..41af3ecc1 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -73,16 +73,21 @@ def rk45_step( if k1 is None: # reuse k1 if available k1 = fn(time, **filter_kwargs(state, fn)) - k2 = fn(time + h * (1 / 5), **add_scaled(state, [k1], [1 / 5], h)) - k3 = fn(time + h * (3 / 10), **add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h)) - k4 = fn(time + h * (4 / 5), **add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h)) + k2 = fn(time + h * (1 / 5), **filter_kwargs(add_scaled(state, [k1], [1 / 5], h), fn)) + k3 = fn(time + h * (3 / 10), **filter_kwargs(add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h), fn)) + k4 = fn(time + h * (4 / 5), **filter_kwargs(add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h), fn)) k5 = fn( time + h * (8 / 9), - **add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), + **filter_kwargs( + add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), fn + ), ) k6 = fn( time + h, - **add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), + **filter_kwargs( + add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), + fn, + ), ) # 5th order solution @@ -140,24 +145,38 @@ def tsit5_step( if k1 is None: # reuse k1 if available k1 = fn(time, **filter_kwargs(state, fn)) - k2 = fn(time + h * c2, **add_scaled(state, [k1], [0.161], h)) - k3 = fn(time + h * c3, **add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h)) + k2 = fn(time + h * c2, **filter_kwargs(add_scaled(state, [k1], [0.161], h), fn)) + k3 = fn( + time + h * c3, **filter_kwargs(add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h), fn) + ) k4 = fn( - time + h * c4, **add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h) + time + h * c4, + **filter_kwargs( + add_scaled(state, [k1, k2, k3], 
[2.897153057105494, -6.359448489975075, 4.362295432869581], h), fn + ), ) k5 = fn( time + h * c5, - **add_scaled( - state, [k1, k2, k3, k4], [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h + **filter_kwargs( + add_scaled( + state, + [k1, k2, k3, k4], + [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], + h, + ), + fn, ), ) k6 = fn( time + h, - **add_scaled( - state, - [k1, k2, k3, k4, k5], - [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838], - h, + **filter_kwargs( + add_scaled( + state, + [k1, k2, k3, k4, k5], + [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838], + h, + ), + fn, ), ) From e5857083a3470313ccb9d2e8edcaa6ab80693e88 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 18:35:13 +0100 Subject: [PATCH 31/36] fix density computation --- .../diffusion_model/diffusion_model.py | 12 +++++ .../networks/flow_matching/flow_matching.py | 23 +++++++-- bayesflow/utils/integrate.py | 50 +++++++------------ 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 8cbce1e87..c2f5b5fde 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -413,6 +413,12 @@ def _forward( raise ValueError("Stochastic methods are not supported for forward integration.") if density: + if integrate_kwargs["steps"] == "adaptive": + logging.warning( + "Using adaptive integration for density estimation can lead to " + "problems with autodiff. Switching to 200 fixed steps instead." + ) + integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) @@ -461,6 +467,12 @@ def _inverse( if density: if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") + if integrate_kwargs["steps"] == "adaptive": + logging.warning( + "Using adaptive integration for density estimation can lead to " + "problems with autodiff. Switching to 200 fixed steps instead." + ) + integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 485cbbd9e..ea581a7c5 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Sequence import keras @@ -236,14 +237,21 @@ def f(x): def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = self.integrate_kwargs | kwargs if density: + if integrate_kwargs["steps"] == "adaptive": + logging.warning( + "Using adaptive integration for density estimation can lead to " + "problems with autodiff. Switching to 200 fixed steps instead." 
+ ) + integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) return {"xz": v, "trace": trace} state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))} - state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs) z = state["xz"] log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1) @@ -254,7 +262,7 @@ def deltas(time, xz): return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)} state = {"xz": x} - state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs) z = state["xz"] @@ -263,14 +271,21 @@ def deltas(time, xz): def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = self.integrate_kwargs | kwargs if density: + if integrate_kwargs["steps"] == "adaptive": + logging.warning( + "Using adaptive integration for density estimation can lead to " + "problems with autodiff. Switching to 200 fixed steps instead." + ) + integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) return {"xz": v, "trace": trace} state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))} - state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs) x = state["xz"] log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1) @@ -281,7 +296,7 @@ def deltas(time, xz): return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)} state = {"xz": z} - state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs) x = state["xz"] diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 41af3ecc1..b00a45325 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -22,13 +22,6 @@ STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] -def _check_all_nans(state: StateDict): - all_nans_flags = [] - for v in state.values(): - all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) - return keras.ops.all(keras.ops.stack(all_nans_flags)) - - def euler_step( fn: Callable, state: StateDict, @@ -243,22 +236,17 @@ def integrate_fixed( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) step_size = (stop_time - start_time) / steps - def cond(_loop_var, _loop_state, _loop_time): - all_nans = _check_all_nans(_loop_state) - end_now = keras.ops.less(_loop_var, steps) - return keras.ops.logical_and(~all_nans, end_now) - - def body(_loop_var, _loop_state, _loop_time): - _loop_state, _loop_time, _, _ = step_fn(_loop_state, _loop_time, step_size) - return _loop_var + 1, _loop_state, _loop_time + def body(_loop_var, _loop_state): + _state, _time = _loop_state + _state, _time, _, _ = step_fn(_state, _time, step_size) + return _state, _time - _, state, _ = 
keras.ops.while_loop( - cond, + state, _ = keras.ops.fori_loop( + 0, + steps, body, - [0, state, start_time], + (state, start_time), ) - if _check_all_nans(state): - raise RuntimeError("All values are NaNs in state during integration.") return state @@ -283,25 +271,18 @@ def integrate_scheduled( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) - def cond(_loop_var, _loop_state): - all_nans = _check_all_nans(_loop_state) - end_now = keras.ops.less(_loop_var, len(steps) - 1) - return keras.ops.logical_and(~all_nans, end_now) - def body(_loop_var, _loop_state): _time = steps[_loop_var] step_size = steps[_loop_var + 1] - steps[_loop_var] _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) - return _loop_var + 1, _loop_state + return _loop_state - _, state = keras.ops.while_loop( - cond, + state = keras.ops.fori_loop( + 0, + keras.ops.shape(steps)[0] - 1, body, - [0, state], + state, ) - - if _check_all_nans(state): - raise RuntimeError("All values are NaNs in state during integration.") return state @@ -501,6 +482,11 @@ def integrate( ############ SDE Solvers ############# +def _check_all_nans(state: StateDict): + all_nans_flags = [] + for v in state.values(): + all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) + return keras.ops.all(keras.ops.stack(all_nans_flags)) def stochastic_adaptive_step_size_controller( From ac07af288a4162cd7a79340167c173bcf8f1a875 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 19:00:27 +0100 Subject: [PATCH 32/36] fix jax all nans --- bayesflow/utils/integrate.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b00a45325..a208cb8e1 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -11,6 +11,7 @@ from bayesflow.types import Tensor from bayesflow.utils import filter_kwargs from bayesflow.utils.logging import warning +from keras import backend as K from . import logging @@ -22,6 +23,15 @@ STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] +def _check_all_nans(state: StateDict): + if K.backend() == "jax": + return False # JAX backend does not support checks of the state variables + all_nans_flags = [] + for v in state.values(): + all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) + return keras.ops.all(keras.ops.stack(all_nans_flags)) + + def euler_step( fn: Callable, state: StateDict, @@ -482,11 +492,6 @@ def integrate( ############ SDE Solvers ############# -def _check_all_nans(state: StateDict): - all_nans_flags = [] - for v in state.values(): - all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) - return keras.ops.all(keras.ops.stack(all_nans_flags)) def stochastic_adaptive_step_size_controller( From 4c9d44b6540ffc373046989d5cc716e7790a1cfb Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 21:16:42 +0100 Subject: [PATCH 33/36] fix jax all nans --- bayesflow/utils/integrate.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a208cb8e1..3f5581894 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -11,7 +11,6 @@ from bayesflow.types import Tensor from bayesflow.utils import filter_kwargs from bayesflow.utils.logging import warning -from keras import backend as K from . 
import logging

@@ -24,8 +23,6 @@


 def _check_all_nans(state: StateDict):
-    if K.backend() == "jax":
-        return False  # JAX backend does not support checks of the state variables
     all_nans_flags = []
     for v in state.values():
         all_nans_flags.append(keras.ops.all(keras.ops.isnan(v)))
@@ -376,7 +373,7 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
 
         # Step counter: increment only on accepted steps
         updated_step = _step + keras.ops.where(accepted, 1.0, 0.0)
-        _count_not_accepted = _count_not_accepted + 1 if not accepted else _count_not_accepted
+        _count_not_accepted = _count_not_accepted + keras.ops.where(accepted, 0.0, 1.0)  # count rejected steps
 
         # For the next iteration, always use the new suggested step size
         return updated_state, updated_time, new_step_size, updated_step, updated_k1, _count_not_accepted

From 87297455b8be90dbe9bad92cfd710e9dde549600 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Wed, 3 Dec 2025 21:17:58 +0100
Subject: [PATCH 34/36] fix jax all nans

---
 bayesflow/utils/integrate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index 3f5581894..03c3aff8a 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -12,7 +12,7 @@
 from bayesflow.utils import filter_kwargs
 from bayesflow.utils.logging import warning
 
-from . import logging
+import logging

From adefa7b5dcece15c61cdd714f87c67b515fcf212 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Wed, 3 Dec 2025 22:34:35 +0100
Subject: [PATCH 35/36] relax tols in tests

---
 tests/test_utils/test_integrate.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py
index 44a6fc60f..b3214c229 100644
--- a/tests/test_utils/test_integrate.py
+++ b/tests/test_utils/test_integrate.py
@@ -8,9 +8,8 @@
 TOLERANCE_EULER = 1e-3  # Euler with fixed steps requires a larger tolerance
 
 # tolerances for SDE tests
-TOL_MEAN = 3e-2
+TOL_MEAN = 5e-2
 TOL_VAR = 5e-2
-TOL_DET = 1e-3
 
 
 @pytest.mark.parametrize("method", ["euler", "rk45", "tsit5"])
@@ -123,7 +122,6 @@ def test_forward_additive_ou_weak_means_and_vars(method, use_adapt):
 
     x_0 = 1.2  # initial condition at time 0
     T = 1.0
-
     # batch of trajectories
     N = 10000
     seed = keras.random.SeedGenerator(42)
@@ -149,15 +147,14 @@ def diffusion_fn(t, x):
         steps=steps,
         seed=seed,
         method=method,
-        max_steps=1_000,
     )
 
     x_T = np.array(out["x"])
     emp_mean = float(x_T.mean())
     emp_var = float(x_T.var())
 
-    np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0)
-    np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0)
+    np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN)
+    np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR)
 
 
 @pytest.mark.parametrize(
@@ -188,8 +185,7 @@ def test_backward_additive_ou_weak_means_and_vars(method, use_adapt):
 
     x_T = 1.2  # initial condition at time T
     T = 1.0
-
-    # batch of trajectories
-    N = 10000  # large enough to control sampling error
+    N = 10000
     seed = keras.random.SeedGenerator(42)
 
     def drift_fn(t, x):
@@ -216,15 +212,14 @@ def diffusion_fn(t, x):
         steps=steps,
         seed=seed,
         method=method,
-        max_steps=1_000,
     )
 
     x_0 = np.array(out["x"])
     emp_mean = float(x_0.mean())
     emp_var = float(x_0.var())
 
-    np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0)
-    np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0)
+    np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN)
+    
np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR) @pytest.mark.parametrize( @@ -270,7 +265,7 @@ def diffusion_fn(t, x): )["x"] exact = x0 * np.exp(a * T) - np.testing.assert_allclose(np.array(out).mean(), exact, atol=TOL_DET, rtol=0.1) + np.testing.assert_allclose(np.array(out).mean(), exact, atol=1e-3, rtol=0.1) @pytest.mark.parametrize("steps", [500]) From f9823f8c71307cb3d69d09a6e8772b16bcff3d59 Mon Sep 17 00:00:00 2001 From: arrjon Date: Thu, 4 Dec 2025 16:09:23 +0100 Subject: [PATCH 36/36] enable density computation with adaptive step size solvers --- .../networks/diffusion_model/diffusion_model.py | 12 ------------ bayesflow/networks/flow_matching/flow_matching.py | 13 ------------- 2 files changed, 25 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index c2f5b5fde..8cbce1e87 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -413,12 +413,6 @@ def _forward( raise ValueError("Stochastic methods are not supported for forward integration.") if density: - if integrate_kwargs["steps"] == "adaptive": - logging.warning( - "Using adaptive integration for density estimation can lead to " - "problems with autodiff. Switching to 200 fixed steps instead." - ) - integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) @@ -467,12 +461,6 @@ def _inverse( if density: if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") - if integrate_kwargs["steps"] == "adaptive": - logging.warning( - "Using adaptive integration for density estimation can lead to " - "problems with autodiff. Switching to 200 fixed steps instead." - ) - integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index ea581a7c5..808cee681 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -1,4 +1,3 @@ -import logging from collections.abc import Sequence import keras @@ -239,12 +238,6 @@ def _forward( ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = self.integrate_kwargs | kwargs if density: - if integrate_kwargs["steps"] == "adaptive": - logging.warning( - "Using adaptive integration for density estimation can lead to " - "problems with autodiff. Switching to 200 fixed steps instead." - ) - integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) @@ -273,12 +266,6 @@ def _inverse( ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = self.integrate_kwargs | kwargs if density: - if integrate_kwargs["steps"] == "adaptive": - logging.warning( - "Using adaptive integration for density estimation can lead to " - "problems with autodiff. Switching to 200 fixed steps instead." - ) - integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
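Taken together, the series moves both models onto adaptive solvers by default. A minimal usage sketch of the resulting API (call signature pieced together from the tests above; the toy ODE and tolerance are illustrative assumptions):

    from bayesflow.utils import integrate

    # Integrate dx/dt = t**2 from t=0 to t=1 with the adaptive Tsitouras 5(4) pair.
    def fn(t, x):
        return {"x": t**2}

    state = integrate(
        fn,
        {"x": 0.0},
        start_time=0.0,
        stop_time=1.0,
        steps="adaptive",
        method="tsit5",
        tolerance=1e-6,
    )
    print(state["x"])  # close to the exact value 1/3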