diff --git a/pyproject.toml b/pyproject.toml index 8fc1d9b..38bb04d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", "License :: OSI Approved :: Apache Software License", ] -dependencies = ["jax>=0.4.0", "jaxlib>=0.4.0", "fmmax>=0.8.0"] +dependencies = ["jax>=0.4.0", "jaxlib>=0.4.0", "fmmax>=1.0.0"] [project.optional-dependencies] test = ["pre-commit", "pytest-cov", "ruff", "optax", "mypy"] @@ -25,4 +25,4 @@ packages = ["thermox"] [tool.ruff] [tool.ruff.lint.per-file-ignores] -"__init__.py" = ["F401", "F821"] \ No newline at end of file +"__init__.py" = ["F401", "F821"] diff --git a/thermox/utils.py b/thermox/utils.py index c08e8a6..c47e038 100644 --- a/thermox/utils.py +++ b/thermox/utils.py @@ -1,7 +1,7 @@ from typing import NamedTuple, Tuple from jax import numpy as jnp from jax import Array -from fmmax.utils import ( +from fmmax.eig import ( eig, ) # differentiable and jit-able eigendecomposition, not yet available in jax, see https://github.com/google/jax/issues/2748