5 changes: 0 additions & 5 deletions src/skillmodels/kalman_filters.py
@@ -1,5 +1,3 @@
-import functools
-
 import jax
 import jax.numpy as jnp
 
@@ -16,7 +14,6 @@
 # ======================================================================================
 
 
-@functools.partial(jax.checkpoint, prevent_cse=False)
 def kalman_update(
     states,
     upper_chols,
@@ -160,7 +157,6 @@ def calculate_sigma_scaling_factor_and_weights(n_states, kappa=2):
     return scaling_factor, weights
 
 
-@functools.partial(jax.checkpoint, static_argnums=0, prevent_cse=False)
 def kalman_predict(
     transition_func,
     states,
@@ -232,7 +228,6 @@ def kalman_predict(
     return predicted_states, predicted_covs
 
 
-@functools.partial(jax.checkpoint, prevent_cse=False)
 def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factors):
    """Calculate the array of sigma_points for the unscented transform.
 
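The three deleted decorators each wrapped an individual filter step in jax.checkpoint, which saves only the wrapped function's inputs on the forward pass and recomputes intermediates during the backward pass, trading extra compute for lower peak memory. A minimal sketch of that behavior with a toy function (not the skillmodels code):

import jax
import jax.numpy as jnp

# jax.checkpoint (a.k.a. jax.remat) discards intermediates on the forward
# pass and recomputes them when the gradient is taken.
@jax.checkpoint
def step(x):
    return jnp.sin(jnp.sin(x))

# The gradient value is identical with or without the decorator.
print(jax.grad(step)(2.0))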
4 changes: 2 additions & 2 deletions src/skillmodels/likelihood_function.py
@@ -132,8 +132,8 @@ def log_likelihood_obs(
         transition_func=transition_func,
         observed_factors=observed_factors,
     )
-
-    static_out = jax.lax.scan(_body, carry, loop_args)[1]
+    _body = jax.checkpoint(_body, prevent_cse=False)
+    static_out = jax.lax.scan(_body, carry, loop_args, unroll=False)[1]
 
     # clip contributions before aggregation to preserve as much information as
     # possible.
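Instead of checkpointing each Kalman step separately, the scan body is now checkpointed once. Checkpointing the body of jax.lax.scan is the standard recipe for memory-efficient backpropagation through long loops: the backward pass recomputes each step's intermediates rather than storing all of them at once. A toy sketch of the pattern (names and the loss are illustrative):

import jax
import jax.numpy as jnp

def _body(carry, x):
    # stand-in for one filter iteration (update + predict)
    new_carry = jnp.tanh(carry + x)
    return new_carry, new_carry

def loss(init, xs):
    # Checkpoint the body once; every scan step is then rematerialized
    # during the backward pass instead of being kept in memory.
    checkpointed_body = jax.checkpoint(_body, prevent_cse=False)
    return jax.lax.scan(checkpointed_body, init, xs)[1].sum()

xs = jnp.arange(10.0)
print(jax.grad(loss)(jnp.zeros(()), xs))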
26 changes: 24 additions & 2 deletions src/skillmodels/maximization_inputs.py
@@ -23,12 +23,14 @@
 jax.config.update("jax_enable_x64", True)  # noqa: FBT003
 
 
-def get_maximization_inputs(model_dict, data):
+def get_maximization_inputs(model_dict, data, split_dataset=1):
     """Create inputs for optimagic's maximize function.
 
     Args:
         model_dict (dict): The model specification. See: :ref:`model_specs`
         data (DataFrame): dataset in long format.
+        split_dataset (int): Controls into how many slices to split the dataset
+            during the gradient computation.
 
     Returns a dictionary with keys:
         loglike (function): A jax jitted function that takes an optimagic-style
@@ -121,7 +123,27 @@ def loglikeobs(params):
     def loglike_and_gradient(params):
         params_vec = partialed_get_jnp_params_vec(params)
         crit = float(_jitted_loglike(params_vec))
-        grad = _to_numpy(_gradient(params_vec))
+        n_obs = processed_data["measurements"].shape[1]
+        _grad = jnp.zeros_like(params_vec)
+        start = 0
+        stop = int(n_obs / split_dataset)
+        step = int(n_obs / split_dataset)
+        for i in range(split_dataset):
+            stop = n_obs if i == split_dataset - 1 else stop
+            measurements_slice = processed_data["measurements"][:, start:stop]
+            controls_slice = processed_data["controls"][:, start:stop, :]
+            observed_factors_slice = processed_data["observed_factors"][
+                :, start:stop, :
+            ]
+            _grad += _gradient(
+                params_vec,
+                measurements=measurements_slice,
+                controls=controls_slice,
+                observed_factors=observed_factors_slice,
+            )
+            start += step
+            stop += step
+        grad = _to_numpy(_grad)
         return crit, grad
 
     def debug_loglike(params):
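The rewritten loglike_and_gradient accumulates the gradient over split_dataset slices of the observation axis; the last slice is extended to n_obs so observations lost to the integer division are not dropped. This is valid because the log likelihood is a sum over independent observations, so the gradient of the full sum equals the sum of the per-slice gradients, and only one slice's backward pass needs to be resident in memory at a time. A toy verification of that identity (hypothetical loss function, not the skillmodels API):

import jax
import jax.numpy as jnp

def loglike(params, obs):
    # any criterion that sums over observations has this additivity property
    return -jnp.sum((obs - params) ** 2)

grad = jax.grad(loglike)
obs = jnp.arange(12.0)

full = grad(1.0, obs)
sliced = sum(grad(1.0, obs[start : start + 4]) for start in range(0, 12, 4))
print(jnp.allclose(full, sliced))  # True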
2 changes: 1 addition & 1 deletion src/skillmodels/qr.py
@@ -29,7 +29,7 @@ def _householder(r: jax.Array, tau: jax.Array):
         vi = jnp.expand_dims(r[:, i], 1)
         vi = vi.at[0:i].set(0)
         vi = vi.at[i].set(1)
-        h = h @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi)))
+        h = h - tau[i] * (h @ vi) @ jnp.transpose(vi)
     return h[:, :n]
 
 
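The one-line change in _householder uses the identity h @ (I - tau * v v^T) = h - tau * (h @ v) v^T: applying a Householder reflector as a rank-1 update avoids materializing the m x m matrix jnp.eye(m) - tau[i] * (vi @ vi^T) and replaces an O(m^3) matrix-matrix product with O(m^2) work. A quick numerical check of the equivalence with random inputs:

import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
m = 6
h = jnp.asarray(rng.normal(size=(m, m)))
vi = jnp.asarray(rng.normal(size=(m, 1)))
tau = 0.7

# old: form the full m x m reflector, then multiply (O(m^3))
old = h @ (jnp.eye(m) - tau * (vi @ jnp.transpose(vi)))
# new: rank-1 update built from the matrix-vector product h @ vi (O(m^2))
new = h - tau * (h @ vi) @ jnp.transpose(vi)

print(jnp.allclose(old, new))  # True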
13 changes: 13 additions & 0 deletions tests/test_likelihood_regression.py
@@ -91,6 +91,19 @@ def test_likelihood_values_have_not_changed(model2, model2_data, model_name, fun_key):
         old_loglike = np.array(json.load(j)).sum()
     aaae(new_loglike, old_loglike)
 
+def test_splitting_does_not_change_gradient(model2, model2_data):
+
+    inputs = get_maximization_inputs(model2, model2_data)
+    inputs_split = get_maximization_inputs(model2, model2_data, 13)
+
+    params = inputs["params_template"]
+    params["value"] = 0.1
+
+    _, gradient = inputs["loglike_and_gradient"](params)
+    _, gradient_split = inputs_split["loglike_and_gradient"](params)
+
+    aaae(gradient, gradient_split)
+
 
 @pytest.mark.parametrize(
     ("model_name", "fun_key"), product(MODEL_NAMES, ["loglikeobs"])