|
10 | 10 | import lighthouse.utils as lh_utils |
11 | 11 |
|
12 | 12 |
|
13 | | -def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module: |
14 | | - with ctx, ir.Location.unknown(): |
15 | | - module = ir.Module.create() |
16 | | - with ir.InsertionPoint(module.body): |
17 | | - mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) |
18 | | - |
19 | | - # Return a new buffer initialized with input's data. |
20 | | - @func.func(mem_type) |
21 | | - def copy(input): |
22 | | - new_buf = memref.alloc(mem_type, [], []) |
23 | | - memref.copy(input, new_buf) |
24 | | - return new_buf |
25 | | - |
26 | | - # Free given buffer. |
27 | | - @func.func(mem_type) |
28 | | - def module_dealloc(input): |
29 | | - memref.dealloc(input) |
| 13 | +def create_mlir_module(shape: list[int]) -> ir.Module: |
| 14 | + module = ir.Module.create() |
| 15 | + with ir.InsertionPoint(module.body): |
| 16 | + mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) |
| 17 | + |
| 18 | + # Return a new buffer initialized with input's data. |
| 19 | + @func.func(mem_type) |
| 20 | + def copy(input): |
| 21 | + new_buf = memref.alloc(mem_type, [], []) |
| 22 | + memref.copy(input, new_buf) |
| 23 | + return new_buf |
| 24 | + |
| 25 | + # Free given buffer. |
| 26 | + @func.func(mem_type) |
| 27 | + def module_dealloc(input): |
| 28 | + memref.dealloc(input) |
30 | 29 |
|
31 | 30 | return module |
32 | 31 |
|
33 | 32 |
|
34 | 33 | def lower_to_llvm(operation: ir.Operation) -> None: |
35 | | - with operation.context: |
36 | | - pm = PassManager("builtin.module") |
37 | | - pm.add("func.func(llvm-request-c-wrappers)") |
38 | | - pm.add("convert-to-llvm") |
39 | | - pm.add("reconcile-unrealized-casts") |
40 | | - pm.add("cse") |
41 | | - pm.add("canonicalize") |
| 34 | + pm = PassManager("builtin.module") |
| 35 | + pm.add("func.func(llvm-request-c-wrappers)") |
| 36 | + pm.add("convert-to-llvm") |
| 37 | + pm.add("reconcile-unrealized-casts") |
| 38 | + pm.add("cse") |
| 39 | + pm.add("canonicalize") |
42 | 40 | pm.run(operation) |
43 | 41 |
|
44 | 42 |
|
@@ -73,8 +71,7 @@ def main(): |
73 | 71 | shape = [16, 32] |
74 | 72 |
|
75 | 73 | # Create and compile test module. |
76 | | - ctx = ir.Context() |
77 | | - kernel = create_mlir_module(ctx, shape) |
| 74 | + kernel = create_mlir_module(shape) |
78 | 75 | lower_to_llvm(kernel.operation) |
79 | 76 | eng = ExecutionEngine(kernel, opt_level=3) |
80 | 77 | eng.initialize() |
@@ -114,4 +111,5 @@ def main(): |
114 | 111 |
|
115 | 112 |
|
116 | 113 | if __name__ == "__main__": |
117 | | - main() |
| 114 | + with ir.Context(), ir.Location.unknown(): |
| 115 | + main() |
0 commit comments