Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/skillmodels/kalman_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
import jax
import jax.numpy as jnp

array_qr_jax = jax.vmap(jax.vmap(jnp.linalg.qr))
from skillmodels.qr import qr_gpu

# Doubly-vmapped (batched) QR decomposition.
# On GPU we dispatch to the custom Householder-based implementation, because
# jnp.linalg.qr is extremely slow there for a batch of small matrices (see
# skillmodels.qr); on CPU/TPU the standard jnp.linalg.qr is used.
array_qr_jax = (
    jax.vmap(jax.vmap(qr_gpu))
    if jax.default_backend() == "gpu"
    else jax.vmap(jax.vmap(jnp.linalg.qr))
)

# ======================================================================================
# Update Step
Expand Down
72 changes: 72 additions & 0 deletions src/skillmodels/qr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import jax
import jax.numpy as jnp


@jax.custom_jvp
def qr_gpu(a: jax.Array):
    """Reduced QR decomposition built on LAPACK's raw Householder output.

    Returns (q, r) like ``jnp.linalg.qr(a)`` but assembles Q via a custom
    Householder product, which is much faster on GPU for batches of small
    matrices.
    """
    raw, tau = jnp.linalg.qr(a, mode="raw")
    # ``raw`` holds the transposed elementary reflectors; undo the transpose.
    factors = raw.mT
    n = tau.shape[0]
    q = _householder(factors, tau)
    r = jnp.triu(factors[:n])
    return q, r


def _householder(r: jax.Array, tau: jax.Array):
"""Custom implementation of the Householder Product.

Uses the outputs of jnp.linalg.qr with mode = "raw" to calculate Q. This is needed
because the JAX implementation is extremely slow for a batch of small matrices.
"""
m = r.shape[0]
n = tau.shape[0]
# Calculate Householder Vector which is saved in the lower triangle of R
v1 = jnp.expand_dims(r[:, 0], 1)
v1 = v1.at[0:0].set(0)
v1 = v1.at[0].set(1)
h = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1))
# Multiply all Householder Vectors Q = H(1)*H(2)...*H(n)
for i in range(1, n):
vi = jnp.expand_dims(r[:, i], 1)
vi = vi.at[0:i].set(0)
vi = vi.at[i].set(1)
h = h @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi)))
return h[:, :n]


def _t(x: jax.Array) -> jax.Array:
"""Transpose batched Matrix."""
return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))


def _h(x: jax.Array) -> jax.Array:
"""Hermitian Transpose of a Matrix."""
return _t(x).conj()


def _tril(m: jax.Array, k: int = 0) -> jax.Array:
"""Select lower Triangle of a Matrix."""
*_, dim_n, dim_m = m.shape
mask = jnp.tri(dim_n, dim_m, k, bool)
return jax.lax.select(
jax.lax.broadcast(mask, m.shape[:-2]), m, jax.lax.zeros_like_array(m)
)


@qr_gpu.defjvp
def qr_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the custom QR decomposition.

    Registered via ``qr_gpu.defjvp`` so that ``jax.grad``/``jax.jvp`` work
    through ``qr_gpu``. Mirrors JAX's built-in QR JVP rule.
    """
    # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
    (x,) = primals
    (dx,) = tangents
    # Recompute the primal outputs with the custom decomposition.
    q, r = qr_gpu(x)
    *_, m, n = x.shape
    # dX @ R^{-1} via a triangular solve.
    dx_rinv = jax.lax.linalg.triangular_solve(r, dx)  # Right side solve by default
    qt_dx_rinv = _h(q) @ dx_rinv
    qt_dx_rinv_lower = _tril(qt_dx_rinv, -1)
    do = qt_dx_rinv_lower - _h(qt_dx_rinv_lower)  # This is skew-symmetric
    # The following correction is necessary for complex inputs
    i = jax.lax.expand_dims(jnp.eye(n, n), range(qt_dx_rinv.ndim - 2))
    do = do + i * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype))
    dq = q @ (do - qt_dx_rinv) + dx_rinv
    dr = (qt_dx_rinv - do) @ r
    return (q, r), (dq, dr)
45 changes: 45 additions & 0 deletions tests/test_qr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal as aaae

from skillmodels.qr import qr_gpu

# Fixed seed so the random covariance fixture is reproducible across runs.
SEED = 20


@pytest.fixture
def cov_matrix():
    """A fixed 7x7 symmetric positive-(semi)definite covariance matrix."""
    rng = np.random.default_rng(SEED)
    factor = rng.uniform(low=-1, high=3, size=(7, 7))
    return factor @ factor.T * 0.5 + np.eye(7)


def test_q(cov_matrix):
    """The custom Q factor must match jnp.linalg.qr's Q factor."""
    q_custom = qr_gpu(cov_matrix)[0]
    q_reference = jnp.linalg.qr(cov_matrix)[0]
    aaae(q_custom, q_reference)


def test_r(cov_matrix):
    """The custom R factor must match jnp.linalg.qr's R factor."""
    r_custom = qr_gpu(cov_matrix)[1]
    r_reference = jnp.linalg.qr(cov_matrix)[1]
    aaae(r_custom, r_reference)


def test_grad_qr(cov_matrix):
    """Gradients through qr_gpu must match gradients through jnp.linalg.qr."""

    def make_loss(qr_func):
        # Scalar-valued wrapper so jax.grad applies to either decomposition.
        def loss(a):
            q, r = qr_func(a)
            return jnp.sum(r) + jnp.sum(q)

        return loss

    grad_gpu = jax.grad(make_loss(qr_gpu))(cov_matrix)
    grad_jax = jax.grad(make_loss(jnp.linalg.qr))(cov_matrix)
    aaae(grad_gpu, grad_jax)
Loading