Skip to content

Commit 7416cbf

Browse files
adam-smnkkurapov-peter
authored andcommitted
Split lowering and add pass pipeline example
1 parent 4fb023f commit 7416cbf

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

python/examples/mlir/compile_and_run.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,15 @@
11
import torch
2-
<<<<<<< HEAD
3-
import argparse
4-
=======
52
import os
6-
>>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example)
73

84
from mlir import ir
95
from mlir.dialects import transform
106
from mlir.dialects.transform import structured
117
from mlir.dialects.transform import interpreter
128
from mlir.execution_engine import ExecutionEngine
13-
<<<<<<< HEAD
149
from mlir.passmanager import PassManager
15-
=======
1610
from mlir.runtime.np_to_memref import (
1711
get_ranked_memref_descriptor,
1812
)
19-
>>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example)
2013

2114
from lighthouse import utils as lh_utils
2215

@@ -45,7 +38,7 @@ def create_kernel(ctx: ir.Context) -> ir.Module:
4538
def create_schedule(ctx: ir.Context) -> ir.Module:
4639
"""
4740
Create an MLIR module containing transformation schedule.
48-
The schedule provides necessary steps to lower the kernel to LLVM IR.
41+
The schedule provides partial lowering to scalar operations.
4942
5043
Args:
5144
ctx: MLIR context.
@@ -92,12 +85,6 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
9285
transform.ApplyCommonSubexpressionEliminationOp(mod)
9386
with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns):
9487
transform.ApplyCanonicalizationPatternsOp()
95-
# Lower to LLVM.
96-
mod = transform.apply_registered_pass(anytype, mod, "convert-scf-to-cf")
97-
mod = transform.apply_registered_pass(anytype, mod, "convert-to-llvm")
98-
mod = transform.apply_registered_pass(
99-
anytype, mod, "reconcile-unrealized-casts"
100-
)
10188

10289
# Terminate the schedule.
10390
transform.YieldOp()
@@ -120,6 +107,27 @@ def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
120107
)
121108

122109

110+
def create_pass_pipeline(ctx: ir.Context) -> PassManager:
111+
"""
112+
Create an MLIR pass pipeline.
113+
The pipeline lowers operations further down to LLVM dialect.
114+
115+
Args:
116+
ctx: MLIR context.
117+
"""
118+
with ctx:
119+
# Create a pass manager that applies passes to the whole module.
120+
pm = PassManager("builtin.module")
121+
# Lower to LLVM.
122+
pm.add("convert-scf-to-cf")
123+
pm.add("convert-to-llvm")
124+
pm.add("reconcile-unrealized-casts")
125+
# Cleanup
126+
pm.add("cse")
127+
pm.add("canonicalize")
128+
return pm
129+
130+
123131
# The example's entry point.
124132
def main():
125133
### Baseline computation ###
@@ -131,20 +139,24 @@ def main():
131139
out_ref = torch.add(a, b)
132140

133141
### MLIR payload preparation ###
134-
# Create payload kernel and lowering schedule.
142+
# Create payload kernel.
135143
ctx = ir.Context()
136144
kernel = create_kernel(ctx)
137-
schedule = create_schedule(ctx)
138145

139-
# Lower the kernel to LLVM dialect.
146+
# Create a transform schedule and apply initial lowering.
147+
schedule = create_schedule(ctx)
140148
apply_schedule(kernel, schedule)
141149

150+
# Create a pass pipeline and lower the kernel to LLVM dialect.
151+
pm = create_pass_pipeline(ctx)
152+
pm.run(kernel.operation)
153+
142154
### Compilation ###
143155
# External shared libraries, containing MLIR runner utilities, are are generally
144156
# required to execute the compiled module.
145157
#
146158
# Get paths to MLIR runner shared libraries through an environment variable.
147-
mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS").split(":")
159+
mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":")
148160

149161
# JIT the kernel.
150162
eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs)

0 commit comments

Comments
 (0)