@@ -88,6 +88,111 @@ def kernel_constants_iteration(x: torch.Tensor, *, _launcher=_default_launcher):
8888 # src[test_unroll_tuples.py:N]: return result
8989 return result
9090
91+ --- assertExpectedJournal(TestUnrollTuples.test_dict_comprehension)
92+ from __future__ import annotations
93+
94+ import torch
95+ import triton
96+ import triton.language as tl
97+ from helion.runtime import default_launcher as _default_launcher
98+
@triton.jit
def _helion_kernel_dict_comprehension(x, result, _BLOCK_SIZE_0: tl.constexpr):
    """Accumulate ``x * 2.0 + x * 4.0 + x * 6.0`` for one tile and store it.

    Each program instance handles ``_BLOCK_SIZE_0`` consecutive elements,
    selected by ``tl.program_id(0)``. Loads and the final store are issued
    with a mask of ``None``.
    """
    # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
    tile_start = tl.program_id(0) * _BLOCK_SIZE_0
    lanes = (tile_start + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
    total = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
    # The terms are added one at a time, in source order, so the
    # floating-point summation order stays ((0 + 2x) + 4x) + 6x.
    # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[1]
    total = total + tl.load(x + lanes * 1, None) * 2.0
    # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[2]
    total = total + tl.load(x + lanes * 1, None) * 4.0
    # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[3]
    total = total + tl.load(x + lanes * 1, None) * 6.0
    # src[test_unroll_tuples.py:N]: result[tile_idx] = acc
    tl.store(result + lanes * 1, total, None)
124+
def kernel_dict_comprehension(x: torch.Tensor, *, _launcher=_default_launcher):
    """Test dict comprehension with constants.

    Allocates a zero-filled tensor shaped like ``x``, launches
    ``_helion_kernel_dict_comprehension`` to fill it, and returns it.

    Args:
        x: Input tensor read by the kernel.
            NOTE(review): the launch grid is derived from a hard-coded
            length of 16 and the kernel reads/writes without a mask, so
            this presumably expects exactly 16 elements along dim 0 -
            confirm against the caller.
        _launcher: Callable used to launch the Triton kernel; defaults to
            ``_default_launcher``.

    Returns:
        The tensor written by the kernel.
    """
    # src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
    result = torch.zeros_like(x)
    # src[test_unroll_tuples.py:N]: multipliers = {k: k * 2 for k in (1, 2, 3)}
    # The dict is not rebuilt on the host: nothing here reads it, and its
    # values appear as the literals 2.0 / 4.0 / 6.0 inside the kernel.
    # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
    _BLOCK_SIZE_0 = 16
    # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
    # src[test_unroll_tuples.py:N]:     acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
    # src[test_unroll_tuples.py:N]:     # Access dict with literal keys
    # src[test_unroll_tuples.py:N-N]: ...
    _launcher(
        _helion_kernel_dict_comprehension,
        (triton.cdiv(16, _BLOCK_SIZE_0),),
        x,
        result,
        _BLOCK_SIZE_0,
        num_warps=4,
        num_stages=1,
    )
    # src[test_unroll_tuples.py:N]: return result
    return result
140+
141+ --- assertExpectedJournal(TestUnrollTuples.test_dict_comprehension_with_range)
142+ from __future__ import annotations
143+
144+ import torch
145+ import triton
146+ import triton.language as tl
147+ from helion.runtime import default_launcher as _default_launcher
148+
@triton.jit
def _helion_kernel_dict_comprehension_with_range(x, result, _BLOCK_SIZE_0: tl.constexpr):
    """Accumulate ``x * 2.0 + x * 4.0 + x * 6.0 + x * 8.0`` for one tile and store it.

    Each program instance handles ``_BLOCK_SIZE_0`` consecutive elements,
    selected by ``tl.program_id(0)``. Loads and the final store are issued
    with a mask of ``None``.
    """
    # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
    tile_start = tl.program_id(0) * _BLOCK_SIZE_0
    lanes = (tile_start + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
    total = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
    # The terms are added one at a time, in source order, so the
    # floating-point summation order stays (((0 + 2x) + 4x) + 6x) + 8x.
    # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[0]
    total = total + tl.load(x + lanes * 1, None) * 2.0
    # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[1]
    total = total + tl.load(x + lanes * 1, None) * 4.0
    # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[2]
    total = total + tl.load(x + lanes * 1, None) * 6.0
    # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[3]
    total = total + tl.load(x + lanes * 1, None) * 8.0
    # src[test_unroll_tuples.py:N]: result[tile_idx] = acc
    tl.store(result + lanes * 1, total, None)
179+
def kernel_dict_comprehension_with_range(x: torch.Tensor, *, _launcher=_default_launcher):
    """Test dict comprehension with range for key generation.

    Allocates a zero-filled tensor shaped like ``x``, launches
    ``_helion_kernel_dict_comprehension_with_range`` to fill it, and
    returns it.

    Args:
        x: Input tensor read by the kernel.
            NOTE(review): the launch grid is derived from a hard-coded
            length of 16 and the kernel reads/writes without a mask, so
            this presumably expects exactly 16 elements along dim 0 -
            confirm against the caller.
        _launcher: Callable used to launch the Triton kernel; defaults to
            ``_default_launcher``.

    Returns:
        The tensor written by the kernel.
    """
    # src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
    result = torch.zeros_like(x)
    # src[test_unroll_tuples.py:N]: multipliers = {i: (i + 1) * 2 for i in range(4)}
    # The dict is not rebuilt on the host: nothing here reads it, and its
    # values appear as the literals 2.0 / 4.0 / 6.0 / 8.0 inside the kernel.
    # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
    _BLOCK_SIZE_0 = 16
    # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
    # src[test_unroll_tuples.py:N]:     acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
    # src[test_unroll_tuples.py:N]:     # Access dict with literal keys
    # src[test_unroll_tuples.py:N-N]: ...
    _launcher(
        _helion_kernel_dict_comprehension_with_range,
        (triton.cdiv(16, _BLOCK_SIZE_0),),
        x,
        result,
        _BLOCK_SIZE_0,
        num_warps=4,
        num_stages=1,
    )
    # src[test_unroll_tuples.py:N]: return result
    return result
195+
91196--- assertExpectedJournal(TestUnrollTuples.test_enumerate_constants)
92197from __future__ import annotations
93198
0 commit comments