@@ -1393,6 +1393,9 @@ def create_inference_pipeline(
         task_type: TaskType,
         name: Optional[str] = None,
         description: Optional[str] = None,
+        reference_df: Optional[pd.DataFrame] = None,
+        reference_dataset_file_path: Optional[str] = None,
+        reference_dataset_config_file_path: Optional[str] = None,
     ) -> InferencePipeline:
         """Creates an inference pipeline in an Openlayer project.
 
@@ -1441,6 +1444,19 @@ def create_inference_pipeline(
         platform. Refer to :obj:`upload_reference_dataset` and
         :obj:`publish_batch_data` for detailed examples.
         """
+        if ((reference_df is None) and (reference_dataset_file_path is None)) ^ (
+            reference_dataset_config_file_path is None
+        ):
+            raise ValueError(
+                "You must specify both a reference dataset and"
+                " its config, or neither of them."
+            )
+        if reference_df is not None and reference_dataset_file_path is not None:
+            raise ValueError(
+                "Please specify either a reference dataset or a reference dataset"
+                " file path, not both."
+            )
+
         # Validate inference pipeline
         inference_pipeline_config = {
             "name": name or "Production",
@@ -1452,19 +1468,62 @@ def create_inference_pipeline(
             )
         )
         failed_validations = inference_pipeline_validator.validate()
-
         if failed_validations:
             raise exceptions.OpenlayerValidationError(
                 "There are issues with the inference pipeline.\n"
                 "Make sure to fix all of the issues listed above before creating it.",
             ) from None
 
-        endpoint = f"projects/{project_id}/inference-pipelines"
-        payload = {
-            "name": name,
-            "description": description,
-        }
-        inference_pipeline_data = self.api.post_request(endpoint, body=payload)
+        # Validate reference dataset and augment config
+        if reference_dataset_config_file_path is not None:
+            dataset_validator = dataset_validators.get_validator(
+                task_type=task_type,
+                dataset_config_file_path=reference_dataset_config_file_path,
+                dataset_df=reference_df,
+            )
+            failed_validations = dataset_validator.validate()
+
+            if failed_validations:
+                raise exceptions.OpenlayerValidationError(
+                    "There are issues with the reference dataset and its config.\n"
+                    "Make sure to fix all of the issues listed above before the upload.",
+                ) from None
+
+            # Load dataset config and augment with defaults
+            reference_dataset_config = utils.read_yaml(
+                reference_dataset_config_file_path
+            )
+            reference_dataset_data = DatasetSchema().load(
+                {"task_type": task_type.value, **reference_dataset_config}
+            )
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Copy relevant files to tmp dir if a reference dataset is provided
+            if reference_dataset_config_file_path is not None:
+                utils.write_yaml(
+                    reference_dataset_data, f"{tmp_dir}/dataset_config.yaml"
+                )
+                if reference_df is not None:
+                    reference_df.to_csv(f"{tmp_dir}/dataset.csv", index=False)
+                else:
+                    shutil.copy(
+                        reference_dataset_file_path,
+                        f"{tmp_dir}/dataset.csv",
+                    )
+
+            tar_file_path = os.path.join(tmp_dir, "tarfile")
+            with tarfile.open(tar_file_path, mode="w:gz") as tar:
+                tar.add(tmp_dir, arcname="reference_dataset")
+
+            endpoint = f"projects/{project_id}/inference-pipelines"
+            inference_pipeline_data = self.api.upload(
+                endpoint=endpoint,
+                file_path=tar_file_path,
+                object_name="tarfile",
+                body=inference_pipeline_config,
+                storage_uri_key="referenceDatasetUri",
+                method="POST",
+            )
         inference_pipeline = InferencePipeline(
             inference_pipeline_data, self.api.upload, self, task_type
         )
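
Taken together, the new parameters let a caller attach a reference dataset at pipeline-creation time instead of uploading it separately afterwards. A minimal usage sketch, assuming the client- and project-level wrappers documented elsewhere in the package forward these keyword arguments (the project name, file paths, and config filename below are hypothetical):

    import openlayer
    from openlayer.tasks import TaskType
    import pandas as pd

    client = openlayer.OpenlayerClient("YOUR_API_KEY_HERE")
    project = client.create_project(
        name="Churn Prediction",  # hypothetical project
        task_type=TaskType.TabularClassification,
    )

    # Either pass the reference data as an in-memory DataFrame together
    # with its YAML config (a dataset and its config go together) ...
    reference_df = pd.read_csv("reference_dataset.csv")  # hypothetical path
    inference_pipeline = project.create_inference_pipeline(
        name="Production",
        description="Monitors production traffic",
        reference_df=reference_df,
        reference_dataset_config_file_path="reference_dataset_config.yaml",
    )

    # ... or point at the CSV on disk via reference_dataset_file_path instead.
    # Passing both reference_df and reference_dataset_file_path raises a ValueError.

Bundling the config and CSV into a single tar.gz means the reference dataset rides along on the same POST request that creates the pipeline, so a pipeline is never created without its baseline data when one is supplied.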