-
Notifications
You must be signed in to change notification settings - Fork 3
mlflow: Add version tags for registered models #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add the following model version tags when logging a model to MLflow: * model_uri: The URI of the model artifact * model_type: The type of the model (e.g. 'medcat_snomed') * validation_status: The validation status of the model (e.g. 'pending') Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just make sure that versions[0] will always return the most recently saved model version.
(versions[0] had been used prior to this PR and should have been tested)
app/management/tracker_client.py
Outdated
| model_name: str, | ||
| model_manager: ModelManager, | ||
| validation_status: str = "pending", | ||
| model_type: Optional[str] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All CMS models have the ModelType, hence there's no need to make the argument optional and mlflow.set_tag() will not set a None value.
|
@baixiac so |
Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds version tags to registered models in MLflow to improve model tracking and discovery. The changes introduce three new tags: model_uri, model_type, and validation_status that are attached to model versions when logging models.
Key changes:
- Added
_set_model_version_tagshelper method to standardize tag setting across model registration flows - Updated
save_modelandsave_pretrained_modelmethods to acceptmodel_typeparameter and set version tags - Modified all trainer implementations to pass model type information when saving models
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| app/management/tracker_client.py | Added _set_model_version_tags static method and updated save_model/save_pretrained_model to set version tags including model_uri, model_type, and validation_status |
| app/trainers/metacat_trainer.py | Updated save_model call to include model type from model service |
| app/trainers/medcat_trainer.py | Updated save_model calls (2 locations) to include model type from model service |
| app/trainers/medcat_deid_trainer.py | Updated save_model call to include model type from model service |
| app/trainers/huggingface_ner_trainer.py | Updated save_model calls (2 locations) to include model type from model service |
| app/trainers/huggingface_llm_trainer.py | Updated save_model call to include model type from model service |
| app/cli/cli.py | Updated save_pretrained_model call to pass model_type parameter |
| tests/app/monitoring/test_tracker_client.py | Enhanced tests to verify version tags are set correctly; added mock setup for pretrained model test |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| mlflow.set_tag.has_calls( | ||
| [ | ||
| call("training.output.package", "file.zip"), | ||
| call("training.output.model_uri", artifact_uri), | ||
| call("training.output.model_type", "model_type"), | ||
| ], | ||
| any_order=False, | ||
| ) |
Copilot
AI
Dec 5, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorrect assertion method. Should be assert_has_calls instead of has_calls. The current code will not actually perform the assertion, allowing the test to pass even if the calls were not made.
Add the following model version tags when logging a model to MLflow: