11import torch
2- < << << << HEAD
3- import argparse
4- == == == =
52import os
6- > >> >> >> 9647 a7f ([examples ][mlir ] Basic MLIR compilation and execution example )
73
84from mlir import ir
95from mlir .dialects import transform
106from mlir .dialects .transform import structured
117from mlir .dialects .transform import interpreter
128from mlir .execution_engine import ExecutionEngine
13- < << << << HEAD
149from mlir .passmanager import PassManager
15- == == == =
1610from mlir .runtime .np_to_memref import (
1711 get_ranked_memref_descriptor ,
1812)
19- > >> >> >> 9647 a7f ([examples ][mlir ] Basic MLIR compilation and execution example )
2013
2114from lighthouse import utils as lh_utils
2215
@@ -45,7 +38,7 @@ def create_kernel(ctx: ir.Context) -> ir.Module:
4538def create_schedule (ctx : ir .Context ) -> ir .Module :
4639 """
4740 Create an MLIR module containing transformation schedule.
48- The schedule provides necessary steps to lower the kernel to LLVM IR .
41+ The schedule provides partial lowering to scalar operations .
4942
5043 Args:
5144 ctx: MLIR context.
@@ -92,12 +85,6 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
9285 transform .ApplyCommonSubexpressionEliminationOp (mod )
9386 with ir .InsertionPoint (transform .ApplyPatternsOp (mod ).patterns ):
9487 transform .ApplyCanonicalizationPatternsOp ()
95- # Lower to LLVM.
96- mod = transform .apply_registered_pass (anytype , mod , "convert-scf-to-cf" )
97- mod = transform .apply_registered_pass (anytype , mod , "convert-to-llvm" )
98- mod = transform .apply_registered_pass (
99- anytype , mod , "reconcile-unrealized-casts"
100- )
10188
10289 # Terminate the schedule.
10390 transform .YieldOp ()
@@ -120,6 +107,27 @@ def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
120107 )
121108
122109
110+ def create_pass_pipeline (ctx : ir .Context ) -> PassManager :
111+ """
112+ Create an MLIR pass pipeline.
113+ The pipeline lowers operations further down to LLVM dialect.
114+
115+ Args:
116+ ctx: MLIR context.
117+ """
118+ with ctx :
119+ # Create a pass manager that applies passes to the whole module.
120+ pm = PassManager ("builtin.module" )
121+ # Lower to LLVM.
122+ pm .add ("convert-scf-to-cf" )
123+ pm .add ("convert-to-llvm" )
124+ pm .add ("reconcile-unrealized-casts" )
125+ # Cleanup
126+ pm .add ("cse" )
127+ pm .add ("canonicalize" )
128+ return pm
129+
130+
123131# The example's entry point.
124132def main ():
125133 ### Baseline computation ###
@@ -131,20 +139,24 @@ def main():
131139 out_ref = torch .add (a , b )
132140
133141 ### MLIR payload preparation ###
134- # Create payload kernel and lowering schedule .
142+ # Create payload kernel.
135143 ctx = ir .Context ()
136144 kernel = create_kernel (ctx )
137- schedule = create_schedule (ctx )
138145
139- # Lower the kernel to LLVM dialect.
146+ # Create a transform schedule and apply initial lowering.
147+ schedule = create_schedule (ctx )
140148 apply_schedule (kernel , schedule )
141149
150+ # Create a pass pipeline and lower the kernel to LLVM dialect.
151+ pm = create_pass_pipeline (ctx )
152+ pm .run (kernel .operation )
153+
142154 ### Compilation ###
143155 # External shared libraries, containing MLIR runner utilities, are are generally
144156 # required to execute the compiled module.
145157 #
146158 # Get paths to MLIR runner shared libraries through an environment variable.
147- mlir_libs = os .environ .get ("LIGHTHOUSE_SHARED_LIBS" ).split (":" )
159+ mlir_libs = os .environ .get ("LIGHTHOUSE_SHARED_LIBS" , default = "" ).split (":" )
148160
149161 # JIT the kernel.
150162 eng = ExecutionEngine (kernel , opt_level = 2 , shared_libs = mlir_libs )
0 commit comments