diff --git a/src/skillmodels/kalman_filters.py b/src/skillmodels/kalman_filters.py
index 37d22ef..defe35c 100644
--- a/src/skillmodels/kalman_filters.py
+++ b/src/skillmodels/kalman_filters.py
@@ -3,8 +3,13 @@
 import jax
 import jax.numpy as jnp
 
-array_qr_jax = jax.vmap(jax.vmap(jnp.linalg.qr))
+from skillmodels.qr import qr_gpu
+
+array_qr_jax = (
+    jax.vmap(jax.vmap(qr_gpu))
+    if jax.default_backend() == "gpu"
+    else jax.vmap(jax.vmap(jnp.linalg.qr))
+)
 
 # ======================================================================================
 # Update Step
 # ======================================================================================
diff --git a/src/skillmodels/qr.py b/src/skillmodels/qr.py
new file mode 100644
index 0000000..7e95b8a
--- /dev/null
+++ b/src/skillmodels/qr.py
@@ -0,0 +1,72 @@
+import jax
+import jax.numpy as jnp
+
+
+@jax.custom_jvp
+def qr_gpu(a: jax.Array):
+    """Custom implementation of the QR decomposition."""
+    h, tau = jnp.linalg.qr(a, mode="raw")
+
+    q = _householder(h.mT, tau)
+    return q, jnp.triu(h.mT[: tau.shape[0]])
+
+
+def _householder(r: jax.Array, tau: jax.Array):
+    """Custom implementation of the Householder product.
+
+    Uses the outputs of jnp.linalg.qr with mode="raw" to calculate Q. This is needed
+    because the JAX implementation is extremely slow for a batch of small matrices.
+    """
+    m = r.shape[0]
+    n = tau.shape[0]
+    # Each Householder vector is stored below the diagonal of a column of R: entries
+    # above the diagonal are zeroed out and the diagonal entry is set to one.
+    v1 = jnp.expand_dims(r[:, 0], 1)
+    v1 = v1.at[0].set(1)
+    h = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1))
+    # Accumulate the product of all Householder reflectors: Q = H(1) H(2) ... H(n)
+    for i in range(1, n):
+        vi = jnp.expand_dims(r[:, i], 1)
+        vi = vi.at[0:i].set(0)
+        vi = vi.at[i].set(1)
+        h = h @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi)))
+    return h[:, :n]
+
+
+def _t(x: jax.Array) -> jax.Array:
+    """Transpose a batched matrix."""
+    return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
+
+
+def _h(x: jax.Array) -> jax.Array:
+    """Hermitian transpose of a batched matrix."""
+    return _t(x).conj()
+
+
+def _tril(m: jax.Array, k: int = 0) -> jax.Array:
+    """Select the lower triangle of a batched matrix."""
+    *_, dim_n, dim_m = m.shape
+    mask = jnp.tri(dim_n, dim_m, k, bool)
+    return jax.lax.select(
+        jax.lax.broadcast(mask, m.shape[:-2]), m, jax.lax.zeros_like_array(m)
+    )
+
+
+@qr_gpu.defjvp
+def qr_jvp_rule(primals, tangents):
+    """Calculate the derivative of the custom QR decomposition."""
+    # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
+    (x,) = primals
+    (dx,) = tangents
+    q, r = qr_gpu(x)
+    *_, m, n = x.shape
+    dx_rinv = jax.lax.linalg.triangular_solve(r, dx)  # Right side solve by default
+    qt_dx_rinv = _h(q) @ dx_rinv
+    qt_dx_rinv_lower = _tril(qt_dx_rinv, -1)
+    do = qt_dx_rinv_lower - _h(qt_dx_rinv_lower)  # This is skew-symmetric
+    # The following correction is necessary for complex inputs
+    i = jax.lax.expand_dims(jnp.eye(n, n), range(qt_dx_rinv.ndim - 2))
+    do = do + i * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype))
+    dq = q @ (do - qt_dx_rinv) + dx_rinv
+    dr = (qt_dx_rinv - do) @ r
+    return (q, r), (dq, dr)
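Note (not part of the patch): a minimal sketch of how the batched decomposition above
is meant to be called. The double vmap mirrors array_qr_jax in kalman_filters.py; the
batch shapes are invented for this example.

    import jax
    import jax.numpy as jnp

    from skillmodels.qr import qr_gpu

    # A batch of small matrices, e.g. (n_obs, n_mixtures, dim, dim).
    a = jax.random.normal(jax.random.PRNGKey(0), (100, 2, 8, 8))
    q, r = jax.vmap(jax.vmap(qr_gpu))(a)

    # Q @ R reconstructs the input up to floating point error.
    assert jnp.allclose(q @ r, a, atol=1e-4)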
diff --git a/tests/test_qr.py b/tests/test_qr.py
new file mode 100644
index 0000000..5a35475
--- /dev/null
+++ b/tests/test_qr.py
@@ -0,0 +1,46 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from numpy.testing import assert_array_almost_equal as aaae
+
+from skillmodels.qr import qr_gpu
+
+SEED = 20
+
+
+@pytest.fixture
+def cov_matrix():
+    # Construct a random symmetric positive definite matrix (mimicking a covariance).
+    fixedrng = np.random.default_rng(SEED)
+    factorized = fixedrng.uniform(low=-1, high=3, size=(7, 7))
+    cov = factorized @ factorized.T * 0.5 + np.eye(7)
+    return cov
+
+
+def test_q(cov_matrix):
+    q_gpu, _ = qr_gpu(cov_matrix)
+    q_jax, _ = jnp.linalg.qr(cov_matrix)
+    aaae(q_gpu, q_jax)
+
+
+def test_r(cov_matrix):
+    _, r_gpu = qr_gpu(cov_matrix)
+    _, r_jax = jnp.linalg.qr(cov_matrix)
+    aaae(r_gpu, r_jax)
+
+
+def test_grad_qr(cov_matrix):
+    def f_jax(a):
+        q, r = jnp.linalg.qr(a)
+        return jnp.sum(r) + jnp.sum(q)
+
+    def f_gpu(a):
+        q, r = qr_gpu(a)
+        return jnp.sum(r) + jnp.sum(q)
+
+    grad_qr_jax = jax.grad(f_jax)
+    grad_qr_gpu = jax.grad(f_gpu)
+    grad_gpu = grad_qr_gpu(cov_matrix)
+    grad_jax = grad_qr_jax(cov_matrix)
+    aaae(grad_gpu, grad_jax)
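Note (not part of the patch): qr_jvp_rule implements the identity obtained by
differentiating A = Q R, namely Q^T dA R^(-1) = Q^T dQ + dR R^(-1), where Q^T dQ is
skew-symmetric and dR R^(-1) is upper triangular; the rule splits qt_dx_rinv along
exactly those two parts. A quick way to sanity-check the custom JVP against central
finite differences (a hypothetical snippet, not part of the test suite):

    import jax
    import jax.numpy as jnp

    from skillmodels.qr import qr_gpu

    key_a, key_v = jax.random.split(jax.random.PRNGKey(1))
    a = jnp.eye(5) + 0.1 * jax.random.normal(key_a, (5, 5))  # well-conditioned input
    v = jax.random.normal(key_v, (5, 5))  # direction of differentiation
    eps = 1e-3

    # Tangents produced by the custom JVP rule ...
    (q, r), (dq, dr) = jax.jvp(qr_gpu, (a,), (v,))

    # ... versus central finite differences of the primal function.
    q_hi, r_hi = qr_gpu(a + eps * v)
    q_lo, r_lo = qr_gpu(a - eps * v)
    assert jnp.allclose(dq, (q_hi - q_lo) / (2 * eps), atol=1e-2)
    assert jnp.allclose(dr, (r_hi - r_lo) / (2 * eps), atol=1e-2)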