|
4 | 4 | import shutil |
5 | 5 | import subprocess |
6 | 6 | from typing import Dict, Optional, Type, Union |
| 7 | +from uuid import uuid4 |
7 | 8 |
|
8 | 9 | import fsspec |
| 10 | +import mlflow |
9 | 11 | import psutil |
10 | 12 | from jupyter_core.paths import jupyter_data_dir |
11 | 13 | from jupyter_server.transutils import _i18n |
|
42 | 44 | from jupyter_scheduler.orm import Job, JobDefinition, create_session |
43 | 45 | from jupyter_scheduler.utils import create_output_directory, create_output_filename |
44 | 46 |
|
| 47 | +MLFLOW_SERVER_HOST = "127.0.0.1" |
| 48 | +MLFLOW_SERVER_PORT = "5000" |
| 49 | +MLFLOW_SERVER_URI = f"http://{MLFLOW_SERVER_HOST}:{MLFLOW_SERVER_PORT}" |
| 50 | + |
45 | 51 |
|
46 | 52 | class BaseScheduler(LoggingConfigurable): |
47 | 53 | """Base class for schedulers. A default implementation |
@@ -348,16 +354,13 @@ def start_mlflow_server(self): |
348 | 354 | [ |
349 | 355 | "mlflow", |
350 | 356 | "server", |
351 | | - "--backend-store-uri", |
352 | | - "./mlruns", |
353 | | - "--default-artifact-root", |
354 | | - "./mlartifacts", |
355 | 357 | "--host", |
356 | | - "0.0.0.0", |
| 358 | + MLFLOW_SERVER_HOST, |
357 | 359 | "--port", |
358 | | - "5000", |
| 360 | + MLFLOW_SERVER_PORT, |
359 | 361 | ] |
360 | 362 | ) |
| 363 | + mlflow.set_tracking_uri(MLFLOW_SERVER_URI) |
361 | 364 |
|
362 | 365 | def __init__( |
363 | 366 | self, |
@@ -415,6 +418,19 @@ def create_job(self, model: CreateJob) -> str: |
415 | 418 | if not model.output_formats: |
416 | 419 | model.output_formats = [] |
417 | 420 |
|
| 421 | + mlflow_client = mlflow.MlflowClient() |
| 422 | + |
| 423 | + if model.job_definition_id and model.mlflow_experiment_id: |
| 424 | + experiment_id = model.mlflow_experiment_id |
| 425 | + else: |
| 426 | + experiment_id = mlflow_client.create_experiment(f"{model.name}-{uuid4()}") |
| 427 | + model.mlflow_experiment_id = experiment_id |
| 428 | + input_file_path = os.path.join(self.root_dir, model.input_uri) |
| 429 | + mlflow.log_artifact(input_file_path, "input") |
| 430 | + |
| 431 | + mlflow_run = mlflow_client.create_run(experiment_id=experiment_id, run_name=model.name) |
| 432 | + model.mlflow_run_id = mlflow_run.info.run_id |
| 433 | + |
418 | 434 | job = Job(**model.dict(exclude_none=True, exclude={"input_uri"})) |
419 | 435 | session.add(job) |
420 | 436 | session.commit() |
@@ -553,6 +569,12 @@ def create_job_definition(self, model: CreateJobDefinition) -> str: |
553 | 569 | if not self.file_exists(model.input_uri): |
554 | 570 | raise InputUriError(model.input_uri) |
555 | 571 |
|
| 572 | + mlflow_client = mlflow.MlflowClient() |
| 573 | + experiment_id = mlflow_client.create_experiment(f"{model.name}-{uuid4()}") |
| 574 | + model.mlflow_experiment_id = experiment_id |
| 575 | + input_file_path = os.path.join(self.root_dir, model.input_uri) |
| 576 | + mlflow.log_artifact(input_file_path, "input") |
| 577 | + |
556 | 578 | job_definition = JobDefinition(**model.dict(exclude_none=True, exclude={"input_uri"})) |
557 | 579 | session.add(job_definition) |
558 | 580 | session.commit() |
|
0 commit comments