Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 158 additions & 17 deletions tests/test_extensions/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# License: BSD 3-Clause
from __future__ import annotations

import inspect
from collections import OrderedDict

import inspect
import numpy as np
import pytest

import openml.testing
from openml.extensions import get_extension_by_flow, get_extension_by_model, register_extension
from openml.extensions import Extension, get_extension_by_flow, get_extension_by_model, register_extension


class DummyFlow:
Expand Down Expand Up @@ -40,23 +42,76 @@ def can_handle_model(model):
return False


class DummyExtension(Extension):
@classmethod
def can_handle_flow(cls, flow):
return isinstance(flow, DummyFlow)

@classmethod
def can_handle_model(cls, model):
return isinstance(model, DummyModel)

def flow_to_model(
self,
flow,
initialize_with_defaults=False,
strict_version=True,
):
if not isinstance(flow, DummyFlow):
raise ValueError("Invalid flow")

model = DummyModel()
model.defaults = initialize_with_defaults
model.strict_version = strict_version
return model

def model_to_flow(self, model):
if not isinstance(model, DummyModel):
raise ValueError("Invalid model")
return DummyFlow()

def get_version_information(self):
return ["dummy==1.0"]

def create_setup_string(self, model):
return "DummyModel()"

def is_estimator(self, model):
return isinstance(model, DummyModel)

def seed_model(self, model, seed):
model.seed = seed
return model

def _run_model_on_fold(
self,
model,
task,
X_train,
rep_no,
fold_no,
y_train=None,
X_test=None,
):
preds = np.zeros(len(X_train))
probs = None
measures = OrderedDict()
trace = None
return preds, probs, measures, trace

def obtain_parameter_values(self, flow, model=None):
return []

def check_if_model_fitted(self, model):
return False

def instantiate_model_from_hpo_class(self, model, trace_iteration):
return DummyModel()


def _unregister():
# "Un-register" the test extensions
while True:
rem_dum_ext1 = False
rem_dum_ext2 = False
try:
openml.extensions.extensions.remove(DummyExtension1)
rem_dum_ext1 = True
except ValueError:
pass
try:
openml.extensions.extensions.remove(DummyExtension2)
rem_dum_ext2 = True
except ValueError:
pass
if not rem_dum_ext1 and not rem_dum_ext2:
break
openml.extensions.extensions.clear()


class TestInit(openml.testing.TestBase):
Expand Down Expand Up @@ -91,3 +146,89 @@ def test_get_extension_by_model(self):
ValueError, match="Multiple extensions registered which can handle model:"
):
get_extension_by_model(DummyModel())


def test_flow_to_model_with_defaults():
"""Test flow_to_model with initialize_with_defaults=True."""
ext = DummyExtension()
flow = DummyFlow()

model = ext.flow_to_model(flow, initialize_with_defaults=True)

assert isinstance(model, DummyModel)
assert model.defaults is True

def test_flow_to_model_strict_version():
"""Test flow_to_model with strict_version parameter."""
ext = DummyExtension()
flow = DummyFlow()

model_strict = ext.flow_to_model(flow, strict_version=True)
model_non_strict = ext.flow_to_model(flow, strict_version=False)

assert isinstance(model_strict, DummyModel)
assert model_strict.strict_version is True

assert isinstance(model_non_strict, DummyModel)
assert model_non_strict.strict_version is False

def test_model_to_flow_conversion():
"""Test converting a model back to flow representation."""
ext = DummyExtension()
model = DummyModel()

flow = ext.model_to_flow(model)

assert isinstance(flow, DummyFlow)


def test_invalid_flow_raises_error():
"""Test that invalid flow raises appropriate error."""
class InvalidFlow:
pass

ext = DummyExtension()
flow = InvalidFlow()

with pytest.raises(ValueError, match="Invalid flow"):
ext.flow_to_model(flow)


def test_extension_not_found_error_message():
"""Test error message contains helpful information."""
class UnknownModel:
pass

_unregister()

with pytest.raises(ValueError, match="No extension registered"):
get_extension_by_model(UnknownModel(), raise_if_no_extension=True)


def test_register_same_extension_twice():
"""Test behavior when registering same extension twice."""
register_extension(DummyExtension)
register_extension(DummyExtension)

matches = [
ext for ext in openml.extensions.extensions
if ext is DummyExtension
]

assert len(matches) == 2


def test_extension_priority_order():
"""Test that extensions are checked in registration order."""
_unregister()

class DummyExtensionA(DummyExtension):
pass
class DummyExtensionB(DummyExtension):
pass

register_extension(DummyExtensionA)
register_extension(DummyExtensionB)

assert openml.extensions.extensions[0] is DummyExtensionA
assert openml.extensions.extensions[1] is DummyExtensionB