Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 112 additions & 3 deletions corrai/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from corrai.base.metrics import nmbe, cv_rmse
from corrai.base.utils import as_1_column_dataframe
from corrai.base.model import Model

MODEL_MAP = {
"TREE_REGRESSOR": RandomForestRegressor(),
Expand Down Expand Up @@ -196,7 +197,7 @@ class MultiModelSO(BaseEstimator, RegressorMixin):
>>> import pandas as pd
>>> from sklearn.datasets import load_diabetes
>>> from sklearn.model_selection import train_test_split
>>> from corrai.learning.model_selection import MultiModelSO
>>> from corrai.surrogate import MultiModelSO
>>>
>>> data = load_diabetes(as_frame=True)
>>> X = data.data
Expand Down Expand Up @@ -242,6 +243,11 @@ def __init__(
self.models = models if models is not None else list(MODEL_MAP.keys())
self.model_map = {mod: clone(MODEL_MAP[mod]) for mod in self.models}

@property
def feature_names_in_(self):
check_is_fitted(self, ["_is_fitted"])
return self.get_model().feature_names_in_

def __sklearn_is_fitted__(self):
"""
Check fitted status and return a Boolean value.
Expand Down Expand Up @@ -279,6 +285,7 @@ def fit(self, X: pd.DataFrame, y: pd.DataFrame | pd.Series, verbose=True):

self.fine_tune(X, y, self.best_model_key)

self.target_name_ = y.name
self._is_fitted = True

def predict(
Expand All @@ -290,9 +297,15 @@ def predict(
X = as_1_column_dataframe(X)

if isinstance(X, pd.DataFrame):
return pd.DataFrame(model_for_prediction.predict(X), index=X.index)
return pd.DataFrame(
model_for_prediction.predict(X),
index=X.index,
columns=[self.target_name_],
)
else:
return pd.DataFrame(model_for_prediction.predict(X))
return pd.DataFrame(
model_for_prediction.predict(X), columns=[self.target_name_]
)

def get_model(self, model: str = None):
if model is None:
Expand Down Expand Up @@ -331,3 +344,99 @@ def fine_tune(
cvres = grid_search.cv_results_
for mean_score, params in zip(cvres["mean_test_score"], cvres["params"]):
print(np.sqrt(-mean_score), params)


class StaticScikitModel(Model):
"""
Wrapper class for static surrogate MultiModelSingleOutput class and scikit-learn
regressors within the Corrai framework.

This class adapts corrai's `MultiModelSO` and scikit-learn models
to the :class:`Model` interface, enabling parameter-to-property mapping and
simulation execution. It is intended for non-dynamic (static) models where
outputs are single values or vectors rather than time-dependent series.

Parameters
----------
scikit_model : MultiModelSO or RegressorMixin
The underlying scikit-learn model or a Corrai
:class:`MultiModelSO` meta-estimator.
target_name : str, optional
Name of the output variable. Required when ``scikit_model``
is not an instance of :class:`MultiModelSO`.

Attributes
----------
is_dynamic : bool
Always ``False`` for this wrapper, since it represents static models.
scikit_model : MultiModelSO or RegressorMixin
The wrapped scikit-learn model used for predictions.
target_name : str or None
Output variable name when applicable.
"""

def __init__(
self, scikit_model: MultiModelSO | RegressorMixin, target_name: str = None
):
super().__init__(is_dynamic=False)
self.scikit_model = scikit_model
self.target_name = target_name

def simulate(
self,
property_dict: dict[str, str | int | float] = None,
simulation_options: dict = None,
**simulation_kwargs,
) -> pd.DataFrame | pd.Series:
"""
Run the scikit-learn model prediction.

Combines provided parameter values and simulation options into a
feature vector, validates compatibility with the underlying model,
and returns predictions.

Parameters
----------
property_dict : dict of {str: int, float, or str}, optional
Mapping from feature names to values to use for prediction.
simulation_options : dict, optional
Additional feature overrides or configuration parameters to
include in the feature vector.
**simulation_kwargs
Extra keyword arguments for future extensions (currently unused).

Returns
-------
pd.Series or pd.DataFrame
- If ``scikit_model`` is an instance of :class:`MultiModelSO`:
returns a :class:`pandas.Series` of predictions.
- Otherwise, returns a :class:`pandas.DataFrame` with column
name ``target_name``.

Raises
------
ValueError
If unknown feature names are provided or if ``target_name``
is missing when required.
"""

param_df = pd.Series(property_dict)
if simulation_options is not None:
param_df = pd.concat([param_df, pd.Series(simulation_options)])
param_df = param_df.to_frame().T

missing = set(param_df.columns) - set(self.scikit_model.feature_names_in_)
if missing:
raise ValueError(f"Unknown features: {missing}")

if isinstance(self.scikit_model, MultiModelSO):
return self.scikit_model.predict(param_df)
elif self.target_name is not None:
return pd.DataFrame(
data=self.scikit_model.predict(param_df), columns=[self.target_name]
)
else:
raise ValueError(
"scikit_model is not an instance of MultiModelSO"
"target_name must be specified"
)
31 changes: 29 additions & 2 deletions tests/test_surrogate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import itertools

import pandas as pd
import numpy as np

from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import make_pipeline
from corrai.surrogate import ModelTrainer, MultiModelSO
from sklearn.linear_model import LinearRegression

from corrai.surrogate import ModelTrainer, MultiModelSO, StaticScikitModel


class TestLearning:
Expand Down Expand Up @@ -41,4 +45,27 @@ def test_mumoso_and_trainer(self):

model_pipe.predict(x_df)

assert True

class TestScikitWrapper:
def test_scikit_wrapper(self):
ds = pd.DataFrame(
{
"x_1": np.arange(0.0, 10.0, 1),
"x_2": np.arange(10.0, 20.0, 1),
"y": 4.0 * np.arange(10.0, 20.0, 1) + 2.0 * np.arange(0.0, 10.0, 1),
}
)

in_df = {"x_1": 2.0, "x_2": 4.0}

ref_df = pd.DataFrame({"y": 28.0}, index=[0])

mumoso = MultiModelSO()
mumoso.fit(ds[["x_1", "x_2"]], ds["y"])
stat_mod = StaticScikitModel(mumoso)
pd.testing.assert_frame_equal(stat_mod.simulate(in_df), ref_df)

line_reg = LinearRegression()
line_reg.fit(ds[["x_1", "x_2"]], ds["y"])
scikit_mod = StaticScikitModel(line_reg, target_name="y")
pd.testing.assert_frame_equal(scikit_mod.simulate(in_df), ref_df)