Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
034b2c4
correct rk45 and add tsit5
arrjon Nov 24, 2025
3cae883
add predictor corrector
arrjon Nov 24, 2025
39682b1
add adaptive sampler SDE
arrjon Nov 24, 2025
770abc7
add shark
arrjon Nov 24, 2025
eba6892
rm warn
arrjon Nov 24, 2025
e901b73
fix dt
arrjon Nov 24, 2025
e8be555
fix adaptive step size
arrjon Nov 24, 2025
36a16b3
refactor stochastic integrator
arrjon Nov 24, 2025
de57eaf
refactor stochastic integrator
arrjon Nov 24, 2025
dde5451
refactor stochastic integrator
arrjon Nov 24, 2025
3d2c80e
fix adaptive
arrjon Nov 24, 2025
ed5e89f
fix Tsit5
arrjon Nov 25, 2025
c4b52a7
fix sampler
arrjon Nov 25, 2025
ac22af5
updated stochastic solvers
arrjon Nov 27, 2025
44570cf
add Langevin
arrjon Nov 27, 2025
531c610
add Langevin
arrjon Nov 28, 2025
fdeeb2f
add adaptive step size
arrjon Nov 28, 2025
f45c2dc
tune adaptive step size
arrjon Nov 29, 2025
9fd7707
add Gotta Go Fast SDE sampler
arrjon Nov 30, 2025
5c5abd3
improve adaptive ODE samplers
arrjon Nov 30, 2025
dd021bb
fix schedule test
arrjon Nov 30, 2025
1fe2c60
improved defaults
arrjon Dec 1, 2025
a771e32
improved defaults
arrjon Dec 1, 2025
a7adea2
improved initial step size
arrjon Dec 1, 2025
08853fb
improved initial step size
arrjon Dec 1, 2025
23a69ea
check nan in integrate
arrjon Dec 2, 2025
b9e8c96
set default
arrjon Dec 2, 2025
be78470
update model defaults
arrjon Dec 3, 2025
ad27606
make loop jax compatible
arrjon Dec 3, 2025
5a1a3fa
filter kwargs
arrjon Dec 3, 2025
e585708
fix density computation
arrjon Dec 3, 2025
ac07af2
fix jax all nans
arrjon Dec 3, 2025
4c9d44b
fix jax all nans
arrjon Dec 3, 2025
8729745
fix jax all nans
arrjon Dec 3, 2025
adefa7b
relax tols in tests
arrjon Dec 3, 2025
f9823f8
enable density computation with adaptive step size solvers
arrjon Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 72 additions & 16 deletions bayesflow/networks/diffusion_model/diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
integrate_stochastic,
logging,
tensor_utils,
STOCHASTIC_METHODS,
)
from bayesflow.utils.serialization import serialize, deserialize, serializable

Expand All @@ -39,13 +40,13 @@ class DiffusionModel(InferenceNetwork):
"activation": "mish",
"kernel_initializer": "he_normal",
"residual": True,
"dropout": 0.0,
"dropout": 0.05,
"spectral_normalization": False,
}

INTEGRATE_DEFAULT_CONFIG = {
"method": "rk45",
"steps": 100,
"method": "two_step_adaptive",
"steps": "adaptive",
}

def __init__(
Expand Down Expand Up @@ -243,6 +244,55 @@ def _apply_subnet(
else:
return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)

def score(
    self,
    xz: Tensor,
    time: float | Tensor = None,
    log_snr_t: Tensor = None,
    conditions: Tensor = None,
    training: bool = False,
) -> Tensor:
    """
    Computes the score of the target or latent variable `xz`.

    Exactly one of `time` or `log_snr_t` must be provided: if `log_snr_t` is
    None, it is derived from `time` via the noise schedule.

    Parameters
    ----------
    xz : Tensor
        The current state of the latent variable `z`, typically of shape (..., D),
        where D is the dimensionality of the latent space.
    time : float or Tensor
        Scalar or tensor representing the time (or noise level) at which the score
        should be computed. Will be broadcast to `xz`. If None, log_snr_t must be provided.
    log_snr_t : Tensor
        The log signal-to-noise ratio at time `t`. If None, time must be provided.
    conditions : Tensor, optional
        Conditional inputs to the network, such as conditioning variables
        or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
    training : bool, optional
        Whether the model is in training mode. Affects behavior of dropout, batch norm,
        or other stochastic layers. Default is False.

    Returns
    -------
    Tensor
        The score tensor of the same shape as `xz`, computed as
        (alpha_t * x_pred - xz) / sigma_t**2 from the network's denoising
        prediction at the given noise level.
    """
    if log_snr_t is None:
        # Derive the noise level from `time` and reshape it so it broadcasts
        # against `xz` with a trailing singleton channel dimension.
        # NOTE(review): assumes both lines belong to this branch — when
        # `log_snr_t` is passed explicitly, the caller (e.g. `velocity`)
        # is expected to supply it already broadcast; confirm against caller.
        log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
    alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

    # Run the subnet on (state, transformed noise level, conditions) and
    # project its output to the prediction space.
    subnet_out = self._apply_subnet(
        xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
    )
    pred = self.output_projector(subnet_out, training=training)

    # Convert the network's raw prediction into an estimate of the clean data x.
    x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

    # Score parameterization in terms of the denoised estimate:
    # score = (alpha_t * x_pred - xz) / sigma_t^2.
    score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
    return score

def velocity(
self,
xz: Tensor,
Expand Down Expand Up @@ -279,19 +329,10 @@ def velocity(
The velocity tensor of the same shape as `xz`, representing the right-hand
side of the SDE or ODE at the given `time`.
"""
# calculate the current noise level and transform into correct shape
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

subnet_out = self._apply_subnet(
xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
)
pred = self.output_projector(subnet_out, training=training)

x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training)

# compute velocity f, g of the SDE or ODE
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training)
Expand Down Expand Up @@ -368,7 +409,7 @@ def _forward(
integrate_kwargs = integrate_kwargs | self.integrate_kwargs
integrate_kwargs = integrate_kwargs | kwargs

if integrate_kwargs["method"] == "euler_maruyama":
if integrate_kwargs["method"] in STOCHASTIC_METHODS:
raise ValueError("Stochastic methods are not supported for forward integration.")

if density:
Expand Down Expand Up @@ -418,7 +459,7 @@ def _inverse(
integrate_kwargs = integrate_kwargs | self.integrate_kwargs
integrate_kwargs = integrate_kwargs | kwargs
if density:
if integrate_kwargs["method"] == "euler_maruyama":
if integrate_kwargs["method"] in STOCHASTIC_METHODS:
raise ValueError("Stochastic methods are not supported for density computation.")

def deltas(time, xz):
Expand All @@ -437,7 +478,7 @@ def deltas(time, xz):
return x, log_density

state = {"xz": z}
if integrate_kwargs["method"] == "euler_maruyama":
if integrate_kwargs["method"] in STOCHASTIC_METHODS:

def deltas(time, xz):
return {
Expand All @@ -447,9 +488,24 @@ def deltas(time, xz):
def diffusion(time, xz):
return {"xz": self.diffusion_term(xz, time=time, training=training)}

score_fn = None
if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin":

def score_fn(time, xz):
return {
"xz": self.score(
xz,
time=time,
conditions=conditions,
training=training,
)
}

state = integrate_stochastic(
drift_fn=deltas,
diffusion_fn=diffusion,
score_fn=score_fn,
noise_schedule=self.noise_schedule,
state=state,
seed=self.seed_generator,
**integrate_kwargs,
Expand Down
14 changes: 8 additions & 6 deletions bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class FlowMatching(InferenceNetwork):
}

INTEGRATE_DEFAULT_CONFIG = {
"method": "rk45",
"steps": 100,
"method": "tsit5",
"steps": "adaptive",
}

def __init__(
Expand Down Expand Up @@ -236,14 +236,15 @@ def f(x):
def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
integrate_kwargs = self.integrate_kwargs | kwargs
if density:

def deltas(time, xz):
v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
return {"xz": v, "trace": trace}

state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs)

z = state["xz"]
log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)
Expand All @@ -254,7 +255,7 @@ def deltas(time, xz):
return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

state = {"xz": x}
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs)

z = state["xz"]

Expand All @@ -263,14 +264,15 @@ def deltas(time, xz):
def _inverse(
self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
integrate_kwargs = self.integrate_kwargs | kwargs
if density:

def deltas(time, xz):
v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
return {"xz": v, "trace": trace}

state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs)

x = state["xz"]
log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)
Expand All @@ -281,7 +283,7 @@ def deltas(time, xz):
return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

state = {"xz": z}
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs)

x = state["xz"]

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)

from .hparam_utils import find_batch_size, find_memory_budget
from .integrate import integrate, integrate_stochastic
from .integrate import integrate, integrate_stochastic, DETERMINISTIC_METHODS, STOCHASTIC_METHODS

from .io import (
pickle_load,
Expand Down
Loading