diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index ca8a634e9..8cbce1e87 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -16,6 +16,7 @@ integrate_stochastic, logging, tensor_utils, + STOCHASTIC_METHODS, ) from bayesflow.utils.serialization import serialize, deserialize, serializable @@ -39,13 +40,13 @@ class DiffusionModel(InferenceNetwork): "activation": "mish", "kernel_initializer": "he_normal", "residual": True, - "dropout": 0.0, + "dropout": 0.05, "spectral_normalization": False, } INTEGRATE_DEFAULT_CONFIG = { - "method": "rk45", - "steps": 100, + "method": "two_step_adaptive", + "steps": "adaptive", } def __init__( @@ -243,6 +244,55 @@ def _apply_subnet( else: return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training) + def score( + self, + xz: Tensor, + time: float | Tensor = None, + log_snr_t: Tensor = None, + conditions: Tensor = None, + training: bool = False, + ) -> Tensor: + """ + Computes the score of the target or latent variable `xz`. + + Parameters + ---------- + xz : Tensor + The current state of the latent variable `z`, typically of shape (..., D), + where D is the dimensionality of the latent space. + time : float or Tensor + Scalar or tensor representing the time (or noise level) at which the velocity + should be computed. Will be broadcasted to xz. If None, log_snr_t must be provided. + log_snr_t : Tensor + The log signal-to-noise ratio at time `t`. If None, time must be provided. + conditions : Tensor, optional + Conditional inputs to the network, such as conditioning variables + or encoder outputs. Shape must be broadcastable with `xz`. Default is None. + training : bool, optional + Whether the model is in training mode. Affects behavior of dropout, batch norm, + or other stochastic layers. Default is False. + + Returns + ------- + Tensor + The velocity tensor of the same shape as `xz`, representing the right-hand + side of the SDE or ODE at the given `time`. + """ + if log_snr_t is None: + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + subnet_out = self._apply_subnet( + xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training + ) + pred = self.output_projector(subnet_out, training=training) + + x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) + + score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + return score + def velocity( self, xz: Tensor, @@ -279,19 +329,10 @@ def velocity( The velocity tensor of the same shape as `xz`, representing the right-hand side of the SDE or ODE at the given `time`. 
""" - # calculate the current noise level and transform into correct shape log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - subnet_out = self._apply_subnet( - xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training - ) - pred = self.output_projector(subnet_out, training=training) - - x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) - score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training) # compute velocity f, g of the SDE or ODE f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -368,7 +409,7 @@ def _forward( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for forward integration.") if density: @@ -418,7 +459,7 @@ def _inverse( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs if density: - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): @@ -437,7 +478,7 @@ def deltas(time, xz): return x, log_density state = {"xz": z} - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in STOCHASTIC_METHODS: def deltas(time, xz): return { @@ -447,9 +488,24 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} + score_fn = None + if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin": + + def score_fn(time, xz): + return { + "xz": self.score( + xz, + time=time, + conditions=conditions, + training=training, + ) + } + state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, + score_fn=score_fn, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **integrate_kwargs, diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index fa74089a4..808cee681 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -53,8 +53,8 @@ class FlowMatching(InferenceNetwork): } INTEGRATE_DEFAULT_CONFIG = { - "method": "rk45", - "steps": 100, + "method": "tsit5", + "steps": "adaptive", } def __init__( @@ -236,6 +236,7 @@ def f(x): def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = self.integrate_kwargs | kwargs if density: def deltas(time, xz): @@ -243,7 +244,7 @@ def deltas(time, xz): return {"xz": v, "trace": trace} state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))} - state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs) z = state["xz"] log_density = self.base_distribution.log_prob(z) + 
keras.ops.squeeze(state["trace"], axis=-1) @@ -254,7 +255,7 @@ def deltas(time, xz): return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)} state = {"xz": x} - state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs) z = state["xz"] @@ -263,6 +264,7 @@ def deltas(time, xz): def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = self.integrate_kwargs | kwargs if density: def deltas(time, xz): @@ -270,7 +272,7 @@ def deltas(time, xz): return {"xz": v, "trace": trace} state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))} - state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs) x = state["xz"] log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1) @@ -281,7 +283,7 @@ def deltas(time, xz): return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)} state = {"xz": z} - state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs) x = state["xz"] diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index a8d28a50a..25b7dd920 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -47,7 +47,7 @@ ) from .hparam_utils import find_batch_size, find_memory_budget -from .integrate import integrate, integrate_stochastic +from .integrate import integrate, integrate_stochastic, DETERMINISTIC_METHODS, STOCHASTIC_METHODS from .io import ( pickle_load, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b197ea975..03c3aff8a 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -1,4 +1,5 @@ from collections.abc import Callable, Sequence +from typing import Dict, Tuple, Optional from functools import partial import keras @@ -11,128 +12,219 @@ from bayesflow.utils import filter_kwargs from bayesflow.utils.logging import warning -from . 
import logging +import logging ArrayLike = int | float | Tensor +StateDict = Dict[str, ArrayLike] + + +DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"] +STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] + + +def _check_all_nans(state: StateDict): + all_nans_flags = [] + for v in state.values(): + all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) + return keras.ops.all(keras.ops.stack(all_nans_flags)) def euler_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - use_adaptive_step_size: bool = False, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + **kwargs, +) -> Tuple[StateDict, ArrayLike, None, ArrayLike]: k1 = fn(time, **filter_kwargs(state, fn)) - if use_adaptive_step_size: - intermediate_state = state.copy() - for key, delta in k1.items(): - intermediate_state[key] = state[key] + step_size * delta - - k2 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) - - # check all keys are equal - if set(k1.keys()) != set(k2.keys()): - raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.") - - # compute next step size - intermediate_error = keras.ops.stack([keras.ops.norm(k2[key] - k1[key], ord=2, axis=-1) for key in k1]) - new_step_size = step_size * tolerance / (intermediate_error + 1e-9) - - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - - # consolidate step size - new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) - else: - new_step_size = step_size - - # apply updates new_state = state.copy() for key in k1.keys(): new_state[key] = state[key] + step_size * k1[key] - new_time = time + step_size - return new_state, new_time, new_step_size + return new_state, new_time, None, 0.0 + + +def add_scaled(state, ks, coeffs, h): + out = {} + for key, y in state.items(): + acc = keras.ops.zeros_like(y) + for c, k in zip(coeffs, ks): + acc = acc + c * k[key] + out[key] = y + h * acc + return out def rk45_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, - last_step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - use_adaptive_step_size: bool = False, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): - step_size = last_step_size - - k1 = fn(time, **filter_kwargs(state, fn)) - - intermediate_state = state.copy() - for key, delta in k1.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta + step_size: ArrayLike, + k1: StateDict = None, + use_adaptive_step_size: bool = True, +) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: + """ + Dormand-Prince 5(4) method with embedded error estimation [1]. 
- k2 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + Dormand (1996), Numerical Methods for Differential Equations: A Computational Approach + """ + h = step_size + + if k1 is None: # reuse k1 if available + k1 = fn(time, **filter_kwargs(state, fn)) + k2 = fn(time + h * (1 / 5), **filter_kwargs(add_scaled(state, [k1], [1 / 5], h), fn)) + k3 = fn(time + h * (3 / 10), **filter_kwargs(add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h), fn)) + k4 = fn(time + h * (4 / 5), **filter_kwargs(add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h), fn)) + k5 = fn( + time + h * (8 / 9), + **filter_kwargs( + add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), fn + ), + ) + k6 = fn( + time + h, + **filter_kwargs( + add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), + fn, + ), + ) - intermediate_state = state.copy() - for key, delta in k2.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta + # 5th order solution + new_state = {} + for key in k1.keys(): + new_state[key] = state[key] + h * ( + 35 / 384 * k1[key] + 500 / 1113 * k3[key] + 125 / 192 * k4[key] - 2187 / 6784 * k5[key] + 11 / 84 * k6[key] + ) - k3 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + new_time = time + h + if not use_adaptive_step_size: + return new_state, new_time, None, 0.0 - intermediate_state = state.copy() - for key, delta in k3.items(): - intermediate_state[key] = state[key] + step_size * delta + k7 = fn(time + h, **filter_kwargs(new_state, fn)) - k4 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) + # 4th order embedded solution + err_state = {} + for key in k1.keys(): + y4 = state[key] + h * ( + 5179 / 57600 * k1[key] + + 7571 / 16695 * k3[key] + + 393 / 640 * k4[key] + - 92097 / 339200 * k5[key] + + 187 / 2100 * k6[key] + + 1 / 40 * k7[key] + ) + err_state[key] = new_state[key] - y4 - if use_adaptive_step_size: - intermediate_state = state.copy() - for key, delta in k4.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) - k5 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + return new_state, new_time, k7, err - # check all keys are equal - if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5]): - raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.") - # compute next step size - intermediate_error = keras.ops.stack([keras.ops.norm(k5[key] - k4[key], ord=2, axis=-1) for key in k5.keys()]) - new_step_size = step_size * tolerance / (intermediate_error + 1e-9) +def tsit5_step( + fn: Callable, + state: StateDict, + time: ArrayLike, + step_size: ArrayLike, + k1: StateDict = None, + use_adaptive_step_size: bool = True, +) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: + """ + Implements a single step of the Tsitouras 5/4 Runge-Kutta method [1]. 
- new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + [1] Tsitouras (2011), Runge--Kutta pairs of order 5(4) satisfying only the first column simplifying assumption + """ + h = step_size + + # Butcher tableau coefficients + c2 = 0.161 + c3 = 0.327 + c4 = 0.9 + c5 = 0.9800255409045097 + + if k1 is None: # reuse k1 if available + k1 = fn(time, **filter_kwargs(state, fn)) + k2 = fn(time + h * c2, **filter_kwargs(add_scaled(state, [k1], [0.161], h), fn)) + k3 = fn( + time + h * c3, **filter_kwargs(add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h), fn) + ) + k4 = fn( + time + h * c4, + **filter_kwargs( + add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h), fn + ), + ) + k5 = fn( + time + h * c5, + **filter_kwargs( + add_scaled( + state, + [k1, k2, k3, k4], + [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], + h, + ), + fn, + ), + ) + k6 = fn( + time + h, + **filter_kwargs( + add_scaled( + state, + [k1, k2, k3, k4, k5], + [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838], + h, + ), + fn, + ), + ) - # consolidate step size - new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) - else: - new_step_size = step_size + # 5th order solution: b coefficients + new_state = {} + for key in state.keys(): + new_state[key] = state[key] + h * ( + 0.09646076681806523 * k1[key] + + 0.01 * k2[key] + + 0.4798896504144996 * k3[key] + + 1.379008574103742 * k4[key] + - 3.290069515436081 * k5[key] + + 2.324710524099774 * k6[key] + ) - # apply updates - new_state = state.copy() - for key in k1.keys(): - new_state[key] = state[key] + (step_size / 6.0) * (k1[key] + 2.0 * k2[key] + 2.0 * k3[key] + k4[key]) + new_time = time + h + if not use_adaptive_step_size: + return new_state, new_time, None, 0.0 + + k7 = fn(time + h, **filter_kwargs(new_state, fn)) + + err_state = {} + for key in state.keys(): + err_state[key] = h * ( + -0.00178001105222577714 * k1[key] + - 0.0008164344596567469 * k2[key] + + 0.007880878010261995 * k3[key] + - 0.1447110071732629 * k4[key] + + 0.5823571654525552 * k5[key] + - 0.45808210592918697 * k6[key] + + 0.015151515151515152 * k7[key] + ) - new_time = time + step_size + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) - return new_state, new_time, new_step_size + return new_state, new_time, k7, err def integrate_fixed( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, steps: int, - method: str = "rk45", + method: str, **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: if steps <= 0: raise ValueError("Number of steps must be positive.") @@ -141,6 +233,8 @@ def integrate_fixed( step_fn = euler_step case "rk45": step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step case str() as name: raise ValueError(f"Unknown integration method name: {name!r}") case other: @@ -149,16 +243,53 @@ def integrate_fixed( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) step_size = (stop_time - start_time) / steps - time = start_time - def body(_loop_var, _loop_state): _state, _time = _loop_state - _state, _time, _ = step_fn(_state, _time, step_size) - + _state, _time, _, _ = step_fn(_state, _time, step_size) return _state, _time - state, time = keras.ops.fori_loop(0, steps, body, (state, time)) + state, _ = keras.ops.fori_loop( + 0, + steps, + 
body, + (state, start_time), + ) + return state + +def integrate_scheduled( + fn: Callable, + state: StateDict, + steps: Tensor | np.ndarray, + method: str, + **kwargs, +) -> StateDict: + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + + def body(_loop_var, _loop_state): + _time = steps[_loop_var] + step_size = steps[_loop_var + 1] - steps[_loop_var] + _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_state + + state = keras.ops.fori_loop( + 0, + keras.ops.shape(steps)[0] - 1, + body, + state, + ) return state @@ -167,114 +298,120 @@ def integrate_adaptive( state: dict[str, ArrayLike], start_time: ArrayLike, stop_time: ArrayLike, - min_steps: int = 10, - max_steps: int = 1000, - method: str = "rk45", + min_steps: int, + max_steps: int, + method: str, **kwargs, ) -> dict[str, ArrayLike]: if max_steps <= min_steps: raise ValueError("Maximum number of steps must be greater than minimum number of steps.") match method: - case "euler": - step_fn = euler_step case "rk45": step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step + case "euler": + raise ValueError("Adaptive step sizing is not supported for the 'euler' method.") case str() as name: raise ValueError(f"Unknown integration method name: {name!r}") case other: raise TypeError(f"Invalid integration method: {other!r}") + tolerance = keras.ops.convert_to_tensor(kwargs.get("tolerance", 1e-6), dtype="float32") step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True) + initial_step = (stop_time - start_time) / float(min_steps) + step0 = keras.ops.convert_to_tensor(0.0, dtype="float32") + count_not_accepted = 0 - def cond(_state, _time, _step_size, _step): - # while step < min_steps or time_remaining > 0 and step < max_steps + # "First Same As Last" (FSAL) property + k1_0 = fn(start_time, **filter_kwargs(state, fn)) - # time remaining after the next step - time_remaining = keras.ops.abs(stop_time - (_time + _step_size)) + def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted): + time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (_time + _step_size)) + step_lt_min = keras.ops.less(_step, float(min_steps)) + step_lt_max = keras.ops.less(_step, float(max_steps)) - return keras.ops.logical_or( - keras.ops.all(_step < min_steps), - keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.all(_step < max_steps)), - ) - - def body(_state, _time, _step_size, _step): - _step = _step + 1 + all_nans = _check_all_nans(_state) - # time remaining after the next step - time_remaining = stop_time - (_time + _step_size) + end_now = keras.ops.logical_or( + step_lt_min, keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max) + ) + return keras.ops.logical_and(~all_nans, end_now) + def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): + # Time remaining from current point + time_remaining = keras.ops.abs(stop_time - _time) min_step_size = time_remaining / (max_steps - _step) max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0) - - # reorder - min_step_size, max_step_size = ( - keras.ops.minimum(min_step_size, max_step_size), - keras.ops.maximum(min_step_size, max_step_size), + h = keras.ops.sign(_step_size) * 
keras.ops.clip(keras.ops.abs(_step_size), min_step_size, max_step_size) + + # Take one trial step + new_state, new_time, new_k1, err = step_fn( + state=_state, + time=_time, + step_size=h, + k1=_k1, ) - _state, _time, _step_size = step_fn( - _state, _time, _step_size, min_step_size=min_step_size, max_step_size=max_step_size + new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) + new_step_size = keras.ops.sign(new_step_size) * keras.ops.clip( + keras.ops.abs(new_step_size), min_step_size, max_step_size ) - return _state, _time, _step_size, _step - - # select initial step size conservatively - step_size = (stop_time - start_time) / max_steps - - step = 0 - time = start_time - - state, time, step_size, step = keras.ops.while_loop(cond, body, [state, time, step_size, step]) - - # do the last step - step_size = stop_time - time - state, _, _ = step_fn(state, time, step_size) - step = step + 1 + # Error control: reject if err > tolerance + too_big = keras.ops.greater(err, tolerance) + at_min = keras.ops.less_equal( + keras.ops.abs(h), + keras.ops.abs(min_step_size), + ) + accepted = keras.ops.logical_or(keras.ops.logical_not(too_big), at_min) - logging.debug("Finished integration after {} steps.", step) + updated_state = keras.ops.cond(accepted, lambda: new_state, lambda: _state) + updated_time = keras.ops.cond(accepted, lambda: new_time, lambda: _time) + updated_k1 = keras.ops.cond(accepted, lambda: new_k1, lambda: _k1) - return state + # Step counter: increment only on accepted steps + updated_step = _step + keras.ops.where(accepted, 1.0, 0.0) + _count_not_accepted = _count_not_accepted + keras.ops.where(accepted, 1.0, 0.0) + # For the next iteration, always use the new suggested step size + return updated_state, updated_time, new_step_size, updated_step, updated_k1, _count_not_accepted -def integrate_scheduled( - fn: Callable, - state: dict[str, ArrayLike], - steps: Tensor | np.ndarray, - method: str = "rk45", - **kwargs, -) -> dict[str, ArrayLike]: - match method: - case "euler": - step_fn = euler_step - case "rk45": - step_fn = rk45_step - case str() as name: - raise ValueError(f"Unknown integration method name: {name!r}") - case other: - raise TypeError(f"Invalid integration method: {other!r}") - - step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) - - def body(_loop_var, _loop_state): - _time = steps[_loop_var] - step_size = steps[_loop_var + 1] - steps[_loop_var] + # Run the adaptive loop + state, time, step_size, step, k1, count_not_accepted = keras.ops.while_loop( + cond, + body, + [state, start_time, initial_step, step0, k1_0, count_not_accepted], + ) - _loop_state, _, _ = step_fn(_loop_state, _time, step_size) - return _loop_state + if _check_all_nans(state): + raise RuntimeError(f"All values are NaNs in state during integration at {time}.") + + # Final step to hit stop_time exactly + time_diff = stop_time - time + time_remaining = keras.ops.sign(stop_time - start_time) * time_diff + if keras.ops.all(time_remaining > 0): + state, time, _, _ = step_fn( + state=state, + time=time, + step_size=time_diff, + k1=k1, + ) + step = step + 1.0 - state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + logging.debug(f"Finished integration after {step} steps with {count_not_accepted} rejected steps.") return state def integrate_scipy( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, scipy_kwargs: dict | None = None, **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: import 
scipy.integrate scipy_kwargs = scipy_kwargs or {} @@ -316,15 +453,15 @@ def scipy_wrapper_fn(time, x): def integrate( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike | None = None, stop_time: ArrayLike | None = None, min_steps: int = 10, max_steps: int = 10_000, steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, - method: str = "euler", + method: str = "rk45", **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: if start_time is None or stop_time is None: raise ValueError( @@ -351,14 +488,59 @@ def integrate( raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") +############ SDE Solvers ############# + + +def stochastic_adaptive_step_size_controller( + state, + drift, + adaptive_factor: ArrayLike, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), +) -> ArrayLike: + """ + Adaptive step size controller based on [1]. Similar to a tamed explicit Euler method when used in Euler-Maruyama. + + Adaptive step sizing uses: + h = max(1, ||x||**2) / max(1, ||f(x)||**2) * adaptive_factor + + + [1] Fang & Giles, Adaptive Euler-Maruyama Method for SDEs with Non-Globally Lipschitz Drift Coefficients (2020) + + Returns + ------- + New step size. + """ + state_norms = [] + drift_norms = [] + for key in state.keys(): + state_norms.append(keras.ops.norm(state[key], ord=2, axis=-1)) + drift_norms.append(keras.ops.norm(drift[key], ord=2, axis=-1)) + state_norm = keras.ops.stack(state_norms) + drift_norm = keras.ops.stack(drift_norms) + max_state_norm = keras.ops.maximum( + keras.ops.cast(1.0, dtype=keras.ops.dtype(state_norm)), keras.ops.max(state_norm) ** 2 + ) + max_drift_norm = keras.ops.maximum( + keras.ops.cast(1.0, dtype=keras.ops.dtype(drift_norm)), keras.ops.max(drift_norm) ** 2 + ) + new_step_size = max_state_norm / max_drift_norm * adaptive_factor + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + return new_step_size + + def euler_maruyama_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: dict[str, ArrayLike], -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + noise: StateDict, + use_adaptive_step_size: bool = False, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), + **kwargs, +) -> Union[Tuple[StateDict, ArrayLike, ArrayLike], Tuple[StateDict, ArrayLike, ArrayLike, StateDict]]: """ Performs a single Euler-Maruyama step for stochastic differential equations. @@ -369,6 +551,9 @@ def euler_maruyama_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. + use_adaptive_step_size: Whether to use adaptive step sizing. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. Returns: new_state: Updated state after one Euler-Maruyama step. 
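A minimal usage sketch of the new `integrate_stochastic` entry point with the fixed-step Euler-Maruyama solver, mirroring the additive-noise Ornstein-Uhlenbeck test added later in this diff; the parameter values here are illustrative only.

import keras
import numpy as np
from bayesflow.utils import integrate_stochastic

# Additive-noise Ornstein-Uhlenbeck SDE: dX = a X dt + sigma dW
a, sigma = -1.0, 0.5

def drift_fn(t, x):
    return {"x": a * x}

def diffusion_fn(t, x):
    # state-independent (additive) diffusion term
    return {"x": keras.ops.convert_to_tensor([sigma])}

state = {"x": keras.ops.ones((10_000,)) * 1.2}
out = integrate_stochastic(
    drift_fn=drift_fn,
    diffusion_fn=diffusion_fn,
    state=state,
    start_time=0.0,
    stop_time=1.0,
    steps=200,  # or steps="adaptive" for the euler_maruyama / two_step_adaptive solvers
    seed=keras.random.SeedGenerator(42),
    method="euler_maruyama",
)
x_T = np.array(out["x"])  # sample mean should weakly match E[X(T)] = 1.2 * exp(a * T)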
@@ -378,78 +563,817 @@ def euler_maruyama_step( drift = drift_fn(time, **filter_kwargs(state, drift_fn)) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) - # Check noise keys - if set(diffusion.keys()) != set(noise.keys()): - raise ValueError("Keys of diffusion terms and noise do not match.") + new_step_size = step_size + if use_adaptive_step_size: + sign_step = keras.ops.sign(step_size) + new_step_size = stochastic_adaptive_step_size_controller( + state=state, + drift=drift, + adaptive_factor=max_step_size, + min_step_size=min_step_size, + max_step_size=max_step_size, + ) + new_step_size = sign_step * keras.ops.abs(new_step_size) + + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(new_step_size)) new_state = {} for key, d in drift.items(): + base = state[key] + new_step_size * d + if key in diffusion: + base = base + diffusion[key] * sqrt_step_size * noise[key] + new_state[key] = base + + if use_adaptive_step_size: + return new_state, time + new_step_size, new_step_size, state + return new_state, time + new_step_size, new_step_size + + +def two_step_adaptive_step( + drift_fn: Callable, + diffusion_fn: Callable, + state: StateDict, + time: ArrayLike, + step_size: ArrayLike, + noise: StateDict, + last_state: StateDict = None, + use_adaptive_step_size: bool = True, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), + e_rel: float = 0.1, + e_abs: float = None, + r: float = 0.9, + adapt_safety: float = 0.9, + **kwargs, +) -> Union[ + Tuple[StateDict, ArrayLike, ArrayLike], + Tuple[StateDict, ArrayLike, ArrayLike, StateDict], +]: + """ + Performs a single adaptive step for stochastic differential equations based on [1]. + + Based on + + This method uses a predictor-corrector approach with error estimation: + 1. Take an Euler-Maruyama step (predictor) + 2. Take another Euler-Maruyama step from the predicted state + 3. Average the two predictions (corrector) + 4. Estimate error and adapt step size + + When step_size reaches min_step_size, steps are always accepted regardless of + error to ensure progress and termination within max_steps. + + [1] Jolicoeur-Martineau et al. (2021) "Gotta Go Fast When Generating Data with Score-Based Models" + + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors (pre-scaled by sqrt(dt)). + last_state: Previous state for error estimation. + use_adaptive_step_size: Whether to adapt step size. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + e_rel: Relative error tolerance. + e_abs: Absolute error tolerance. Default assumes standardized targets. + r: Order of the method for step size adaptation. + adapt_safety: Safety factor for step size adaptation. + **kwargs: Additional arguments passed to drift_fn and diffusion_fn. + + Returns: + new_state: Updated state after one adaptive step. + new_time: time + dt (or time if step rejected). + new_step_size: Adapted step size for next iteration. 
+ """ + state_euler, time_mid, _ = euler_maruyama_step( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=state, + time=time, + step_size=step_size, + min_step_size=min_step_size, + max_step_size=max_step_size, + noise=noise, + use_adaptive_step_size=False, + ) + + # Compute drift and diffusion at new state, but update from old state + drift_mid = drift_fn(time_mid, **filter_kwargs(state_euler, drift_fn)) + diffusion_mid = diffusion_fn(time_mid, **filter_kwargs(state_euler, diffusion_fn)) + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) + + state_euler_mid = {} + for key, d in drift_mid.items(): + base = state[key] + step_size * d + if key in diffusion_mid: + base = base + diffusion_mid[key] * sqrt_step_size * noise[key] + state_euler_mid[key] = base + + # average the two predictions + state_heun = {} + for key in state.keys(): + state_heun[key] = 0.5 * (state_euler[key] + state_euler_mid[key]) + + # Error estimation + if use_adaptive_step_size: + if e_abs is None: + e_abs = 0.02576 # 1% of 99% CI of standardized unit variance + # Check if we're at minimum step size - if so, force acceptance + at_min_step = keras.ops.less_equal(keras.ops.abs(step_size), min_step_size) + + # Compute error tolerance for each component + e_abs_tensor = keras.ops.cast(e_abs, dtype=keras.ops.dtype(list(state.values())[0])) + e_rel_tensor = keras.ops.cast(e_rel, dtype=keras.ops.dtype(list(state.values())[0])) + + max_error = keras.ops.cast(0.0, dtype=keras.ops.dtype(list(state.values())[0])) + + for key in state.keys(): + # Local error estimate: difference between Heun and first Euler step + error_estimate = keras.ops.abs(state_heun[key] - state_euler[key]) + + # Tolerance threshold + delta = keras.ops.maximum( + e_abs_tensor, + e_rel_tensor * keras.ops.maximum(keras.ops.abs(state_euler[key]), keras.ops.abs(last_state[key])), + ) + + # Normalized error + normalized_error = error_estimate / (delta + 1e-10) + + # Maximum error across all components and batch dimensions + component_max_error = keras.ops.max(normalized_error) + max_error = keras.ops.maximum(max_error, component_max_error) + + error_scale = 1 # 1/sqrt(n_params) + E2 = error_scale * max_error + + # Accept step if error is acceptable OR if at minimum step size + error_acceptable = keras.ops.less_equal(E2, keras.ops.cast(1.0, dtype=keras.ops.dtype(E2))) + accepted = keras.ops.logical_or(error_acceptable, at_min_step) + + # Adapt step size for next iteration (only if not at minimum) + # Ensure E2 is not zero to avoid division issues + E2_safe = keras.ops.maximum(E2, 1e-10) + + # New step size based on error estimate + adapt_factor = adapt_safety * keras.ops.power(E2_safe, -r) + new_step_candidate = step_size * adapt_factor + + # Clamp to valid range + new_step_size = keras.ops.clip(keras.ops.abs(new_step_candidate), min_step_size, max_step_size) + new_step_size = keras.ops.sign(step_size) * new_step_size + + # Return appropriate state based on acceptance + new_state = keras.ops.cond(accepted, lambda: state_heun, lambda: state) + + new_time = keras.ops.cond(accepted, lambda: time_mid, lambda: time) + + prev_state = keras.ops.cond(accepted, lambda: state_euler, lambda: state) + + return new_state, new_time, new_step_size, prev_state + + else: + return state_heun, time_mid, step_size + + +def compute_levy_area( + state: StateDict, diffusion: StateDict, noise: StateDict, noise_aux: StateDict, step_size: ArrayLike +) -> StateDict: + step_size_abs = keras.ops.abs(step_size) + sqrt_step_size = keras.ops.sqrt(step_size_abs) + inv_sqrt3 = 
keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(step_size_abs)) + + # Build Lévy area H_k from w_k and Z_k + H = {} + for k in state.keys(): + if k in diffusion: + term1 = 0.5 * step_size_abs * noise[k] + term2 = 0.5 * step_size_abs * sqrt_step_size * inv_sqrt3 * noise_aux[k] + H[k] = term1 + term2 + else: + H[k] = keras.ops.zeros_like(state[k]) + return H + + +def sea_step( + drift_fn: Callable, + diffusion_fn: Callable, + state: StateDict, + time: ArrayLike, + step_size: ArrayLike, + noise: StateDict, # standard normals + noise_aux: StateDict, # standard normals + **kwargs, +) -> Tuple[StateDict, ArrayLike, ArrayLike]: + """ + Performs a single shifted Euler step for SDEs with additive noise [1]. + + Compared to Euler-Maruyama, this evaluates the drift at a shifted state, + which improves the local error and the global error constant for additive noise. + + The scheme is + X_{n+1} = X_n + f(t_n, X_n + g(t_n) * (0.5 * ΔW_n + ΔH_n) * h + g(t_n) * ΔW_n + + [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise. + + Returns: + new_state: Updated state after one SEA step. + new_time: time + dt. + """ + # Compute diffusion + diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) + + la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size) + + # Build shifted state: X_shift = X + g * (0.5 * ΔW + ΔH) + shifted_state = {} + for key, x in state.items(): + if key in diffusion: + shifted_state[key] = x + diffusion[key] * (0.5 * sqrt_step_size * noise[key] + la[key]) + else: + shifted_state[key] = x + + # Drift evaluated at shifted state + drift_shifted = drift_fn(time, **filter_kwargs(shifted_state, drift_fn)) + + # Final update + new_state = {} + for key, d in drift_shifted.items(): base = state[key] + step_size * d - if key in diffusion: # stochastic update - base = base + diffusion[key] * noise[key] + if key in diffusion: + base = base + diffusion[key] * sqrt_step_size * noise[key] new_state[key] = base - return new_state, time + step_size + return new_state, time + step_size, step_size -def integrate_stochastic( +def shark_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, + time: ArrayLike, + step_size: ArrayLike, + noise: StateDict, + noise_aux: StateDict, + **kwargs, +) -> Tuple[StateDict, ArrayLike, ArrayLike]: + """ + Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion + per step and has a strong order 1.5. + + SHARK method as specified: + + 1) ỹ_k = y_k + g(y_k) H_k + 2) ỹ_{k+5/6} = ỹ_k + (5/6)[ f(ỹ_k) h + g(ỹ_k) W_k ] + 3) y_{k+1} = y_k + + (2/5) f(ỹ_k) h + + (3/5) f(ỹ_{k+5/6}) h + + g(ỹ_k) ( 2/5 W_k + 6/5 H_k ) + + g(ỹ_{k+5/6}) ( 3/5 W_k - 6/5 H_k ) + + with + H_k = 0.5 * |h| * W_k + (|h| ** 1.5) / (2 * sqrt(3)) * Z_k + + [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) + + Args: + drift_fn: Function computing the drift term f(t, **state). 
+ diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise. + + Returns: + new_state: Updated state after one SHARK step. + new_time: time + dt. + """ + h = step_size + t = time + h_mag = keras.ops.abs(h) + sqrt_h_mag = keras.ops.sqrt(h_mag) + + diffusion = diffusion_fn(t, **filter_kwargs(state, diffusion_fn)) + + la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size) + + # === 1) shifted initial state === + y_tilde_k = {} + for k in state.keys(): + if k in diffusion: + y_tilde_k[k] = state[k] + diffusion[k] * la[k] + else: + y_tilde_k[k] = state[k] + + # === evaluate drift and diffusion at ỹ_k === + f_tilde_k = drift_fn(t, **filter_kwargs(y_tilde_k, drift_fn)) + g_tilde_k = diffusion_fn(t, **filter_kwargs(y_tilde_k, diffusion_fn)) + + # === 2) internal stage at 5/6 === + y_tilde_mid = {} + for k in state.keys(): + drift_part = (5.0 / 6.0) * f_tilde_k[k] * h + if k in g_tilde_k: + sto_part = (5.0 / 6.0) * g_tilde_k[k] * sqrt_h_mag * noise[k] + else: + sto_part = keras.ops.zeros_like(state[k]) + y_tilde_mid[k] = y_tilde_k[k] + drift_part + sto_part + + # === evaluate drift and diffusion at ỹ_(k+5/6) === + f_tilde_mid = drift_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, drift_fn)) + g_tilde_mid = diffusion_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, diffusion_fn)) + + # === 3) final update === + new_state = {} + for k in state.keys(): + # deterministic weights + det = state[k] + (2.0 / 5.0) * f_tilde_k[k] * h + (3.0 / 5.0) * f_tilde_mid[k] * h + + # stochastic parts + sto1 = ( + g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * la[k]) + if k in g_tilde_k + else keras.ops.zeros_like(det) + ) + sto2 = ( + g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * la[k]) + if k in g_tilde_mid + else keras.ops.zeros_like(det) + ) + + new_state[k] = det + sto1 + sto2 + + return new_state, t + h, h + + +def _apply_corrector( + new_state: StateDict, + new_time: ArrayLike, + i: ArrayLike, + corrector_steps: int, + score_fn: Optional[Callable], + corrector_noise_history: StateDict, + step_size_factor: ArrayLike = 0.01, + noise_schedule=None, +) -> StateDict: + """Helper function to apply corrector steps [1]. 
+ + [1] Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations" (2020) + """ + if corrector_steps <= 0: + return new_state + + for j in range(corrector_steps): + score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) + _z_corr = {k: corrector_noise_history[k][i, j] for k in new_state.keys()} + + log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) + alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + for k in new_state.keys(): + if k in score: + # Calculate required norms for Langevin step + z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True) + score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) + score_norm = keras.ops.maximum(score_norm, 1e-8) + + # Compute step size for the Langevin update + e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 + + # Annealed Langevin Dynamics update + new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k] + return new_state + + +def integrate_stochastic_fixed( + step_fn: Callable, + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, steps: int, + min_step_size: ArrayLike, + max_step_size: ArrayLike, + z_history: StateDict, + z_extra_history: StateDict, + score_fn: Optional[Callable], + step_size_factor: ArrayLike, + corrector_noise_history: StateDict, + corrector_steps: int = 0, + noise_schedule=None, +) -> StateDict: + """ + Performs fixed-step SDE integration. + """ + initial_step = (stop_time - start_time) / float(steps) + + def cond(_loop_var, _loop_state, _loop_time, _loop_step): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) + + def body(_i, _current_state, _current_time, _current_step): + # Determine step size: either the constant size or the remainder to reach stop_time + remaining = keras.ops.abs(stop_time - _current_time) + sign = keras.ops.sign(_current_step) + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) + dt = sign * dt_mag + + # Generate noise increment + _noise_i = {k: z_history[k][_i] for k in _current_state.keys()} + if len(z_extra_history) == 0: + _noise_extra_i = None + else: + _noise_extra_i = {k: z_extra_history[k][_i] for k in _current_state.keys()} + + new_state, new_time, new_step = step_fn( + state=_current_state, + time=_current_time, + step_size=dt, + min_step_size=min_step_size, + max_step_size=keras.ops.minimum(max_step_size, remaining), + noise=_noise_i, + noise_aux=_noise_extra_i, + use_adaptive_step_size=False, + ) + + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + return _i + 1, new_state, new_time, initial_step + + _, final_state, final_time, _ = keras.ops.while_loop( + cond, + body, + [0, state, start_time, initial_step], + ) + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") + + return final_state + + +def integrate_stochastic_adaptive( + step_fn: Callable, + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + max_steps: int, + min_step_size: ArrayLike, + max_step_size: ArrayLike, + initial_step: ArrayLike, + z_history: StateDict, + z_extra_history: StateDict, + score_fn: Optional[Callable], + step_size_factor: ArrayLike, + corrector_noise_history: StateDict, + 
corrector_steps: int = 0, + noise_schedule=None, +) -> StateDict: + """ + Performs adaptive-step SDE integration. + """ + initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, state) + + def cond(i, current_state, current_time, current_step, last_state): + time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step)) + all_nans = _check_all_nans(current_state) + end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) + return keras.ops.logical_and(~all_nans, end_now) + + def body_adaptive(_i, _current_state, _current_time, _current_step, _last_state): + # Step Size Control + remaining = keras.ops.abs(stop_time - _current_time) + sign = keras.ops.sign(_current_step) + # Ensure the next step does not overshoot the stop_time + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) + dt = sign * dt_mag + + _noise_i = {k: z_history[k][_i] for k in _current_state.keys()} + _noise_extra_i = None + if len(z_extra_history) > 0: + _noise_extra_i = {k: z_extra_history[k][_i] for k in _current_state.keys()} + + new_state, new_time, new_step, _new_current_state = step_fn( + state=_current_state, + last_state=_last_state, + time=_current_time, + step_size=dt, + min_step_size=min_step_size, + max_step_size=keras.ops.minimum(max_step_size, remaining), + noise=_noise_i, + noise_aux=_noise_extra_i, + use_adaptive_step_size=True, + ) + + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + + return _i + 1, new_state, new_time, new_step, _new_current_state + + # Execute the adaptive loop + final_counter, final_state, final_time, _, final_k1 = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") + + # Final step to hit stop_time exactly + time_diff = stop_time - final_time + time_remaining = keras.ops.sign(stop_time - start_time) * time_diff + if keras.ops.all(time_remaining > 0): + noise_final = {k: z_history[k][-1] for k in final_state.keys()} + noise_extra_final = None + if len(z_extra_history) > 0: + noise_extra_final = {k: z_extra_history[k][-1] for k in final_state.keys()} + + final_state, _, _ = step_fn( + state=final_state, + time=final_time, + step_size=time_diff, + last_state=final_k1, + min_step_size=min_step_size, + max_step_size=time_remaining, + noise=noise_final, + noise_aux=noise_extra_final, + use_adaptive_step_size=False, + ) + final_counter = final_counter + 1 + + logging.debug(f"Finished integration after {final_counter}.") + return final_state + + +def integrate_langevin( + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + steps: int, + z_history: StateDict, + score_fn: Callable, + noise_schedule, + corrector_noise_history: StateDict, + step_size_factor: ArrayLike = 0.01, + corrector_steps: int = 0, +) -> StateDict: + """ + Annealed Langevin dynamics using the given score_fn and noise_schedule [1]. + + At each step i with time t_i, performs for every state component k: + state_k <- state_k + e * score_k + sqrt(2 * e) * z + + Times are stepped linearly from start_time to stop_time. 
+ + [1] Song et al., "Generative Modeling by Estimating Gradients of the Data Distribution" (2020) + """ + + if steps <= 0: + raise ValueError("Number of Langevin steps must be positive.") + if score_fn is None or noise_schedule is None: + raise ValueError("score_fn and noise_schedule must be provided.") + # basic shape check + for k, v in state.items(): + if k not in z_history: + raise ValueError(f"Missing noise for key {k!r} in z_history.") + if keras.ops.shape(z_history[k])[0] < steps: + raise ValueError(f"z_history[{k!r}] has fewer than {steps} steps.") + + # Linear time grid + dt = (stop_time - start_time) / float(steps) + effective_factor = step_size_factor * 100 / np.sqrt(steps) + + def cond(_loop_var, _loop_state, _loop_time): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) + + def body(_i, _loop_state, _loop_time): + # score at current time + score = score_fn(_loop_time, **filter_kwargs(_loop_state, score_fn)) + + # noise schedule + log_snr_t = noise_schedule.get_log_snr(t=_loop_time, training=False) + _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + new_state: StateDict = {} + for k in _loop_state.keys(): + s_k = score.get(k, None) + if s_k is None: + new_state[k] = _loop_state[k] + continue + + e = effective_factor * sigma_t**2 + new_state[k] = _loop_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i] + + new_time = _loop_time + dt + + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + + return _i + 1, new_state, new_time + + _, final_state, final_time = keras.ops.while_loop( + cond, + body, + (0, state, start_time), + ) + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") + return final_state + + +def integrate_stochastic( + drift_fn: Callable, + diffusion_fn: Callable, + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, seed: keras.random.SeedGenerator, + steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", + min_steps: int = 50, + max_steps: int = 10_000, + score_fn: Callable = None, + corrector_steps: int = 0, + noise_schedule=None, + step_size_factor: ArrayLike = 0.01, **kwargs, -) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: +) -> StateDict: """ Integrates a stochastic differential equation from start_time to stop_time. + Dispatches to fixed-step or adaptive-step integration logic. + Args: drift_fn: Function that computes the drift term. diffusion_fn: Function that computes the diffusion term. state: Dictionary containing the initial state. start_time: Starting time for integration. - stop_time: Ending time for integration. - steps: Number of integration steps. + stop_time: Ending time for integration. steps: Number of integration steps. seed: Random seed for noise generation. - method: Integration method to use, e.g., 'euler_maruyama'. + steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps. + method: Integration method to use, e.g., 'euler_maruyama' or 'shark'. + min_steps: Minimum number of steps for adaptive integration. + max_steps: Maximum number of steps for adaptive integration. 
+ score_fn: Optional score function for predictor-corrector sampling. + corrector_steps: Number of corrector steps to take after each predictor step. + noise_schedule: Noise schedule object for computing alpha_t in corrector. + step_size_factor: Scaling factor for corrector step size. **kwargs: Additional arguments to pass to the step function. - Returns: - If return_noise is False, returns the final state dictionary. - If return_noise is True, returns a tuple of (final_state, noise_history). + Returns: Final state dictionary after integration. """ - if steps <= 0: - raise ValueError("Number of steps must be positive.") + is_adaptive = isinstance(steps, str) and steps in ["adaptive", "dynamic"] + + if is_adaptive: + if start_time is None or stop_time is None: + raise ValueError("Please provide start_time and stop_time for adaptive integration.") + if min_steps <= 0 or max_steps <= 0 or max_steps < min_steps: + raise ValueError("min_steps and max_steps must be positive, and max_steps >= min_steps.") + + loop_steps = max_steps + initial_step = (stop_time - start_time) / float(min_steps) + span_mag = keras.ops.abs(stop_time - start_time) + min_step_size = span_mag / keras.ops.cast(max_steps, span_mag.dtype) + max_step_size = span_mag / keras.ops.cast(min_steps, span_mag.dtype) + else: + if steps <= 0: + raise ValueError("Number of steps must be positive.") + loop_steps = int(steps) + initial_step = (stop_time - start_time) / float(loop_steps) + # For fixed step, min/max step size are just the fixed step size + min_step_size, max_step_size = initial_step, initial_step + + # Pre-generate corrector noise if requested + corrector_noise_history = {} + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") + for key, val in state.items(): + shape = keras.ops.shape(val) + corrector_noise_history[key] = keras.random.normal( + (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed + ) - # Select step function based on method match method: case "euler_maruyama": - step_fn = euler_maruyama_step + step_fn_raw = euler_maruyama_step + case "sea": + step_fn_raw = sea_step + if is_adaptive: + raise ValueError("SEA SDE solver does not support adaptive steps.") + case "shark": + step_fn_raw = shark_step + if is_adaptive: + raise ValueError("SHARK SDE solver does not support adaptive steps.") + case "two_step_adaptive": + step_fn_raw = two_step_adaptive_step + case "langevin": + if is_adaptive: + raise ValueError("Langevin sampling does not support adaptive steps.") + + z_history = {} + for key, val in state.items(): + shape = keras.ops.shape(val) + z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + + return integrate_langevin( + state=state, + start_time=start_time, + stop_time=stop_time, + steps=loop_steps, + z_history=z_history, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_steps=corrector_steps, + corrector_noise_history=corrector_noise_history, + ) case other: raise TypeError(f"Invalid integration method: {other!r}") - # Prepare step function with partial application - step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs) - - # Time step - step_size = (stop_time - start_time) / steps - sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size)) + # Partial the step function with common arguments + step_fn = partial( + step_fn_raw, + 
drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + **kwargs, + ) - # Pre-generate noise history: shape = (steps, *state_shape) - noise_history = {} + # Pre-generate standard normals for the predictor step (up to max_steps) + z_history = {} + z_extra_history = {} for key, val in state.items(): - noise_history[key] = ( - keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt + shape = keras.ops.shape(val) + z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + if method in ["sea", "shark"]: + z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + + if is_adaptive: + return integrate_stochastic_adaptive( + step_fn=step_fn, + state=state, + start_time=start_time, + stop_time=stop_time, + max_steps=max_steps, + min_step_size=min_step_size, + max_step_size=max_step_size, + initial_step=initial_step, + z_history=z_history, + z_extra_history=z_extra_history, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + else: + return integrate_stochastic_fixed( + step_fn=step_fn, + state=state, + start_time=start_time, + stop_time=stop_time, + min_step_size=min_step_size, + max_step_size=max_step_size, + steps=loop_steps, + z_history=z_history, + z_extra_history=z_extra_history, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, ) - - def body(_loop_var, _loop_state): - _current_state, _current_time = _loop_state - _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()} - new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) - return new_state, new_time - - final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time)) - return final_state diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index db5c448d7..b3214c229 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -1,23 +1,31 @@ import numpy as np +import keras +import pytest +from bayesflow.utils import integrate, integrate_stochastic -def test_scheduled_integration(): - import keras - from bayesflow.utils import integrate +TOLERANCE_ADAPTIVE = 1e-6 # Adaptive solvers should be very accurate. 
+TOLERANCE_EULER = 1e-3  # Euler with fixed steps requires a larger tolerance
+# tolerances for SDE tests
+TOL_MEAN = 5e-2
+TOL_VAR = 5e-2
+
+
+@pytest.mark.parametrize("method", ["euler", "rk45", "tsit5"])
+def test_scheduled_integration(method):
     def fn(t, x):
         return {"x": t**2}
 
-    steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0])
-    approximate_result = 0.0 + 0.5**2 * 0.5
-    result = integrate(fn, {"x": 0.0}, steps=steps)["x"]
-    assert result == approximate_result
+    def analytical_result(t):
+        return (t**3) / 3.0
 
+    steps = keras.ops.arange(0.0, 1.0 + 1e-6, 0.01)
+    result = integrate(fn, {"x": 0.0}, steps=steps, method=method)["x"]
+    np.testing.assert_allclose(result, analytical_result(steps[-1]), atol=1e-1, rtol=1e-1)
 
-def test_scipy_integration():
-    import keras
-    from bayesflow.utils import integrate
 
+def test_scipy_integration():
     def fn(t, x):
         return {"x": keras.ops.exp(t)}
 
@@ -34,3 +42,300 @@ def fn(t, x):
         scipy_kwargs={"atol": 1e-6, "rtol": 1e-6},
     )["x"]
     np.testing.assert_allclose(exact_result, result, atol=1e-6, rtol=1e-6)
+
+
+@pytest.mark.parametrize(
+    "method, atol", [("euler", TOLERANCE_EULER), ("rk45", TOLERANCE_ADAPTIVE), ("tsit5", TOLERANCE_ADAPTIVE)]
+)
+def test_analytical_integration(method, atol):
+    def fn(t, x):
+        return {"x": keras.ops.convert_to_tensor([2.0 * t])}
+
+    initial_state = {"x": keras.ops.convert_to_tensor([1.0])}
+    T_final = 1.0
+    num_steps = 100
+    analytical_result = 1.0 + T_final**2
+
+    result = integrate(fn, initial_state, start_time=0.0, stop_time=T_final, steps=num_steps, method=method)["x"]
+    if method == "euler":
+        result_adaptive = result
+    else:
+        result_adaptive = integrate(
+            fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000
+        )["x"]
+
+    np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1)
+    np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1)
+
+
+@pytest.mark.parametrize(
+    "method, atol", [("euler", TOLERANCE_EULER), ("rk45", TOLERANCE_ADAPTIVE), ("tsit5", TOLERANCE_ADAPTIVE)]
+)
+def test_analytical_backward_integration(method, atol):
+    T_final = 1.0
+
+    def fn(t, x):
+        return {"x": keras.ops.convert_to_tensor([2.0 * t])}
+
+    num_steps = 100
+    analytical_result = 1.0
+    initial_state = {"x": keras.ops.convert_to_tensor([1.0 + T_final**2])}
+
+    result = integrate(fn, initial_state, start_time=T_final, stop_time=0.0, steps=num_steps, method=method)["x"]
+    if method == "euler":
+        result_adaptive = result
+    else:
+        result_adaptive = integrate(
+            fn, initial_state, start_time=T_final, stop_time=0.0, steps="adaptive", method=method, max_steps=1_000
+        )["x"]
+
+    np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1)
+    np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1)
+
+
+@pytest.mark.parametrize(
+    "method,use_adapt",
+    [
+        ("euler_maruyama", False),
+        ("euler_maruyama", True),
+        ("sea", False),
+        ("shark", False),
+        ("two_step_adaptive", False),
+        ("two_step_adaptive", True),
+    ],
+)
+def test_forward_additive_ou_weak_means_and_vars(method, use_adapt):
+    """
+    Ornstein-Uhlenbeck with additive noise, integrated FORWARD in time.
+    This serves as a sanity check that forward integration still works correctly.
+
+    Forward SDE:
+        dX = a X dt + sigma dW
+
+    Exact at time T starting from X(0) = x_0:
+        E[X(T)]   = x_0 * exp(a T)
+        Var[X(T)] = sigma^2 * (exp(2 a T) - 1) / (2 a)
+    """
+    # SDE parameters
+    a = -1.0
+    sigma = 0.5
+    x_0 = 1.2  # initial condition at time 0
+    T = 1.0
+
+    N = 10000
+    seed = keras.random.SeedGenerator(42)
+
+    def drift_fn(t, x):
+        return {"x": a * x}
+
+    def diffusion_fn(t, x):
+        return {"x": keras.ops.convert_to_tensor([sigma])}
+
+    initial_state = {"x": keras.ops.ones((N,)) * x_0}
+    steps = 200 if not use_adapt else "adaptive"
+
+    # Expected mean and variance at t=T
+    exp_mean = x_0 * np.exp(a * T)
+    exp_var = sigma**2 * (np.exp(2.0 * a * T) - 1.0) / (2.0 * a)
+
+    out = integrate_stochastic(
+        drift_fn=drift_fn,
+        diffusion_fn=diffusion_fn,
+        state=initial_state,
+        start_time=0.0,
+        stop_time=T,
+        steps=steps,
+        seed=seed,
+        method=method,
+    )
+
+    x_T = np.array(out["x"])
+    emp_mean = float(x_T.mean())
+    emp_var = float(x_T.var())
+
+    np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN)
+    np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR)
+
+
+@pytest.mark.parametrize(
+    "method,use_adapt",
+    [
+        ("euler_maruyama", False),
+        ("euler_maruyama", True),
+        ("sea", False),
+        ("shark", False),
+        ("two_step_adaptive", False),
+        ("two_step_adaptive", True),
+    ],
+)
+def test_backward_additive_ou_weak_means_and_vars(method, use_adapt):
+    """
+    Ornstein-Uhlenbeck with additive noise, integrated BACKWARD in time.
+
+    When integrating from t=T back to t=0 with initial condition X(T) = x_T,
+    we get X(0), which should satisfy:
+        E[X(0)]   = x_T * exp(-a T)      (the sign of a flips because time runs backward)
+        Var[X(0)] = sigma^2 * (exp(-2 a T) - 1) / (-2 a)
+
+    We verify weak accuracy by matching empirical mean and variance.
+    """
+    # SDE parameters
+    a = -1.0
+    sigma = 0.5
+    x_T = 1.2  # initial condition at time T
+    T = 1.0
+
+    N = 10000
+    seed = keras.random.SeedGenerator(42)
+
+    def drift_fn(t, x):
+        return {"x": a * x}
+
+    def diffusion_fn(t, x):
+        # additive noise, independent of state
+        return {"x": keras.ops.convert_to_tensor([sigma])}
+
+    # Start at time T with value x_T
+    initial_state = {"x": keras.ops.ones((N,)) * x_T}
+    steps = 200 if not use_adapt else "adaptive"
+    # Expected mean and variance at t=0 after integrating backward from t=T
+    # For backward integration, the effective drift coefficient changes sign
+    exp_mean = x_T * np.exp(-a * T)
+    exp_var = sigma**2 * (np.exp(-2.0 * a * T) - 1.0) / (-2.0 * a)
+
+    out = integrate_stochastic(
+        drift_fn=drift_fn,
+        diffusion_fn=diffusion_fn,
+        state=initial_state,
+        start_time=T,
+        stop_time=0.0,
+        steps=steps,
+        seed=seed,
+        method=method,
+    )
+
+    x_0 = np.array(out["x"])
+    emp_mean = float(x_0.mean())
+    emp_var = float(x_0.var())
+
+    np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN)
+    np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR)
+
+
+@pytest.mark.parametrize(
+    "method,use_adapt",
+    [
+        ("euler_maruyama", False),
+        ("euler_maruyama", True),
+        ("sea", False),
+        ("shark", False),
+        ("two_step_adaptive", False),
+        ("two_step_adaptive", True),
+    ],
+)
+def test_zero_noise_reduces_to_deterministic(method, use_adapt):
+    """
+    With zero diffusion the SDE reduces to the ODE
+        dX = a X dt  (exact solution: X(T) = x0 * exp(a T))
+    """
+    a = 0.7
+    x0 = 0.9
+    T = 1.25
+    steps = 200 if not use_adapt else "adaptive"
+    seed = keras.random.SeedGenerator(0)
+
+    def drift_fn(t, x):
+        return {"x": a * x}
+
+    def diffusion_fn(t, x):
+        # identically zero diffusion
+        return {"x": keras.ops.convert_to_tensor([0.0])}
+
+    initial_state = {"x": keras.ops.ones((256,)) * x0}
+    out = integrate_stochastic(
+        drift_fn=drift_fn,
+        diffusion_fn=diffusion_fn,
+        state=initial_state,
+        start_time=0.0,
+        stop_time=T,
+        steps=steps,
+        seed=seed,
+        method=method,
+        max_steps=1_000,
+    )["x"]
+
+    exact = x0 * np.exp(a * T)
+    np.testing.assert_allclose(np.array(out).mean(), exact, atol=1e-3, rtol=0.1)
+
+
+@pytest.mark.parametrize("steps", [500])
+def test_langevin_gaussian_sampling(steps):
+    """
+    Test annealed Langevin dynamics on a 1D Gaussian target.
+
+    Target distribution: N(mu, sigma^2), with score
+        ∇_x log p(x) = -(x - mu) / sigma^2
+
+    We verify that the empirical mean and variance after Langevin sampling
+    match the target within a loose tolerance (to allow for Monte Carlo noise).
+    """
+    # target parameters
+    mu = 0.3
+    sigma = 0.7
+
+    # number of particles
+    N = 20000
+    start_time = 0.0
+    stop_time = 1.0
+
+    # tolerances for mean and variance
+    tol_mean = 5e-2
+    tol_var = 5e-2
+
+    # initial state: broad Gaussian, independent of target
+    seed = keras.random.SeedGenerator(42)
+    x0 = keras.random.normal((N,), dtype="float32", seed=seed)
+    initial_state = {"x": x0}
+
+    # simple dummy noise schedule: constant alpha
+    class DummyNoiseSchedule:
+        def get_log_snr(self, t, training=False):
+            return keras.ops.zeros_like(t)
+
+        def get_alpha_sigma(self, log_snr_t):
+            alpha_t = keras.ops.ones_like(log_snr_t)
+            sigma_t = keras.ops.ones_like(log_snr_t)
+            return alpha_t, sigma_t
+
+    noise_schedule = DummyNoiseSchedule()
+
+    # score of the target Gaussian
+    def score_fn(t, x):
+        s = -(x - mu) / (sigma**2)
+        return {"x": s}
+
+    # run Langevin
+    final_state = integrate_stochastic(
+        drift_fn=None,
+        diffusion_fn=None,
+        score_fn=score_fn,
+        noise_schedule=noise_schedule,
+        state=initial_state,
+        start_time=start_time,
+        stop_time=stop_time,
+        steps=steps,
+        seed=seed,
+        method="langevin",
+        max_steps=1_000,
+        corrector_steps=1,
+    )
+
+    xT = np.array(final_state["x"])
+    emp_mean = float(xT.mean())
+    emp_var = float(xT.var())
+
+    exp_mean = mu
+    exp_var = sigma**2
+
+    np.testing.assert_allclose(emp_mean, exp_mean, atol=tol_mean, rtol=0.0)
+    np.testing.assert_allclose(emp_var, exp_var, atol=tol_var, rtol=0.0)
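
For illustration only (not part of the patch): a minimal sketch of how the extended integrate_stochastic interface exercised by the tests above might be called from user code. The drift, diffusion, and score functions and ConstantNoiseSchedule below are hypothetical stand-ins, the keyword names are taken from the tests in this diff, and any defaults not shown (for example min_steps) are assumed to behave as in the tests.

import keras
import numpy as np

from bayesflow.utils import integrate_stochastic


# Hypothetical stand-ins: linear drift, additive unit diffusion, and the score of a
# standard normal target; the constant-alpha schedule mirrors the dummy used in the tests.
class ConstantNoiseSchedule:
    def get_log_snr(self, t, training=False):
        return keras.ops.zeros_like(t)

    def get_alpha_sigma(self, log_snr_t):
        return keras.ops.ones_like(log_snr_t), keras.ops.ones_like(log_snr_t)


def drift_fn(t, x):
    return {"x": -x}


def diffusion_fn(t, x):
    return {"x": keras.ops.ones_like(x)}


def score_fn(t, x):
    return {"x": -x}


seed = keras.random.SeedGenerator(0)
state = {"x": keras.random.normal((4096,), seed=seed)}

# Adaptive predictor-corrector run: one corrector step per predictor step.
final_state = integrate_stochastic(
    drift_fn=drift_fn,
    diffusion_fn=diffusion_fn,
    score_fn=score_fn,                       # required because corrector_steps > 0
    noise_schedule=ConstantNoiseSchedule(),  # used to compute alpha_t in the corrector
    state=state,
    start_time=1.0,
    stop_time=0.0,
    steps="adaptive",
    max_steps=1_000,
    corrector_steps=1,
    method="two_step_adaptive",
    seed=seed,
)
print(np.array(final_state["x"]).mean())

Setting corrector_steps=0 and omitting score_fn and noise_schedule recovers plain predictor-only sampling, as in the Ornstein-Uhlenbeck tests above.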