@@ -2,9 +2,9 @@
 import contextvars
 import os
-from typing import TYPE_CHECKING
 
 import torch
+import triton
 
 from .. import _compat as _compat  # ensure Triton compatibility patches run
 from .config import Config as Config
@@ -14,12 +14,14 @@
 from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
 from .triton_helpers import triton_wait_signal as triton_wait_signal
 
-if TYPE_CHECKING:
-    import triton
-
 
 def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
-    return torch.empty(size, device="cuda", dtype=torch.int8)
+    # Dynamically get device from Triton backend
+    current_target = triton.runtime.driver.active.get_current_target()
+    if current_target is None:
+        raise RuntimeError("No active Triton target available")
+    backend = current_target.backend
+    return torch.empty(size, device=backend, dtype=torch.int8)
 
 
 def set_triton_allocator() -> None:
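
For context, _alloc_fn now derives its device from Triton's active target rather than hardcoding "cuda". A minimal sketch of how such an allocator is typically registered, assuming the elided body of set_triton_allocator wraps triton.set_allocator, which accepts a callable with exactly this (size, alignment, stream) signature:

    import triton

    def set_triton_allocator() -> None:
        # Assumption: the elided body does roughly this. Once registered,
        # Triton calls _alloc_fn whenever a kernel needs scratch device
        # memory, so the allocation must land on the backend the kernel
        # is compiled for.
        triton.set_allocator(_alloc_fn)
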
@@ -51,8 +53,13 @@ def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
     Grid size to use for a persistent kernel on the device after accounting
     for any reserved SMs. Always at least 1.
     """
-    assert device.type in ["cuda", "xpu", "cpu"], "TODO: implement for other devices"
     available_sms: int
+    assert device.type in [
+        "cuda",
+        "xpu",
+        "cpu",
+        "mtia",
+    ], "TODO: implement for other devices"
     if device.type == "cpu":
         try:
             num_threads = int(torch.get_num_threads())
@@ -66,8 +73,19 @@ def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
     # TODO(EikanWang): gpu_subslice_count is an out-of-date term; we should update it to the XeCore number.
     elif device.type == "xpu":
         available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    elif device.type == "mtia":
+        device_props = torch.mtia.get_device_properties(device.index)
+        if "maxGridHeight" in device_props and "maxGridWidth" in device_props:
+            available_sms = device_props["maxGridHeight"] * device_props["maxGridWidth"]
+        else:
+            raise RuntimeError(
+                f"Unable to determine SM count for MTIA device. "
+                f"Available properties: {list(device_props.keys())}"
+            )
     else:
-        raise AssertionError("TODO: implement for other devices")
+        raise NotImplementedError(
+            f"get_num_sm not implemented for device type: {device.type}"
+        )
 
     if reserved_sms <= 0:
         return available_sms
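
A small worked example of the reserved-SMs accounting. The tail of get_num_sm is elided from this diff, so the clamp below is presumed from the docstring's "Always at least 1"; the counts are illustrative, not real device properties:

    # Illustrative sketch; 132 is a made-up SM count, not a queried value.
    available_sms = 132                          # e.g. from device properties
    reserved_sms = 8
    grid = max(available_sms - reserved_sms, 1)  # presumed clamp, per the docstring
    assert grid == 124
    # On MTIA the analogue added above is the product of the grid dimensions:
    # available_sms = maxGridHeight * maxGridWidth
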
@@ -83,6 +101,7 @@ def default_launcher(
     **kwargs: dict,
 ) -> object:
     """Default launcher function that executes the kernel immediately."""
+    # For both CUDA and MTIA, use the same kernel execution path.
     return triton_kernel.run(
         *args,
         grid=grid,
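
default_launcher forwards to the kernel's run method with an explicit grid; per the new comment, the same path serves both CUDA and MTIA. A hedged sketch of equivalent usage with a toy kernel (the kernel, sizes, and grid below are hypothetical, and the call tail after grid=grid is elided from this diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _add_one(ptr, n, BLOCK: tl.constexpr):
        # Each program instance handles one BLOCK-sized tile.
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(ptr + offs, tl.load(ptr + offs, mask=mask) + 1, mask=mask)

    x = torch.zeros(1024, device="cuda")
    # default_launcher(_add_one, (8,), x, 1024, BLOCK=128) is expected to
    # behave like the conventional subscript launch, since both go through
    # JITFunction.run:
    _add_one[(8,)](x, 1024, BLOCK=128)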