11import torch
2+ < << << << HEAD
23import argparse
4+ == == == =
5+ import os
6+ > >> >> >> 9647 a7f ([examples ][mlir ] Basic MLIR compilation and execution example )
37
48from mlir import ir
59from mlir .dialects import transform
610from mlir .dialects .transform import structured
711from mlir .dialects .transform import interpreter
812from mlir .execution_engine import ExecutionEngine
13+ < << << << HEAD
914from mlir .passmanager import PassManager
15+ == == == =
16+ from mlir .runtime .np_to_memref import (
17+ get_ranked_memref_descriptor ,
18+ )
19+ > >> >> >> 9647 a7f ([examples ][mlir ] Basic MLIR compilation and execution example )
1020
1121from lighthouse import utils as lh_utils
1222
@@ -35,7 +45,7 @@ def create_kernel(ctx: ir.Context) -> ir.Module:
3545def create_schedule (ctx : ir .Context ) -> ir .Module :
3646 """
3747 Create an MLIR module containing transformation schedule.
38- The schedule provides partial lowering to scalar operations .
48+ The schedule provides necessary steps to lower the kernel to LLVM IR .
3949
4050 Args:
4151 ctx: MLIR context.
@@ -47,26 +57,25 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
4757 ir .UnitAttr .get ()
4858 )
4959
50- # For simplicity, use generic matchers without requiring specific types.
51- anytype = transform .any_op_t ()
52-
5360 # Create entry point transformation sequence.
5461 with ir .InsertionPoint (schedule .body ):
5562 named_seq = transform .NamedSequenceOp (
56- sym_name = "__transform_main" ,
57- input_types = [ anytype ],
58- result_types = [],
63+ "__transform_main" ,
64+ [ transform . AnyOpType . get () ],
65+ [],
5966 arg_attrs = [{"transform.readonly" : ir .UnitAttr .get ()}],
6067 )
6168
6269 # Create the schedule.
6370 with ir .InsertionPoint (named_seq .body ):
71+ # For simplicity, use generic transform matchers.
72+ anytype = transform .AnyOpType .get ()
73+
6474 # Find the kernel's function op.
6575 func = structured .MatchOp .match_op_names (
6676 named_seq .bodyTarget , ["func.func" ]
6777 )
68- # Use C interface wrappers - required to make function executable
69- # after jitting.
78+ # Use C interface wrappers - required to make function executable after jitting.
7079 func = transform .apply_registered_pass (
7180 anytype , func , "llvm-request-c-wrappers"
7281 )
@@ -80,16 +89,22 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
8089 anytype , mod , "convert-linalg-to-loops"
8190 )
8291 # Cleanup.
83- transform .apply_cse (mod )
92+ transform .ApplyCommonSubexpressionEliminationOp (mod )
8493 with ir .InsertionPoint (transform .ApplyPatternsOp (mod ).patterns ):
85- transform .apply_patterns_canonicalization ()
94+ transform .ApplyCanonicalizationPatternsOp ()
95+ # Lower to LLVM.
96+ mod = transform .apply_registered_pass (anytype , mod , "convert-scf-to-cf" )
97+ mod = transform .apply_registered_pass (anytype , mod , "convert-to-llvm" )
98+ mod = transform .apply_registered_pass (
99+ anytype , mod , "reconcile-unrealized-casts"
100+ )
86101
87102 # Terminate the schedule.
88- transform .yield_ ([] )
103+ transform .YieldOp ( )
89104 return schedule
90105
91106
92- def apply_schedule (kernel : ir .Module , schedule : ir .Module ) -> None :
107+ def apply_schedule (kernel : ir .Module , schedule : ir .Module ) -> ir . Module :
93108 """
94109 Apply transformation schedule to a kernel module.
95110 The kernel is modified in-place.
@@ -105,29 +120,8 @@ def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
105120 )
106121
107122
108- def create_pass_pipeline (ctx : ir .Context ) -> PassManager :
109- """
110- Create an MLIR pass pipeline.
111- The pipeline lowers operations further down to LLVM dialect.
112-
113- Args:
114- ctx: MLIR context.
115- """
116- with ctx :
117- # Create a pass manager that applies passes to the whole module.
118- pm = PassManager ("builtin.module" )
119- # Lower to LLVM.
120- pm .add ("convert-scf-to-cf" )
121- pm .add ("convert-to-llvm" )
122- pm .add ("reconcile-unrealized-casts" )
123- # Cleanup
124- pm .add ("cse" )
125- pm .add ("canonicalize" )
126- return pm
127-
128-
129123# The example's entry point.
130- def main (args ):
124+ def main ():
131125 ### Baseline computation ###
132126 # Create inputs.
133127 a = torch .randn (16 , 32 , dtype = torch .float32 )
@@ -137,50 +131,36 @@ def main(args):
137131 out_ref = torch .add (a , b )
138132
139133 ### MLIR payload preparation ###
140- # Create payload kernel.
134+ # Create payload kernel and lowering schedule .
141135 ctx = ir .Context ()
142136 kernel = create_kernel (ctx )
143-
144- # Create a transform schedule and apply initial lowering.
145137 schedule = create_schedule (ctx )
138+ # Lower the kernel to LLVM dialect.
146139 apply_schedule (kernel , schedule )
147140
148- # Create a pass pipeline and lower the kernel to LLVM dialect.
149- pm = create_pass_pipeline (ctx )
150- pm .run (kernel .operation )
151-
152141 ### Compilation ###
153- # Parse additional libraries if present.
142+ # External shared libraries, containing MLIR runner utilities, are are generally
143+ # required to execute the compiled module.
154144 #
155- # External shared libraries, runtime utilities, might be needed to execute
156- # the compiled module.
157- # The execution engine requires full paths to the libraries.
158- mlir_libs = []
159- if args .shared_libs :
160- mlir_libs += args .shared_libs .split ("," )
145+ # Get paths to MLIR runner shared libraries through an environment variable.
146+ mlir_libs = os .environ .get ("LIGHTHOUSE_SHARED_LIBS" ).split (":" )
161147
162148 # JIT the kernel.
163149 eng = ExecutionEngine (kernel , opt_level = 2 , shared_libs = mlir_libs )
164-
165- # Initialize the JIT engine.
166- #
167- # The deferred initialization executes global constructors that might
168- # have been created by the module during engine creation (for example,
169- # when `gpu.module` is present) or registered afterwards.
170- #
171- # Initialization is not strictly necessary in this case.
172- # However, it is a good practice to perform it regardless.
173- eng .initialize ()
174-
175150 # Get the kernel function.
176151 add_func = eng .lookup ("add" )
177152
178153 ### Execution ###
154+ # Create corresponding memref descriptors containing input data.
155+ a_mem = get_ranked_memref_descriptor (a .numpy ())
156+ b_mem = get_ranked_memref_descriptor (b .numpy ())
157+
179158 # Create an empty buffer to hold results.
180159 out = torch .empty_like (out_ref )
160+ out_mem = get_ranked_memref_descriptor (out .numpy ())
181161
182162 # Execute the kernel.
183- args = lh_utils .torch_to_packed_args ([ a , b , out ])
163+ args = lh_utils .memrefs_to_packed_args ([ a_mem , b_mem , out_mem ])
184164 add_func (args )
185165
186166 ### Verification ###
@@ -192,21 +172,4 @@ def main(args):
192172
193173
194174if __name__ == "__main__" :
195- parser = argparse .ArgumentParser ()
196-
197- # External shared libraries, runtime utilities, might be needed to
198- # execute the compiled module.
199- # For example, MLIR runner utils libraries such as:
200- # - libmlir_runner_utils.so
201- # - libmlir_c_runner_utils.so
202- #
203- # Full paths to the libraries should be provided.
204- # For example:
205- # --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so
206- parser .add_argument (
207- "--shared-libs" ,
208- type = str ,
209- help = "Comma-separated list of libraries to link dynamically" ,
210- )
211- args = parser .parse_args ()
212- main (args )
175+ main ()
0 commit comments