diff --git a/app/cli/cli.py b/app/cli/cli.py index 4358338..4fa7b40 100644 --- a/app/cli/cli.py +++ b/app/cli/cli.py @@ -299,6 +299,7 @@ def register_model( model_name=model_name, model_path=model_path, model_manager=ModelManager(model_service_type, config), + model_type=model_type.value, training_type=t_type, run_name=run_name, model_config=m_config, diff --git a/app/management/tracker_client.py b/app/management/tracker_client.py index 61686a2..a63a18c 100644 --- a/app/management/tracker_client.py +++ b/app/management/tracker_client.py @@ -346,6 +346,36 @@ def log_model_config(config: Dict[str, str]) -> None: mlflow.log_params(config) + @staticmethod + def _set_model_version_tags( + client: MlflowClient, + model_name: str, + version: str, + model_type: str, + validation_status: Optional[str] = None, + ) -> None: + """ + Sets standard tags on a model version for serving and discovery. + + Args: + client (MlflowClient): The MLflow client to use for setting tags. + model_name (str): The name of the registered model. + version (str): The version of the model. + model_type (str): The type of the model (e.g., "medcat_snomed"). + validation_status (Optional[str]): The status of the model validation (e.g., "pending"). + """ + try: + client.set_model_version_tag( + name=model_name, version=version, key="model_uri", value=f"models:/{model_name}/{version}" + ) + client.set_model_version_tag(name=model_name, version=version, key="model_type", value=model_type) + if validation_status is not None: + client.set_model_version_tag( + name=model_name, version=version, key="validation_status", value=validation_status + ) + except Exception: + logger.warning("Failed to set tags on version %s of model %s", version, model_name) + @staticmethod def log_model( model_name: str, @@ -381,6 +411,7 @@ def save_pretrained_model( model_name: str, model_path: str, model_manager: ModelManager, + model_type: str, training_type: Optional[str] = "", run_name: Optional[str] = "", model_config: Optional[Dict] = None, @@ -394,6 +425,7 @@ def save_pretrained_model( model_name (str): The name of the model. model_path (str): The path to the pretrained model. model_manager (ModelManager): The instance of ModelManager used for model saving. + model_type (str): The type of the model (e.g., "medcat_snomed"). training_type (Optional[str]): The type of training used for the model. run_name (Optional[str]): The name of the run for identification purposes. model_config (Optional[Dict]): The configuration of the model to save. @@ -423,6 +455,10 @@ def save_pretrained_model( mlflow.set_tags(tags) model_name = model_name.replace(" ", "_") TrackerClient.log_model(model_name, model_path, model_manager, model_name) + client = MlflowClient() + versions = client.search_model_versions(f"name='{model_name}'", order_by=["version_number DESC"]) + if versions: + TrackerClient._set_model_version_tags(client, model_name, versions[0].version, model_type) TrackerClient.end_with_success() except KeyboardInterrupt: TrackerClient.end_with_interruption() @@ -502,6 +538,7 @@ def save_model( filepath: str, model_name: str, model_manager: ModelManager, + model_type: str, validation_status: str = "pending", ) -> str: """ @@ -511,6 +548,7 @@ def save_model( filepath (str): The artifact path of the model to save. model_name (str): The name of the model. model_manager (ModelManager): The instance of ModelManager used for model saving. + model_type (str): The type of the model (e.g., "medcat_snomed"). validation_status (str): The status of the model validation (default: "pending"). Returns: @@ -523,18 +561,19 @@ def save_model( if not mlflow.get_tracking_uri().startswith("file:/"): TrackerClient.log_model(model_name, filepath, model_manager, model_name) - versions = self.mlflow_client.search_model_versions(f"name='{model_name}'") - self.mlflow_client.set_model_version_tag( - name=model_name, - version=versions[0].version, - key="validation_status", - value=validation_status, + versions = self.mlflow_client.search_model_versions( + f"name='{model_name}'", order_by=["version_number DESC"] ) + if versions: + TrackerClient._set_model_version_tags( + self.mlflow_client, model_name, versions[0].version, model_type, validation_status + ) else: TrackerClient.log_model(model_name, filepath, model_manager) artifact_uri = mlflow.get_artifact_uri(model_name) mlflow.set_tag("training.output.model_uri", artifact_uri) + mlflow.set_tag("training.output.model_type", model_type) return artifact_uri diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py index b85f44b..6084c46 100644 --- a/app/trainers/huggingface_llm_trainer.py +++ b/app/trainers/huggingface_llm_trainer.py @@ -436,6 +436,7 @@ def run( retrained_model_pack_path, self._model_name, self._model_manager, + self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: diff --git a/app/trainers/huggingface_ner_trainer.py b/app/trainers/huggingface_ner_trainer.py index 2aa44aa..e4fc0b6 100644 --- a/app/trainers/huggingface_ner_trainer.py +++ b/app/trainers/huggingface_ner_trainer.py @@ -254,6 +254,7 @@ def run( retrained_model_pack_path, self._model_name, self._model_manager, + self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: @@ -739,6 +740,7 @@ def _compute_loss( retrained_model_pack_path, self._model_name, self._model_manager, + self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: diff --git a/app/trainers/medcat_deid_trainer.py b/app/trainers/medcat_deid_trainer.py index 65ac7be..34d2749 100644 --- a/app/trainers/medcat_deid_trainer.py +++ b/app/trainers/medcat_deid_trainer.py @@ -185,7 +185,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/app/trainers/medcat_trainer.py b/app/trainers/medcat_trainer.py index e49068f..6ef7fe2 100644 --- a/app/trainers/medcat_trainer.py +++ b/app/trainers/medcat_trainer.py @@ -211,7 +211,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: @@ -472,7 +477,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + self._model_service.info().model_type.value, + ) logger.info(f"Retrained model saved: {model_uri}") self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/app/trainers/metacat_trainer.py b/app/trainers/metacat_trainer.py index 5cce3b9..49a47b8 100644 --- a/app/trainers/metacat_trainer.py +++ b/app/trainers/metacat_trainer.py @@ -159,7 +159,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/tests/app/monitoring/test_tracker_client.py b/tests/app/monitoring/test_tracker_client.py index edcbaee..1029a0f 100644 --- a/tests/app/monitoring/test_tracker_client.py +++ b/tests/app/monitoring/test_tracker_client.py @@ -3,7 +3,7 @@ import datasets import pytest import pandas as pd -from unittest.mock import Mock, call, ANY +from unittest.mock import Mock, call, patch, ANY from app.management.tracker_client import TrackerClient from app.data import doc_dataset from app.domain import TrainerBackend @@ -161,15 +161,30 @@ def test_save_model(mlflow_fixture): mlflow_client.search_model_versions.return_value = [version] tracker_client.mlflow_client = mlflow_client - artifact_uri = tracker_client.save_model("path/to/file.zip", "model_name", model_manager, "validation_status") + artifact_uri = tracker_client.save_model( + "path/to/file.zip", "model_name", model_manager, "model_type", "validation_status" + ) assert "artifacts/model_name" in artifact_uri model_manager.log_model.assert_called_once_with("model_name", "path/to/file.zip", "model_name") - mlflow_client.set_model_version_tag.assert_called_once_with(name="model_name", version="1", key="validation_status", value="validation_status") + mlflow_client.search_model_versions.assert_called_once_with( + "name='model_name'", order_by=["version_number DESC"] + ) + assert mlflow_client.set_model_version_tag.call_count == 3 + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_uri", value="models:/model_name/1" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_type", value="model_type" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="validation_status", value="validation_status" + ) mlflow.set_tag.has_calls( [ call("training.output.package", "file.zip"), call("training.output.model_uri", artifact_uri), + call("training.output.model_type", "model_type"), ], any_order=False, ) @@ -184,14 +199,21 @@ def test_save_model_local(mlflow_fixture): model_manager.save_model.assert_called_once_with("local_dir", "filepath") -def test_save_pretrained_model(mlflow_fixture): +@patch("app.management.tracker_client.MlflowClient") +def test_save_pretrained_model(mock_mlflow_client_class, mlflow_fixture): tracker_client = TrackerClient("") model_manager = Mock() + mlflow_client = Mock() + version = Mock() + version.version = "1" + mlflow_client.search_model_versions.return_value = [version] + mock_mlflow_client_class.return_value = mlflow_client tracker_client.save_pretrained_model( "model_name", "model_path", model_manager, + "model_type", "training_type", "run_name", {"param": "value"}, @@ -212,6 +234,17 @@ def test_save_pretrained_model(mlflow_fixture): assert len(mlflow.set_tags.call_args.args[0]["mlflow.source.name"]) > 0 assert mlflow.set_tags.call_args.args[0]["tag_name"] == "tag_value" + mlflow_client.search_model_versions.assert_called_once_with( + "name='model_name'", order_by=["version_number DESC"] + ) + assert mlflow_client.set_model_version_tag.call_count == 2 + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_uri", value="models:/model_name/1" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_type", value="model_type" + ) + def test_log_single_exception(mlflow_fixture): tracker_client = TrackerClient("")