Skip to content

Commit a92e202

Browse files
Merge pull request #1641 from roboflow/feature/model-variant-in-roboflow-api
Add ability for inference-exp to parse modelVariant from RFAPI
2 parents 8c1bab3 + 50ce008 commit a92e202

File tree

5 files changed

+9
-0
lines changed

5 files changed

+9
-0
lines changed

inference_experimental/inference_exp/models/auto_loaders/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def describe_model(
120120
model_id=model_metadata.model_id,
121121
requested_model_id=model_id,
122122
model_architecture=model_metadata.model_architecture,
123+
model_variant=model_metadata.model_variant,
123124
task_type=model_metadata.task_type,
124125
weights_provider=weights_provider,
125126
registered_packages=len(model_metadata.model_packages),

inference_experimental/inference_exp/models/auto_loaders/presentation_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def render_table_with_model_overview(
2020
model_id: str,
2121
requested_model_id: str,
2222
model_architecture: str,
23+
model_variant: Optional[str],
2324
task_type: Optional[str],
2425
weights_provider: str,
2526
registered_packages: int,
@@ -32,6 +33,7 @@ def render_table_with_model_overview(
3233
model_id_str = f"{model_id_str} (alias: {requested_model_id})"
3334
table.add_row("Model ID:", model_id_str)
3435
table.add_row("Architecture:", model_architecture)
36+
table.add_row("Variant:", model_variant or "N/A")
3537
table.add_row("Task:", task_type or "N/A")
3638
table.add_row("Weights provider:", weights_provider)
3739
table.add_row("Number of packages:", str(registered_packages))

inference_experimental/inference_exp/weights_providers/entities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,4 @@ class ModelMetadata:
156156
model_architecture: str
157157
model_packages: List[ModelPackageMetadata]
158158
task_type: Optional[str] = field(default=None)
159+
model_variant: Optional[str] = field(default=None)

inference_experimental/inference_exp/weights_providers/roboflow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class RoboflowModelMetadata(BaseModel):
6161
type: Literal["external-model-metadata-v1"]
6262
model_id: str = Field(alias="modelId")
6363
model_architecture: str = Field(alias="modelArchitecture")
64+
model_variant: Optional[str] = Field(alias="modelVariant", default=None)
6465
task_type: Optional[str] = Field(alias="taskType", default=None)
6566
model_packages: List[Union[RoboflowModelPackageV1, dict]] = Field(
6667
alias="modelPackages",
@@ -81,6 +82,7 @@ def get_roboflow_model(model_id: str, api_key: Optional[str] = None) -> ModelMet
8182
model_architecture=model_metadata.model_architecture,
8283
model_packages=parsed_model_packages,
8384
task_type=model_metadata.task_type,
85+
model_variant=model_metadata.model_variant,
8486
)
8587

8688

inference_experimental/tests/unit_tests/weights_providers/test_roboflow.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,6 +1317,7 @@ def test_get_roboflow_model(requests_mock: Mocker) -> None:
13171317
"type": "external-model-metadata-v1",
13181318
"modelId": "my-model",
13191319
"modelArchitecture": "yolov8",
1320+
"modelVariant": "yolov8-n",
13201321
"taskType": "object-detection",
13211322
"modelPackages": [
13221323
{
@@ -1352,6 +1353,7 @@ def test_get_roboflow_model(requests_mock: Mocker) -> None:
13521353
"type": "external-model-metadata-v1",
13531354
"modelId": "my-model",
13541355
"modelArchitecture": "yolov8",
1356+
"modelVariant": "yolov8-n",
13551357
"taskType": "object-detection",
13561358
"modelPackages": [
13571359
{
@@ -1401,5 +1403,6 @@ def test_get_roboflow_model(requests_mock: Mocker) -> None:
14011403
# then
14021404
assert result.model_id == "my-model"
14031405
assert result.model_architecture == "yolov8"
1406+
assert result.model_variant == "yolov8-n"
14041407
assert result.task_type == "object-detection"
14051408
assert len(result.model_packages) == 2

0 commit comments

Comments
 (0)