Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions corrai/base/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime as dt
from typing import Union

import pandas as pd
Expand Down Expand Up @@ -346,3 +347,39 @@ def simulate(
setattr(self, prop, val)

return pd.Series({"res": self.prop_1 * self.prop_2 + self.prop_3})


class Sine(PyModel):
def __init__(self):
super().__init__(is_dynamic=True)
self.omega = 2
self.amplitude = 5

def simulate(
self,
property_dict: dict[str, str | int | float] = None,
simulation_options: dict = None,
**simulation_kwargs,
) -> pd.DataFrame | pd.Series:
self.set_property_values(property_dict)

start = simulation_options.get("start", "2009-01-01 00:00:00")
stop = simulation_options.get("stop", "2009-01-02 00:00:00")
output_freq = simulation_options.get("freq", "h")

index = pd.date_range(start, stop, freq=output_freq, tz="UTC")
cumsum_second = np.arange(
0, (index[-1] - index[0]).total_seconds() + 1, step=3600
)

return pd.DataFrame(
data=self.amplitude
* np.sin(
self.omega
* np.pi
/ dt.timedelta(days=1).total_seconds()
* cumsum_second
),
columns=["res"],
index=index,
)
26 changes: 26 additions & 0 deletions corrai/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import datetime as dt
from collections.abc import Iterable
from typing import Callable


def _reshape_1d(sample):
Expand Down Expand Up @@ -258,3 +259,28 @@ def get_reversed_dict(dictionary, values=None):
values = [values]

return {val: key for key, val in dictionary.items() if val in values}


def check_indicators_configs(
is_dynamic: bool,
indicators_configs: list[str]
| list[tuple[str, str | Callable] | tuple[str, str | Callable, pd.Series]]
| None,
):
if is_dynamic:
if indicators_configs is None:
raise ValueError(
"Model is dynamic. At least one indicators and its aggregation "
"method must be provided"
)
if isinstance(indicators_configs[0], str):
raise ValueError(
"Invalid 'indicators_configs'. Model is dynamic"
"At least 'method' is required"
)
else:
if indicators_configs is not None and isinstance(indicators_configs[0], tuple):
raise ValueError(
"Invalid 'indicators_configs'. Model is static. "
"'indicators_configs' must be a list of string"
)
67 changes: 48 additions & 19 deletions corrai/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from corrai.base.math import METHODS
from corrai.base.model import Model
from corrai.base.utils import check_indicators_configs
from corrai.sampling import Sample
from corrai.base.parameter import Parameter

Expand Down Expand Up @@ -209,35 +210,18 @@ def evaluate(
np.array([[val[1] for val in parameter_value_pairs]]), [res]
)

check_indicators_configs(self.model.is_dynamic, indicators_configs)

if self.model.is_dynamic:
if indicators_configs is None:
raise ValueError(
"Model is dynamic. At least one indicators and its aggregation "
"method must be provided"
)
if isinstance(indicators_configs[0], str):
raise ValueError(
"Invalid 'indicators_configs'. Model is dynamic"
"At least 'method' is required"
)
results = pd.Series()
for config in indicators_configs:
col, func, *extra = config
series = res[col]

if isinstance(func, str):
func = METHODS[func]

results[col] = func(series, *extra)
return pd.Series(results)
else:
if indicators_configs is not None and isinstance(
indicators_configs[0], tuple
):
raise ValueError(
"Invalid 'indicators_configs'. Model is static. "
"'indicators_configs' must be a list of string"
)
return res[indicators_configs] if indicators_configs is not None else res

def scipy_obj_function(self, x: np.ndarray, *args) -> float:
Expand Down Expand Up @@ -830,3 +814,48 @@ def plot_sample(
quantile_band=quantile_band,
type_graph=type_graph,
)

@wraps(Sample.plot_pcp)
def plot_pcp(
self,
indicators_configs: list[str]
| list[tuple[str, str | Callable] | tuple[str, str | Callable, pd.Series]],
color_by: str | None = None,
title: str | None = "Parallel Coordinates — Samples",
html_file_path: str | None = None,
) -> go.Figure:
return self.model_evaluator.sample.plot_pcp(
indicators_configs=indicators_configs,
color_by=color_by,
title=title,
html_file_path=html_file_path,
)

@wraps(Sample.plot_hist)
def plot_hist(
self,
indicator: str,
method: str = "mean",
unit: str = "",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
bins: int = 30,
colors: str = "orange",
reference_value: int | float = None,
reference_label: str = "Reference",
show_rug: bool = False,
title: str = None,
):
return self.model_evaluator.sample.plot_hist(
indicator=indicator,
method=method,
unit=unit,
agg_method_kwarg=agg_method_kwarg,
reference_time_series=reference_time_series,
bins=bins,
colors=colors,
reference_value=reference_value,
reference_label=reference_label,
show_rug=show_rug,
title=title,
)
157 changes: 104 additions & 53 deletions corrai/sampling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import wraps
from typing import Union
from typing import Union, Callable

import numpy as np
import pandas as pd
Expand All @@ -15,48 +15,42 @@
from SALib.sample import sobol as sobol_sampler
from SALib.sample import fast_sampler, latin

from corrai.base.utils import check_indicators_configs
from corrai.base.parameter import Parameter
from corrai.base.model import Model
from corrai.base.math import aggregate_time_series
from corrai.base.simulate import run_simulations


def plot_pcp(
parameter_values: np.ndarray,
parameter_names: list[str],
parameter_values: pd.DataFrame,
aggregated_results: pd.DataFrame,
*,
bounds: list[tuple[float, float]] | None = None,
color_by: str | None = None,
title: str | None = "Parallel Coordinates — Samples",
html_file_path: str | None = None,
) -> go.Figure:
"""
Creates a Parallel Coordinates Plot (PCP) for parameter samples and aggregated indicators.
Each vertical axis corresponds to a parameter or an aggregated indicator,
and each polyline represents one simulation.
Creates a Parallel Coordinates Plot (PCP) for parameter samples and aggregated
indicators. Each vertical axis corresponds to a parameter or an aggregated
indicator, and each polyline represents one simulation.
"""

if parameter_values.shape[0] != len(aggregated_results):
raise ValueError("Mismatch between number of samples and aggregated results.")
if len(parameter_names) != parameter_values.shape[1]:
if parameter_values.shape[0] != aggregated_results.shape[0]:
raise ValueError(
"`parameter_names` length must match parameter_values.shape[1]."
"Shape mismatch between parameter_values and aggregated_results"
)

df = pd.DataFrame(
parameter_values, columns=parameter_names, index=aggregated_results.index
)
df = pd.concat([df, aggregated_results], axis=1)
df = pd.concat([parameter_values, aggregated_results], axis=1)

if color_by is None:
if not aggregated_results.empty:
color_by = aggregated_results.columns[0]
else:
color_by = parameter_names[0]
color_by = parameter_values.columns[0]

dimensions = []
for j, pname in enumerate(parameter_names):
for j, pname in enumerate(parameter_values.columns):
dim = {"label": pname, "values": df[pname].to_numpy()}
if bounds is not None:
lb, ub = bounds[j]
Expand Down Expand Up @@ -712,6 +706,94 @@ def _legend_for(i: int) -> str:
)
return fig

def plot_pcp(
self,
indicators_configs: list[str]
| list[tuple[str, str | Callable] | tuple[str, str | Callable, pd.Series]],
color_by: str | None = None,
title: str | None = "Parallel Coordinates — Samples",
html_file_path: str | None = None,
) -> go.Figure:
"""
This method produces an interactive PCP visualization that allows comparison
of model parameters against aggregated indicators from simulation results.
It supports both dynamic and static models.

For dynamic models, the specified indicators are aggregated across time using
the provided functions (e.g., "mean", "sum", error metrics). For static models,
the indicators are taken directly from the stored results.

Parameters
----------
indicators_configs : list of str or list of tuple
Configuration of indicators to include in the plot.

- For dynamic models, each element must be a tuple of the form:
``(indicator_name, method)`` or
``(indicator_name, method, reference_series)``.

Here:
* `indicator_name` : str
Column name in the simulation results to aggregate.
* `method` : str or Callable
Aggregation function or metric to apply.
* `reference_series` : pandas.Series, optional
Reference time series required for error-based methods
(e.g., mean absolute error).

- For static models, a simple list of indicator names (str) is sufficient.

color_by : str, optional
Name of a parameter or result column to use for coloring the PCP lines.
If None, all lines are plotted in the same color.

title : str, default="Parallel Coordinates — Samples"
Title of the plot.

html_file_path : str, optional
If provided, saves the interactive plot as an HTML file at the specified
path.

Returns
-------
plotly.graph_objects.Figure
The generated parallel coordinates figure. The figure can be displayed
interactively in a Jupyter notebook, web browser, or exported to HTML.

Raises
------
ValueError
If the `indicators_configs` are incompatible with the model type
(dynamic vs static).

See Also
--------
get_aggregated_time_series :
For details on supported aggregation methods and how indicator values
are computed for dynamic models.
"""

check_indicators_configs(self.is_dynamic, indicators_configs)

if self.is_dynamic:
results = pd.DataFrame()
for config in indicators_configs:
col, func, *extra = config
results[f"{func}_{col}"] = self.get_aggregated_time_series(
col, func, reference_time_series=None if not extra else extra[0]
)
else:
results = self.get_static_results_as_df()[indicators_configs]

return plot_pcp(
parameter_values=self.values,
aggregated_results=results,
bounds=self.get_parameters_intervals().tolist(),
color_by=color_by,
title=title,
html_file_path=html_file_path,
)


class Sampler(ABC):
"""
Expand Down Expand Up @@ -862,48 +944,17 @@ def get_sample_aggregated_time_series(
indicator, method, agg_method_kwarg, reference_time_series, freq, prefix
)

@wraps(Sample.plot_pcp)
def plot_pcp(
self,
indicator: str | None = None,
method: str | list[str] = "mean",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
freq: str | pd.Timedelta | dt.timedelta = None,
prefix: str | None = None,
bounds: list[tuple[float, float]] | None = None,
indicators_configs: list[str]
| list[tuple[str, str | Callable] | tuple[str, str | Callable, pd.Series]],
color_by: str | None = None,
title: str | None = "Parallel Coordinates — Samples",
html_file_path: str | None = None,
) -> go.Figure:
if indicator is None:
aggregated = pd.DataFrame(index=range(len(self.values)))
else:
methods = [method] if isinstance(method, str) else method
dfs = []
for m in methods:
this_prefix = prefix if prefix is not None else m
agg = aggregate_time_series(
results=self.results,
indicator=indicator,
method=m,
agg_method_kwarg=agg_method_kwarg,
reference_time_series=reference_time_series,
freq=freq,
prefix=this_prefix,
)
if agg is not None and not agg.empty:
dfs.append(agg)
aggregated = (
pd.concat(dfs, axis=1)
if dfs
else pd.DataFrame(index=range(len(self.values)))
)

return plot_pcp(
parameter_values=self.values,
parameter_names=[p.name for p in self.parameters],
aggregated_results=aggregated,
bounds=bounds,
return self.sample.plot_pcp(
indicators_configs=indicators_configs,
color_by=color_by,
title=title,
html_file_path=html_file_path,
Expand Down
Loading