@@ -174,6 +174,42 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
174174 result .to (torch .float32 ), x .to (torch .float32 ), atol = 1e-2 , rtol = 1e-2
175175 )
176176
def test_load_2d_indexing_without_extra_mask(self):
    """Check that hl.load indexed by two tile index tensors yields a 2D tile.

    The kernel runs in ref eager mode and copies its square input tile by
    tile, so the result must match the input.
    """

    @helion.kernel(ref_mode=helion.RefMode.EAGER)
    def copy_via_load(mask: torch.Tensor) -> torch.Tensor:
        side = mask.shape[0]
        copied = torch.zeros_like(mask)
        for rows, cols in hl.tile([side, side]):
            # The pair of per-dimension index tensors is expected to give a
            # [rows, cols]-shaped tile, which is stored back unchanged.
            copied[rows, cols] = hl.load(mask, [rows.index, cols.index])
        return copied

    with assert_ref_eager_mode():
        # Lower-triangular ones: a non-uniform pattern to compare against.
        ones = torch.ones(4, 4, device=DEVICE, dtype=torch.float32)
        expected = torch.tril(ones)
        actual = copy_via_load(expected)
        torch.testing.assert_close(actual, expected)
194+
def test_load_3d_indexing_without_extra_mask(self):
    """Check that hl.load indexed by three tile index tensors yields a 3D tile.

    The kernel runs in ref eager mode and copies its 3D input tile by tile,
    so the result must match the input.
    """

    @helion.kernel(ref_mode=helion.RefMode.EAGER)
    def copy_via_load(x: torch.Tensor) -> torch.Tensor:
        size0, size1, size2 = x.shape
        copied = torch.zeros_like(x)
        for t0, t1, t2 in hl.tile([size0, size1, size2]):
            # The three per-dimension index tensors are expected to give a
            # [t0, t1, t2]-shaped tile, which is stored back unchanged.
            copied[t0, t1, t2] = hl.load(x, [t0.index, t1.index, t2.index])
        return copied

    with assert_ref_eager_mode():
        # 24 distinct values laid out as a 2 x 3 x 4 tensor.
        flat = torch.arange(24, device=DEVICE, dtype=torch.float32)
        expected = flat.reshape(2, 3, 4)
        actual = copy_via_load(expected)
        torch.testing.assert_close(actual, expected)
212+
177213
# Hand control to unittest's runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments