Skip to content

Commit 2e0acb9

Browse files
committed
Simplify ctx usage
1 parent 328cf29 commit 2e0acb9

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

python/examples/mlir/memref_management.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,33 @@
1010
import lighthouse.utils as lh_utils
1111

1212

13-
def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module:
14-
with ctx, ir.Location.unknown():
15-
module = ir.Module.create()
16-
with ir.InsertionPoint(module.body):
17-
mem_type = ir.MemRefType.get(shape, ir.F32Type.get())
18-
19-
# Return a new buffer initialized with input's data.
20-
@func.func(mem_type)
21-
def copy(input):
22-
new_buf = memref.alloc(mem_type, [], [])
23-
memref.copy(input, new_buf)
24-
return new_buf
25-
26-
# Free given buffer.
27-
@func.func(mem_type)
28-
def module_dealloc(input):
29-
memref.dealloc(input)
13+
def create_mlir_module(shape: list[int]) -> ir.Module:
14+
module = ir.Module.create()
15+
with ir.InsertionPoint(module.body):
16+
mem_type = ir.MemRefType.get(shape, ir.F32Type.get())
17+
18+
# Return a new buffer initialized with input's data.
19+
@func.func(mem_type)
20+
def copy(input):
21+
new_buf = memref.alloc(mem_type, [], [])
22+
memref.copy(input, new_buf)
23+
return new_buf
24+
25+
# Free given buffer.
26+
@func.func(mem_type)
27+
def module_dealloc(input):
28+
memref.dealloc(input)
3029

3130
return module
3231

3332

3433
def lower_to_llvm(operation: ir.Operation) -> None:
35-
with operation.context:
36-
pm = PassManager("builtin.module")
37-
pm.add("func.func(llvm-request-c-wrappers)")
38-
pm.add("convert-to-llvm")
39-
pm.add("reconcile-unrealized-casts")
40-
pm.add("cse")
41-
pm.add("canonicalize")
34+
pm = PassManager("builtin.module")
35+
pm.add("func.func(llvm-request-c-wrappers)")
36+
pm.add("convert-to-llvm")
37+
pm.add("reconcile-unrealized-casts")
38+
pm.add("cse")
39+
pm.add("canonicalize")
4240
pm.run(operation)
4341

4442

@@ -73,8 +71,7 @@ def main():
7371
shape = [16, 32]
7472

7573
# Create and compile test module.
76-
ctx = ir.Context()
77-
kernel = create_mlir_module(ctx, shape)
74+
kernel = create_mlir_module(shape)
7875
lower_to_llvm(kernel.operation)
7976
eng = ExecutionEngine(kernel, opt_level=3)
8077
eng.initialize()
@@ -114,4 +111,5 @@ def main():
114111

115112

116113
if __name__ == "__main__":
117-
main()
114+
with ir.Context(), ir.Location.unknown():
115+
main()

0 commit comments

Comments
 (0)