
Commit bd23099

Add feed forward
1 parent 2f80f17 commit bd23099

File tree

1 file changed

ingress/mlir-gen/mlir_gen/test/test_core.py

Lines changed: 97 additions & 5 deletions
@@ -326,32 +326,34 @@ def get_matmul(a: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value:
 
 # torch.nn.functional.linear
 def get_linear(a: ir.Value, w: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value:
+    elty = out.type.element_type
+    zero = arith.constant(elty, 0.0)
+    out_zeroed = linalg.fill(zero, outs=[out])
+
     # a[i, k] * w[j, k] -> out[i, j]
     i, j, k = [ir.AffineDimExpr.get(d) for d in range(3)]
     a_map = affine_map(3, [i, k])  # (batch, in_feat)
-    w_map = affine_map(3, [j, k])  # (out_feat, in_feat) - note: we use j for first dim
+    w_map = affine_map(3, [j, k])  # (out_feat, in_feat)
     out_map = affine_map(3, [i, j])  # (batch, out_feat)
 
-    # First compute the matmul into out (which will accumulate)
     @linalg.generic(
         [a, w],
-        [out],
+        [out_zeroed],
         [a_map, w_map, out_map],
         [parallel, parallel, reduction],
     )
     def matmul_op(a_elem, w_elem, out_elem):
         prod = arith.MulFOp(a_elem, w_elem).result
         return arith.AddFOp(out_elem, prod).result
 
-    # Step 2: Add bias using broadcasting
     # b[j] -> out[i, j]
     i2, j2 = [ir.AffineDimExpr.get(d) for d in range(2)]
     b_map = affine_map(2, [j2])  # (out_feat,)
     out_map2 = affine_map(2, [i2, j2])  # (batch, out_feat)
 
     @linalg.generic(
         [matmul_op, b],
-        [out],
+        [out_zeroed],
         [out_map2, b_map, out_map2],
         [parallel, parallel],
     )
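
For reference, the two linalg.generic ops above together implement torch.nn.functional.linear: the first is an einsum-style matmul that reduces over k, and the second broadcasts the bias over the batch dimension. The new linalg.fill matters because the matmul body accumulates into its output operand, so it must start from zeros rather than from an uninitialized buffer. A minimal NumPy sketch of the same semantics (linear_ref is an illustrative name, not part of the commit):

import numpy as np

def linear_ref(a, w, b):
    # out[i, j] starts at zero, mirroring the linalg.fill on out_zeroed
    out = np.zeros((a.shape[0], w.shape[0]), dtype=a.dtype)
    # matmul_op: out[i, j] += a[i, k] * w[j, k], reducing over k
    out += np.einsum("ik,jk->ij", a, w)
    # bias op: out[i, j] = out[i, j] + b[j], broadcast over the batch dim i
    return out + b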
@@ -1243,3 +1245,93 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
 
     assert torch.allclose(out1, xq_out, rtol=0.01, atol=0.01, equal_nan=True)
     assert torch.allclose(out2, xk_out, rtol=0.01, atol=0.01, equal_nan=True)
+
+
+def test_feed_forward():
+    def generate_module(ctx, elty):
+        with ctx, ir.Location.unknown():
+            module = ir.Module.create()
+            with ir.InsertionPoint(module.body):
+                input_type = ir.RankedTensorType.get((4, 16), elty)
+                hidden_type = ir.RankedTensorType.get((4, 64), elty)
+                output_type = ir.RankedTensorType.get((4, 16), elty)
+                weight1_type = ir.RankedTensorType.get((64, 16), elty)
+                bias1_type = ir.RankedTensorType.get((64,), elty)
+                weight2_type = ir.RankedTensorType.get((16, 64), elty)
+                bias2_type = ir.RankedTensorType.get((16,), elty)
+                weight3_type = ir.RankedTensorType.get((64, 16), elty)
+                bias3_type = ir.RankedTensorType.get((64,), elty)
+
+                @func.FuncOp.from_py_func(
+                    input_type,
+                    weight1_type,
+                    bias1_type,
+                    weight2_type,
+                    bias2_type,
+                    weight3_type,
+                    bias3_type,
+                    output_type,
+                    name="feed_forward",
+                )
+                def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
+                    # Compute hidden = linear(x, w1, b1)
+                    hidden_uninit = tensor.EmptyOp(hidden_type.shape, elty).result
+                    hidden = get_linear(x, w1, b1, hidden_uninit)
+
+                    # Compute hidden_silu = silu(hidden)
+                    hidden_silu_uninit = tensor.EmptyOp(hidden_type.shape, elty).result
+                    hidden_silu = get_silu(hidden, hidden_silu_uninit)
+
+                    # Compute gate = linear(x, w3, b3)
+                    gate_uninit = tensor.EmptyOp(hidden_type.shape, elty).result
+                    gate = get_linear(x, w3, b3, gate_uninit)
+
+                    # Compute activated = hidden_silu * gate
+                    activated_uninit = tensor.EmptyOp(hidden_type.shape, elty).result
+                    activated = get_mul(hidden_silu, gate, activated_uninit)
+
+                    # Compute out = linear(activated, w2, b2)
+                    get_linear(activated, w2, b2, out)
+
+        return module
+
+    ctx = ir.Context()
+    ir_type = to_ir_type("f32", ctx)
+    module = generate_module(ctx, ir_type)
+    bufferize_module(ctx, module)
+    schedule = create_schedule(ctx)
+    apply_schedule(module, schedule)
+    pm = create_pass_pipeline(ctx)
+    pm.run(module.operation)
+
+    eng = ExecutionEngine(module, opt_level=2)
+    func_ptr = eng.lookup("feed_forward")
+
+    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    x = torch.randn(4, 16, dtype=torch_dtype)
+    w1 = torch.randn(64, 16, dtype=torch_dtype)
+    b1 = torch.randn(64, dtype=torch_dtype)
+    w2 = torch.randn(16, 64, dtype=torch_dtype)
+    b2 = torch.randn(16, dtype=torch_dtype)
+    w3 = torch.randn(64, 16, dtype=torch_dtype)
+    b3 = torch.randn(64, dtype=torch_dtype)
+
+    hidden_ref = torch.nn.functional.linear(x, w1, b1)
+    activated_ref = torch.nn.functional.silu(hidden_ref)
+    activated_ref *= torch.nn.functional.linear(x, w3, b3)
+    out_ref = torch.nn.functional.linear(activated_ref, w2, b2)
+    out = torch.empty_like(out_ref)
+    out.zero_()
+    x_mem = get_ranked_memref_descriptor(x.numpy())
+    w1_mem = get_ranked_memref_descriptor(w1.numpy())
+    b1_mem = get_ranked_memref_descriptor(b1.numpy())
+    w2_mem = get_ranked_memref_descriptor(w2.numpy())
+    b2_mem = get_ranked_memref_descriptor(b2.numpy())
+    w3_mem = get_ranked_memref_descriptor(w3.numpy())
+    b3_mem = get_ranked_memref_descriptor(b3.numpy())
+    out_mem = get_ranked_memref_descriptor(out.numpy())
+    args = lh_utils.memrefs_to_packed_args(
+        [x_mem, w1_mem, b1_mem, w2_mem, b2_mem, w3_mem, b3_mem, out_mem]
+    )
+    func_ptr(args)
+    assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
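
The reference computation in this test is the gated (SwiGLU-style) feed-forward block: an up projection and a gate projection from the model dimension (16) to the hidden dimension (64), a SiLU activation on the up path, an elementwise product with the gate, and a down projection back to 16. As a hedged sketch, an equivalent standalone PyTorch module could look like the following (the class and attribute names are illustrative, not part of the commit):

import torch

class FeedForward(torch.nn.Module):
    def __init__(self, dim=16, hidden_dim=64):
        super().__init__()
        self.w1 = torch.nn.Linear(dim, hidden_dim)  # up projection (w1, b1)
        self.w3 = torch.nn.Linear(dim, hidden_dim)  # gate projection (w3, b3)
        self.w2 = torch.nn.Linear(hidden_dim, dim)  # down projection (w2, b2)

    def forward(self, x):
        # silu(linear(x, w1, b1)) * linear(x, w3, b3), then project back down
        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))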
