Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scripts/run_loading_benchmark_on_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pandas as pd
from torch.utils.data import DataLoader

from arrayloader_benchmarks import benchmark_loader
from arrayloader_benchmarks import benchmark_loader, compute_spec

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -254,6 +254,7 @@ def run(
"num_workers": num_workers,
"batch_size": batch_size,
"chunk_size": chunk_size,
"compute_spec": compute_spec.get_aws_sagemaker_instance_type(),
"run_uid": ln.context.run.uid,
"timestamp": datetime.datetime.now(datetime.UTC),
"user": ln.setup.settings.user.handle,
Expand Down
30 changes: 30 additions & 0 deletions src/arrayloader_benchmarks/compute_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

import json

import boto3


def get_aws_sagemaker_instance_type(
    *,
    metadata_path: str = "/opt/ml/metadata/resource-metadata.json",
    region_name: str = "us-west-2",
) -> str:
    """Get the instance type of the current SageMaker Studio instance.

    Loads the resource-metadata JSON file, uses its ``DomainId`` and
    ``SpaceName`` entries to call ``describe_space`` on a boto3 SageMaker
    client, and reads ``SpaceSettings -> JupyterLabAppSettings ->
    DefaultResourceSpec -> InstanceType`` from the response.

    This is a best-effort lookup: it never raises. Any failure (metadata
    file missing or malformed, required keys absent, client creation or the
    API call failing, unexpected response shape) yields ``"unknown"``.

    Args:
        metadata_path: Location of the resource-metadata JSON file.
        region_name: AWS region passed to the SageMaker client.

    Returns:
        The instance type string from the space settings, or ``"unknown"``
        when it cannot be determined.
    """
    try:
        # Read the metadata
        with open(metadata_path, encoding="utf-8") as f:  # noqa
            metadata = json.load(f)

        # Pull the identifiers first so a malformed metadata file fails
        # before any AWS client is created.
        domain_id = metadata["DomainId"]
        space_name = metadata["SpaceName"]

        sagemaker = boto3.client("sagemaker", region_name=region_name)

        # Try to describe the space
        space_response = sagemaker.describe_space(
            DomainId=domain_id, SpaceName=space_name
        )

        # Navigate through the nested settings to find the instance type;
        # `or {}` also covers keys that are present but null.
        space_settings = space_response.get("SpaceSettings") or {}
        jupyter_settings = space_settings.get("JupyterLabAppSettings") or {}
        default_resource_spec = jupyter_settings.get("DefaultResourceSpec") or {}
        instance_type = default_resource_spec.get("InstanceType")
    except Exception:  # noqa: BLE001
        # Deliberately broad: recording the compute spec must never break
        # the caller, whatever went wrong above.
        return "unknown"

    # A missing, null or empty instance type is reported as "unknown" so the
    # declared `str` return type always holds.
    return instance_type or "unknown"
Loading