@@ -326,32 +326,34 @@ def get_matmul(a: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value:
 
 # torch.nn.functional.linear
 def get_linear(a: ir.Value, w: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value:
+    elty = out.type.element_type
+    zero = arith.constant(elty, 0.0)
+    out_zeroed = linalg.fill(zero, outs=[out])  # zero-init the accumulator the reduction reads
+
     # a[i, k] * w[j, k] -> out[i, j]
     i, j, k = [ir.AffineDimExpr.get(d) for d in range(3)]
     a_map = affine_map(3, [i, k])  # (batch, in_feat)
-    w_map = affine_map(3, [j, k])  # (out_feat, in_feat) - note: we use j for first dim
+    w_map = affine_map(3, [j, k])  # (out_feat, in_feat)
     out_map = affine_map(3, [i, j])  # (batch, out_feat)
 
-    # First compute the matmul into out (which will accumulate)
     @linalg.generic(
         [a, w],
-        [out],
+        [out_zeroed],
         [a_map, w_map, out_map],
         [parallel, parallel, reduction],
     )
     def matmul_op(a_elem, w_elem, out_elem):
         prod = arith.MulFOp(a_elem, w_elem).result
         return arith.AddFOp(out_elem, prod).result
 
-    # Step 2: Add bias using broadcasting
     # b[j] -> out[i, j]
     i2, j2 = [ir.AffineDimExpr.get(d) for d in range(2)]
     b_map = affine_map(2, [j2])  # (out_feat,)
     out_map2 = affine_map(2, [i2, j2])  # (batch, out_feat)
 
     @linalg.generic(
         [matmul_op, b],
-        [out],
+        [out_zeroed],
         [out_map2, b_map, out_map2],
         [parallel, parallel],
     )
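
The zero-fill added above matters because the matmul `linalg.generic` reads `out_elem` as its accumulator: without initializing the destination, the reduction would sum into whatever the buffer happened to hold. As a sanity check, here is a minimal NumPy sketch of what the two generics compute together (the function and its names are illustrative, not part of this patch):

    import numpy as np

    def linear_ref(a, w, b):
        # matmul generic: out[i, j] += a[i, k] * w[j, k], accumulator starts at zero
        out = np.zeros((a.shape[0], w.shape[0]), dtype=a.dtype)
        for k in range(a.shape[1]):
            out += np.outer(a[:, k], w[:, k])
        # bias generic: b[j] broadcast across the batch dimension i
        return out + b

This matches torch.nn.functional.linear, i.e. a @ w.T + b.
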
@@ -1243,3 +1245,93 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
 
     assert torch.allclose(out1, xq_out, rtol=0.01, atol=0.01, equal_nan=True)
     assert torch.allclose(out2, xk_out, rtol=0.01, atol=0.01, equal_nan=True)
+
+
+def test_feed_forward():
+    def generate_module(ctx, elty):
+        with ctx, ir.Location.unknown():
+            module = ir.Module.create()
+            with ir.InsertionPoint(module.body):
+                input_type = ir.RankedTensorType.get((4, 16), elty)
+                hidden_type = ir.RankedTensorType.get((4, 64), elty)
+                output_type = ir.RankedTensorType.get((4, 16), elty)
+                weight1_type = ir.RankedTensorType.get((64, 16), elty)
+                bias1_type = ir.RankedTensorType.get((64,), elty)
+                weight2_type = ir.RankedTensorType.get((16, 64), elty)
+                bias2_type = ir.RankedTensorType.get((16,), elty)
+                weight3_type = ir.RankedTensorType.get((64, 16), elty)
+                bias3_type = ir.RankedTensorType.get((64,), elty)
+
+                @func.FuncOp.from_py_func(
+                    input_type,
+                    weight1_type,
+                    bias1_type,
+                    weight2_type,
+                    bias2_type,
+                    weight3_type,
+                    bias3_type,
+                    output_type,
+                    name="feed_forward",
+                )
+                def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
+                    # Compute hidden = linear(x, w1, b1)
+                    hidden_uninit = tensor.EmptyOp(hidden_type.shape, elty).result
+                    hidden = get_linear(x, w1, b1, hidden_uninit)
+
+                    # Compute hidden_silu = silu(hidden)
+                    hidden_silu_uninit = tensor.EmptyOp(hidden_type.shape, elty).result
+                    hidden_silu = get_silu(hidden, hidden_silu_uninit)
+
+                    # Compute gate = linear(x, w3, b3)
+                    gate_uninit = tensor.EmptyOp(hidden_type.shape, elty).result
+                    gate = get_linear(x, w3, b3, gate_uninit)
+
+                    # Compute activated = hidden_silu * gate
+                    activated_uninit = tensor.EmptyOp(hidden_type.shape, elty).result
+                    activated = get_mul(hidden_silu, gate, activated_uninit)
+
+                    # Compute out = linear(activated, w2, b2)
+                    get_linear(activated, w2, b2, out)
+
+        return module
+
+    ctx = ir.Context()
+    ir_type = to_ir_type("f32", ctx)
+    module = generate_module(ctx, ir_type)
+    bufferize_module(ctx, module)
+    schedule = create_schedule(ctx)
+    apply_schedule(module, schedule)
+    pm = create_pass_pipeline(ctx)
+    pm.run(module.operation)
+
+    eng = ExecutionEngine(module, opt_level=2)
+    func_ptr = eng.lookup("feed_forward")
+
+    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    x = torch.randn(4, 16, dtype=torch_dtype)
+    w1 = torch.randn(64, 16, dtype=torch_dtype)
+    b1 = torch.randn(64, dtype=torch_dtype)
+    w2 = torch.randn(16, 64, dtype=torch_dtype)
+    b2 = torch.randn(16, dtype=torch_dtype)
+    w3 = torch.randn(64, 16, dtype=torch_dtype)
+    b3 = torch.randn(64, dtype=torch_dtype)
+
+    hidden_ref = torch.nn.functional.linear(x, w1, b1)
+    activated_ref = torch.nn.functional.silu(hidden_ref)
+    activated_ref *= torch.nn.functional.linear(x, w3, b3)
+    out_ref = torch.nn.functional.linear(activated_ref, w2, b2)
+    out = torch.empty_like(out_ref)
+    out.zero_()
+    x_mem = get_ranked_memref_descriptor(x.numpy())
+    w1_mem = get_ranked_memref_descriptor(w1.numpy())
+    b1_mem = get_ranked_memref_descriptor(b1.numpy())
+    w2_mem = get_ranked_memref_descriptor(w2.numpy())
+    b2_mem = get_ranked_memref_descriptor(b2.numpy())
+    w3_mem = get_ranked_memref_descriptor(w3.numpy())
+    b3_mem = get_ranked_memref_descriptor(b3.numpy())
+    out_mem = get_ranked_memref_descriptor(out.numpy())
+    args = lh_utils.memrefs_to_packed_args(
+        [x_mem, w1_mem, b1_mem, w2_mem, b2_mem, w3_mem, b3_mem, out_mem]
+    )
+    func_ptr(args)
+    assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
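
For reference, the block under test is the SwiGLU-style feed-forward used in LLaMA-family models, here with biases: out = linear(silu(linear(x, w1, b1)) * linear(x, w3, b3), w2, b2). A pure-torch sketch of the same block, for comparison only (the module and its names are illustrative, not part of this patch):

    import torch

    class FeedForwardRef(torch.nn.Module):
        # Mirrors the test's reference math with dim=16, hidden=64
        def __init__(self, dim: int = 16, hidden: int = 64):
            super().__init__()
            self.w1 = torch.nn.Linear(dim, hidden)  # up projection
            self.w3 = torch.nn.Linear(dim, hidden)  # gate projection
            self.w2 = torch.nn.Linear(hidden, dim)  # down projection

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))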