diff --git a/src/skillmodels/kalman_filters.py b/src/skillmodels/kalman_filters.py
index defe35c..f9cfae9 100644
--- a/src/skillmodels/kalman_filters.py
+++ b/src/skillmodels/kalman_filters.py
@@ -1,5 +1,3 @@
-import functools
-
 import jax
 import jax.numpy as jnp
 
@@ -16,7 +14,6 @@
 # ======================================================================================
 
 
-@functools.partial(jax.checkpoint, prevent_cse=False)
 def kalman_update(
     states,
     upper_chols,
@@ -160,7 +157,6 @@
     return scaling_factor, weights
 
 
-@functools.partial(jax.checkpoint, static_argnums=0, prevent_cse=False)
 def kalman_predict(
     transition_func,
     states,
@@ -232,7 +228,6 @@
     return predicted_states, predicted_covs
 
 
-@functools.partial(jax.checkpoint, prevent_cse=False)
 def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factors):
     """Calculate the array of sigma_points for the unscented transform.
 
diff --git a/src/skillmodels/likelihood_function.py b/src/skillmodels/likelihood_function.py
index 61ffee6..4e7eec2 100644
--- a/src/skillmodels/likelihood_function.py
+++ b/src/skillmodels/likelihood_function.py
@@ -132,8 +132,8 @@
         transition_func=transition_func,
         observed_factors=observed_factors,
     )
-
-    static_out = jax.lax.scan(_body, carry, loop_args)[1]
+    _body = jax.checkpoint(_body, prevent_cse=False)
+    static_out = jax.lax.scan(_body, carry, loop_args, unroll=False)[1]
 
     # clip contributions before aggregation to preserve as much information as
     # possible.
diff --git a/src/skillmodels/maximization_inputs.py b/src/skillmodels/maximization_inputs.py
index 2e0dbfa..ae4d6cf 100644
--- a/src/skillmodels/maximization_inputs.py
+++ b/src/skillmodels/maximization_inputs.py
@@ -23,12 +23,14 @@
 jax.config.update("jax_enable_x64", True)  # noqa: FBT003
 
 
-def get_maximization_inputs(model_dict, data):
+def get_maximization_inputs(model_dict, data, split_dataset=1):
     """Create inputs for optimagic's maximize function.
 
     Args:
         model_dict (dict): The model specification. See: :ref:`model_specs`
         data (DataFrame): dataset in long format.
+        split_dataset (int): Number of slices into which the dataset is split
+            during the gradient computation.
 
     Returns a dictionary with keys:
         loglike (function): A jax jitted function that takes an optimagic-style
@@ -121,7 +123,27 @@
     def loglike_and_gradient(params):
         params_vec = partialed_get_jnp_params_vec(params)
         crit = float(_jitted_loglike(params_vec))
-        grad = _to_numpy(_gradient(params_vec))
+        n_obs = processed_data["measurements"].shape[1]
+        _grad = jnp.zeros_like(params_vec)
+        start = 0
+        stop = int(n_obs / split_dataset)
+        step = int(n_obs / split_dataset)
+        for i in range(split_dataset):
+            stop = n_obs if i == split_dataset - 1 else stop
+            measurements_slice = processed_data["measurements"][:, start:stop]
+            controls_slice = processed_data["controls"][:, start:stop, :]
+            observed_factors_slice = processed_data["observed_factors"][
+                :, start:stop, :
+            ]
+            _grad += _gradient(
+                params_vec,
+                measurements=measurements_slice,
+                controls=controls_slice,
+                observed_factors=observed_factors_slice,
+            )
+            start += step
+            stop += step
+        grad = _to_numpy(_grad)
         return crit, grad
 
     def debug_loglike(params):
diff --git a/src/skillmodels/qr.py b/src/skillmodels/qr.py
index 7e95b8a..1cf835d 100644
--- a/src/skillmodels/qr.py
+++ b/src/skillmodels/qr.py
@@ -29,7 +29,7 @@
         vi = jnp.expand_dims(r[:, i], 1)
         vi = vi.at[0:i].set(0)
         vi = vi.at[i].set(1)
-        h = h @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi)))
+        h = h - tau[i] * (h @ vi) @ jnp.transpose(vi)
     return h[:, :n]
 
 
diff --git a/tests/test_likelihood_regression.py b/tests/test_likelihood_regression.py
index f973acf..000e4c0 100644
--- a/tests/test_likelihood_regression.py
+++ b/tests/test_likelihood_regression.py
@@ -91,6 +91,19 @@
         old_loglike = np.array(json.load(j)).sum()
     aaae(new_loglike, old_loglike)
 
+
+def test_splitting_does_not_change_gradient(model2, model2_data):
+    inputs = get_maximization_inputs(model2, model2_data)
+    inputs_split = get_maximization_inputs(model2, model2_data, 13)
+
+    params = inputs["params_template"]
+    params["value"] = 0.1
+
+    _, gradient = inputs["loglike_and_gradient"](params)
+    _, gradient_split = inputs_split["loglike_and_gradient"](params)
+
+    aaae(gradient, gradient_split)
+
 @pytest.mark.parametrize(
     ("model_name", "fun_key"),
     product(MODEL_NAMES, ["loglikeobs"])
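
Note on the split_dataset loop in maximization_inputs.py: accumulating per-slice gradients is valid because the log likelihood is a sum of per-observation contributions, so the gradient over the full dataset equals the sum of the gradients over disjoint observation slices, while the reverse-mode memory footprint scales with the slice size rather than the full sample. The sketch below is not part of the patch; it demonstrates the identity with a hypothetical toy objective (toy_loglike), not the skillmodels likelihood.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_loglike(params, data):
    # hypothetical additive objective: one contribution per observation (column)
    return jnp.sum(-((data - params[:, None]) ** 2))


params = jnp.array([0.3, -1.2, 0.7])
data = jax.random.normal(jax.random.PRNGKey(0), (3, 10))

full_grad = jax.grad(toy_loglike)(params, data)

# accumulate the gradient over slices of the observation axis, mirroring the
# start/stop/step bookkeeping in loglike_and_gradient (remainder observations
# go into the last slice)
n_obs = data.shape[1]
n_splits = 4
step = n_obs // n_splits
grad = jnp.zeros_like(params)
start = 0
for i in range(n_splits):
    stop = n_obs if i == n_splits - 1 else start + step
    grad += jax.grad(toy_loglike)(params, data[:, start:stop])
    start = stop

assert jnp.allclose(grad, full_grad)

Each slice re-runs the forward pass, so runtime grows somewhat while peak memory during differentiation drops roughly by the split factor; with split_dataset=1 the loop reduces to the previous single gradient call.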
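
Note on the qr.py change: since H @ (I - tau * v v^T) = H - tau * (H v) v^T, the rewritten line is an algebraically equivalent rank-1 update that avoids materializing jnp.eye(m) inside the Householder loop. The check below is not part of the patch; shapes and the tau value are illustrative only.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

m = 5
h = jax.random.normal(jax.random.PRNGKey(0), (m, m))
vi = jax.random.normal(jax.random.PRNGKey(1), (m, 1))
tau_i = 0.37  # hypothetical Householder coefficient

# explicit m x m reflector, as on the removed line
old = h @ (jnp.eye(m) - tau_i * (vi @ jnp.transpose(vi)))
# rank-1 update, as on the added line
new = h - tau_i * (h @ vi) @ jnp.transpose(vi)

assert jnp.allclose(old, new)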