Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ Miscellaneous
data.apoptosis_markers
tagged_array.TaggedArray
tagged_array.Tag
tagged_array.DistributionCollection
tagged_array.DistributionContainer

.. currentmodule:: moscot.base.problems
.. autosummary::
Expand Down
3 changes: 1 addition & 2 deletions src/moscot/backends/ott/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from moscot.backends.ott._utils import sinkhorn_divergence
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
from moscot.costs import register_cost

__all__ = [
Expand All @@ -11,7 +11,6 @@
"SinkhornSolver",
"NeuralOutput",
"sinkhorn_divergence",
"GENOTLinSolver",
"GraphOTTOutput",
]

Expand Down
213 changes: 6 additions & 207 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import abc
import functools
import inspect
import math
import types
from typing import (
Any,
Hashable,
List,
Literal,
Mapping,
NamedTuple,
Expand All @@ -17,21 +14,13 @@
Union,
)

import optax

import jax
import jax.numpy as jnp
import numpy as np
from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
from ott.neural.datasets import OTData, OTDataset
from ott.neural.methods.flows import dynamics, genot
from ott.neural.networks.layers import time_encoder
from ott.neural.networks.velocity_field import VelocityField
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
from ott.solvers.utils import uniform_sampler

from moscot._logging import logger
from moscot._types import (
Expand All @@ -43,23 +32,23 @@
)
from moscot.backends.ott._utils import (
InitializerResolver,
Loader,
MultiLoader,
_instantiate_geodesic_cost,
alpha_to_fused_penalty,
check_shapes,
convert_scipy_sparse,
data_match_fn,
densify,
ensure_2d,
)
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
from moscot.base.problems._utils import TimeScalesHeatKernel
from moscot.base.solver import OTSolver
from moscot.costs import get_cost
from moscot.utils.tagged_array import DistributionCollection, TaggedArray
from moscot.utils.tagged_array import TaggedArray

__all__ = ["SinkhornSolver", "GWSolver", "GENOTLinSolver"]
__all__ = [
"SinkhornSolver",
"GWSolver",
]

OTTSolver_t = Union[
sinkhorn.Sinkhorn,
Expand Down Expand Up @@ -516,193 +505,3 @@ def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
problem_kwargs -= {"geom_xx", "geom_yy", "geom_xy", "fused_penalty"}
problem_kwargs |= {"alpha"}
return geom_kwargs | problem_kwargs, {"epsilon"}


class GENOTLinSolver(OTSolver[OTTOutput]):
    """Solver class for genot.GENOT linear :cite:`klein2023generative`.

    All keyword arguments given at construction are stored and consulted later
    when the velocity field and the :class:`genot.GENOT` instance are built in
    :meth:`_prepare` and when training runs in :meth:`_solve`.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initiate the class with any kwargs passed to the ott-jax class."""
        super().__init__()
        self._train_sampler: Optional[MultiLoader] = None
        self._valid_sampler: Optional[MultiLoader] = None
        self._neural_kwargs = kwargs

    @property
    def problem_kind(self) -> ProblemKind_t:  # noqa: D102
        return "linear"

    def _prepare(  # type: ignore[override]
        self,
        distributions: DistributionCollection[K],
        sample_pairs: List[Tuple[Any, Any]],
        train_size: float = 0.9,
        batch_size: int = 1024,
        is_conditional: bool = True,
        **kwargs: Any,
    ) -> Tuple[MultiLoader, MultiLoader]:
        """Build the training/validation loaders and instantiate the GENOT solver.

        Parameters
        ----------
        distributions
            Collection indexed by the keys that appear in ``sample_pairs``.
        sample_pairs
            ``(source_key, target_key)`` pairs; one loader pair is created per entry.
        train_size
            Fraction of each distribution used for training, ``0 < train_size <= 1``.
            With ``1.0`` no split is done and the same loader serves for training
            and validation.
        batch_size
            Batch size handed to every :class:`Loader`.
        is_conditional
            Whether conditions are passed to the datasets and the velocity field.
        kwargs
            Only ``seed`` and ``is_aligned`` are read.

        Returns
        -------
        The training and the validation :class:`MultiLoader`.

        Raises
        ------
        ValueError
            If ``train_size`` is outside ``(0, 1]``.
        """
        if not 0.0 < train_size <= 1.0:
            raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")

        is_aligned = kwargs.get("is_aligned", False)
        no_split = train_size == 1.0
        # Without a split the caller's seed is forwarded untouched (possibly ``None``);
        # splitting needs a concrete seed, so it falls back to 0.
        seed = kwargs.get("seed") if no_split else kwargs.get("seed", 0)

        train_loaders = []
        validate_loaders = []
        for sample_pair in sample_pairs:
            src = distributions[sample_pair[0]]
            tgt = distributions[sample_pair[1]]

            if no_split:
                loader = self._make_loader(
                    src.xy,
                    src.conditions if is_conditional else None,
                    tgt.xy,
                    tgt.conditions if is_conditional else None,
                    batch_size=batch_size,
                    seed=seed,
                    is_aligned=is_aligned,
                )
                train_loaders.append(loader)
                validate_loaders.append(loader)
                continue

            src_split = self._split_data(
                src.xy, conditions=src.conditions, train_size=train_size, seed=seed, a=src.a, b=src.b
            )
            tgt_split = self._split_data(
                tgt.xy, conditions=tgt.conditions, train_size=train_size, seed=seed, a=tgt.a, b=tgt.b
            )
            train_loaders.append(
                self._make_loader(
                    src_split.data_train,
                    src_split.conditions_train if is_conditional else None,
                    tgt_split.data_train,
                    tgt_split.conditions_train if is_conditional else None,
                    batch_size=batch_size,
                    seed=seed,
                    is_aligned=is_aligned,
                )
            )
            validate_loaders.append(
                self._make_loader(
                    src_split.data_valid,
                    src_split.conditions_valid if is_conditional else None,
                    tgt_split.data_valid,
                    tgt_split.conditions_valid if is_conditional else None,
                    batch_size=batch_size,
                    seed=seed,
                    is_aligned=is_aligned,
                )
            )

        # The solver and the multi-loaders use the seed given at construction time,
        # not the one passed through ``kwargs``.
        solver_seed = self._neural_kwargs.get("seed", 0)
        self._solver = self._build_solver(is_conditional=is_conditional, seed=solver_seed)
        return (
            MultiLoader(datasets=train_loaders, seed=solver_seed),
            MultiLoader(datasets=validate_loaders, seed=solver_seed),
        )

    @staticmethod
    def _make_loader(
        src_lin: ArrayLike,
        src_condition: Optional[ArrayLike],
        tgt_lin: ArrayLike,
        tgt_condition: Optional[ArrayLike],
        *,
        batch_size: int,
        seed: Optional[int],
        is_aligned: bool,
    ) -> Loader:
        """Wrap one source/target pair into an ``OTDataset`` and return its ``Loader``."""
        dataset = OTDataset(
            src_data=OTData(lin=src_lin, condition=src_condition),
            tgt_data=OTData(lin=tgt_lin, condition=tgt_condition),
            seed=seed,
            is_aligned=is_aligned,
        )
        return Loader(dataset, batch_size=batch_size, seed=seed)

    def _build_solver(self, is_conditional: bool, seed: int) -> genot.GENOT:
        """Create the velocity field and the GENOT instance from the stored kwargs."""
        source_dim = self._neural_kwargs.get("input_dim", 0)
        target_dim = source_dim
        condition_dim = self._neural_kwargs.get("cond_dim", 0)
        # TODO(ilan-gold): What are reasonable defaults here?
        neural_vf = VelocityField(
            output_dims=[*self._neural_kwargs.get("velocity_field_output_dims", []), target_dim],
            condition_dims=(
                self._neural_kwargs.get("velocity_field_condition_dims", [source_dim + condition_dim])
                if is_conditional
                else None
            ),
            hidden_dims=self._neural_kwargs.get("velocity_field_hidden_dims", [1024, 1024, 1024]),
            time_dims=self._neural_kwargs.get("velocity_field_time_dims", None),
            time_encoder=self._neural_kwargs.get(
                "velocity_field_time_encoder", functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024)
            ),
        )
        rng = jax.random.PRNGKey(seed)
        # Default matching parameters are only supplied for the built-in match function;
        # a user-provided ``data_match_fn`` gets no extra keyword arguments unless given.
        data_match_fn_kwargs = self._neural_kwargs.get(
            "data_match_fn_kwargs",
            {} if "data_match_fn" in self._neural_kwargs else {"epsilon": 1e-1, "tau_a": 1.0, "tau_b": 1.0},
        )
        time_sampler = self._neural_kwargs.get("time_sampler", uniform_sampler)
        optimizer = self._neural_kwargs.get("optimizer", optax.adam(learning_rate=1e-4))
        return genot.GENOT(
            vf=neural_vf,
            flow=self._neural_kwargs.get(
                "flow",
                dynamics.ConstantNoiseFlow(0.1),
            ),
            data_match_fn=functools.partial(
                self._neural_kwargs.get("data_match_fn", data_match_fn), typ="lin", **data_match_fn_kwargs
            ),
            source_dim=source_dim,
            target_dim=target_dim,
            condition_dim=condition_dim if is_conditional else None,
            optimizer=optimizer,
            time_sampler=time_sampler,
            rng=rng,
            latent_noise_fn=self._neural_kwargs.get("latent_noise_fn", None),
            **self._neural_kwargs.get("velocity_field_train_state_kwargs", {}),
        )

    def _split_data(  # TODO: adapt for Gromov terms
        self,
        x: ArrayLike,
        conditions: Optional[ArrayLike],
        train_size: float,
        seed: int,
        a: Optional[ArrayLike] = None,
        b: Optional[ArrayLike] = None,
    ) -> SingleDistributionData:
        """Randomly split one distribution into a training and a validation part.

        The first ``ceil(train_size * n)`` shuffled rows form the training part and
        the remainder the validation part. A single index permutation is applied to
        ``x``, ``conditions``, ``a`` and ``b`` so that every row keeps its own
        condition and weights after shuffling.

        NOTE(review): assumes ``conditions``, ``a`` and ``b`` (when given) have the
        same leading length as ``x`` - confirm against callers.
        """
        n_samples_x = x.shape[0]
        n_train_x = math.ceil(train_size * n_samples_x)
        permutation = np.random.default_rng(seed).permutation(n_samples_x)
        train_idx = permutation[:n_train_x]
        valid_idx = permutation[n_train_x:]

        def take(arr: Optional[ArrayLike], idx: np.ndarray) -> Optional[ArrayLike]:
            # Optional inputs stay ``None`` in both halves.
            return None if arr is None else arr[idx]

        return SingleDistributionData(
            data_train=x[train_idx],
            data_valid=x[valid_idx],
            conditions_train=take(conditions, train_idx),
            conditions_valid=take(conditions, valid_idx),
            a_train=take(a, train_idx),
            a_valid=take(a, valid_idx),
            b_train=take(b, train_idx),
            b_valid=take(b, valid_idx),
        )

    @property
    def solver(self) -> genot.GENOT:
        """Underlying optimal transport solver."""
        return self._solver

    @classmethod
    def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
        # Second element is an empty *set*, matching the annotated return type.
        return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, set()

    def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput:  # type: ignore[override]
        """Train the solver on the first (training) sampler and wrap the result.

        The second (validation) sampler is currently not used.
        """
        seed = self._neural_kwargs.get("seed", 0)  # TODO(ilan-gold): unify rng handling like OTT tests
        rng = jax.random.PRNGKey(seed)
        logs = self.solver(
            data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng
        )  # TODO(ilan-gold): validation and figure out defaults
        return NeuralOutput(self.solver, logs)
41 changes: 26 additions & 15 deletions src/moscot/backends/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Literal, Tuple, TypeVar, Union

from moscot import _registry
from moscot._types import ProblemKind_t

if TYPE_CHECKING:
from moscot.backends import ott
from moscot.neural.backends import neural_ott

__all__ = ["get_solver", "register_solver", "get_available_backends"]

register_solver_t = Callable[
[Literal["linear", "quadratic"], Optional[Literal["GENOTLinSolver"]]],
Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"],
]
__all__ = ["get_solver", "register_solver", "get_available_backends"]


_REGISTRY = _registry.Registry()

GWSolver = TypeVar("GWSolver", bound="ott.GWSolver")
SinkhornSolver = TypeVar("SinkhornSolver", bound="ott.SinkhornSolver")
GENOTSolver = TypeVar("GENOTSolver", bound="neural_ott.GENOTSolver")


def get_solver(problem_kind: ProblemKind_t, *, backend: str = "ott", return_class: bool = False, **kwargs: Any) -> Any:
"""TODO."""
Expand All @@ -27,7 +28,7 @@ def get_solver(problem_kind: ProblemKind_t, *, backend: str = "ott", return_clas

def register_solver(
backend: str,
) -> Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"]:
) -> Union[SinkhornSolver, GWSolver, GENOTSolver]:
"""Register a solver for a specific backend.

Parameters
Expand All @@ -42,23 +43,33 @@ def register_solver(
return _REGISTRY.register(backend) # type: ignore[return-value]


@register_solver("ott")
def _(
@register_solver("ott") # type: ignore[misc]
def create_ott_solver(
problem_kind: Literal["linear", "quadratic"],
solver_name: Optional[Literal["GENOTLinSolver"]] = None,
) -> Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"]:
solver_name: Any = None,
) -> Union[SinkhornSolver, GWSolver]:
from moscot.backends import ott

if problem_kind == "linear":
if solver_name == "GENOTLinSolver":
return ott.GENOTLinSolver # type: ignore[return-value]
if solver_name is None:
return ott.SinkhornSolver # type: ignore[return-value]
return ott.SinkhornSolver # type: ignore[return-value]
if problem_kind == "quadratic":
return ott.GWSolver # type: ignore[return-value]
raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}`, {solver_name} problem.")


@register_solver("neural_ott")
def create_neural_ott_solver(
    problem_kind: Literal["linear", "quadratic"],
    solver_name: Any = None,
) -> GENOTSolver:
    """Return the solver class of the ``neural_ott`` backend selected by name.

    Parameters
    ----------
    problem_kind
        Kind of the problem; only used in the error message.
    solver_name
        Name of the requested solver; ``"GENOTSolver"`` is the only supported value.

    Returns
    -------
    The ``GENOTSolver`` class.

    Raises
    ------
    NotImplementedError
        If ``solver_name`` is anything other than ``"GENOTSolver"``.
    """
    # Imported lazily so the neural backend is only loaded when it is requested.
    from moscot.neural.backends import neural_ott

    if solver_name != "GENOTSolver":
        raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}`, {solver_name} problem.")
    return neural_ott.GENOTSolver


def get_available_backends() -> Tuple[str, ...]:
    """Collect every entry yielded by iterating the solver registry into a tuple."""
    return tuple(_REGISTRY)
Empty file.
3 changes: 3 additions & 0 deletions src/moscot/neural/backends/neural_ott/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from moscot.neural.backends.neural_ott.solver import GENOTSolver

__all__ = ["GENOTSolver"]
Loading
Loading