
Commit 601d7dd

Add setup for Helion to compile on MTIA with basic test (#1169)
1 parent 58cbc67 commit 601d7dd

3 files changed: +55 −7 lines changed

helion/_testing.py

Lines changed: 27 additions & 0 deletions
@@ -61,6 +61,15 @@ def is_cpu() -> bool:
     )
 
 
+def is_mtia() -> bool:
+    """Return True if running on MTIA."""
+    return _get_triton_backend() == "mtia"
+
+
+def skipIfMTIA(reason: str) -> Callable[[Callable], Callable]:
+    return unittest.skipIf(is_mtia(), reason)
+
+
 class _LogCapture(logging.Handler):
     """Simple logging handler to capture log records."""
 
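Not part of the diff: is_mtia() delegates to the pre-existing _get_triton_backend() helper, whose body is outside this hunk. As a rough sketch only, assuming that helper reads the backend name from Triton's active driver (the same API _alloc_fn in helion/runtime/__init__.py switches to below); the function name here is hypothetical:

    import triton

    def _get_triton_backend_sketch() -> str | None:
        """Hypothetical stand-in for helion._testing._get_triton_backend."""
        try:
            # The active driver reports a target whose backend is "cuda", "xpu", "mtia", ...
            target = triton.runtime.driver.active.get_current_target()
        except Exception:
            return None
        return None if target is None else target.backend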

@@ -101,11 +110,14 @@ def is_cuda() -> bool:
 
 PROJECT_ROOT: Path = Path(__file__).parent.parent
 EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
+DEVICE = None
 
 if is_cpu():
     DEVICE = torch.device("cpu")
 elif torch.xpu.is_available():
     DEVICE = torch.device("xpu")
+elif is_mtia():
+    DEVICE = torch.device("mtia")
 else:
     DEVICE = torch.device("cuda")
 

@@ -1006,6 +1018,21 @@ class TestCase(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         cls._expected_journal = AssertExpectedJournal(cls)
+
+        if is_mtia():
+            # pyrefly: ignore [missing-import]
+            import mtia.host_runtime.torch_mtia.dynamic_library  # noqa: F401
+
+            # pyrefly: ignore [missing-import]
+            from mtia.re.re_unittest_lib import MTIAUnittest
+
+            # pyrefly: ignore [missing-import]
+            from triton_mtia.python.mtia.eager import mtia_triton_launcher
+
+            # Call MTIAUnittest.setUpClass for MTIA initialization
+            MTIAUnittest.setUpClass.__func__(cls)
+            # Initialize MTIA properly
+            mtia_triton_launcher.init()
         super().setUpClass()
 
     @classmethod
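Taken together, the hooks above mean a test module needs no MTIA-specific setup of its own: inheriting from helion._testing.TestCase runs the MTIA initialization, DEVICE resolves to torch.device("mtia"), and unsupported cases can be decorated with skipIfMTIA. A minimal sketch, with a made-up kernel and test class that are not part of this commit (decorator arguments are illustrative):

    import torch

    import helion
    import helion.language as hl
    from helion._testing import DEVICE, TestCase, skipIfMTIA

    @helion.kernel(use_default_config=True)  # decorator arguments are illustrative
    def add_one(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(x.size(0)):
            out[tile] = x[tile] + 1.0
        return out

    class TestMTIABasics(TestCase):  # hypothetical test class
        def test_add_one(self) -> None:
            # DEVICE is torch.device("mtia") when the active Triton backend is MTIA.
            x = torch.randn(128, device=DEVICE)
            torch.testing.assert_close(add_one(x), x + 1.0)

        @skipIfMTIA("illustrative: skip a case the MTIA backend cannot run yet")
        def test_skipped_on_mtia(self) -> None:
            ...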

helion/runtime/__init__.py

Lines changed: 26 additions & 7 deletions
@@ -2,9 +2,9 @@
 
 import contextvars
 import os
-from typing import TYPE_CHECKING
 
 import torch
+import triton
 
 from .. import _compat as _compat  # ensure Triton compatibility patches run
 from .config import Config as Config
@@ -14,12 +14,14 @@
 from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
 from .triton_helpers import triton_wait_signal as triton_wait_signal
 
-if TYPE_CHECKING:
-    import triton
-
 
 def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
-    return torch.empty(size, device="cuda", dtype=torch.int8)
+    # Dynamically get device from Triton backend
+    current_target = triton.runtime.driver.active.get_current_target()
+    if current_target is None:
+        raise RuntimeError("No active Triton target available")
+    backend = current_target.backend
+    return torch.empty(size, device=backend, dtype=torch.int8)
 
 
 def set_triton_allocator() -> None:
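The body of set_triton_allocator() is unchanged by this commit and not shown here. As a sketch of the mechanism only, assuming the allocator is registered through Triton's triton.set_allocator hook (which passes the same (size, alignment, stream) arguments _alloc_fn accepts); the actual Helion implementation may differ:

    import torch
    import triton

    def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
        # Same idea as the hunk above: derive the device from the active Triton
        # backend ("cuda", "xpu", "mtia", ...) instead of hard-coding "cuda".
        current_target = triton.runtime.driver.active.get_current_target()
        if current_target is None:
            raise RuntimeError("No active Triton target available")
        return torch.empty(size, device=current_target.backend, dtype=torch.int8)

    def set_triton_allocator_sketch() -> None:
        # Assumption: the allocator is installed via triton.set_allocator so Triton
        # uses it for host-side scratch buffers regardless of the backend.
        triton.set_allocator(_alloc_fn)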
@@ -51,8 +53,13 @@ def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
         Grid size to use for a persistent kernel on the device after accounting
         for any reserved SMs. Always at least 1.
     """
-    assert device.type in ["cuda", "xpu", "cpu"], "TODO: implement for other devices"
     available_sms: int
+    assert device.type in [
+        "cuda",
+        "xpu",
+        "cpu",
+        "mtia",
+    ], "TODO: implement for other devices"
     if device.type == "cpu":
         try:
             num_threads = int(torch.get_num_threads())
@@ -66,8 +73,19 @@ def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
     # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number.
     elif device.type == "xpu":
         available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    elif device.type == "mtia":
+        device_props = torch.mtia.get_device_properties(device.index)
+        if "maxGridHeight" in device_props and "maxGridWidth" in device_props:
+            available_sms = device_props["maxGridHeight"] * device_props["maxGridWidth"]
+        else:
+            raise RuntimeError(
+                f"Unable to determine SM count for MTIA device. "
+                f"Available properties: {list(device_props.keys())}"
+            )
     else:
-        raise AssertionError("TODO: implement for other devices")
+        raise NotImplementedError(
+            f"get_num_sm not implemented for device type: {device.type}"
+        )
 
     if reserved_sms <= 0:
         return available_sms
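On MTIA, the compute-unit count is derived from the maxGridHeight * maxGridWidth device properties. A short usage sketch of the helper when sizing a persistent-kernel grid (the device index and reserved_sms value are arbitrary):

    import torch

    from helion.runtime import get_num_sm

    device = torch.device("mtia", 0)  # or "cuda" / "xpu" / "cpu"
    # Returns the available compute units minus the reserved ones, at least 1.
    num_blocks = get_num_sm(device, reserved_sms=2)
    grid = (num_blocks,)  # 1-D launch grid for a persistent kernel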
@@ -83,6 +101,7 @@ def default_launcher(
     **kwargs: dict,
 ) -> object:
     """Default launcher function that executes the kernel immediately."""
+    # For both CUDA and MTIA, use the same kernel execution
     return triton_kernel.run(
         *args,
         grid=grid,

test/test_constexpr.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfMTIA
 from helion._testing import skipIfRefEager
 import helion.language as hl
 

@@ -95,6 +96,7 @@ def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
         self.assertExpectedJournal(code)
 
     @skipIfRefEager("Triton codegen does not work in ref eager mode")
+    @skipIfMTIA('Not supported on MTIA. Error: "Expected IntList but got GenericList"')
     def test_block_size_constexpr_assignment_in_host_code(self) -> None:
         @helion.kernel(
             config=helion.Config(

0 commit comments
