Skip to content

Commit 0a25818

Browse files
adam-smnkkurapov-peter
authored andcommitted
[examples][mlir] Basic MLIR compilation and execution example
Adds a simple end-to-end example demonstrating programatic transform schedule creation, MLIR JIT compilation, execution, and numerical verification of the result. Additionally, 'utils' submodule is added with basic tools to simplify creation of ctype arguments in format accepted by jitted function.
1 parent ba8aa78 commit 0a25818

File tree

2 files changed

+45
-79
lines changed

2 files changed

+45
-79
lines changed
Lines changed: 42 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
import torch
2+
<<<<<<< HEAD
23
import argparse
4+
=======
5+
import os
6+
>>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example)
37

48
from mlir import ir
59
from mlir.dialects import transform
610
from mlir.dialects.transform import structured
711
from mlir.dialects.transform import interpreter
812
from mlir.execution_engine import ExecutionEngine
13+
<<<<<<< HEAD
914
from mlir.passmanager import PassManager
15+
=======
16+
from mlir.runtime.np_to_memref import (
17+
get_ranked_memref_descriptor,
18+
)
19+
>>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example)
1020

1121
from lighthouse import utils as lh_utils
1222

@@ -35,7 +45,7 @@ def create_kernel(ctx: ir.Context) -> ir.Module:
3545
def create_schedule(ctx: ir.Context) -> ir.Module:
3646
"""
3747
Create an MLIR module containing transformation schedule.
38-
The schedule provides partial lowering to scalar operations.
48+
The schedule provides necessary steps to lower the kernel to LLVM IR.
3949
4050
Args:
4151
ctx: MLIR context.
@@ -47,26 +57,25 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
4757
ir.UnitAttr.get()
4858
)
4959

50-
# For simplicity, use generic matchers without requiring specific types.
51-
anytype = transform.any_op_t()
52-
5360
# Create entry point transformation sequence.
5461
with ir.InsertionPoint(schedule.body):
5562
named_seq = transform.NamedSequenceOp(
56-
sym_name="__transform_main",
57-
input_types=[anytype],
58-
result_types=[],
63+
"__transform_main",
64+
[transform.AnyOpType.get()],
65+
[],
5966
arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
6067
)
6168

6269
# Create the schedule.
6370
with ir.InsertionPoint(named_seq.body):
71+
# For simplicity, use generic transform matchers.
72+
anytype = transform.AnyOpType.get()
73+
6474
# Find the kernel's function op.
6575
func = structured.MatchOp.match_op_names(
6676
named_seq.bodyTarget, ["func.func"]
6777
)
68-
# Use C interface wrappers - required to make function executable
69-
# after jitting.
78+
# Use C interface wrappers - required to make function executable after jitting.
7079
func = transform.apply_registered_pass(
7180
anytype, func, "llvm-request-c-wrappers"
7281
)
@@ -80,16 +89,22 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
8089
anytype, mod, "convert-linalg-to-loops"
8190
)
8291
# Cleanup.
83-
transform.apply_cse(mod)
92+
transform.ApplyCommonSubexpressionEliminationOp(mod)
8493
with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns):
85-
transform.apply_patterns_canonicalization()
94+
transform.ApplyCanonicalizationPatternsOp()
95+
# Lower to LLVM.
96+
mod = transform.apply_registered_pass(anytype, mod, "convert-scf-to-cf")
97+
mod = transform.apply_registered_pass(anytype, mod, "convert-to-llvm")
98+
mod = transform.apply_registered_pass(
99+
anytype, mod, "reconcile-unrealized-casts"
100+
)
86101

87102
# Terminate the schedule.
88-
transform.yield_([])
103+
transform.YieldOp()
89104
return schedule
90105

91106

92-
def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
107+
def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> ir.Module:
93108
"""
94109
Apply transformation schedule to a kernel module.
95110
The kernel is modified in-place.
@@ -105,29 +120,8 @@ def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
105120
)
106121

107122

108-
def create_pass_pipeline(ctx: ir.Context) -> PassManager:
109-
"""
110-
Create an MLIR pass pipeline.
111-
The pipeline lowers operations further down to LLVM dialect.
112-
113-
Args:
114-
ctx: MLIR context.
115-
"""
116-
with ctx:
117-
# Create a pass manager that applies passes to the whole module.
118-
pm = PassManager("builtin.module")
119-
# Lower to LLVM.
120-
pm.add("convert-scf-to-cf")
121-
pm.add("convert-to-llvm")
122-
pm.add("reconcile-unrealized-casts")
123-
# Cleanup
124-
pm.add("cse")
125-
pm.add("canonicalize")
126-
return pm
127-
128-
129123
# The example's entry point.
130-
def main(args):
124+
def main():
131125
### Baseline computation ###
132126
# Create inputs.
133127
a = torch.randn(16, 32, dtype=torch.float32)
@@ -137,50 +131,36 @@ def main(args):
137131
out_ref = torch.add(a, b)
138132

139133
### MLIR payload preparation ###
140-
# Create payload kernel.
134+
# Create payload kernel and lowering schedule.
141135
ctx = ir.Context()
142136
kernel = create_kernel(ctx)
143-
144-
# Create a transform schedule and apply initial lowering.
145137
schedule = create_schedule(ctx)
138+
# Lower the kernel to LLVM dialect.
146139
apply_schedule(kernel, schedule)
147140

148-
# Create a pass pipeline and lower the kernel to LLVM dialect.
149-
pm = create_pass_pipeline(ctx)
150-
pm.run(kernel.operation)
151-
152141
### Compilation ###
153-
# Parse additional libraries if present.
142+
# External shared libraries, containing MLIR runner utilities, are are generally
143+
# required to execute the compiled module.
154144
#
155-
# External shared libraries, runtime utilities, might be needed to execute
156-
# the compiled module.
157-
# The execution engine requires full paths to the libraries.
158-
mlir_libs = []
159-
if args.shared_libs:
160-
mlir_libs += args.shared_libs.split(",")
145+
# Get paths to MLIR runner shared libraries through an environment variable.
146+
mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS").split(":")
161147

162148
# JIT the kernel.
163149
eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs)
164-
165-
# Initialize the JIT engine.
166-
#
167-
# The deferred initialization executes global constructors that might
168-
# have been created by the module during engine creation (for example,
169-
# when `gpu.module` is present) or registered afterwards.
170-
#
171-
# Initialization is not strictly necessary in this case.
172-
# However, it is a good practice to perform it regardless.
173-
eng.initialize()
174-
175150
# Get the kernel function.
176151
add_func = eng.lookup("add")
177152

178153
### Execution ###
154+
# Create corresponding memref descriptors containing input data.
155+
a_mem = get_ranked_memref_descriptor(a.numpy())
156+
b_mem = get_ranked_memref_descriptor(b.numpy())
157+
179158
# Create an empty buffer to hold results.
180159
out = torch.empty_like(out_ref)
160+
out_mem = get_ranked_memref_descriptor(out.numpy())
181161

182162
# Execute the kernel.
183-
args = lh_utils.torch_to_packed_args([a, b, out])
163+
args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem])
184164
add_func(args)
185165

186166
### Verification ###
@@ -192,21 +172,4 @@ def main(args):
192172

193173

194174
if __name__ == "__main__":
195-
parser = argparse.ArgumentParser()
196-
197-
# External shared libraries, runtime utilities, might be needed to
198-
# execute the compiled module.
199-
# For example, MLIR runner utils libraries such as:
200-
# - libmlir_runner_utils.so
201-
# - libmlir_c_runner_utils.so
202-
#
203-
# Full paths to the libraries should be provided.
204-
# For example:
205-
# --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so
206-
parser.add_argument(
207-
"--shared-libs",
208-
type=str,
209-
help="Comma-separated list of libraries to link dynamically",
210-
)
211-
args = parser.parse_args()
212-
main(args)
175+
main()

python/lighthouse/utils/runtime_args.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import ctypes
2+
<<<<<<< HEAD
23
import torch
34

45
from mlir.runtime.np_to_memref import (
56
get_ranked_memref_descriptor,
67
)
8+
=======
9+
>>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example)
710

811

912
def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]:

0 commit comments

Comments
 (0)