Skip to content

Commit 4f32d3f

Browse files
gustavocidornelas and whoseoyster
authored and committed
Support optional specification of reference dataset on inference pipeline creation
1 parent 6ef02f0 commit 4f32d3f

File tree

1 file changed

+66
-7
lines changed

1 file changed

+66
-7
lines changed

openlayer/__init__.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,6 +1393,9 @@ def create_inference_pipeline(
13931393
task_type: TaskType,
13941394
name: Optional[str] = None,
13951395
description: Optional[str] = None,
1396+
reference_df: Optional[pd.DataFrame] = None,
1397+
reference_dataset_file_path: Optional[str] = None,
1398+
reference_dataset_config_file_path: Optional[str] = None,
13961399
) -> InferencePipeline:
13971400
"""Creates an inference pipeline in an Openlayer project.
13981401
@@ -1441,6 +1444,19 @@ def create_inference_pipeline(
14411444
platform. Refer to :obj:`upload_reference_dataset` and
14421445
:obj:`publish_batch_data` for detailed examples.
14431446
"""
1447+
if (reference_df is None) ^ (reference_dataset_config_file_path is None) or (
1448+
reference_dataset_file_path is None
1449+
) ^ (reference_dataset_config_file_path is None):
1450+
raise ValueError(
1451+
"You must specify both a reference dataset and"
1452+
" its config or none of them."
1453+
)
1454+
if reference_df is not None and reference_dataset_file_path is not None:
1455+
raise ValueError(
1456+
"Please specify either a reference dataset or a reference dataset"
1457+
" file path."
1458+
)
1459+
14441460
# Validate inference pipeline
14451461
inference_pipeline_config = {
14461462
"name": name or "Production",
@@ -1452,19 +1468,62 @@ def create_inference_pipeline(
14521468
)
14531469
)
14541470
failed_validations = inference_pipeline_validator.validate()
1455-
14561471
if failed_validations:
14571472
raise exceptions.OpenlayerValidationError(
14581473
"There are issues with the inference pipeline. \n"
14591474
"Make sure to fix all of the issues listed above before creating it.",
14601475
) from None
14611476

1462-
endpoint = f"projects/{project_id}/inference-pipelines"
1463-
payload = {
1464-
"name": name,
1465-
"description": description,
1466-
}
1467-
inference_pipeline_data = self.api.post_request(endpoint, body=payload)
1477+
# Validate reference dataset and augment config
1478+
if reference_dataset_config_file_path is not None:
1479+
dataset_validator = dataset_validators.get_validator(
1480+
task_type=task_type,
1481+
dataset_config_file_path=reference_dataset_config_file_path,
1482+
dataset_df=reference_df,
1483+
)
1484+
failed_validations = dataset_validator.validate()
1485+
1486+
if failed_validations:
1487+
raise exceptions.OpenlayerValidationError(
1488+
"There are issues with the reference dataset and its config. \n"
1489+
"Make sure to fix all of the issues listed above before the upload.",
1490+
) from None
1491+
1492+
# Load dataset config and augment with defaults
1493+
reference_dataset_config = utils.read_yaml(
1494+
reference_dataset_config_file_path
1495+
)
1496+
reference_dataset_data = DatasetSchema().load(
1497+
{"task_type": task_type.value, **reference_dataset_config}
1498+
)
1499+
1500+
with tempfile.TemporaryDirectory() as tmp_dir:
1501+
# Copy relevant files to tmp dir if reference dataset is provided
1502+
if reference_dataset_config_file_path is not None:
1503+
utils.write_yaml(
1504+
reference_dataset_data, f"{tmp_dir}/dataset_config.yaml"
1505+
)
1506+
if reference_df is not None:
1507+
reference_df.to_csv(f"{tmp_dir}/dataset.csv", index=False)
1508+
else:
1509+
shutil.copy(
1510+
reference_dataset_file_path,
1511+
f"{tmp_dir}/dataset.csv",
1512+
)
1513+
1514+
tar_file_path = os.path.join(tmp_dir, "tarfile")
1515+
with tarfile.open(tar_file_path, mode="w:gz") as tar:
1516+
tar.add(tmp_dir, arcname=os.path.basename("reference_dataset"))
1517+
1518+
endpoint = f"projects/{project_id}/inference-pipelines"
1519+
inference_pipeline_data = self.api.upload(
1520+
endpoint=endpoint,
1521+
file_path=tar_file_path,
1522+
object_name="tarfile",
1523+
body=inference_pipeline_config,
1524+
storage_uri_key="referenceDatasetUri",
1525+
method="POST",
1526+
)
14681527
inference_pipeline = InferencePipeline(
14691528
inference_pipeline_data, self.api.upload, self, task_type
14701529
)

0 commit comments

Comments (0)