Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import os
import re
from collections import defaultdict
from datetime import date, datetime, timedelta, timezone

import boto3
Expand Down Expand Up @@ -95,7 +96,7 @@ def _get_instance_type_parameters(): # noqa: C901
return _get_instance_type_parameters._cache

result = {}
excluded_instance_type_prefixes = [
excluded_instance_type_prefixes = (
"m1",
"m2",
"m3",
Expand All @@ -117,70 +118,75 @@ def _get_instance_type_parameters(): # noqa: C901
"g3",
"p2",
"p3",
]
)

for region in ["us-east-1", "us-west-2"]: # Only populate instance type for big regions
ec2_client = boto3.client("ec2", region_name=region)
# The following conversion is required becase Python jinja doesn't like "-"
region_jinja = region.replace("-", "_").upper()
try:
xlarge_instances = []
instance_type_availability_zones = {}
# Use describe_instance_types with pagination
xlarge_instances = set()
all_gpu_instances = set()
instance_type_availability_zones = defaultdict(list)
# Get instance type offerings and build AZ mapping
paginator = ec2_client.get_paginator("describe_instance_type_offerings")

for page in paginator.paginate(LocationType="availability-zone"):
for instance_type in page["InstanceTypeOfferings"]:
for offering in page["InstanceTypeOfferings"]:
instance_type_name = offering["InstanceType"]
instance_type_availability_zones[instance_type_name].append(offering["Location"])
# Check if instance type ends with '.xlarge'
if instance_type["InstanceType"].endswith(".xlarge") and _is_current_instance_type_generation(
excluded_instance_type_prefixes, instance_type
if instance_type_name.endswith(".xlarge") and _is_current_instance_type_generation(
excluded_instance_type_prefixes, offering
):
xlarge_instances.append(instance_type["InstanceType"])
if instance_type_availability_zones.get(instance_type["InstanceType"]):
instance_type_availability_zones[instance_type["InstanceType"]].append(
instance_type["Location"]
)
else:
instance_type_availability_zones[instance_type["InstanceType"]] = [instance_type["Location"]]
xlarge_instances.add(instance_type_name)
# Get a list of only GPU instances of any size available in the region
if (
instance_type_name.startswith("p") or instance_type_name.startswith("g")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure only p/g types are GPU instance? What about trn instances?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trn has NeuronInfo instead of GPuInfo, which describe's the in-house inference chip we create.

for GPU so far I have found only these p and g family

) and _is_current_instance_type_generation(excluded_instance_type_prefixes, offering):
all_gpu_instances.add(instance_type_name)

xlarge_instances = list(set(xlarge_instances)) # Remove redundancy.
# Get GPU instance details in batches of 100
all_gpu_list = list(all_gpu_instances)
gpu_instances = []
paginator = ec2_client.get_paginator("describe_instance_types")
for page in paginator.paginate(InstanceTypes=xlarge_instances):
for instance_type in page["InstanceTypes"]:
if _is_nvidia_gpu_instance_type(instance_type) and "g6f" not in instance_type["InstanceType"]:
gpu_instances.append(instance_type["InstanceType"])
# DescribeInstanceType API Limit of 100 instances
batch_size = 100

for i in range(0, len(all_gpu_list), batch_size):
gpu_instance_type_batch = all_gpu_list[i : i + batch_size] # noqa: E203
for page in paginator.paginate(InstanceTypes=gpu_instance_type_batch):
for instance_type in page["InstanceTypes"]:
if _is_nvidia_gpu_instance_type(instance_type) and "g6f" not in instance_type["InstanceType"]:
if instance_type.get("GpuInfo").get("Gpus")[0].get(
"Count"
) >= 4 and _is_current_instance_type_generation(
excluded_instance_type_prefixes, instance_type
):
# Find instance types with 4 or more GPUs. Number of GPUs can change test behavior.
# For example, it takes longer for DCGM health check to diagnose multiple GPUs.
instance_size = instance_type["InstanceType"].split(".")[1][: -len("xlarge")]
if instance_size and int(instance_size) < 20:
# Avoid using very expensive instance types
gpu_instances.append(instance_type["InstanceType"])
else:
gpu_instances.append(instance_type["InstanceType"])

for page in paginator.paginate():
for instance_type in page["InstanceTypes"]:
if (
_is_nvidia_gpu_instance_type(instance_type)
and instance_type.get("GpuInfo").get("Gpus")[0].get("Count") >= 4
and _is_current_instance_type_generation(excluded_instance_type_prefixes, instance_type)
):
# Find instance types with 4 or more GPUs. Number of GPUs can change test behavior.
# For example, it takes longer for DCGM health check to diagnose multiple GPUs.
instance_size = instance_type["InstanceType"].split(".")[1][: -len("xlarge")]
if instance_size and int(instance_size) < 20:
# Avoid using very expensive instance types
gpu_instances.append(instance_type["InstanceType"])

xlarge_instances.sort()
gpu_instances.sort()
xlarge_sorted = sorted(xlarge_instances)
gpu_sorted = sorted(gpu_instances)
today_number = (date.today() - date(2020, 1, 1)).days
for index in range(len(xlarge_instances)):
instance_type = xlarge_instances[(today_number + index) % len(xlarge_instances)]
for index, _ in enumerate(xlarge_sorted):
instance_type = xlarge_sorted[(today_number + index) % len(xlarge_sorted)]
azs = instance_type_availability_zones[instance_type]
result[f"{region_jinja}_INSTANCE_TYPE_{index}"] = instance_type[: -len(".xlarge")]
availability_zones = instance_type_availability_zones[instance_type]
result[f"{region_jinja}_INSTANCE_TYPE_{index}_AZ"] = (
availability_zones[0] if len(availability_zones) <= 2 else region
)
for index in range(len(gpu_instances)):
instance_type = gpu_instances[(today_number + index) % len(gpu_instances)]
result[f"{region_jinja}_INSTANCE_TYPE_{index}_AZ"] = azs[0] if len(azs) <= 2 else region

for index, _ in enumerate(gpu_sorted):
instance_type = gpu_sorted[(today_number + index) % len(gpu_sorted)]
azs = instance_type_availability_zones[instance_type]
result[f"{region_jinja}_GPU_INSTANCE_TYPE_{index}"] = instance_type
availability_zones = instance_type_availability_zones[instance_type]
result[f"{region_jinja}_GPU_INSTANCE_TYPE_{index}_AZ"] = (
availability_zones[0] if len(availability_zones) <= 2 else region
)
result[f"{region_jinja}_GPU_INSTANCE_TYPE_{index}_AZ"] = azs[0] if len(azs) <= 2 else region

except Exception as e:
print(f"Error getting instance types: {str(e)}. Using c5 and g4dn as the default instance type")
for index in range(100):
Expand Down
Loading