Skip to content
Draft
583 changes: 583 additions & 0 deletions notebooks/structural_components_dataclass.ipynb

Large diffs are not rendered by default.

241 changes: 241 additions & 0 deletions pymc_extras/statespace/core/properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from __future__ import annotations

import warnings

from collections.abc import Iterator
from copy import deepcopy
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Generic, Self, TypeVar

from pymc_extras.statespace.core import PyMCStateSpace
from pymc_extras.statespace.utils.constants import (
ALL_STATE_AUX_DIM,
ALL_STATE_DIM,
OBS_STATE_AUX_DIM,
OBS_STATE_DIM,
SHOCK_AUX_DIM,
SHOCK_DIM,
)

if TYPE_CHECKING:
from pymc_extras.statespace.models.structural.core import Component


@dataclass(frozen=True)
class Property:
def __str__(self) -> str:
return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))


T = TypeVar("T", bound=Property)


@dataclass(frozen=True)
class Info(Generic[T]):
items: tuple[T, ...]
key_field: str = "name"
_index: dict[str, T] | None = None

def __post_init__(self):
index = {}
missing_attr = []
for item in self.items:
if not hasattr(item, self.key_field):
missing_attr.append(item)
continue
key = getattr(item, self.key_field)
# if key in index:
# raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") # This needs to be possible for shared states
Comment on lines +47 to +48
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That shouldn't happen here though, it should come up in merge or add right? And we handle it there with the allow_duplicates flag

index[key] = item
if missing_attr:
raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}")
object.__setattr__(self, "_index", index)

def _key(self, item: T) -> str:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used?

return getattr(item, self.key_field)

def get(self, key: str, default=None) -> T | None:
return self._index.get(key, default)

def __getitem__(self, key: str) -> T:
try:
return self._index[key]
except KeyError as e:
available = ", ".join(self._index.keys())
raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e

def __contains__(self, key: object) -> bool:
return key in self._index

def __iter__(self) -> Iterator[T]:
return iter(self.items)

def __len__(self) -> int:
return len(self.items)

def __str__(self) -> str:
return f"{self.key_field}s: {list(self._index.keys())}"

def add(self, new_item: T):
return type(self)([*self.items, new_item])

def merge(self, other: Self, allow_duplicates: bool = False) -> Self:
if not isinstance(other, type(self)):
raise TypeError(f"Cannot merge {type(other).__name__} with {type(self).__name__}")

overlapping = set(self.names) & set(other.names)
if overlapping and not allow_duplicates:
raise ValueError(f"Duplicate names found: {overlapping}")

return type(self)(list(self.items) + list(other.items))

@property
def names(self) -> tuple[str, ...]:
return tuple(self._index.keys())

def copy(self) -> Info[T]:
return deepcopy(self)


@dataclass(frozen=True)
class Parameter(Property):
name: str
shape: tuple[int, ...]
dims: tuple[str, ...]
constraints: str | None = None


@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
def __init__(self, parameters: list[Parameter]):
super().__init__(items=tuple(parameters), key_field="name")


@dataclass(frozen=True)
class Data(Property):
name: str
shape: tuple[int, ...]
dims: tuple[str, ...]
is_exogenous: bool


@dataclass(frozen=True)
class DataInfo(Info[Data]):
def __init__(self, data: list[Data]):
super().__init__(items=tuple(data), key_field="name")

@property
def needs_exogenous_data(self) -> bool:
return any(d.is_exogenous for d in self.items)

@property
def exogenous_names(self) -> tuple[str, ...]:
return tuple(d.name for d in self.items if d.is_exogenous)

def __str__(self) -> str:
return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}"


@dataclass(frozen=True)
class Coord(Property):
dimension: str
labels: tuple[str, ...]


@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
def __init__(self, coords: list[Coord]):
super().__init__(items=tuple(coords), key_field="dimension")

def __str__(self) -> str:
base = "coordinates:"
for coord in self.items:
coord_str = str(coord)
indented = "\n".join(" " + line for line in coord_str.splitlines())
base += "\n" + indented + "\n"
return base

@classmethod
def default_coords_from_model(
cls, model: PyMCStateSpace | Component
) -> (
Self
): # TODO: Need to figure out how to include Component type was causing circular import issues
states = tuple(model.state_names)
obs_states = tuple(model.observed_states)
shocks = tuple(model.shock_names)

dim_to_labels = (
(ALL_STATE_DIM, states),
(ALL_STATE_AUX_DIM, states),
(OBS_STATE_DIM, obs_states),
(OBS_STATE_AUX_DIM, obs_states),
(SHOCK_DIM, shocks),
(SHOCK_AUX_DIM, shocks),
)

coords = [Coord(dimension=dim, labels=labels) for dim, labels in dim_to_labels]
return cls(coords)

def to_dict(self):
return {coord.dimension: coord.labels for coord in self.items if len(coord.labels) > 0}


@dataclass(frozen=True)
class State(Property):
name: str
observed: bool
shared: bool


@dataclass(frozen=True)
class StateInfo(Info[State]):
def __init__(self, states: list[State]):
super().__init__(items=tuple(states), key_field="name")

def __str__(self) -> str:
return (
f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}"
)

@property
def observed_states(self) -> tuple[State, ...]: # Is this needed??
return tuple(s for s in self.items if s.observed)
Comment on lines +201 to +203
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to keep it (as an alias for observed_state_names), then pick one to be the "canonical" name and just return that one from the other ones, rather than re-writing the loop in several places


@property
def observed_state_names(self) -> tuple[State, ...]:
return tuple(s.name for s in self.items if s.observed)

@property
def unobserved_state_names(self) -> tuple[State, ...]:
return tuple(s.name for s in self.items if not s.observed)

def merge(self, other: StateInfo, allow_duplicates: bool = False) -> StateInfo:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why doesn't the base class version work?

"""Combine states from two StateInfo objects."""
if not isinstance(other, StateInfo):
raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")

overlapping = set(self.names) & set(other.names)
if overlapping and not allow_duplicates:
# This is necessary for shared states
warnings.warn(
f"Duplicate state names found: {overlapping}. Merge will ONLY retain unique states",
UserWarning,
)
return StateInfo(
states=list(self.items)
+ [item for item in other.items if item.name not in overlapping]
)

return StateInfo(states=list(self.items) + list(other.items))


@dataclass(frozen=True)
class Shock(Property):
name: str


@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
def __init__(self, shocks: list[Shock]):
super().__init__(items=tuple(shocks), key_field="name")
Loading
Loading