From 8b1903327b37cac0c278de3cd230c3b84b2292bb Mon Sep 17 00:00:00 2001 From: mj023 Date: Tue, 11 Mar 2025 20:30:11 +0100 Subject: [PATCH 1/6] Implement QR Factorization --- src/skillmodels/kalman_filters.py | 64 +++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/src/skillmodels/kalman_filters.py b/src/skillmodels/kalman_filters.py index 37d22ef..7418a6d 100644 --- a/src/skillmodels/kalman_filters.py +++ b/src/skillmodels/kalman_filters.py @@ -1,9 +1,66 @@ import functools - import jax import jax.numpy as jnp -array_qr_jax = jax.vmap(jax.vmap(jnp.linalg.qr)) + +@jax.custom_jvp +def qr(A): + """Custom implementation of the QR Decomposition""" + r,tau = jnp.linalg.qr(A, mode='raw') + + q = _householder(r.mT,tau) + return q,jnp.triu(r.mT[:tau.shape[0]]) + + +def _householder(r,tau): + """Custom implementation of the Householder Product to calculate Q from the outputs of + jnp.linalg.qr with mode = "raw". This is needed because the JAX implementation is extremely slow + for a batch of small matrices. + """ + m = r.shape[0] + n = tau.shape[0] + v1 = jnp.expand_dims(r[:,0], 1) + v1 = v1.at[0:0].set(0) + v1 = v1.at[0].set(1) + H = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1)) + for i in range(1, n): + vi = jnp.expand_dims(r[:,i], 1) + vi = vi.at[0:i].set(0) + vi = vi.at[i].set(1) + H = H @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi))) + return H[:,:n] + +def _T(x: jax.Array) -> jax.Array: + return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) + +def _H(x: jax.Array) -> jax.Array: + return _T(x).conj() + +def _tril(m: jax.Array, k:int = 0) -> jax.Array: + *_, N, M = m.shape + mask = jnp.tri(N, M, k,bool) + return jax.lax.select(jax.lax.broadcast(mask, m.shape[:-2]), m, jax.lax.zeros_like_array(m)) + +@qr.defjvp +def qr_jvp_rule(primals, tangents): + """Calculates the derivative of the custom QR composition.""" + # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation. + x, = primals + dx, = tangents + q, r, = qr(x) + *_, m, n = x.shape + dx_rinv = jax.lax.linalg.triangular_solve(r, dx) # Right side solve by default + qt_dx_rinv = _H(q) @ dx_rinv + qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) + do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric + # The following correction is necessary for complex inputs + I = jax.lax.expand_dims(jnp.eye(n, n), range(qt_dx_rinv.ndim - 2)) + do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) + dq = q @ (do - qt_dx_rinv) + dx_rinv + dr = (qt_dx_rinv - do) @ r + return (q, r), (dq, dr) + +array_qr_jax = jax.vmap(jax.vmap(qr)) # ====================================================================================== @@ -11,7 +68,7 @@ # ====================================================================================== -@functools.partial(jax.checkpoint, prevent_cse=False) + def kalman_update( states, upper_chols, @@ -227,6 +284,7 @@ def kalman_predict( return predicted_states, predicted_covs + @functools.partial(jax.checkpoint, prevent_cse=False) def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factors): """Calculate the array of sigma_points for the unscented transform. 
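A quick sanity check for what this first patch adds — a minimal sketch, not part of the patch itself, assuming patch 1 is applied and skillmodels is importable. Since qr assembles Q from the same LAPACK-style Householder reflectors that jnp.linalg.qr(..., mode="raw") returns, both factorizations should agree up to floating-point error, and the doubly-vmapped array_qr_jax should accept a 4-d stack of matrices:

    import jax.numpy as jnp
    import numpy as np

    from skillmodels.kalman_filters import array_qr_jax, qr

    rng = np.random.default_rng(0)

    # Single small square matrix: custom QR vs. the built-in QR.
    a = jnp.asarray(rng.normal(size=(5, 5)))
    q_custom, r_custom = qr(a)
    q_ref, r_ref = jnp.linalg.qr(a)
    np.testing.assert_allclose(q_custom, q_ref, atol=1e-5)
    np.testing.assert_allclose(r_custom, r_ref, atol=1e-5)

    # Batched use, as in the Kalman filter: two leading batch dimensions.
    stacked = jnp.asarray(rng.normal(size=(3, 4, 5, 5)))
    q_batch, r_batch = array_qr_jax(stacked)
    assert q_batch.shape == (3, 4, 5, 5)
    assert r_batch.shape == (3, 4, 5, 5)

The explicit reflector loop in _householder is deliberately simple: per its own docstring, JAX's built-in Q assembly is extremely slow for a batch of small matrices, which is exactly the workload the filter produces (one small covariance factor per individual per period).
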
From a003a89dba8c922954ab6b900b83db5b49ccca04 Mon Sep 17 00:00:00 2001 From: mj023 Date: Thu, 13 Mar 2025 16:36:10 +0100 Subject: [PATCH 2/6] Move QR and add tests --- src/skillmodels/kalman_filters.py | 63 ++----------------------------- src/skillmodels/qr.py | 63 +++++++++++++++++++++++++++++++ tests/test_qr.py | 25 ++++++++++++ 3 files changed, 91 insertions(+), 60 deletions(-) create mode 100644 src/skillmodels/qr.py create mode 100644 tests/test_qr.py diff --git a/src/skillmodels/kalman_filters.py b/src/skillmodels/kalman_filters.py index 7418a6d..fd7dbc6 100644 --- a/src/skillmodels/kalman_filters.py +++ b/src/skillmodels/kalman_filters.py @@ -1,74 +1,17 @@ import functools import jax import jax.numpy as jnp +from skillmodels.qr import qr_gpu -@jax.custom_jvp -def qr(A): - """Custom implementation of the QR Decomposition""" - r,tau = jnp.linalg.qr(A, mode='raw') - - q = _householder(r.mT,tau) - return q,jnp.triu(r.mT[:tau.shape[0]]) - - -def _householder(r,tau): - """Custom implementation of the Householder Product to calculate Q from the outputs of - jnp.linalg.qr with mode = "raw". This is needed because the JAX implementation is extremely slow - for a batch of small matrices. - """ - m = r.shape[0] - n = tau.shape[0] - v1 = jnp.expand_dims(r[:,0], 1) - v1 = v1.at[0:0].set(0) - v1 = v1.at[0].set(1) - H = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1)) - for i in range(1, n): - vi = jnp.expand_dims(r[:,i], 1) - vi = vi.at[0:i].set(0) - vi = vi.at[i].set(1) - H = H @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi))) - return H[:,:n] - -def _T(x: jax.Array) -> jax.Array: - return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) - -def _H(x: jax.Array) -> jax.Array: - return _T(x).conj() - -def _tril(m: jax.Array, k:int = 0) -> jax.Array: - *_, N, M = m.shape - mask = jnp.tri(N, M, k,bool) - return jax.lax.select(jax.lax.broadcast(mask, m.shape[:-2]), m, jax.lax.zeros_like_array(m)) - -@qr.defjvp -def qr_jvp_rule(primals, tangents): - """Calculates the derivative of the custom QR composition.""" - # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation. 
- x, = primals - dx, = tangents - q, r, = qr(x) - *_, m, n = x.shape - dx_rinv = jax.lax.linalg.triangular_solve(r, dx) # Right side solve by default - qt_dx_rinv = _H(q) @ dx_rinv - qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) - do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric - # The following correction is necessary for complex inputs - I = jax.lax.expand_dims(jnp.eye(n, n), range(qt_dx_rinv.ndim - 2)) - do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) - dq = q @ (do - qt_dx_rinv) + dx_rinv - dr = (qt_dx_rinv - do) @ r - return (q, r), (dq, dr) - -array_qr_jax = jax.vmap(jax.vmap(qr)) - +array_qr_jax = jax.vmap(jax.vmap(qr_gpu)) if jax.default_backend()=='gpu' else jax.vmap(jax.vmap(jnp.linalg.qr)) # ====================================================================================== # Update Step # ====================================================================================== - +@functools.partial(jax.checkpoint, prevent_cse=False) def kalman_update( states, upper_chols, diff --git a/src/skillmodels/qr.py b/src/skillmodels/qr.py new file mode 100644 index 0000000..ef51f82 --- /dev/null +++ b/src/skillmodels/qr.py @@ -0,0 +1,63 @@ +import jax +import jax.numpy as jnp + +@jax.custom_jvp +def qr_gpu(A: jax.Array): + """Custom implementation of the QR Decomposition""" + r,tau = jnp.linalg.qr(A, mode='raw') + + q = _householder(r.mT,tau) + return q,jnp.triu(r.mT[:tau.shape[0]]) + + +def _householder(r: jax.Array,tau: jax.Array): + """Custom implementation of the Householder Product to calculate Q from the outputs of + jnp.linalg.qr with mode = "raw". This is needed because the JAX implementation is extremely slow + for a batch of small matrices. + """ + m = r.shape[0] + n = tau.shape[0] + # Calculate Householder Vector which is saved in the lower triangle + v1 = jnp.expand_dims(r[:,0], 1) + v1 = v1.at[0:0].set(0) + v1 = v1.at[0].set(1) + H = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1)) + for i in range(1, n): + vi = jnp.expand_dims(r[:,i], 1) + vi = vi.at[0:i].set(0) + vi = vi.at[i].set(1) + H = H @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi))) + return H[:,:n] + +def _T(x: jax.Array) -> jax.Array: + """Transpose batched Matrix.""" + return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) + +def _H(x: jax.Array) -> jax.Array: + """Hermitian Transpose of a Matrix.""" + return _T(x).conj() + +def _tril(m: jax.Array, k:int = 0) -> jax.Array: + """Select lower Triangle of a Matrix.""" + *_, N, M = m.shape + mask = jnp.tri(N, M, k,bool) + return jax.lax.select(jax.lax.broadcast(mask, m.shape[:-2]), m, jax.lax.zeros_like_array(m)) + +@qr_gpu.defjvp +def qr_jvp_rule(primals, tangents): + """Calculates the derivative of the custom QR composition.""" + # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation. 
+ x, = primals + dx, = tangents + q, r, = qr_gpu(x) + *_, m, n = x.shape + dx_rinv = jax.lax.linalg.triangular_solve(r, dx) # Right side solve by default + qt_dx_rinv = _H(q) @ dx_rinv + qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) + do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric + # The following correction is necessary for complex inputs + I = jax.lax.expand_dims(jnp.eye(n, n), range(qt_dx_rinv.ndim - 2)) + do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) + dq = q @ (do - qt_dx_rinv) + dx_rinv + dr = (qt_dx_rinv - do) @ r + return (q, r), (dq, dr) \ No newline at end of file diff --git a/tests/test_qr.py b/tests/test_qr.py new file mode 100644 index 0000000..10b5805 --- /dev/null +++ b/tests/test_qr.py @@ -0,0 +1,25 @@ +import jax +import jax.numpy as jnp +import numpy as np +from numpy.testing import assert_array_almost_equal as aaae +from skillmodels.qr import qr_gpu + +def test_qr(): + factorized = np.random.uniform(low=-1, high=3, size=(7, 7)) + cov = factorized @ factorized.T * 0.5 + np.eye(7) + q_gpu, r_gpu = qr_gpu(cov) + q_jax, r_jax = jnp.linalg.qr(cov) + def f_jax(A): + q,r = jnp.linalg.qr(A) + return jnp.sum(r) + jnp.sum(q) + def f_gpu(A): + q,r = qr_gpu(A) + return jnp.sum(r) + jnp.sum(q) + grad_qr_jax = jax.grad(f_jax) + grad_qr_gpu = jax.grad(f_gpu) + grad_gpu = grad_qr_gpu(cov) + grad_jax = grad_qr_jax(cov) + aaae(q_gpu, q_jax) + aaae(r_gpu, r_jax) + aaae(grad_gpu, grad_jax) + aaae(grad_gpu, grad_jax) \ No newline at end of file From f4701e4cd2fb5862d42554594e46e5b471dacf2a Mon Sep 17 00:00:00 2001 From: mj023 Date: Mon, 17 Mar 2025 16:19:37 +0100 Subject: [PATCH 3/6] Fix Seed, Split tests --- src/skillmodels/kalman_filters.py | 10 ++-- src/skillmodels/qr.py | 83 +++++++++++++++++-------------- tests/test_qr.py | 46 ++++++++++++----- 3 files changed, 86 insertions(+), 53 deletions(-) diff --git a/src/skillmodels/kalman_filters.py b/src/skillmodels/kalman_filters.py index fd7dbc6..defe35c 100644 --- a/src/skillmodels/kalman_filters.py +++ b/src/skillmodels/kalman_filters.py @@ -1,10 +1,15 @@ import functools + import jax import jax.numpy as jnp -from skillmodels.qr import qr_gpu +from skillmodels.qr import qr_gpu -array_qr_jax = jax.vmap(jax.vmap(qr_gpu)) if jax.default_backend()=='gpu' else jax.vmap(jax.vmap(jnp.linalg.qr)) +array_qr_jax = ( + jax.vmap(jax.vmap(qr_gpu)) + if jax.default_backend() == "gpu" + else jax.vmap(jax.vmap(jnp.linalg.qr)) +) # ====================================================================================== # Update Step @@ -227,7 +232,6 @@ def kalman_predict( return predicted_states, predicted_covs - @functools.partial(jax.checkpoint, prevent_cse=False) def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factors): """Calculate the array of sigma_points for the unscented transform. diff --git a/src/skillmodels/qr.py b/src/skillmodels/qr.py index ef51f82..7e95b8a 100644 --- a/src/skillmodels/qr.py +++ b/src/skillmodels/qr.py @@ -1,63 +1,72 @@ import jax import jax.numpy as jnp + @jax.custom_jvp -def qr_gpu(A: jax.Array): - """Custom implementation of the QR Decomposition""" - r,tau = jnp.linalg.qr(A, mode='raw') - - q = _householder(r.mT,tau) - return q,jnp.triu(r.mT[:tau.shape[0]]) - - -def _householder(r: jax.Array,tau: jax.Array): - """Custom implementation of the Householder Product to calculate Q from the outputs of - jnp.linalg.qr with mode = "raw". This is needed because the JAX implementation is extremely slow - for a batch of small matrices. 
+def qr_gpu(a: jax.Array): + """Custom implementation of the QR Decomposition.""" + r, tau = jnp.linalg.qr(a, mode="raw") + + q = _householder(r.mT, tau) + return q, jnp.triu(r.mT[: tau.shape[0]]) + + +def _householder(r: jax.Array, tau: jax.Array): + """Custom implementation of the Householder Product. + + Uses the outputs of jnp.linalg.qr with mode = "raw" to calculate Q. This is needed + because the JAX implementation is extremely slow for a batch of small matrices. """ m = r.shape[0] n = tau.shape[0] - # Calculate Householder Vector which is saved in the lower triangle - v1 = jnp.expand_dims(r[:,0], 1) + # Calculate Householder Vector which is saved in the lower triangle of R + v1 = jnp.expand_dims(r[:, 0], 1) v1 = v1.at[0:0].set(0) v1 = v1.at[0].set(1) - H = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1)) + h = jnp.eye(m) - tau[0] * (v1 @ jnp.transpose(v1)) + # Multiply all Householder Vectors Q = H(1)*H(2)...*H(n) for i in range(1, n): - vi = jnp.expand_dims(r[:,i], 1) + vi = jnp.expand_dims(r[:, i], 1) vi = vi.at[0:i].set(0) vi = vi.at[i].set(1) - H = H @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi))) - return H[:,:n] + h = h @ (jnp.eye(m) - tau[i] * (vi @ jnp.transpose(vi))) + return h[:, :n] + + +def _t(x: jax.Array) -> jax.Array: + """Transpose batched Matrix.""" + return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) + + +def _h(x: jax.Array) -> jax.Array: + """Hermitian Transpose of a Matrix.""" + return _t(x).conj() -def _T(x: jax.Array) -> jax.Array: - """Transpose batched Matrix.""" - return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) -def _H(x: jax.Array) -> jax.Array: - """Hermitian Transpose of a Matrix.""" - return _T(x).conj() +def _tril(m: jax.Array, k: int = 0) -> jax.Array: + """Select lower Triangle of a Matrix.""" + *_, dim_n, dim_m = m.shape + mask = jnp.tri(dim_n, dim_m, k, bool) + return jax.lax.select( + jax.lax.broadcast(mask, m.shape[:-2]), m, jax.lax.zeros_like_array(m) + ) -def _tril(m: jax.Array, k:int = 0) -> jax.Array: - """Select lower Triangle of a Matrix.""" - *_, N, M = m.shape - mask = jnp.tri(N, M, k,bool) - return jax.lax.select(jax.lax.broadcast(mask, m.shape[:-2]), m, jax.lax.zeros_like_array(m)) @qr_gpu.defjvp def qr_jvp_rule(primals, tangents): """Calculates the derivative of the custom QR composition.""" # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation. 
- x, = primals - dx, = tangents - q, r, = qr_gpu(x) + (x,) = primals + (dx,) = tangents + q, r = qr_gpu(x) *_, m, n = x.shape dx_rinv = jax.lax.linalg.triangular_solve(r, dx) # Right side solve by default - qt_dx_rinv = _H(q) @ dx_rinv + qt_dx_rinv = _h(q) @ dx_rinv qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) - do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric + do = qt_dx_rinv_lower - _h(qt_dx_rinv_lower) # This is skew-symmetric # The following correction is necessary for complex inputs - I = jax.lax.expand_dims(jnp.eye(n, n), range(qt_dx_rinv.ndim - 2)) - do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) + i = jax.lax.expand_dims(jnp.eye(n, n), range(qt_dx_rinv.ndim - 2)) + do = do + i * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) dq = q @ (do - qt_dx_rinv) + dx_rinv dr = (qt_dx_rinv - do) @ r - return (q, r), (dq, dr) \ No newline at end of file + return (q, r), (dq, dr) diff --git a/tests/test_qr.py b/tests/test_qr.py index 10b5805..5a35475 100644 --- a/tests/test_qr.py +++ b/tests/test_qr.py @@ -1,25 +1,45 @@ import jax import jax.numpy as jnp import numpy as np +import pytest from numpy.testing import assert_array_almost_equal as aaae + from skillmodels.qr import qr_gpu -def test_qr(): - factorized = np.random.uniform(low=-1, high=3, size=(7, 7)) +SEED = 20 + + +@pytest.fixture +def cov_matrix(): + fixedrng = np.random.default_rng(SEED) + factorized = fixedrng.uniform(low=-1, high=3, size=(7, 7)) cov = factorized @ factorized.T * 0.5 + np.eye(7) - q_gpu, r_gpu = qr_gpu(cov) - q_jax, r_jax = jnp.linalg.qr(cov) - def f_jax(A): - q,r = jnp.linalg.qr(A) + return cov + + +def test_q(cov_matrix): + q_gpu, _ = qr_gpu(cov_matrix) + q_jax, _ = jnp.linalg.qr(cov_matrix) + aaae(q_gpu, q_jax) + + +def test_r(cov_matrix): + _, r_gpu = qr_gpu(cov_matrix) + _, r_jax = jnp.linalg.qr(cov_matrix) + aaae(r_gpu, r_jax) + + +def test_grad_qr(cov_matrix): + def f_jax(a): + q, r = jnp.linalg.qr(a) return jnp.sum(r) + jnp.sum(q) - def f_gpu(A): - q,r = qr_gpu(A) + + def f_gpu(a): + q, r = qr_gpu(a) return jnp.sum(r) + jnp.sum(q) + grad_qr_jax = jax.grad(f_jax) grad_qr_gpu = jax.grad(f_gpu) - grad_gpu = grad_qr_gpu(cov) - grad_jax = grad_qr_jax(cov) - aaae(q_gpu, q_jax) - aaae(r_gpu, r_jax) + grad_gpu = grad_qr_gpu(cov_matrix) + grad_jax = grad_qr_jax(cov_matrix) aaae(grad_gpu, grad_jax) - aaae(grad_gpu, grad_jax) \ No newline at end of file From 720345a3ff01620e50521e488a5a712c679fce5c Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 17 Mar 2025 17:54:40 +0100 Subject: [PATCH 4/6] Run gpu tests on GHA. 
--- .github/workflows/main.yml | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0e11d88..cd37515 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v4 - uses: prefix-dev/setup-pixi@v0.8.3 with: - pixi-version: v0.42.1 + pixi-version: v0.42.4 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: test-cpu @@ -41,6 +41,23 @@ jobs: uses: codecov/codecov-action@v4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + run-tests-cuda: + name: Run tests with CUDA for ubuntu-latest on Python 3.13 + runs-on: ubuntu-latest + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v4 + - uses: prefix-dev/setup-pixi@v0.8.3 + with: + pixi-version: v0.42.4 + cache: true + cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} + environments: test-gpu + activate-environment: true + - name: Run pytest + shell: bash -l {0} + run: pixi run -e test-gpu # run-mypy: # name: Run mypy on Python 3.13 # runs-on: ubuntu-latest From 16a5f63ce9effc5b256dcf6dd65123a0c2dc84d2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 17 Mar 2025 18:08:18 +0100 Subject: [PATCH 5/6] Had misremembered pixi version. --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cd37515..fa0da08 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v4 - uses: prefix-dev/setup-pixi@v0.8.3 with: - pixi-version: v0.42.4 + pixi-version: v0.42.1 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: test-cpu @@ -50,7 +50,7 @@ jobs: - uses: actions/checkout@v4 - uses: prefix-dev/setup-pixi@v0.8.3 with: - pixi-version: v0.42.4 + pixi-version: v0.42.1 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: test-gpu From 20976e5834d31b9f3d6ea04b8ac4f0e2a2728288 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 17 Mar 2025 18:13:00 +0100 Subject: [PATCH 6/6] Revert the last set of changes, no GPU support on GHA... --- .github/workflows/main.yml | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fa0da08..0e11d88 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -41,23 +41,6 @@ jobs: uses: codecov/codecov-action@v4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - run-tests-cuda: - name: Run tests with CUDA for ubuntu-latest on Python 3.13 - runs-on: ubuntu-latest - strategy: - fail-fast: false - steps: - - uses: actions/checkout@v4 - - uses: prefix-dev/setup-pixi@v0.8.3 - with: - pixi-version: v0.42.1 - cache: true - cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} - environments: test-gpu - activate-environment: true - - name: Run pytest - shell: bash -l {0} - run: pixi run -e test-gpu # run-mypy: # name: Run mypy on Python 3.13 # runs-on: ubuntu-latest
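With the series complete, the decomposition and its custom JVP rule live in src/skillmodels/qr.py, and kalman_filters.py dispatches to qr_gpu only when jax.default_backend() == "gpu". A minimal end-to-end sketch of checking the custom derivative, assuming that final layout: pushing a tangent through qr_gpu with jax.jvp should reproduce the tangents that JAX's built-in QR differentiation yields for the same primal/tangent pair, complementing the jax.grad comparison in test_grad_qr:

    import jax
    import jax.numpy as jnp
    import numpy as np

    from skillmodels.qr import qr_gpu

    rng = np.random.default_rng(42)
    a = jnp.asarray(rng.normal(size=(6, 6)))
    da = jnp.asarray(rng.normal(size=(6, 6)))  # tangent direction

    # Forward-mode derivative through the custom rule ...
    (q_gpu, r_gpu), (dq_gpu, dr_gpu) = jax.jvp(qr_gpu, (a,), (da,))
    # ... and through JAX's built-in QR differentiation.
    (q_jax, r_jax), (dq_jax, dr_jax) = jax.jvp(jnp.linalg.qr, (a,), (da,))

    np.testing.assert_allclose(q_gpu, q_jax, atol=1e-5)
    np.testing.assert_allclose(dq_gpu, dq_jax, atol=1e-5)
    np.testing.assert_allclose(dr_gpu, dr_jax, atol=1e-5)

Running this on a CPU backend is still meaningful: the JVP rule is attached to qr_gpu itself via jax.custom_jvp, so it is exercised regardless of which implementation array_qr_jax selects — which is also why the test suite can cover it even though the GPU CI job in patches 4-5 had to be reverted in patch 6.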