From c6ce0d9acd592fcdf00c014b998fc0223044d3cc Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 10 Jan 2025 15:47:30 +0100 Subject: [PATCH 1/5] tests are currently passing --- src/moscot/backends/ott/__init__.py | 3 +- src/moscot/backends/ott/solver.py | 213 +--------------- src/moscot/backends/utils.py | 17 +- src/moscot/neural/backends/__init__.py | 0 .../neural/backends/neural_ott/__init__.py | 3 + .../neural/backends/neural_ott/solver.py | 224 +++++++++++++++++ src/moscot/neural/base/problems/problem.py | 26 +- src/moscot/neural/data/__init__.py | 4 + .../neural/data/_distribution_collection.py | 229 ++++++++++++++++++ src/moscot/neural/data/_policy_loader.py | 169 +++++++++++++ .../neural/problems/generic/_generic.py | 15 +- src/moscot/utils/tagged_array.py | 187 +------------- .../test_conditional_neural_problem.py | 12 +- 13 files changed, 667 insertions(+), 435 deletions(-) create mode 100644 src/moscot/neural/backends/__init__.py create mode 100644 src/moscot/neural/backends/neural_ott/__init__.py create mode 100644 src/moscot/neural/backends/neural_ott/solver.py create mode 100644 src/moscot/neural/data/__init__.py create mode 100644 src/moscot/neural/data/_distribution_collection.py create mode 100644 src/moscot/neural/data/_policy_loader.py diff --git a/src/moscot/backends/ott/__init__.py b/src/moscot/backends/ott/__init__.py index 7fdae526c..19d9292d6 100644 --- a/src/moscot/backends/ott/__init__.py +++ b/src/moscot/backends/ott/__init__.py @@ -2,7 +2,7 @@ from moscot.backends.ott._utils import sinkhorn_divergence from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput -from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver +from moscot.backends.ott.solver import GWSolver, SinkhornSolver from moscot.costs import register_cost __all__ = [ @@ -11,7 +11,6 @@ "SinkhornSolver", "NeuralOutput", "sinkhorn_divergence", - "GENOTLinSolver", "GraphOTTOutput", ] diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index 50ca82fa4..513448e31 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -1,12 +1,9 @@ import abc -import functools import inspect -import math import types from typing import ( Any, Hashable, - List, Literal, Mapping, NamedTuple, @@ -17,21 +14,13 @@ Union, ) -import optax - import jax import jax.numpy as jnp -import numpy as np from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud -from ott.neural.datasets import OTData, OTDataset -from ott.neural.methods.flows import dynamics, genot -from ott.neural.networks.layers import time_encoder -from ott.neural.networks.velocity_field import VelocityField from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr -from ott.solvers.utils import uniform_sampler from moscot._logging import logger from moscot._types import ( @@ -43,23 +32,23 @@ ) from moscot.backends.ott._utils import ( InitializerResolver, - Loader, - MultiLoader, _instantiate_geodesic_cost, alpha_to_fused_penalty, check_shapes, convert_scipy_sparse, - data_match_fn, densify, ensure_2d, ) -from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput +from moscot.backends.ott.output import GraphOTTOutput, OTTOutput from moscot.base.problems._utils import TimeScalesHeatKernel from moscot.base.solver import OTSolver from moscot.costs import get_cost -from moscot.utils.tagged_array import DistributionCollection, TaggedArray +from moscot.utils.tagged_array import TaggedArray -__all__ = ["SinkhornSolver", "GWSolver", "GENOTLinSolver"] +__all__ = [ + "SinkhornSolver", + "GWSolver", +] OTTSolver_t = Union[ sinkhorn.Sinkhorn, @@ -516,193 +505,3 @@ def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: problem_kwargs -= {"geom_xx", "geom_yy", "geom_xy", "fused_penalty"} problem_kwargs |= {"alpha"} return geom_kwargs | problem_kwargs, {"epsilon"} - - -class GENOTLinSolver(OTSolver[OTTOutput]): - """Solver class for genot.GENOT linear :cite:`klein2023generative`.""" - - def __init__(self, **kwargs: Any) -> None: - """Initiate the class with any kwargs passed to the ott-jax class.""" - super().__init__() - self._train_sampler: Optional[MultiLoader] = None - self._valid_sampler: Optional[MultiLoader] = None - self._neural_kwargs = kwargs - - @property - def problem_kind(self) -> ProblemKind_t: # noqa: D102 - return "linear" - - def _prepare( # type: ignore[override] - self, - distributions: DistributionCollection[K], - sample_pairs: List[Tuple[Any, Any]], - train_size: float = 0.9, - batch_size: int = 1024, - is_conditional: bool = True, - **kwargs: Any, - ) -> Tuple[MultiLoader, MultiLoader]: - train_loaders = [] - validate_loaders = [] - seed = kwargs.get("seed") - is_aligned = kwargs.get("is_aligned", False) - if train_size == 1.0: - for sample_pair in sample_pairs: - source_key = sample_pair[0] - target_key = sample_pair[1] - src_data = OTData( - lin=distributions[source_key].xy, - condition=distributions[source_key].conditions if is_conditional else None, - ) - tgt_data = OTData( - lin=distributions[target_key].xy, - condition=distributions[target_key].conditions if is_conditional else None, - ) - dataset = OTDataset(src_data=src_data, tgt_data=tgt_data, seed=seed, is_aligned=is_aligned) - loader = Loader(dataset, batch_size=batch_size, seed=seed) - train_loaders.append(loader) - validate_loaders.append(loader) - else: - if train_size > 1.0 or train_size <= 0.0: - raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1") - - seed = kwargs.get("seed", 0) - for sample_pair in sample_pairs: - source_key = sample_pair[0] - target_key = sample_pair[1] - source_data: ArrayLike = distributions[source_key].xy - target_data: ArrayLike = distributions[target_key].xy - source_split_data = self._split_data( - source_data, - conditions=distributions[source_key].conditions, - train_size=train_size, - seed=seed, - a=distributions[source_key].a, - b=distributions[source_key].b, - ) - target_split_data = self._split_data( - target_data, - conditions=distributions[target_key].conditions, - train_size=train_size, - seed=seed, - a=distributions[target_key].a, - b=distributions[target_key].b, - ) - src_data_train = OTData( - lin=source_split_data.data_train, - condition=source_split_data.conditions_train if is_conditional else None, - ) - tgt_data_train = OTData( - lin=target_split_data.data_train, - condition=target_split_data.conditions_train if is_conditional else None, - ) - train_dataset = OTDataset( - src_data=src_data_train, tgt_data=tgt_data_train, seed=seed, is_aligned=is_aligned - ) - train_loader = Loader(train_dataset, batch_size=batch_size, seed=seed) - src_data_validate = OTData( - lin=source_split_data.data_valid, - condition=source_split_data.conditions_valid if is_conditional else None, - ) - tgt_data_validate = OTData( - lin=target_split_data.data_valid, - condition=target_split_data.conditions_valid if is_conditional else None, - ) - validate_dataset = OTDataset( - src_data=src_data_validate, tgt_data=tgt_data_validate, seed=seed, is_aligned=is_aligned - ) - validate_loader = Loader(validate_dataset, batch_size=batch_size, seed=seed) - train_loaders.append(train_loader) - validate_loaders.append(validate_loader) - source_dim = self._neural_kwargs.get("input_dim", 0) - target_dim = source_dim - condition_dim = self._neural_kwargs.get("cond_dim", 0) - # TODO(ilan-gold): What are reasonable defaults here? - neural_vf = VelocityField( - output_dims=[*self._neural_kwargs.get("velocity_field_output_dims", []), target_dim], - condition_dims=( - self._neural_kwargs.get("velocity_field_condition_dims", [source_dim + condition_dim]) - if is_conditional - else None - ), - hidden_dims=self._neural_kwargs.get("velocity_field_hidden_dims", [1024, 1024, 1024]), - time_dims=self._neural_kwargs.get("velocity_field_time_dims", None), - time_encoder=self._neural_kwargs.get( - "velocity_field_time_encoder", functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024) - ), - ) - seed = self._neural_kwargs.get("seed", 0) - rng = jax.random.PRNGKey(seed) - data_match_fn_kwargs = self._neural_kwargs.get( - "data_match_fn_kwargs", - {} if "data_match_fn" in self._neural_kwargs else {"epsilon": 1e-1, "tau_a": 1.0, "tau_b": 1.0}, - ) - time_sampler = self._neural_kwargs.get("time_sampler", uniform_sampler) - optimizer = self._neural_kwargs.get("optimizer", optax.adam(learning_rate=1e-4)) - self._solver = genot.GENOT( - vf=neural_vf, - flow=self._neural_kwargs.get( - "flow", - dynamics.ConstantNoiseFlow(0.1), - ), - data_match_fn=functools.partial( - self._neural_kwargs.get("data_match_fn", data_match_fn), typ="lin", **data_match_fn_kwargs - ), - source_dim=source_dim, - target_dim=target_dim, - condition_dim=condition_dim if is_conditional else None, - optimizer=optimizer, - time_sampler=time_sampler, - rng=rng, - latent_noise_fn=self._neural_kwargs.get("latent_noise_fn", None), - **self._neural_kwargs.get("velocity_field_train_state_kwargs", {}), - ) - return ( - MultiLoader(datasets=train_loaders, seed=seed), - MultiLoader(datasets=validate_loaders, seed=seed), - ) - - def _split_data( # TODO: adapt for Gromov terms - self, - x: ArrayLike, - conditions: Optional[ArrayLike], - train_size: float, - seed: int, - a: Optional[ArrayLike] = None, - b: Optional[ArrayLike] = None, - ) -> SingleDistributionData: - n_samples_x = x.shape[0] - n_train_x = math.ceil(train_size * n_samples_x) - rng = np.random.default_rng(seed) - x = rng.permutation(x) - if a is not None: - a = rng.permutation(a) - if b is not None: - b = rng.permutation(b) - - return SingleDistributionData( - data_train=x[:n_train_x], - data_valid=x[n_train_x:], - conditions_train=conditions[:n_train_x] if conditions is not None else None, - conditions_valid=conditions[n_train_x:] if conditions is not None else None, - a_train=a[:n_train_x] if a is not None else None, - a_valid=a[n_train_x:] if a is not None else None, - b_train=b[:n_train_x] if b is not None else None, - b_valid=b[n_train_x:] if b is not None else None, - ) - - @property - def solver(self) -> genot.GENOT: - """Underlying optimal transport solver.""" - return self._solver - - @classmethod - def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: - return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value] - - def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput: # type: ignore[override] - seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests - rng = jax.random.PRNGKey(seed) - logs = self.solver( - data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng - ) # TODO(ilan-gold): validation and figure out defualts - return NeuralOutput(self.solver, logs) diff --git a/src/moscot/backends/utils.py b/src/moscot/backends/utils.py index 988e05413..f603f901c 100644 --- a/src/moscot/backends/utils.py +++ b/src/moscot/backends/utils.py @@ -5,12 +5,13 @@ if TYPE_CHECKING: from moscot.backends import ott + from moscot.neural.backends import neural_ott __all__ = ["get_solver", "register_solver", "get_available_backends"] register_solver_t = Callable[ - [Literal["linear", "quadratic"], Optional[Literal["GENOTLinSolver"]]], - Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"], + [Literal["linear", "quadratic"], Optional[Literal["GENOTSolver"]]], + Union["ott.SinkhornSolver", "ott.GWSolver", "neural_ott.GENOTSolver"], ] @@ -27,7 +28,7 @@ def get_solver(problem_kind: ProblemKind_t, *, backend: str = "ott", return_clas def register_solver( backend: str, -) -> Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"]: +) -> Union["ott.SinkhornSolver", "ott.GWSolver", "neural_ott.GENOTSolver"]: """Register a solver for a specific backend. Parameters @@ -45,13 +46,15 @@ def register_solver( @register_solver("ott") def _( problem_kind: Literal["linear", "quadratic"], - solver_name: Optional[Literal["GENOTLinSolver"]] = None, -) -> Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"]: + solver_name: Optional[Literal["GENOTSolver"]] = None, +) -> Union["ott.SinkhornSolver", "ott.GWSolver", "neural_ott.GENOTSolver"]: from moscot.backends import ott if problem_kind == "linear": - if solver_name == "GENOTLinSolver": - return ott.GENOTLinSolver # type: ignore[return-value] + if solver_name == "GENOTSolver": + from moscot.neural.backends import neural_ott + + return neural_ott.GENOTSolver # type: ignore[return-value] if solver_name is None: return ott.SinkhornSolver # type: ignore[return-value] if problem_kind == "quadratic": diff --git a/src/moscot/neural/backends/__init__.py b/src/moscot/neural/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/moscot/neural/backends/neural_ott/__init__.py b/src/moscot/neural/backends/neural_ott/__init__.py new file mode 100644 index 000000000..b96a92f96 --- /dev/null +++ b/src/moscot/neural/backends/neural_ott/__init__.py @@ -0,0 +1,3 @@ +from moscot.neural.backends.neural_ott.solver import GENOTSolver + +__all__ = ["GENOTSolver"] \ No newline at end of file diff --git a/src/moscot/neural/backends/neural_ott/solver.py b/src/moscot/neural/backends/neural_ott/solver.py new file mode 100644 index 000000000..9b7b312c5 --- /dev/null +++ b/src/moscot/neural/backends/neural_ott/solver.py @@ -0,0 +1,224 @@ +import abc +import functools +import inspect +import math +import types +from typing import ( + Any, + Hashable, + List, + Literal, + Mapping, + NamedTuple, + Optional, + Set, + Tuple, + TypeVar, + Union, +) + +import optax + +import jax +import jax.numpy as jnp +import numpy as np +from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud +from ott.neural.datasets import OTData, OTDataset +from ott.neural.methods.flows import dynamics, genot +from ott.neural.networks.layers import time_encoder +from ott.neural.networks.velocity_field import VelocityField +from ott.problems.linear import linear_problem +from ott.problems.quadratic import quadratic_problem +from ott.solvers.linear import sinkhorn, sinkhorn_lr +from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr +from ott.solvers.utils import uniform_sampler +from moscot.neural.data import PolicyDataLoader + +from moscot._logging import logger +from moscot._types import ( + ArrayLike, + LRInitializer_t, + ProblemKind_t, + QuadInitializer_t, + SinkhornInitializer_t, +) +from moscot.backends.ott._utils import ( + InitializerResolver, + Loader, + MultiLoader, + _instantiate_geodesic_cost, + alpha_to_fused_penalty, + check_shapes, + convert_scipy_sparse, + data_match_fn, + densify, + ensure_2d, +) +from moscot.utils.subset_policy import SubsetPolicy +from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput +from moscot.base.problems._utils import TimeScalesHeatKernel +from moscot.base.solver import BaseSolver +from moscot.neural.data import DistributionCollection +from typing import TypeVar + +K = TypeVar("K", bound=Hashable) + + +__all__ = ["GENOTSolver"] + + +def _split_data( + rng: jax.random.PRNGKey, + distributions: DistributionCollection, + train_size: float = 0.9, +) -> Any: + train_collection: DistributionCollection = {} + val_collection: DistributionCollection = {} + for key, dist in distributions.items(): + n_train_x = math.ceil(train_size * dist.n_samples) + idxs = jax.random.permutation(rng, jnp.arange(dist.n_samples)) + train_collection[key] = dist[idxs[:n_train_x]] + val_collection[key] = dist[idxs[n_train_x:]] + return train_collection, val_collection + + +class GENOTSolver(BaseSolver[NeuralOutput]): + """Solver class for genot.GENOT linear :cite:`klein2023generative`.""" + + def __init__(self, **kwargs: Any) -> None: + """Initiate the class with any kwargs passed to the ott-jax class.""" + super().__init__() + self._neural_kwargs = kwargs + + @property + def problem_kind(self) -> ProblemKind_t: # noqa: D102 + return "linear" + + def _prepare( # type: ignore[override] + self, + distributions: DistributionCollection[K], + policy: SubsetPolicy[K], + train_size: float = 0.9, + batch_size: int = 128, + is_conditional: bool = False, + seed: int = 0, + device: Any = None, + ) -> Tuple[PolicyDataLoader, PolicyDataLoader]: + del device # TODO: ignore for now, but we should handle this properly + + if train_size > 1.0 or train_size <= 0.0: + raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1") + + rng = jax.random.PRNGKey(seed) + + src_renames = tgt_renames = { + "xy": "lin", + "xx": "quad", + } + + if train_size == 1.0: + train_rng, valid_rng, rng = jax.random.split(rng, 3) + train_loader = PolicyDataLoader( + rng=train_rng, + policy=policy, + distributions=distributions, + batch_size=batch_size, + plan=policy.plan(), + src_renames=src_renames, + tgt_renames=tgt_renames, + ) + validate_loader = PolicyDataLoader( + rng=valid_rng, + policy=policy, + distributions=distributions, + batch_size=batch_size, + plan=policy.plan(), + src_renames=src_renames, + tgt_renames=tgt_renames, + ) + + else: + train_rng, valid_rng, split_rng, rng = jax.random.split(rng, 4) + train_dist, valid_dist = _split_data(split_rng, distributions, train_size=train_size) + train_loader = PolicyDataLoader( + rng=train_rng, + policy=policy, + distributions=train_dist, + batch_size=batch_size, + plan=policy.plan(), + src_renames=src_renames, + tgt_renames=tgt_renames, + ) + validate_loader = PolicyDataLoader( + rng=valid_rng, + policy=policy, + distributions=valid_dist, + batch_size=batch_size, + plan=policy.plan(), + src_renames=src_renames, + tgt_renames=tgt_renames, + ) + self.train_loader = train_loader + self.validate_loader = validate_loader + source_dim = self._neural_kwargs.get("input_dim", 0) + target_dim = source_dim + condition_dim = self._neural_kwargs.get("cond_dim", 0) + # TODO(ilan-gold): What are reasonable defaults here? + neural_vf = VelocityField( + output_dims=[*self._neural_kwargs.get("velocity_field_output_dims", []), target_dim], + condition_dims=( + self._neural_kwargs.get("velocity_field_condition_dims", [source_dim + condition_dim]) + if is_conditional + else None + ), + hidden_dims=self._neural_kwargs.get("velocity_field_hidden_dims", [1024, 1024, 1024]), + time_dims=self._neural_kwargs.get("velocity_field_time_dims", None), + time_encoder=self._neural_kwargs.get( + "velocity_field_time_encoder", functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024) + ), + ) + seed = self._neural_kwargs.get("seed", 0) + rng = jax.random.PRNGKey(seed) + data_match_fn_kwargs = self._neural_kwargs.get( + "data_match_fn_kwargs", + {} if "data_match_fn" in self._neural_kwargs else {"epsilon": 1e-1, "tau_a": 1.0, "tau_b": 1.0}, + ) + time_sampler = self._neural_kwargs.get("time_sampler", uniform_sampler) + optimizer = self._neural_kwargs.get("optimizer", optax.adam(learning_rate=1e-4)) + self._solver = genot.GENOT( + vf=neural_vf, + flow=self._neural_kwargs.get( + "flow", + dynamics.ConstantNoiseFlow(0.1), + ), + data_match_fn=functools.partial( + self._neural_kwargs.get("data_match_fn", data_match_fn), typ="lin", **data_match_fn_kwargs + ), + source_dim=source_dim, + target_dim=target_dim, + condition_dim=condition_dim if is_conditional else None, + optimizer=optimizer, + time_sampler=time_sampler, + rng=rng, + latent_noise_fn=self._neural_kwargs.get("latent_noise_fn", None), + **self._neural_kwargs.get("velocity_field_train_state_kwargs", {}), + ) + return train_loader, validate_loader + + + @property + def solver(self) -> genot.GENOT: + """Underlying optimal transport solver.""" + return self._solver + + @classmethod + def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: + return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value] + + def _solve(self, data_samplers: Tuple[PolicyDataLoader, PolicyDataLoader]) -> NeuralOutput: # type: ignore[override] + seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests + rng = jax.random.PRNGKey(seed) + logs = self.solver( + data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng + ) # TODO(ilan-gold): validation and figure out defualts + return NeuralOutput(self.solver, logs) diff --git a/src/moscot/neural/base/problems/problem.py b/src/moscot/neural/base/problems/problem.py index cc142f989..b68ed14d3 100644 --- a/src/moscot/neural/base/problems/problem.py +++ b/src/moscot/neural/base/problems/problem.py @@ -23,6 +23,7 @@ from moscot.base.problems._utils import wrap_prepare, wrap_solve from moscot.base.problems.problem import BaseProblem from moscot.base.solver import OTSolver +from moscot.neural.data import DistributionCollection, DistributionContainer from moscot.utils.subset_policy import ( # type:ignore[attr-defined] ExplicitPolicy, Policy_t, @@ -30,7 +31,6 @@ SubsetPolicy, create_policy, ) -from moscot.utils.tagged_array import DistributionCollection, DistributionContainer K = TypeVar("K", bound=Hashable) @@ -59,7 +59,6 @@ def __init__( self._distributions: Optional[DistributionCollection[K]] = None # type: ignore[valid-type] self._policy: Optional[SubsetPolicy[Any]] = None - self._sample_pairs: Optional[List[Tuple[Any, Any]]] = None self._solver: Optional[OTSolver[BaseNeuralOutput]] = None self._solution: Optional[BaseNeuralOutput] = None @@ -75,11 +74,9 @@ def prepare( xy: Mapping[str, Any], xx: Mapping[str, Any], conditions: Mapping[str, Any], - a: Optional[str] = None, - b: Optional[str] = None, subset: Optional[Sequence[Tuple[K, K]]] = None, + seed: int = 0, reference: K = None, - **kwargs: Any, ) -> "NeuralOTProblem": """Prepare conditional optimal transport problem. @@ -92,10 +89,6 @@ def prepare( Policy defining which pairs of distributions to sample from during training. policy_key %(key)s - a - Source marginals. - b - Target marginals. kwargs Keyword arguments when creating the source/target marginals. @@ -105,6 +98,7 @@ def prepare( Self and modifies the following attributes: TODO. """ + self._seed = seed self._problem_kind = "linear" self._distributions = DistributionCollection() self._solution = None @@ -121,14 +115,12 @@ def prepare( self._policy = self._policy.create_graph(reference=reference) else: _ = self.policy.create_graph() # type: ignore[union-attr] - self._sample_pairs = list(self.policy._graph) # type: ignore[union-attr] for el in self.policy.categories: # type: ignore[union-attr] adata_masked = self.adata[self._create_mask(el)] - a_created = self._create_marginals(adata_masked, data=a, source=True, **kwargs) - b_created = self._create_marginals(adata_masked, data=b, source=False, **kwargs) + # TODO: Marginals self.distributions[el] = DistributionContainer.from_adata( # type: ignore[index] - adata_masked, a=a_created, b=b_created, **xy, **xx, **conditions + adata_masked, **xy, **xx, **conditions ) return self @@ -136,7 +128,7 @@ def prepare( def solve( self, backend: Literal["ott"] = "ott", - solver_name: Literal["GENOTLinSolver"] = "GENOTLinSolver", + solver_name: Literal["GENOTSolver"] = "GENOTSolver", device: Optional[Device_t] = None, **kwargs: Any, ) -> "NeuralOTProblem": @@ -167,18 +159,16 @@ def solve( ) init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) self._solver = solver_class(input_dim=input_dim, cond_dim=cond_dim, **init_kwargs) - # note that the solver call consists of solver._prepare and solver._solve - sample_pairs = self._sample_pairs if self._sample_pairs is not None else [] self._solution = self._solver( # type: ignore[misc] device=device, distributions=self.distributions, - sample_pairs=self._sample_pairs, - is_conditional=len(sample_pairs) > 1, + policy=self.policy, **call_kwargs, ) return self + # TODO: Marginals def _create_marginals( self, adata: AnnData, *, source: bool, data: Optional[str] = None, **kwargs: Any ) -> ArrayLike: diff --git a/src/moscot/neural/data/__init__.py b/src/moscot/neural/data/__init__.py new file mode 100644 index 000000000..83fa7ad2f --- /dev/null +++ b/src/moscot/neural/data/__init__.py @@ -0,0 +1,4 @@ +from moscot.neural.data._policy_loader import PolicyDataLoader +from moscot.neural.data._distribution_collection import DistributionCollection, DistributionContainer + +__all__ = ["PolicyDataLoader", "DistributionCollection", "DistributionContainer"] diff --git a/src/moscot/neural/data/_distribution_collection.py b/src/moscot/neural/data/_distribution_collection.py new file mode 100644 index 000000000..42c49796b --- /dev/null +++ b/src/moscot/neural/data/_distribution_collection.py @@ -0,0 +1,229 @@ +from dataclasses import dataclass +from typing import Any, Literal, Optional, Tuple, Hashable, Union + +import numpy as np +import scipy.sparse as sp +import jax.numpy as jnp +import jax + +from anndata import AnnData + +from moscot._logging import logger +from moscot._types import CostFn_t +from moscot.costs import get_cost + +from typing import TypeVar + +K = TypeVar("K", bound=Hashable) + + +@dataclass(frozen=True, repr=True) +class DistributionContainer: + """Data container for OT problems involving more than two distributions. + + TODO + + Parameters + ---------- + xy + Distribution living in a shared space. + xx + Distribution living in an incomparable space. + conditions + Conditions for the distributions. + cost_xy + Cost function when in the shared space. + cost_xx + Cost function in the incomparable space. + """ + + xy: Optional[jax.Array] + xx: Optional[jax.Array] + conditions: Optional[jax.Array] + cost_xy: Any + cost_xx: Any + + @property + def contains_linear(self) -> bool: + """Whether the distribution contains data corresponding to the linear term.""" + return self.xy is not None + + @property + def contains_quadratic(self) -> bool: + """Whether the distribution contains data corresponding to the quadratic term.""" + return self.xx is not None + + @property + def contains_condition(self) -> bool: + """Whether the distribution contains data corresponding to the condition.""" + return self.conditions is not None + + @property + def n_samples(self) -> int: + """Number of samples in the distribution.""" + return self.xy.shape[0] if self.contains_linear else self.xx.shape[0] + + @staticmethod + def _extract_data( + adata: AnnData, + *, + attr: Literal["X", "obs", "obsp", "obsm", "var", "varm", "layers", "uns"], + key: Optional[str] = None, + ) -> jax.Array: + modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]" + data = getattr(adata, attr) + + try: + if key is not None: + data = data[key] + except KeyError: + raise KeyError(f"Unable to fetch data from `{modifier}`.") from None + except IndexError: + raise IndexError(f"Unable to fetch data from `{modifier}`.") from None + + if attr == "obs": + data = np.asarray(data)[:, None] + if sp.issparse(data): + logger.warning(f"Densifying data in `{modifier}`") + data = data.toarray() + if data.ndim != 2: + raise ValueError(f"Expected `{modifier}` to have `2` dimensions, found `{data.ndim}`.") + + return jnp.array(data) + + @staticmethod + def _verify_input( + xy_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]], + xy_key: Optional[str], + xx_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]], + xx_key: Optional[str], + conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]], + conditions_key: Optional[str], + ) -> Tuple[bool, bool, bool]: + if (xy_attr is None and xy_key is not None) or (xy_attr is not None and xy_key is None): + raise ValueError(r"Either both `xy_attr` and `xy_key` must be `None` or none of them.") + if (xx_attr is None and xx_key is not None) or (xx_attr is not None and xx_key is None): + raise ValueError(r"Either both `xy_attr` and `xy_key` must be `None` or none of them.") + if (conditions_attr is None and conditions_key is not None) or ( + conditions_attr is not None and conditions_key is None + ): + raise ValueError(r"Either both `conditions_attr` and `conditions_key` must be `None` or none of them.") + return xy_attr is not None, xx_attr is not None, conditions_attr is not None + + @classmethod + def from_adata( + cls, + adata: AnnData, + xy_attr: Literal["X", "obsp", "obsm", "layers", "uns"] = None, + xy_key: Optional[str] = None, + xy_cost: CostFn_t = "sq_euclidean", + xx_attr: Literal["X", "obsp", "obsm", "layers", "uns"] = None, + xx_key: Optional[str] = None, + xx_cost: CostFn_t = "sq_euclidean", + conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]] = None, + conditions_key: Optional[str] = None, + backend: Literal["ott"] = "ott", + **kwargs: Any, + ) -> "DistributionContainer": + """Create distribution container from :class:`~anndata.AnnData`. + + .. warning:: + Sparse arrays will be always densified. + + Parameters + ---------- + adata + Annotated data object. + a + Marginals when used as source distribution. + b + Marginals when used as target distribution. + xy_attr + Attribute of `adata` containing the data for the shared space. + xy_key + Key of `xy_attr` containing the data for the shared space. + xy_cost + Cost function when in the shared space. + xx_attr + Attribute of `adata` containing the data for the incomparable space. + xx_key + Key of `xx_attr` containing the data for the incomparable space. + xx_cost + Cost function in the incomparable space. + conditions_attr + Attribute of `adata` containing the conditions. + conditions_key + Key of `conditions_attr` containing the conditions. + backend + Backend to use. + kwargs + Keyword arguments to pass to the cost functions. + + Returns + ------- + The distribution container. + """ + contains_linear, contains_quadratic, contains_condition = cls._verify_input( + xy_attr, xy_key, xx_attr, xx_key, conditions_attr, conditions_key + ) + + if contains_linear: + xy_data = cls._extract_data(adata, attr=xy_attr, key=xy_key) + xy_cost_fn = get_cost(xy_cost, backend=backend, **kwargs) + else: + xy_data = None + xy_cost_fn = None + + if contains_quadratic: + xx_data = cls._extract_data(adata, attr=xx_attr, key=xx_key) + xx_cost_fn = get_cost(xx_cost, backend=backend, **kwargs) + else: + xx_data = None + xx_cost_fn = None + + conditions_data = ( + cls._extract_data(adata, attr=conditions_attr, key=conditions_key) if contains_condition else None # type: ignore[arg-type] # noqa:E501 + ) + return cls(xy=xy_data, xx=xx_data, conditions=conditions_data, cost_xy=xy_cost_fn, cost_xx=xx_cost_fn) + + def __getitem__( + self, + idx: Union[int, slice, jnp.ndarray, jax.Array, list, tuple] + ) -> "DistributionContainer": + """ + Return a new DistributionContainer where .xy, .xx, .conditions + are sliced by `idx` (if they are not None). + + This allows usage like: + new_container = distribution_container[train_ixs] + """ # noqa: D205 + # TODO: Normally this is inefficient + # But we first need to separate the slicing of training and validation data + # Before creating this DistributionContainer! + # Slice xy + new_xy = self.xy[idx] if self.xy is not None else None + + # Slice xx + new_xx = self.xx[idx] if self.xx is not None else None + + # Slice conditions + new_conditions = self.conditions[idx] if self.conditions is not None else None + + # Reuse the same cost functions + return DistributionContainer( + xy=new_xy, + xx=new_xx, + conditions=new_conditions, + cost_xy=self.cost_xy, + cost_xx=self.cost_xx, + ) + + +class DistributionCollection(dict[K, DistributionContainer]): + """Collection of distributions.""" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}{list(self.keys())}" + + def __str__(self) -> str: + return repr(self) diff --git a/src/moscot/neural/data/_policy_loader.py b/src/moscot/neural/data/_policy_loader.py new file mode 100644 index 000000000..90f086809 --- /dev/null +++ b/src/moscot/neural/data/_policy_loader.py @@ -0,0 +1,169 @@ +from moscot.utils.subset_policy import SubsetPolicy +from moscot.neural.data._distribution_collection import DistributionCollection + +import jax + +from typing import Any, Dict, Iterator, List, Optional, Tuple + + +import jax +import jax.numpy as jnp +from typing import Any, Dict, Optional, List, Tuple, Iterator, Sequence + +import functools + + +@functools.partial(jax.jit, static_argnums=(3,)) +def _sample_indices( + rng: jax.Array, idx_src: jnp.ndarray, idx_tgt: jnp.ndarray, batch_size: int +) -> Tuple[jax.Array, jnp.ndarray, jnp.ndarray]: + """ + JIT-compiled function: + - Splits RNG into rng_src/rng_tgt. + - Samples with replacement from idx_src, idx_tgt. + - Returns the updated rng, plus arrays of sample positions for src & tgt. + """ # noqa: D205 + rng, rng_src, rng_tgt = jax.random.split(rng, 3) + src_samples = jax.random.randint(rng_src, shape=(batch_size,), minval=0, maxval=idx_src.shape[0]) + tgt_samples = jax.random.randint(rng_tgt, shape=(batch_size,), minval=0, maxval=idx_tgt.shape[0]) + return rng, src_samples, tgt_samples + + +@jax.jit +def _gather_array(arr: jnp.ndarray, idxs: jnp.ndarray) -> jnp.ndarray: + """ + JIT-compiled function to gather rows from arr at idxs. + If arr is shape [N, ...], idxs is shape [K], result is [K, ...]. + """ # noqa: D205 + return jnp.take(arr, idxs, axis=0) + + +class PolicyDataLoader: + """A data loader for handling subset policies and distribution collections. + + A data loader that: + - Takes a SubsetPolicy (with a plan of edges). + - Has a DistributionCollection mapping node -> DistributionContainer. + - For each distribution container, we check that .xy, .xx, .conditions + all share the same shape[0] if they're not None. + - On each iteration: + 1) Randomly pick an edge (src_node, tgt_node) in Python. + 2) Use a small jitted function `_sample_indices` to sample from + the node_indices (with replacement). + 3) Use a small jitted function `_gather_array` to gather data from + .xy, .xx, .conditions. + 4) Build a final dictionary and yield it. + """ + + def __init__( + self, + rng: jax.Array, + policy: SubsetPolicy[Any], + distributions: DistributionCollection, + batch_size: int = 128, + plan: Optional[Sequence[Tuple[Any, Any]]] = None, + src_prefix: str = "src", + tgt_prefix: str = "tgt", + src_renames: Optional[Dict[str, str]] = None, + tgt_renames: Optional[Dict[str, str]] = None, + ): + + self.policy = policy + self.distributions = distributions + self.rng = rng + self.batch_size = batch_size + self.edges = plan if plan is not None else self.policy.plan() + self.src_prefix = src_prefix + self.tgt_prefix = tgt_prefix + self.src_renames = src_renames if src_renames is not None else {} + self.tgt_renames = tgt_renames if tgt_renames is not None else {} + + # Precompute an index array for each node + self.node_indices: Dict[Any, jnp.ndarray] = {} + self._init_indices() + + def _init_indices(self) -> None: + """Verify shape consistency within each DistributionContainer, store jnp.arange(...) as node_indices.""" + for node, container in self.distributions.items(): + # Gather shapes of non-None arrays + shapes = [] + if container.xy is not None: + shapes.append(container.xy.shape[0]) + if container.xx is not None: + shapes.append(container.xx.shape[0]) + if container.conditions is not None: + shapes.append(container.conditions.shape[0]) + + # All must match + if shapes and not all(s == shapes[0] for s in shapes): + raise ValueError(f"Inconsistent shape for node {node}: {shapes}") + + if shapes: + n = shapes[0] + self.node_indices[node] = jnp.arange(n) + + def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]: + """ + Infinite generator. Each iteration: + 1) Randomly pick an edge (src_node, tgt_node). + 2) Use _sample_indices(...) to get random sample positions from the node's data. + 3) Use _gather_array(...) to gather from .xy, .xx, .conditions. + 4) Build a dict and yield it. + """ + while True: + if not self.edges: + break + + # (A) Pick a random edge in Python + self.rng, rng_edge = jax.random.split(self.rng) + i = jax.random.randint(rng_edge, shape=(), minval=0, maxval=len(self.edges)) + edge = self.edges[int(i)] + src_node, tgt_node = edge + + # Skip if the node doesn't exist in distributions or indices + if src_node not in self.distributions or tgt_node not in self.distributions: + continue + if src_node not in self.node_indices or tgt_node not in self.node_indices: + continue + + src_container = self.distributions[src_node] + tgt_container = self.distributions[tgt_node] + idx_src = self.node_indices[src_node] + idx_tgt = self.node_indices[tgt_node] + + # (B) Sample random positions with a small jitted function + self.rng, src_samples, tgt_samples = _sample_indices(self.rng, idx_src, idx_tgt, self.batch_size) + # Convert to actual indices + src_idxs = jnp.take(idx_src, src_samples) + tgt_idxs = jnp.take(idx_tgt, tgt_samples) + + # (C) Gather data from each relevant array + batch_dict = {} + + src_candidates = [ + ("xy", src_container.xy), + ("xx", src_container.xx), + ("conditions", src_container.conditions), + ] + for key, arr in src_candidates: + if arr is not None: + key_new = self.src_renames.get(key, key) + batch_dict[f"{self.src_prefix}_{key_new}"] = _gather_array(arr, src_idxs) + + tgt_candidates = [ + ("xy", tgt_container.xy), + ("xx", tgt_container.xx), + ("conditions", tgt_container.conditions), + ] + for key, arr in tgt_candidates: + if arr is not None: + key_new = self.tgt_renames.get(key, key) + batch_dict[f"{self.tgt_prefix}_{key_new}"] = _gather_array(arr, tgt_idxs) + if not batch_dict: + continue + + yield batch_dict + + def __len__(self) -> int: + """Optionally define a length if you like, e.g. len(edges).""" + return len(self.edges) diff --git a/src/moscot/neural/problems/generic/_generic.py b/src/moscot/neural/problems/generic/_generic.py index c37a2f75c..abc5bb87f 100644 --- a/src/moscot/neural/problems/generic/_generic.py +++ b/src/moscot/neural/problems/generic/_generic.py @@ -1,6 +1,6 @@ import types from types import MappingProxyType -from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Type, Union +from typing import Any, Dict, Literal, Mapping, Tuple, Type, Union from moscot import _constants from moscot._types import CostKwargs_t, OttCostFn_t, Policy_t @@ -22,9 +22,12 @@ def prepare( key: str, joint_attr: Union[str, Mapping[str, Any]], conditional_attr: Union[str, Mapping[str, Any]], + # src_condition_attr: Union[str, Mapping[str, Any]], + # src_augment_attr: Optional[Union[str, Mapping[str, Any]]] = None, + # src_quad_attr: Optional[Union[str, Mapping[str, Any]]] = None, + # tgt_quad_attr: Optional[Union[str, Mapping[str, Any]]] = None, + # tgt_flow_attr: Optional[Union[str, Mapping[str, Any]]] = None, policy: Literal["sequential", "star", "explicit"] = "sequential", - a: Optional[str] = None, - b: Optional[str] = None, cost: OttCostFn_t = "sq_euclidean", cost_kwargs: CostKwargs_t = types.MappingProxyType({}), **kwargs: Any, @@ -40,8 +43,6 @@ def prepare( xy=xy, xx=xx, conditions=conditions, - a=a, - b=b, **kwargs, ) @@ -58,14 +59,12 @@ def solve( """Solve.""" return super().solve( batch_size=batch_size, - # tau_a=tau_a, # TODO: unbalancedness handler - # tau_b=tau_b, seed=seed, n_iters=iterations, valid_freq=valid_freq, valid_sinkhorn_kwargs=valid_sinkhorn_kwargs, train_size=train_size, - solver_name="GENOTLinSolver", + solver_name="GENOTSolver", **kwargs, ) diff --git a/src/moscot/utils/tagged_array.py b/src/moscot/utils/tagged_array.py index fa6a70f06..4fe681e79 100644 --- a/src/moscot/utils/tagged_array.py +++ b/src/moscot/utils/tagged_array.py @@ -8,12 +8,12 @@ from anndata import AnnData from moscot._logging import logger -from moscot._types import ArrayLike, CostFn_t, OttCostFn_t +from moscot._types import ArrayLike, CostFn_t from moscot.costs import get_cost K = TypeVar("K", bound=Hashable) -__all__ = ["Tag", "TaggedArray", "DistributionContainer", "DistributionCollection"] +__all__ = ["Tag", "TaggedArray"] @enum.unique @@ -188,186 +188,3 @@ def is_point_cloud(self) -> bool: def is_graph(self) -> bool: """Whether :attr:`data_src` is a graph.""" return self.tag == Tag.GRAPH - - -@dataclass(frozen=True, repr=True) -class DistributionContainer: - """Data container for OT problems involving more than two distributions. - - TODO - - Parameters - ---------- - xy - Distribution living in a shared space. - xx - Distribution living in an incomparable space. - a - Marginals when used as source distribution. - b - Marginals when used as target distribution. - conditions - Conditions for the distributions. - cost_xy - Cost function when in the shared space. - cost_xx - Cost function in the incomparable space. - """ - - xy: Optional[ArrayLike] - xx: Optional[ArrayLike] - a: ArrayLike - b: ArrayLike - conditions: Optional[ArrayLike] - cost_xy: OttCostFn_t - cost_xx: OttCostFn_t - - @property - def contains_linear(self) -> bool: - """Whether the distribution contains data corresponding to the linear term.""" - return self.xy is not None - - @property - def contains_quadratic(self) -> bool: - """Whether the distribution contains data corresponding to the quadratic term.""" - return self.xx is not None - - @property - def contains_condition(self) -> bool: - """Whether the distribution contains data corresponding to the condition.""" - return self.conditions is not None - - @staticmethod - def _extract_data( - adata: AnnData, - *, - attr: Literal["X", "obs", "obsp", "obsm", "var", "varm", "layers", "uns"], - key: Optional[str] = None, - ) -> ArrayLike: - modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]" - data = getattr(adata, attr) - - try: - if key is not None: - data = data[key] - except KeyError: - raise KeyError(f"Unable to fetch data from `{modifier}`.") from None - except IndexError: - raise IndexError(f"Unable to fetch data from `{modifier}`.") from None - - if attr == "obs": - data = np.asarray(data)[:, None] - if sp.issparse(data): - logger.warning(f"Densifying data in `{modifier}`") - data = data.A - if data.ndim != 2: - raise ValueError(f"Expected `{modifier}` to have `2` dimensions, found `{data.ndim}`.") - - return data - - @staticmethod - def _verify_input( - xy_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]], - xy_key: Optional[str], - xx_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]], - xx_key: Optional[str], - conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]], - conditions_key: Optional[str], - ) -> Tuple[bool, bool, bool]: - if (xy_attr is None and xy_key is not None) or (xy_attr is not None and xy_key is None): - raise ValueError(r"Either both `xy_attr` and `xy_key` must be `None` or none of them.") - if (xx_attr is None and xx_key is not None) or (xx_attr is not None and xx_key is None): - raise ValueError(r"Either both `xy_attr` and `xy_key` must be `None` or none of them.") - if (conditions_attr is None and conditions_key is not None) or ( - conditions_attr is not None and conditions_key is None - ): - raise ValueError(r"Either both `conditions_attr` and `conditions_key` must be `None` or none of them.") - return xy_attr is not None, xx_attr is not None, conditions_attr is not None - - @classmethod - def from_adata( - cls, - adata: AnnData, - a: ArrayLike, - b: ArrayLike, - xy_attr: Literal["X", "obsp", "obsm", "layers", "uns"] = None, - xy_key: Optional[str] = None, - xy_cost: CostFn_t = "sq_euclidean", - xx_attr: Literal["X", "obsp", "obsm", "layers", "uns"] = None, - xx_key: Optional[str] = None, - xx_cost: CostFn_t = "sq_euclidean", - conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]] = None, - conditions_key: Optional[str] = None, - backend: Literal["ott"] = "ott", - **kwargs: Any, - ) -> "DistributionContainer": - """Create distribution container from :class:`~anndata.AnnData`. - - .. warning:: - Sparse arrays will be always densified. - - Parameters - ---------- - adata - Annotated data object. - a - Marginals when used as source distribution. - b - Marginals when used as target distribution. - xy_attr - Attribute of `adata` containing the data for the shared space. - xy_key - Key of `xy_attr` containing the data for the shared space. - xy_cost - Cost function when in the shared space. - xx_attr - Attribute of `adata` containing the data for the incomparable space. - xx_key - Key of `xx_attr` containing the data for the incomparable space. - xx_cost - Cost function in the incomparable space. - conditions_attr - Attribute of `adata` containing the conditions. - conditions_key - Key of `conditions_attr` containing the conditions. - backend - Backend to use. - kwargs - Keyword arguments to pass to the cost functions. - - Returns - ------- - The distribution container. - """ - contains_linear, contains_quadratic, contains_condition = cls._verify_input( - xy_attr, xy_key, xx_attr, xx_key, conditions_attr, conditions_key - ) - - if contains_linear: - xy_data = cls._extract_data(adata, attr=xy_attr, key=xy_key) - xy_cost_fn = get_cost(xy_cost, backend=backend, **kwargs) - else: - xy_data = None - xy_cost_fn = None - - if contains_quadratic: - xx_data = cls._extract_data(adata, attr=xx_attr, key=xx_key) - xx_cost_fn = get_cost(xx_cost, backend=backend, **kwargs) - else: - xx_data = None - xx_cost_fn = None - - conditions_data = ( - cls._extract_data(adata, attr=conditions_attr, key=conditions_key) if contains_condition else None # type: ignore[arg-type] # noqa:E501 - ) - return cls(xy=xy_data, xx=xx_data, a=a, b=b, conditions=conditions_data, cost_xy=xy_cost_fn, cost_xx=xx_cost_fn) - - -class DistributionCollection(dict[K, DistributionContainer]): - """Collection of distributions.""" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}{list(self.keys())}" - - def __str__(self) -> str: - return repr(self) diff --git a/tests/neural/problems/generic/test_conditional_neural_problem.py b/tests/neural/problems/generic/test_conditional_neural_problem.py index e4cd9b832..bd80685a3 100644 --- a/tests/neural/problems/generic/test_conditional_neural_problem.py +++ b/tests/neural/problems/generic/test_conditional_neural_problem.py @@ -9,10 +9,10 @@ from moscot.base.output import BaseSolverOutput from moscot.neural.base.problems import NeuralOTProblem from moscot.neural.problems.generic import GENOTLinProblem # type: ignore[attr-defined] -from moscot.utils.tagged_array import DistributionCollection, DistributionContainer +from moscot.neural.data import DistributionCollection, DistributionContainer from tests._utils import ATOL, RTOL from tests.problems.conftest import neurallin_cond_args_1 - +import jax.numpy as jnp class TestGENOTLinProblem: @pytest.mark.fast @@ -26,15 +26,11 @@ def test_prepare(self, adata_time: ad.AnnData): container = problem.distributions[0] n_obs_0 = adata_time[adata_time.obs["time"] == 0].n_obs assert isinstance(container, DistributionContainer) - assert isinstance(container.xy, np.ndarray) + assert isinstance(container.xy, jnp.ndarray) assert container.xy.shape == (n_obs_0, 50) assert container.xx is None - assert isinstance(container.conditions, np.ndarray) + assert isinstance(container.conditions, jnp.ndarray) assert container.conditions.shape == (n_obs_0, 1) - assert isinstance(container.a, np.ndarray) - assert container.a.shape == (n_obs_0,) - assert isinstance(container.b, np.ndarray) - assert container.b.shape == (n_obs_0,) assert isinstance(container.cost_xy, costs.SqEuclidean) assert container.cost_xx is None From e262d50b10c9198dcfc46fa59b938f430add9e94 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 10 Jan 2025 16:11:07 +0100 Subject: [PATCH 2/5] formatting --- .../neural/backends/neural_ott/__init__.py | 2 +- .../neural/backends/neural_ott/solver.py | 64 +++---------------- src/moscot/neural/base/problems/problem.py | 1 - src/moscot/neural/data/__init__.py | 5 +- .../neural/data/_distribution_collection.py | 13 ++-- src/moscot/neural/data/_policy_loader.py | 19 +++--- .../test_conditional_neural_problem.py | 5 +- 7 files changed, 31 insertions(+), 78 deletions(-) diff --git a/src/moscot/neural/backends/neural_ott/__init__.py b/src/moscot/neural/backends/neural_ott/__init__.py index b96a92f96..bfb3ebed1 100644 --- a/src/moscot/neural/backends/neural_ott/__init__.py +++ b/src/moscot/neural/backends/neural_ott/__init__.py @@ -1,3 +1,3 @@ from moscot.neural.backends.neural_ott.solver import GENOTSolver -__all__ = ["GENOTSolver"] \ No newline at end of file +__all__ = ["GENOTSolver"] diff --git a/src/moscot/neural/backends/neural_ott/solver.py b/src/moscot/neural/backends/neural_ott/solver.py index 9b7b312c5..b624f7e37 100644 --- a/src/moscot/neural/backends/neural_ott/solver.py +++ b/src/moscot/neural/backends/neural_ott/solver.py @@ -1,65 +1,22 @@ -import abc import functools -import inspect import math -import types -from typing import ( - Any, - Hashable, - List, - Literal, - Mapping, - NamedTuple, - Optional, - Set, - Tuple, - TypeVar, - Union, -) +from typing import Any, Hashable, Set, Tuple, TypeVar import optax import jax import jax.numpy as jnp -import numpy as np -from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud -from ott.neural.datasets import OTData, OTDataset from ott.neural.methods.flows import dynamics, genot from ott.neural.networks.layers import time_encoder from ott.neural.networks.velocity_field import VelocityField -from ott.problems.linear import linear_problem -from ott.problems.quadratic import quadratic_problem -from ott.solvers.linear import sinkhorn, sinkhorn_lr -from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr from ott.solvers.utils import uniform_sampler -from moscot.neural.data import PolicyDataLoader - -from moscot._logging import logger -from moscot._types import ( - ArrayLike, - LRInitializer_t, - ProblemKind_t, - QuadInitializer_t, - SinkhornInitializer_t, -) -from moscot.backends.ott._utils import ( - InitializerResolver, - Loader, - MultiLoader, - _instantiate_geodesic_cost, - alpha_to_fused_penalty, - check_shapes, - convert_scipy_sparse, - data_match_fn, - densify, - ensure_2d, -) -from moscot.utils.subset_policy import SubsetPolicy -from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput -from moscot.base.problems._utils import TimeScalesHeatKernel + +from moscot._types import ProblemKind_t +from moscot.backends.ott._utils import data_match_fn +from moscot.backends.ott.output import NeuralOutput from moscot.base.solver import BaseSolver -from moscot.neural.data import DistributionCollection -from typing import TypeVar +from moscot.neural.data import DistributionCollection, PolicyDataLoader +from moscot.utils.subset_policy import SubsetPolicy K = TypeVar("K", bound=Hashable) @@ -69,11 +26,11 @@ def _split_data( rng: jax.random.PRNGKey, - distributions: DistributionCollection, + distributions: DistributionCollection[K], train_size: float = 0.9, ) -> Any: - train_collection: DistributionCollection = {} - val_collection: DistributionCollection = {} + train_collection: DistributionCollection[K] = {} + val_collection: DistributionCollection[K] = {} for key, dist in distributions.items(): n_train_x = math.ceil(train_size * dist.n_samples) idxs = jax.random.permutation(rng, jnp.arange(dist.n_samples)) @@ -205,7 +162,6 @@ def _prepare( # type: ignore[override] ) return train_loader, validate_loader - @property def solver(self) -> genot.GENOT: """Underlying optimal transport solver.""" diff --git a/src/moscot/neural/base/problems/problem.py b/src/moscot/neural/base/problems/problem.py index b68ed14d3..1c74572df 100644 --- a/src/moscot/neural/base/problems/problem.py +++ b/src/moscot/neural/base/problems/problem.py @@ -2,7 +2,6 @@ Any, Hashable, Iterable, - List, Literal, Mapping, Optional, diff --git a/src/moscot/neural/data/__init__.py b/src/moscot/neural/data/__init__.py index 83fa7ad2f..a3260be83 100644 --- a/src/moscot/neural/data/__init__.py +++ b/src/moscot/neural/data/__init__.py @@ -1,4 +1,7 @@ +from moscot.neural.data._distribution_collection import ( + DistributionCollection, + DistributionContainer, +) from moscot.neural.data._policy_loader import PolicyDataLoader -from moscot.neural.data._distribution_collection import DistributionCollection, DistributionContainer __all__ = ["PolicyDataLoader", "DistributionCollection", "DistributionContainer"] diff --git a/src/moscot/neural/data/_distribution_collection.py b/src/moscot/neural/data/_distribution_collection.py index 42c49796b..a07e37bc1 100644 --- a/src/moscot/neural/data/_distribution_collection.py +++ b/src/moscot/neural/data/_distribution_collection.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Any, Literal, Optional, Tuple, Hashable, Union +from typing import Any, Hashable, Literal, Optional, Tuple, TypeVar, Union +import jax +import jax.numpy as jnp import numpy as np import scipy.sparse as sp -import jax.numpy as jnp -import jax from anndata import AnnData @@ -12,8 +12,6 @@ from moscot._types import CostFn_t from moscot.costs import get_cost -from typing import TypeVar - K = TypeVar("K", bound=Hashable) @@ -61,7 +59,7 @@ def contains_condition(self) -> bool: @property def n_samples(self) -> int: """Number of samples in the distribution.""" - return self.xy.shape[0] if self.contains_linear else self.xx.shape[0] + return self.xy.shape[0] if self.contains_linear else self.xx.shape[0] # type: ignore[union-attr] @staticmethod def _extract_data( @@ -187,8 +185,7 @@ def from_adata( return cls(xy=xy_data, xx=xx_data, conditions=conditions_data, cost_xy=xy_cost_fn, cost_xx=xx_cost_fn) def __getitem__( - self, - idx: Union[int, slice, jnp.ndarray, jax.Array, list, tuple] + self, idx: Union[int, slice, jnp.ndarray, jax.Array, list[Any], tuple[Any]] ) -> "DistributionContainer": """ Return a new DistributionContainer where .xy, .xx, .conditions diff --git a/src/moscot/neural/data/_policy_loader.py b/src/moscot/neural/data/_policy_loader.py index 90f086809..3fef84034 100644 --- a/src/moscot/neural/data/_policy_loader.py +++ b/src/moscot/neural/data/_policy_loader.py @@ -1,16 +1,13 @@ -from moscot.utils.subset_policy import SubsetPolicy -from moscot.neural.data._distribution_collection import DistributionCollection - -import jax - -from typing import Any, Dict, Iterator, List, Optional, Tuple - +import functools +from typing import Any, Dict, Hashable, Iterator, Optional, Sequence, Tuple, TypeVar import jax import jax.numpy as jnp -from typing import Any, Dict, Optional, List, Tuple, Iterator, Sequence -import functools +from moscot.neural.data._distribution_collection import DistributionCollection +from moscot.utils.subset_policy import SubsetPolicy + +K = TypeVar("K", bound=Hashable) @functools.partial(jax.jit, static_argnums=(3,)) @@ -59,7 +56,7 @@ def __init__( self, rng: jax.Array, policy: SubsetPolicy[Any], - distributions: DistributionCollection, + distributions: DistributionCollection[K], batch_size: int = 128, plan: Optional[Sequence[Tuple[Any, Any]]] = None, src_prefix: str = "src", @@ -109,7 +106,7 @@ def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]: 2) Use _sample_indices(...) to get random sample positions from the node's data. 3) Use _gather_array(...) to gather from .xy, .xx, .conditions. 4) Build a dict and yield it. - """ + """ # noqa: D205 while True: if not self.edges: break diff --git a/tests/neural/problems/generic/test_conditional_neural_problem.py b/tests/neural/problems/generic/test_conditional_neural_problem.py index bd80685a3..33b87d749 100644 --- a/tests/neural/problems/generic/test_conditional_neural_problem.py +++ b/tests/neural/problems/generic/test_conditional_neural_problem.py @@ -1,6 +1,7 @@ import optax import pytest +import jax.numpy as jnp import numpy as np from ott.geometry import costs @@ -8,11 +9,11 @@ from moscot.base.output import BaseSolverOutput from moscot.neural.base.problems import NeuralOTProblem -from moscot.neural.problems.generic import GENOTLinProblem # type: ignore[attr-defined] from moscot.neural.data import DistributionCollection, DistributionContainer +from moscot.neural.problems.generic import GENOTLinProblem # type: ignore[attr-defined] from tests._utils import ATOL, RTOL from tests.problems.conftest import neurallin_cond_args_1 -import jax.numpy as jnp + class TestGENOTLinProblem: @pytest.mark.fast From c028c1497a2050c99d8a3e8d9b1d244144ebd41f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 10 Jan 2025 16:35:18 +0100 Subject: [PATCH 3/5] linting and adding genotsolver as a new backend --- src/moscot/backends/utils.py | 42 +++++++++++++--------- src/moscot/neural/base/problems/problem.py | 2 +- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/moscot/backends/utils.py b/src/moscot/backends/utils.py index f603f901c..96ce8e476 100644 --- a/src/moscot/backends/utils.py +++ b/src/moscot/backends/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Literal, Tuple, TypeVar, Union from moscot import _registry from moscot._types import ProblemKind_t @@ -7,16 +7,16 @@ from moscot.backends import ott from moscot.neural.backends import neural_ott -__all__ = ["get_solver", "register_solver", "get_available_backends"] -register_solver_t = Callable[ - [Literal["linear", "quadratic"], Optional[Literal["GENOTSolver"]]], - Union["ott.SinkhornSolver", "ott.GWSolver", "neural_ott.GENOTSolver"], -] +__all__ = ["get_solver", "register_solver", "get_available_backends"] _REGISTRY = _registry.Registry() +GWSolver = TypeVar("GWSolver", bound="ott.GWSolver") +SinkhornSolver = TypeVar("SinkhornSolver", bound="ott.SinkhornSolver") +GENOTSolver = TypeVar("GENOTSolver", bound="neural_ott.GENOTSolver") + def get_solver(problem_kind: ProblemKind_t, *, backend: str = "ott", return_class: bool = False, **kwargs: Any) -> Any: """TODO.""" @@ -28,7 +28,7 @@ def get_solver(problem_kind: ProblemKind_t, *, backend: str = "ott", return_clas def register_solver( backend: str, -) -> Union["ott.SinkhornSolver", "ott.GWSolver", "neural_ott.GENOTSolver"]: +) -> Union[SinkhornSolver, GWSolver, GENOTSolver]: """Register a solver for a specific backend. Parameters @@ -43,25 +43,33 @@ def register_solver( return _REGISTRY.register(backend) # type: ignore[return-value] -@register_solver("ott") -def _( +@register_solver("ott") # type: ignore[misc] +def create_ott_solver( problem_kind: Literal["linear", "quadratic"], - solver_name: Optional[Literal["GENOTSolver"]] = None, -) -> Union["ott.SinkhornSolver", "ott.GWSolver", "neural_ott.GENOTSolver"]: + solver_name: Any = None, +) -> Union[SinkhornSolver, GWSolver]: from moscot.backends import ott if problem_kind == "linear": - if solver_name == "GENOTSolver": - from moscot.neural.backends import neural_ott - - return neural_ott.GENOTSolver # type: ignore[return-value] - if solver_name is None: - return ott.SinkhornSolver # type: ignore[return-value] + return ott.SinkhornSolver # type: ignore[return-value] if problem_kind == "quadratic": return ott.GWSolver # type: ignore[return-value] raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}`, {solver_name} problem.") +@register_solver("neural_ott") +def create_neural_ott_solver( + problem_kind: Literal["linear", "quadratic"], + solver_name: Any = None, +) -> GENOTSolver: + from moscot.neural.backends import neural_ott + + if solver_name == "GENOTSolver": + return neural_ott.GENOTSolver + + raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}`, {solver_name} problem.") + + def get_available_backends() -> Tuple[str, ...]: """Return all available backends.""" return tuple(backend for backend in _REGISTRY) diff --git a/src/moscot/neural/base/problems/problem.py b/src/moscot/neural/base/problems/problem.py index 1c74572df..7c1691d08 100644 --- a/src/moscot/neural/base/problems/problem.py +++ b/src/moscot/neural/base/problems/problem.py @@ -126,7 +126,7 @@ def prepare( @wrap_solve def solve( self, - backend: Literal["ott"] = "ott", + backend: Literal["neural_ott"] = "neural_ott", solver_name: Literal["GENOTSolver"] = "GENOTSolver", device: Optional[Device_t] = None, **kwargs: Any, From 522b9616cef5d80ebe7e4514bc18a688fb1d72bb Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 10 Jan 2025 20:45:50 +0100 Subject: [PATCH 4/5] tests passing and formatting handled --- .../neural/backends/neural_ott/solver.py | 8 +- src/moscot/neural/base/problems/problem.py | 132 +++++++++- src/moscot/neural/data/__init__.py | 4 +- .../neural/data/_distribution_collection.py | 228 ++++-------------- src/moscot/neural/data/_policy_loader.py | 52 ++-- .../neural/problems/generic/_generic.py | 55 +++-- src/moscot/problems/_utils.py | 1 + .../test_conditional_neural_problem.py | 27 +-- 8 files changed, 233 insertions(+), 274 deletions(-) diff --git a/src/moscot/neural/backends/neural_ott/solver.py b/src/moscot/neural/backends/neural_ott/solver.py index b624f7e37..725b4bebd 100644 --- a/src/moscot/neural/backends/neural_ott/solver.py +++ b/src/moscot/neural/backends/neural_ott/solver.py @@ -69,8 +69,8 @@ def _prepare( # type: ignore[override] rng = jax.random.PRNGKey(seed) src_renames = tgt_renames = { - "xy": "lin", - "xx": "quad", + "shared_space": "lin", + "incomparable_space": "quad", } if train_size == 1.0: @@ -80,7 +80,6 @@ def _prepare( # type: ignore[override] policy=policy, distributions=distributions, batch_size=batch_size, - plan=policy.plan(), src_renames=src_renames, tgt_renames=tgt_renames, ) @@ -89,7 +88,6 @@ def _prepare( # type: ignore[override] policy=policy, distributions=distributions, batch_size=batch_size, - plan=policy.plan(), src_renames=src_renames, tgt_renames=tgt_renames, ) @@ -102,7 +100,6 @@ def _prepare( # type: ignore[override] policy=policy, distributions=train_dist, batch_size=batch_size, - plan=policy.plan(), src_renames=src_renames, tgt_renames=tgt_renames, ) @@ -111,7 +108,6 @@ def _prepare( # type: ignore[override] policy=policy, distributions=valid_dist, batch_size=batch_size, - plan=policy.plan(), src_renames=src_renames, tgt_renames=tgt_renames, ) diff --git a/src/moscot/neural/base/problems/problem.py b/src/moscot/neural/base/problems/problem.py index 7c1691d08..38ebf4b48 100644 --- a/src/moscot/neural/base/problems/problem.py +++ b/src/moscot/neural/base/problems/problem.py @@ -1,5 +1,6 @@ from typing import ( Any, + Generic, Hashable, Iterable, Literal, @@ -11,18 +12,22 @@ Union, ) +import jax +import jax.numpy as jnp import numpy as np import pandas as pd +import scipy.sparse as sp from anndata import AnnData from moscot import backends +from moscot._logging import logger from moscot._types import ArrayLike, Device_t from moscot.base.output import BaseNeuralOutput from moscot.base.problems._utils import wrap_prepare, wrap_solve from moscot.base.problems.problem import BaseProblem from moscot.base.solver import OTSolver -from moscot.neural.data import DistributionCollection, DistributionContainer +from moscot.neural.data import DistributionCollection, NeuralDistribution from moscot.utils.subset_policy import ( # type:ignore[attr-defined] ExplicitPolicy, Policy_t, @@ -36,7 +41,7 @@ __all__ = ["NeuralOTProblem"] -class NeuralOTProblem(BaseProblem): # TODO(@MUCDK) check generic types, save and load +class NeuralOTProblem(BaseProblem, Generic[K]): # TODO(@MUCDK) check generic types, save and load """ Base class for all conditional (nerual) optimal transport problems. @@ -56,7 +61,7 @@ def __init__( super().__init__(**kwargs) self._adata = adata - self._distributions: Optional[DistributionCollection[K]] = None # type: ignore[valid-type] + self._distributions: Optional[DistributionCollection[K]] = None self._policy: Optional[SubsetPolicy[Any]] = None self._solver: Optional[OTSolver[BaseNeuralOutput]] = None @@ -70,13 +75,14 @@ def prepare( self, policy_key: str, policy: Policy_t, - xy: Mapping[str, Any], - xx: Mapping[str, Any], - conditions: Mapping[str, Any], + lin: Mapping[str, Any], + src_quad: Optional[Mapping[str, Any]] = None, + tgt_quad: Optional[Mapping[str, Any]] = None, + condition: Optional[Mapping[str, Any]] = None, subset: Optional[Sequence[Tuple[K, K]]] = None, seed: int = 0, reference: K = None, - ) -> "NeuralOTProblem": + ) -> "NeuralOTProblem[K]": """Prepare conditional optimal transport problem. Parameters @@ -115,11 +121,33 @@ def prepare( else: _ = self.policy.create_graph() # type: ignore[union-attr] + if src_quad is None and tgt_quad is not None: + raise ValueError("If `tgt_quad` is provided, `src_quad` must also be provided.") + if src_quad is not None and tgt_quad is None: + raise ValueError("If `src_quad` is provided, `tgt_quad` must also be provided.") + if src_quad is not None: + # which edges will be always source + source_nodes = {el[0] for el in self.policy.plan()} # type: ignore[union-attr] + target_nodes = {el[1] for el in self.policy.plan()} # type: ignore[union-attr] + # if there aren't nodes that are always source or target, we will warn the user + # that we will choose source quad attributes + tgt_quad_nodes = target_nodes - source_nodes + if not source_nodes.isdisjoint(target_nodes): + logger.warning( + "Some nodes are both source and target in the policy plan, " + "we will choose source quad attributes for such nodes." + ) for el in self.policy.categories: # type: ignore[union-attr] adata_masked = self.adata[self._create_mask(el)] # TODO: Marginals - self.distributions[el] = DistributionContainer.from_adata( # type: ignore[index] - adata_masked, **xy, **xx, **conditions + quad = None + if src_quad is not None: + quad = tgt_quad if el in tgt_quad_nodes else src_quad + self.distributions[el] = NeuralOTProblem._create_neural_distribution( # type: ignore[index] + adata_masked, + lin=lin, + quad=quad, + condition=condition, ) return self @@ -130,7 +158,7 @@ def solve( solver_name: Literal["GENOTSolver"] = "GENOTSolver", device: Optional[Device_t] = None, **kwargs: Any, - ) -> "NeuralOTProblem": + ) -> "NeuralOTProblem[K]": """Solve optimal transport problem. Parameters @@ -149,9 +177,14 @@ def solve( - :attr:`solver`: optimal transport solver. - :attr:`solution`: optimal transport solution. """ - tmp = next(iter(self.distributions)) # type: ignore[arg-type] - input_dim = self.distributions[tmp].xy.shape[1] # type: ignore[union-attr, index] - cond_dim = self.distributions[tmp].conditions.shape[1] # type: ignore[union-attr, index] + assert self.distributions is not None + distributions: DistributionCollection[K] = self.distributions + assert next(iter(self.distributions.keys())) is not None + tmp_key: K = next(iter(self.distributions.keys())) + input_dim = distributions[tmp_key].shared_space.shape[1] + cond_dim = 0 + if distributions[tmp_key].condition is not None: + cond_dim = distributions[tmp_key].condition.shape[1] # type: ignore[union-attr] solver_class = backends.get_solver( self.problem_kind, solver_name=solver_name, backend=backend, return_class=True @@ -230,3 +263,76 @@ def solver(self) -> Optional[OTSolver[BaseNeuralOutput]]: def policy(self) -> Optional[SubsetPolicy[Any]]: """Policy used to subset the data.""" return self._policy + + @staticmethod + def _extract_data( + adata: AnnData, + *, + attr: Literal["X", "obs", "obsp", "obsm", "var", "varm", "layers", "uns"], + key: Optional[str] = None, + ) -> jax.Array: + modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]" + data = getattr(adata, attr) + + try: + if key is not None: + data = data[key] + except KeyError: + raise KeyError(f"Unable to fetch data from `{modifier}`.") from None + except IndexError: + raise IndexError(f"Unable to fetch data from `{modifier}`.") from None + + if attr == "obs": + data = np.asarray(data)[:, None] + if sp.issparse(data): + logger.warning(f"Densifying data in `{modifier}`") + data = data.toarray() + if data.ndim != 2: + raise ValueError(f"Expected `{modifier}` to have `2` dimensions, found `{data.ndim}`.") + + return jnp.array(data) + + @staticmethod + def _create_neural_distribution( + adata: AnnData, + lin: Optional[Mapping[str, Any]] = None, + quad: Optional[Mapping[str, Any]] = None, + condition: Optional[Mapping[str, Any]] = None, + ) -> NeuralDistribution: + fields = [ + ("shared_space", lin), + ("incomparable_space", quad), + ("condition", condition), + ] + return NeuralDistribution( + **{ + field_name: NeuralOTProblem._extract_data(adata, **field) + for field_name, field in fields + if field is not None + } + ) + + @staticmethod + def _handle_attr(elem: Union[str, Mapping[str, Any]]) -> dict[str, Any]: + if isinstance(elem, str): + return { + "attr": "obsm", + "key": elem, + } + if isinstance(elem, Mapping): + attr = dict(elem) + if "attr" not in attr: + raise KeyError("`attr` must be provided when `attr` is a mapping.") + if elem["attr"] == "X": + return { + "attr": "X", + } + if elem["attr"] in ("obsm", "obsp", "obs", "uns"): + if "key" not in elem: + raise KeyError("`key` must be provided when `attr` is `obsm`, `obsp`, `obs`, or `uns`.") + return { + "attr": elem["attr"], + "key": elem["key"], + } + + raise TypeError(f"Unrecognized `attr` format: {elem}.") diff --git a/src/moscot/neural/data/__init__.py b/src/moscot/neural/data/__init__.py index a3260be83..6143d8652 100644 --- a/src/moscot/neural/data/__init__.py +++ b/src/moscot/neural/data/__init__.py @@ -1,7 +1,7 @@ from moscot.neural.data._distribution_collection import ( DistributionCollection, - DistributionContainer, + NeuralDistribution, ) from moscot.neural.data._policy_loader import PolicyDataLoader -__all__ = ["PolicyDataLoader", "DistributionCollection", "DistributionContainer"] +__all__ = ["PolicyDataLoader", "DistributionCollection", "NeuralDistribution"] diff --git a/src/moscot/neural/data/_distribution_collection.py b/src/moscot/neural/data/_distribution_collection.py index a07e37bc1..fd0109d6c 100644 --- a/src/moscot/neural/data/_distribution_collection.py +++ b/src/moscot/neural/data/_distribution_collection.py @@ -1,192 +1,67 @@ from dataclasses import dataclass -from typing import Any, Hashable, Literal, Optional, Tuple, TypeVar, Union +from typing import Any, ClassVar, Hashable, Optional, TypeVar, Union import jax import jax.numpy as jnp -import numpy as np -import scipy.sparse as sp - -from anndata import AnnData - -from moscot._logging import logger -from moscot._types import CostFn_t -from moscot.costs import get_cost K = TypeVar("K", bound=Hashable) @dataclass(frozen=True, repr=True) -class DistributionContainer: - """Data container for OT problems involving more than two distributions. +class NeuralDistribution: + """Data container representing a distribution to be used in OT-based flow models. - TODO + Can be either a source or target distribution. + Keep in mind that if a OT-based flow model is used, + sizes such as `n_shared_features`, `n_flow` should be same + in general for both source and target distributions. Parameters ---------- - xy + shared_space (n_samples, n_shared_features) Distribution living in a shared space. - xx + Used for the linear term of matching step. + I.e., given to the matching function of OT Based Flow Models + incomparable_space (n_samples, n_incomparable_features) Distribution living in an incomparable space. - conditions - Conditions for the distributions. - cost_xy - Cost function when in the shared space. - cost_xx - Cost function in the incomparable space. - """ + Used for the quadratic term of matching step. + condition (n_samples, n_conditions) + Condition for the distributions. + augment (n_samples, n_augment) + Augmentation to be used in the flow model. + flow (n_samples, n_flow) + Often equal to `shared_space` but can be different. + This will either be given to the flow model as primary input + (not the case for GENOT as GENOT uses noise instead) or + the output of the flow model. - xy: Optional[jax.Array] - xx: Optional[jax.Array] - conditions: Optional[jax.Array] - cost_xy: Any - cost_xx: Any + """ - @property - def contains_linear(self) -> bool: - """Whether the distribution contains data corresponding to the linear term.""" - return self.xy is not None + shared_space: jax.Array + incomparable_space: Optional[jax.Array] = None + condition: Optional[jax.Array] = None + augment: Optional[jax.Array] = None + flow: Optional[jax.Array] = None - @property - def contains_quadratic(self) -> bool: - """Whether the distribution contains data corresponding to the quadratic term.""" - return self.xx is not None + FIELDS: ClassVar[tuple[str]] = ["shared_space", "incomparable_space", "condition", "augment", "flow"] - @property - def contains_condition(self) -> bool: - """Whether the distribution contains data corresponding to the condition.""" - return self.conditions is not None + def __post_init__(self) -> None: + fields = ["shared_space", "incomparable_space", "condition", "augment", "flow"] + if all(getattr(self, field) is None for field in fields): + raise ValueError(f"At least one of the fields `{fields}` must be provided.") + given_fields = [field for field in fields if getattr(self, field) is not None] + # if all number of samples are not equal + if len({getattr(self, field).shape[0] for field in given_fields}) > 1: + raise ValueError("All fields must have the same number of samples.") @property def n_samples(self) -> int: """Number of samples in the distribution.""" - return self.xy.shape[0] if self.contains_linear else self.xx.shape[0] # type: ignore[union-attr] - - @staticmethod - def _extract_data( - adata: AnnData, - *, - attr: Literal["X", "obs", "obsp", "obsm", "var", "varm", "layers", "uns"], - key: Optional[str] = None, - ) -> jax.Array: - modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]" - data = getattr(adata, attr) - - try: - if key is not None: - data = data[key] - except KeyError: - raise KeyError(f"Unable to fetch data from `{modifier}`.") from None - except IndexError: - raise IndexError(f"Unable to fetch data from `{modifier}`.") from None - - if attr == "obs": - data = np.asarray(data)[:, None] - if sp.issparse(data): - logger.warning(f"Densifying data in `{modifier}`") - data = data.toarray() - if data.ndim != 2: - raise ValueError(f"Expected `{modifier}` to have `2` dimensions, found `{data.ndim}`.") - - return jnp.array(data) - - @staticmethod - def _verify_input( - xy_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]], - xy_key: Optional[str], - xx_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]], - xx_key: Optional[str], - conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]], - conditions_key: Optional[str], - ) -> Tuple[bool, bool, bool]: - if (xy_attr is None and xy_key is not None) or (xy_attr is not None and xy_key is None): - raise ValueError(r"Either both `xy_attr` and `xy_key` must be `None` or none of them.") - if (xx_attr is None and xx_key is not None) or (xx_attr is not None and xx_key is None): - raise ValueError(r"Either both `xy_attr` and `xy_key` must be `None` or none of them.") - if (conditions_attr is None and conditions_key is not None) or ( - conditions_attr is not None and conditions_key is None - ): - raise ValueError(r"Either both `conditions_attr` and `conditions_key` must be `None` or none of them.") - return xy_attr is not None, xx_attr is not None, conditions_attr is not None - - @classmethod - def from_adata( - cls, - adata: AnnData, - xy_attr: Literal["X", "obsp", "obsm", "layers", "uns"] = None, - xy_key: Optional[str] = None, - xy_cost: CostFn_t = "sq_euclidean", - xx_attr: Literal["X", "obsp", "obsm", "layers", "uns"] = None, - xx_key: Optional[str] = None, - xx_cost: CostFn_t = "sq_euclidean", - conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]] = None, - conditions_key: Optional[str] = None, - backend: Literal["ott"] = "ott", - **kwargs: Any, - ) -> "DistributionContainer": - """Create distribution container from :class:`~anndata.AnnData`. - - .. warning:: - Sparse arrays will be always densified. - - Parameters - ---------- - adata - Annotated data object. - a - Marginals when used as source distribution. - b - Marginals when used as target distribution. - xy_attr - Attribute of `adata` containing the data for the shared space. - xy_key - Key of `xy_attr` containing the data for the shared space. - xy_cost - Cost function when in the shared space. - xx_attr - Attribute of `adata` containing the data for the incomparable space. - xx_key - Key of `xx_attr` containing the data for the incomparable space. - xx_cost - Cost function in the incomparable space. - conditions_attr - Attribute of `adata` containing the conditions. - conditions_key - Key of `conditions_attr` containing the conditions. - backend - Backend to use. - kwargs - Keyword arguments to pass to the cost functions. - - Returns - ------- - The distribution container. - """ - contains_linear, contains_quadratic, contains_condition = cls._verify_input( - xy_attr, xy_key, xx_attr, xx_key, conditions_attr, conditions_key - ) - - if contains_linear: - xy_data = cls._extract_data(adata, attr=xy_attr, key=xy_key) - xy_cost_fn = get_cost(xy_cost, backend=backend, **kwargs) - else: - xy_data = None - xy_cost_fn = None - - if contains_quadratic: - xx_data = cls._extract_data(adata, attr=xx_attr, key=xx_key) - xx_cost_fn = get_cost(xx_cost, backend=backend, **kwargs) - else: - xx_data = None - xx_cost_fn = None - - conditions_data = ( - cls._extract_data(adata, attr=conditions_attr, key=conditions_key) if contains_condition else None # type: ignore[arg-type] # noqa:E501 - ) - return cls(xy=xy_data, xx=xx_data, conditions=conditions_data, cost_xy=xy_cost_fn, cost_xx=xx_cost_fn) + return self.shared_space.shape[0] def __getitem__( self, idx: Union[int, slice, jnp.ndarray, jax.Array, list[Any], tuple[Any]] - ) -> "DistributionContainer": + ) -> "NeuralDistribution": """ Return a new DistributionContainer where .xy, .xx, .conditions are sliced by `idx` (if they are not None). @@ -198,27 +73,20 @@ def __getitem__( # But we first need to separate the slicing of training and validation data # Before creating this DistributionContainer! # Slice xy - new_xy = self.xy[idx] if self.xy is not None else None - - # Slice xx - new_xx = self.xx[idx] if self.xx is not None else None + given_fields = [field for field in self.FIELDS if getattr(self, field) is not None] + return NeuralDistribution(**{field: getattr(self, field)[idx] for field in given_fields}) - # Slice conditions - new_conditions = self.conditions[idx] if self.conditions is not None else None - # Reuse the same cost functions - return DistributionContainer( - xy=new_xy, - xx=new_xx, - conditions=new_conditions, - cost_xy=self.cost_xy, - cost_xx=self.cost_xx, - ) - - -class DistributionCollection(dict[K, DistributionContainer]): +@dataclass +class DistributionCollection(dict[K, NeuralDistribution]): """Collection of distributions.""" + def __post_init__(self) -> None: + # check if all the shared spaces have the same shape[1] + shared_spaces = {dist.shared_space.shape[1] for dist in self.values()} + if len(shared_spaces) > 1: + raise ValueError("All shared spaces must have the same number of features.") + def __repr__(self) -> str: return f"{self.__class__.__name__}{list(self.keys())}" diff --git a/src/moscot/neural/data/_policy_loader.py b/src/moscot/neural/data/_policy_loader.py index 3fef84034..4db83b27b 100644 --- a/src/moscot/neural/data/_policy_loader.py +++ b/src/moscot/neural/data/_policy_loader.py @@ -1,10 +1,13 @@ import functools -from typing import Any, Dict, Hashable, Iterator, Optional, Sequence, Tuple, TypeVar +from typing import Any, Dict, Hashable, Iterator, Optional, Tuple, TypeVar import jax import jax.numpy as jnp -from moscot.neural.data._distribution_collection import DistributionCollection +from moscot.neural.data._distribution_collection import ( + DistributionCollection, + NeuralDistribution, +) from moscot.utils.subset_policy import SubsetPolicy K = TypeVar("K", bound=Hashable) @@ -58,46 +61,30 @@ def __init__( policy: SubsetPolicy[Any], distributions: DistributionCollection[K], batch_size: int = 128, - plan: Optional[Sequence[Tuple[Any, Any]]] = None, src_prefix: str = "src", tgt_prefix: str = "tgt", src_renames: Optional[Dict[str, str]] = None, tgt_renames: Optional[Dict[str, str]] = None, ): - self.policy = policy self.distributions = distributions self.rng = rng self.batch_size = batch_size - self.edges = plan if plan is not None else self.policy.plan() + self.edges = self.policy.plan() self.src_prefix = src_prefix self.tgt_prefix = tgt_prefix self.src_renames = src_renames if src_renames is not None else {} self.tgt_renames = tgt_renames if tgt_renames is not None else {} - + self.fields = NeuralDistribution.FIELDS # Precompute an index array for each node self.node_indices: Dict[Any, jnp.ndarray] = {} self._init_indices() def _init_indices(self) -> None: - """Verify shape consistency within each DistributionContainer, store jnp.arange(...) as node_indices.""" + """Precompute an index array for each node.""" for node, container in self.distributions.items(): - # Gather shapes of non-None arrays - shapes = [] - if container.xy is not None: - shapes.append(container.xy.shape[0]) - if container.xx is not None: - shapes.append(container.xx.shape[0]) - if container.conditions is not None: - shapes.append(container.conditions.shape[0]) - - # All must match - if shapes and not all(s == shapes[0] for s in shapes): - raise ValueError(f"Inconsistent shape for node {node}: {shapes}") - - if shapes: - n = shapes[0] - self.node_indices[node] = jnp.arange(n) + idx = jnp.arange(container.n_samples) + self.node_indices[node] = idx def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]: """ @@ -138,24 +125,19 @@ def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]: batch_dict = {} src_candidates = [ - ("xy", src_container.xy), - ("xx", src_container.xx), - ("conditions", src_container.conditions), + (f, getattr(src_container, f)) for f in self.fields if getattr(src_container, f) is not None ] for key, arr in src_candidates: - if arr is not None: - key_new = self.src_renames.get(key, key) - batch_dict[f"{self.src_prefix}_{key_new}"] = _gather_array(arr, src_idxs) + key_new = self.src_renames.get(key, key) + batch_dict[f"{self.src_prefix}_{key_new}"] = _gather_array(arr, src_idxs) tgt_candidates = [ - ("xy", tgt_container.xy), - ("xx", tgt_container.xx), - ("conditions", tgt_container.conditions), + (f, getattr(tgt_container, f)) for f in self.fields if getattr(tgt_container, f) is not None ] for key, arr in tgt_candidates: - if arr is not None: - key_new = self.tgt_renames.get(key, key) - batch_dict[f"{self.tgt_prefix}_{key_new}"] = _gather_array(arr, tgt_idxs) + key_new = self.tgt_renames.get(key, key) + batch_dict[f"{self.tgt_prefix}_{key_new}"] = _gather_array(arr, tgt_idxs) + if not batch_dict: continue diff --git a/src/moscot/neural/problems/generic/_generic.py b/src/moscot/neural/problems/generic/_generic.py index abc5bb87f..3fcf6f2c2 100644 --- a/src/moscot/neural/problems/generic/_generic.py +++ b/src/moscot/neural/problems/generic/_generic.py @@ -1,48 +1,57 @@ import types from types import MappingProxyType -from typing import Any, Dict, Literal, Mapping, Tuple, Type, Union +from typing import ( + Any, + Dict, + Hashable, + Literal, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from moscot import _constants from moscot._types import CostKwargs_t, OttCostFn_t, Policy_t from moscot.neural.base.problems.problem import NeuralOTProblem -from moscot.problems._utils import ( - handle_conditional_attr, - handle_cost_tmp, - handle_joint_attr_tmp, -) __all__ = ["GENOTLinProblem"] +K = TypeVar("K", bound=Hashable) + -class GENOTLinProblem(NeuralOTProblem): +class GENOTLinProblem(NeuralOTProblem[K]): """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" def prepare( self, key: str, joint_attr: Union[str, Mapping[str, Any]], - conditional_attr: Union[str, Mapping[str, Any]], - # src_condition_attr: Union[str, Mapping[str, Any]], - # src_augment_attr: Optional[Union[str, Mapping[str, Any]]] = None, - # src_quad_attr: Optional[Union[str, Mapping[str, Any]]] = None, - # tgt_quad_attr: Optional[Union[str, Mapping[str, Any]]] = None, - # tgt_flow_attr: Optional[Union[str, Mapping[str, Any]]] = None, + condition_attr: Union[str, Mapping[str, Any]] = None, + src_quad_attr: Optional[Union[str, Mapping[str, Any]]] = None, + tgt_quad_attr: Optional[Union[str, Mapping[str, Any]]] = None, policy: Literal["sequential", "star", "explicit"] = "sequential", cost: OttCostFn_t = "sq_euclidean", cost_kwargs: CostKwargs_t = types.MappingProxyType({}), **kwargs: Any, - ) -> "GENOTLinProblem": + ) -> "GENOTLinProblem[K]": """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" self.batch_key = key - xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) - conditions = handle_conditional_attr(conditional_attr) - xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) + # TODO: These cost functions should be going to GENOT match_function somehow + del cost, cost_kwargs + lin = GENOTLinProblem._handle_attr(joint_attr) + src_quad = GENOTLinProblem._handle_attr(src_quad_attr) if src_quad_attr is not None else None + tgt_quad = GENOTLinProblem._handle_attr(tgt_quad_attr) if tgt_quad_attr is not None else None + condition = GENOTLinProblem._handle_attr(condition_attr) if condition_attr is not None else None return super().prepare( policy_key=key, policy=policy, - xy=xy, - xx=xx, - conditions=conditions, + lin=lin, + src_quad=src_quad, + tgt_quad=tgt_quad, + condition=condition, **kwargs, ) @@ -55,7 +64,7 @@ def solve( valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), train_size: float = 1.0, **kwargs: Any, - ) -> "GENOTLinProblem": + ) -> "GENOTLinProblem[K]": """Solve.""" return super().solve( batch_size=batch_size, @@ -69,8 +78,8 @@ def solve( ) @property - def _base_problem_type(self) -> Type[NeuralOTProblem]: - return NeuralOTProblem + def _base_problem_type(self) -> Type[NeuralOTProblem[K]]: + return NeuralOTProblem[K] @property def _valid_policies(self) -> Tuple[Policy_t, ...]: diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index a405a3da6..6965cbccf 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -40,6 +40,7 @@ def _handle_mapping_joint_attr( ) -> Tuple[Dict[str, Any], Union[Literal["local-pca"], Callback_t], Dict[str, Any]]: joint_attr = dict(joint_attr) xy_callback_kwargs = dict(xy_callback_kwargs) + if "attr" in joint_attr and joint_attr["attr"] == "X": return {"x_attr": "X", "y_attr": "X"}, xy_callback, xy_callback_kwargs # type: ignore[return-value] if "attr" in joint_attr and joint_attr["attr"] == "obsm": diff --git a/tests/neural/problems/generic/test_conditional_neural_problem.py b/tests/neural/problems/generic/test_conditional_neural_problem.py index 33b87d749..28a9dfc18 100644 --- a/tests/neural/problems/generic/test_conditional_neural_problem.py +++ b/tests/neural/problems/generic/test_conditional_neural_problem.py @@ -3,13 +3,12 @@ import jax.numpy as jnp import numpy as np -from ott.geometry import costs import anndata as ad from moscot.base.output import BaseSolverOutput from moscot.neural.base.problems import NeuralOTProblem -from moscot.neural.data import DistributionCollection, DistributionContainer +from moscot.neural.data import DistributionCollection, NeuralDistribution from moscot.neural.problems.generic import GENOTLinProblem # type: ignore[attr-defined] from tests._utils import ATOL, RTOL from tests.problems.conftest import neurallin_cond_args_1 @@ -19,26 +18,24 @@ class TestGENOTLinProblem: @pytest.mark.fast def test_prepare(self, adata_time: ad.AnnData): problem = GENOTLinProblem(adata=adata_time) - problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) + problem = problem.prepare(key="time", joint_attr="X_pca", condition_attr={"attr": "obs", "key": "time"}) assert isinstance(problem, NeuralOTProblem) assert isinstance(problem.distributions, DistributionCollection) assert list(problem.distributions.keys()) == [0, 1, 2] container = problem.distributions[0] n_obs_0 = adata_time[adata_time.obs["time"] == 0].n_obs - assert isinstance(container, DistributionContainer) - assert isinstance(container.xy, jnp.ndarray) - assert container.xy.shape == (n_obs_0, 50) - assert container.xx is None - assert isinstance(container.conditions, jnp.ndarray) - assert container.conditions.shape == (n_obs_0, 1) - assert isinstance(container.cost_xy, costs.SqEuclidean) - assert container.cost_xx is None + assert isinstance(container, NeuralDistribution) + assert isinstance(container.shared_space, jnp.ndarray) + assert isinstance(container.condition, jnp.ndarray) + assert container.shared_space.shape == (n_obs_0, 50) + assert container.incomparable_space is None + assert container.condition.shape == (n_obs_0, 1) @pytest.mark.parametrize("train_size", [0.9, 1.0]) def test_solve_balanced_no_baseline(self, adata_time: ad.AnnData, train_size: float): # type: ignore[no-untyped-def] # noqa: E501 problem = GENOTLinProblem(adata=adata_time) - problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) + problem = problem.prepare(key="time", joint_attr="X_pca", condition_attr={"attr": "obs", "key": "time"}) problem = problem.solve(train_size=train_size, **neurallin_cond_args_1) assert isinstance(problem.solution, BaseSolverOutput) @@ -47,12 +44,12 @@ def test_reproducibility(self, adata_time: ad.AnnData): pc_tzero = adata_time[cond_zero_mask].obsm["X_pca"] problem_one = GENOTLinProblem(adata=adata_time) problem_one = problem_one.prepare( - key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}, seed=0 + key="time", joint_attr="X_pca", condition_attr={"attr": "obs", "key": "time"}, seed=0 ) problem_one = problem_one.solve(**neurallin_cond_args_1) problem_two = GENOTLinProblem(adata=adata_time) problem_two = problem_two.prepare( - key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}, seed=0 + key="time", joint_attr="X_pca", condition_attr={"attr": "obs", "key": "time"}, seed=0 ) problem_two = problem_two.solve(**neurallin_cond_args_1) assert np.allclose( @@ -77,7 +74,7 @@ def test_reproducibility(self, adata_time: ad.AnnData): def test_pass_custom_optimizers(self, adata_time: ad.AnnData): problem = GENOTLinProblem(adata=adata_time) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))] - problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) + problem = problem.prepare(key="time", joint_attr="X_pca", condition_attr={"attr": "obs", "key": "time"}) custom_opt = optax.adagrad(1e-4) problem = problem.solve(iterations=2, optimizer=custom_opt) From be67c6ed19625e999004302e157af2afcd2c99ba Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 10 Jan 2025 20:59:10 +0100 Subject: [PATCH 5/5] update doc --- docs/developer.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/developer.rst b/docs/developer.rst index 68e9b110d..110af5080 100644 --- a/docs/developer.rst +++ b/docs/developer.rst @@ -99,8 +99,6 @@ Miscellaneous data.apoptosis_markers tagged_array.TaggedArray tagged_array.Tag - tagged_array.DistributionCollection - tagged_array.DistributionContainer .. currentmodule:: moscot.base.problems .. autosummary::