Skip to content

Commit e75c434

Browse files
authored
Add dict comprehension support (#1191)
1 parent 4755284 commit e75c434

File tree

4 files changed

+267
-2
lines changed

4 files changed

+267
-2
lines changed

helion/_compiler/device_ir.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from .roll_reduction import ReductionRoller
5555
from .source_location import current_location
5656
from .type_propagation import CallableType
57+
from .type_propagation import DictType
5758
from .type_propagation import GridIndexType
5859
from .type_propagation import IterType
5960
from .type_propagation import LiteralType
@@ -1087,6 +1088,36 @@ def evaluate_expression() -> object:
10871088
# Return as tuple to match the expected type for tuple unrolling
10881089
return tuple(results)
10891090

1091+
def visit_DictComp(self, node: ast.DictComp) -> dict[object, object]:
1092+
"""Handle dict comprehension unrolling."""
1093+
assert isinstance(node, ExtendedAST)
1094+
1095+
if len(node.generators) != 1 or node.generators[0].ifs:
1096+
raise exc.StatementNotSupported(
1097+
"Complex dict comprehensions are not supported"
1098+
)
1099+
1100+
generator = node.generators[0]
1101+
assert isinstance(generator.iter, ExtendedAST)
1102+
iter_type = generator.iter._type_info
1103+
1104+
if not isinstance(iter_type, SequenceType):
1105+
raise exc.StatementNotSupported(
1106+
"Dict comprehensions over non-sequence types are not supported"
1107+
)
1108+
1109+
result: dict[object, object] = {}
1110+
1111+
def evaluate_key_value() -> None:
1112+
key = self.visit(node.key)
1113+
value = self.visit(node.value)
1114+
result[key] = value
1115+
1116+
self._handle_sequence_unrolling(
1117+
generator.iter, generator.target, evaluate_key_value, preserve_scope=False
1118+
)
1119+
return result
1120+
10901121
def visit_Dict(self, node: ast.Dict) -> dict[object, object]:
10911122
keys = [self.visit(key) if key is not None else None for key in node.keys]
10921123
values = [self.visit(value) for value in node.values]
@@ -1224,9 +1255,18 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
12241255
# pyrefly: ignore [bad-index]
12251256
return self.visit(value)[index_value]
12261257
raise exc.InvalidSequenceSubscription(node.slice)
1258+
# Check StackTensorType before DictType since StackTensorType inherits from DictType
12271259
if isinstance(type_info, StackTensorType):
12281260
# pyrefly: ignore [bad-argument-type]
12291261
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))
1262+
if isinstance(type_info, DictType):
1263+
key_value = self.visit(node.slice)
1264+
if isinstance(key_value, (str, int)):
1265+
# pyrefly: ignore [bad-index]
1266+
return self.visit(value)[key_value]
1267+
raise exc.TypeInferenceError(
1268+
f"Dict subscript must be a literal str or int, got {type(key_value).__name__}"
1269+
)
12301270
if type_info is not None and type_info.origin.is_host():
12311271
# pyrefly: ignore [bad-argument-type]
12321272
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))

helion/_compiler/type_propagation.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,11 +2445,64 @@ def visit_ListComp(self, node: ast.ListComp) -> TypeInfo:
24452445
def visit_GeneratorExp(self, node: ast.GeneratorExp) -> TypeInfo:
24462446
return self._visit_comprehension(node, "generator expression")
24472447

2448+
def visit_DictComp(self, node: ast.DictComp) -> TypeInfo:
2449+
"""Type propagation for dict comprehensions."""
2450+
if len(node.generators) != 1:
2451+
raise exc.StatementNotSupported(
2452+
"Dict comprehensions with multiple generators are not supported"
2453+
)
2454+
2455+
generator = node.generators[0]
2456+
iter_type = self.visit(generator.iter)
2457+
2458+
# Try to unpack the iterable
2459+
try:
2460+
iterable_elements = iter_type.unpack()
2461+
except NotImplementedError:
2462+
raise exc.StatementNotSupported(
2463+
"Dict comprehensions over non-unpackable iterables are not supported"
2464+
) from None
2465+
2466+
result_elements: dict[str | int, TypeInfo] = {}
2467+
2468+
def clear_type_info(n: ast.AST) -> None:
2469+
"""Clear _type_info on AST nodes to allow re-visiting with different values."""
2470+
if isinstance(n, ExtendedAST):
2471+
n._type_info = None
2472+
for child in ast.iter_child_nodes(n):
2473+
clear_type_info(child)
2474+
2475+
for element_type in iterable_elements:
2476+
self.push_scope()
2477+
try:
2478+
self._assign(generator.target, element_type)
2479+
for if_clause in generator.ifs:
2480+
self.visit(if_clause)
2481+
# Clear type info before visiting to avoid merging with previous iteration
2482+
clear_type_info(node.key)
2483+
clear_type_info(node.value)
2484+
key_type = self.visit(node.key)
2485+
value_type = self.visit(node.value)
2486+
# Get the literal key value by evaluating with proxy
2487+
try:
2488+
key = key_type.proxy()
2489+
except (NotImplementedError, TypeError):
2490+
raise exc.StatementNotSupported(
2491+
"Dict comprehension keys must evaluate to literals"
2492+
) from None
2493+
if not isinstance(key, (str, int)):
2494+
raise exc.StatementNotSupported(
2495+
f"Dict comprehension keys must be str or int, got {type(key).__name__}"
2496+
)
2497+
result_elements[key] = value_type
2498+
finally:
2499+
self.pop_scope()
2500+
2501+
return DictType(self.origin(), result_elements)
2502+
24482503
# TODO(jansel): need to implement these
24492504
# pyrefly: ignore [bad-assignment, bad-param-name-override]
24502505
visit_SetComp: _VisitMethod = _not_supported
2451-
# pyrefly: ignore [bad-assignment, bad-param-name-override]
2452-
visit_DictComp: _VisitMethod = _not_supported
24532506

24542507
# TODO(jansel): support closure functions defined on host
24552508
# pyrefly: ignore [bad-assignment, bad-param-name-override]

test/test_unroll_tuples.expected

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,111 @@ def kernel_constants_iteration(x: torch.Tensor, *, _launcher=_default_launcher):
8888
# src[test_unroll_tuples.py:N]: return result
8989
return result
9090

91+
--- assertExpectedJournal(TestUnrollTuples.test_dict_comprehension)
92+
from __future__ import annotations
93+
94+
import torch
95+
import triton
96+
import triton.language as tl
97+
from helion.runtime import default_launcher as _default_launcher
98+
99+
@triton.jit
100+
def _helion_kernel_dict_comprehension(x, result, _BLOCK_SIZE_0: tl.constexpr):
101+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
102+
pid_0 = tl.program_id(0)
103+
offset_0 = pid_0 * _BLOCK_SIZE_0
104+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
105+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
106+
acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
107+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[1]
108+
load = tl.load(x + indices_0 * 1, None)
109+
v_0 = 2.0
110+
v_1 = load * v_0
111+
v_2 = acc + v_1
112+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[2]
113+
load_1 = tl.load(x + indices_0 * 1, None)
114+
v_3 = 4.0
115+
v_4 = load_1 * v_3
116+
v_5 = v_2 + v_4
117+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[3]
118+
load_2 = tl.load(x + indices_0 * 1, None)
119+
v_6 = 6.0
120+
v_7 = load_2 * v_6
121+
v_8 = v_5 + v_7
122+
# src[test_unroll_tuples.py:N]: result[tile_idx] = acc
123+
tl.store(result + indices_0 * 1, v_8, None)
124+
125+
def kernel_dict_comprehension(x: torch.Tensor, *, _launcher=_default_launcher):
126+
"""Test dict comprehension with constants."""
127+
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
128+
result = torch.zeros_like(x)
129+
# src[test_unroll_tuples.py:N]: multipliers = {k: k * 2 for k in (1, 2, 3)}
130+
multipliers = {k: k * 2 for k in (1, 2, 3)}
131+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
132+
_BLOCK_SIZE_0 = 16
133+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
134+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
135+
# src[test_unroll_tuples.py:N]: # Access dict with literal keys
136+
# src[test_unroll_tuples.py:N-N]: ...
137+
_launcher(_helion_kernel_dict_comprehension, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
138+
# src[test_unroll_tuples.py:N]: return result
139+
return result
140+
141+
--- assertExpectedJournal(TestUnrollTuples.test_dict_comprehension_with_range)
142+
from __future__ import annotations
143+
144+
import torch
145+
import triton
146+
import triton.language as tl
147+
from helion.runtime import default_launcher as _default_launcher
148+
149+
@triton.jit
150+
def _helion_kernel_dict_comprehension_with_range(x, result, _BLOCK_SIZE_0: tl.constexpr):
151+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
152+
pid_0 = tl.program_id(0)
153+
offset_0 = pid_0 * _BLOCK_SIZE_0
154+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
155+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
156+
acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
157+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[0]
158+
load = tl.load(x + indices_0 * 1, None)
159+
v_0 = 2.0
160+
v_1 = load * v_0
161+
v_2 = acc + v_1
162+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[1]
163+
load_1 = tl.load(x + indices_0 * 1, None)
164+
v_3 = 4.0
165+
v_4 = load_1 * v_3
166+
v_5 = v_2 + v_4
167+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[2]
168+
load_2 = tl.load(x + indices_0 * 1, None)
169+
v_6 = 6.0
170+
v_7 = load_2 * v_6
171+
v_8 = v_5 + v_7
172+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[3]
173+
load_3 = tl.load(x + indices_0 * 1, None)
174+
v_9 = 8.0
175+
v_10 = load_3 * v_9
176+
v_11 = v_8 + v_10
177+
# src[test_unroll_tuples.py:N]: result[tile_idx] = acc
178+
tl.store(result + indices_0 * 1, v_11, None)
179+
180+
def kernel_dict_comprehension_with_range(x: torch.Tensor, *, _launcher=_default_launcher):
181+
"""Test dict comprehension with range for key generation."""
182+
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
183+
result = torch.zeros_like(x)
184+
# src[test_unroll_tuples.py:N]: multipliers = {i: (i + 1) * 2 for i in range(4)}
185+
multipliers = {i: (i + 1) * 2 for i in range(4)}
186+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
187+
_BLOCK_SIZE_0 = 16
188+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
189+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
190+
# src[test_unroll_tuples.py:N]: # Access dict with literal keys
191+
# src[test_unroll_tuples.py:N-N]: ...
192+
_launcher(_helion_kernel_dict_comprehension_with_range, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
193+
# src[test_unroll_tuples.py:N]: return result
194+
return result
195+
91196
--- assertExpectedJournal(TestUnrollTuples.test_enumerate_constants)
92197
from __future__ import annotations
93198

test/test_unroll_tuples.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,43 @@ def kernel_tuple_comprehension_with_tensors(
278278
return result
279279

280280

281+
@helion.kernel(autotune_effort="none")
282+
def kernel_dict_comprehension(
283+
x: torch.Tensor,
284+
) -> torch.Tensor:
285+
"""Test dict comprehension with constants."""
286+
result = torch.zeros_like(x)
287+
# Create dict using comprehension
288+
multipliers = {k: k * 2 for k in (1, 2, 3)}
289+
for tile_idx in hl.tile(result.size(0)):
290+
acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
291+
# Access dict with literal keys
292+
acc += x[tile_idx] * multipliers[1]
293+
acc += x[tile_idx] * multipliers[2]
294+
acc += x[tile_idx] * multipliers[3]
295+
result[tile_idx] = acc
296+
return result
297+
298+
299+
@helion.kernel(autotune_effort="none")
300+
def kernel_dict_comprehension_with_range(
301+
x: torch.Tensor,
302+
) -> torch.Tensor:
303+
"""Test dict comprehension with range for key generation."""
304+
result = torch.zeros_like(x)
305+
# Create dict using comprehension with range
306+
multipliers = {i: (i + 1) * 2 for i in range(4)}
307+
for tile_idx in hl.tile(result.size(0)):
308+
acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
309+
# Access dict with literal keys
310+
acc += x[tile_idx] * multipliers[0]
311+
acc += x[tile_idx] * multipliers[1]
312+
acc += x[tile_idx] * multipliers[2]
313+
acc += x[tile_idx] * multipliers[3]
314+
result[tile_idx] = acc
315+
return result
316+
317+
281318
@helion.kernel(autotune_effort="none")
282319
def kernel_list_comprehension_with_function(
283320
x: torch.Tensor,
@@ -725,6 +762,36 @@ def test_tuple_comprehension_with_tensors(self):
725762
expected = tensor1 * 0.5 + tensor2 * 1.0 + tensor3 * 1.5
726763
torch.testing.assert_close(result, expected)
727764

765+
def test_dict_comprehension(self):
766+
"""Test dict comprehension with constants."""
767+
size = (16,)
768+
x = torch.randn(size, device=DEVICE)
769+
770+
code, result = code_and_output(kernel_dict_comprehension, (x,))
771+
772+
# Validate generated code
773+
self.assertExpectedJournal(code)
774+
775+
# Test correctness - multipliers = {1: 2, 2: 4, 3: 6}
776+
# should be x * (2 + 4 + 6) = x * 12
777+
expected = x * 12
778+
torch.testing.assert_close(result, expected)
779+
780+
def test_dict_comprehension_with_range(self):
781+
"""Test dict comprehension with range for key generation."""
782+
size = (16,)
783+
x = torch.randn(size, device=DEVICE)
784+
785+
code, result = code_and_output(kernel_dict_comprehension_with_range, (x,))
786+
787+
# Validate generated code
788+
self.assertExpectedJournal(code)
789+
790+
# Test correctness - multipliers = {0: 2, 1: 4, 2: 6, 3: 8}
791+
# should be x * (2 + 4 + 6 + 8) = x * 20
792+
expected = x * 20
793+
torch.testing.assert_close(result, expected)
794+
728795
def test_list_comprehension_with_function(self):
729796
"""Test list comprehension with expressions."""
730797
size = (14,)

0 commit comments

Comments
 (0)