diff --git a/corrai/sampling.py b/corrai/sampling.py index a609030..771ef1e 100644 --- a/corrai/sampling.py +++ b/corrai/sampling.py @@ -964,7 +964,15 @@ def plot_pcp( results = pd.DataFrame() for config in indicators_configs: col, func, *extra = config - results[f"{func}_{col}"] = self.get_aggregated_time_series( + if extra and isinstance(extra[0], str): + name = extra[0] + elif callable(func): + name = func.__name__ + elif isinstance(func, str): + name = func + else: + raise TypeError(f"Invalid aggregation function: {func}") + results[f"{name}_{col}"] = self.get_aggregated_time_series( col, func, reference_time_series=None if not extra else extra[0] ) else: diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 787e245..9598369 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -378,6 +378,15 @@ def test_plot_pcp_in_sampler(self): assert isinstance(fig, go.Figure) assert len(fig.data) == 1 + def last_val(x): + return x.iloc[-1] + + fig2 = sampler.sample.plot_pcp([("res", last_val)]) + + dims2 = fig2.data[0]["dimensions"] + labels2 = [d["label"] for d in dims2] + assert "last_val_res" in labels2 + def test_lhs_sampler(self): # Dynamic sampler = LHSSampler(