From 2fb2d22503a3ab9fbb4d3be6ec0f3b3040c0a1e6 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 14 Nov 2025 12:38:03 +0100 Subject: [PATCH 1/3] [python][utils] MemRef Manager Adds a utility for manual memory management of memref buffers across Python and jitted MLIR modules. Explicit memory management becomes required when an MLIR function returns a newly allocated buffer e.g., results of a computation. This can become a complex task due to difference in memory models between Python and the MLIR runtime allocators. By default, returned MLIR buffers' lifetime cannot be automatically managed by the Python environment. The Python memref manager aims to address the following challenges: - use of the same runtime allocators as a jitted MLIR module for consistent memory management - lean abstraction using memref descriptors directly - buffers usable both by Python and jitted MLIR modules Current implementation assumes that memref allocation ops are lowered to standard C functions, like 'malloc' and 'free', which are preloaded together with the Python process. --- python/examples/mlir/memref_management.py | 119 ++++++++++++++++++++++ python/lighthouse/utils/__init__.py | 2 + python/lighthouse/utils/memref_manager.py | 98 ++++++++++++++++++ 3 files changed, 219 insertions(+) create mode 100644 python/examples/mlir/memref_management.py create mode 100644 python/lighthouse/utils/memref_manager.py diff --git a/python/examples/mlir/memref_management.py b/python/examples/mlir/memref_management.py new file mode 100644 index 0000000..9cf4b4e --- /dev/null +++ b/python/examples/mlir/memref_management.py @@ -0,0 +1,119 @@ +# RUN: %PYTHON %s + +import torch +import ctypes + +from mlir import ir +from mlir.dialects import func, memref +from mlir.runtime import np_to_memref +from mlir.execution_engine import ExecutionEngine +from mlir.passmanager import PassManager + +import lighthouse.utils as lh_utils + + +def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module: + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) + + # Return a new buffer initialized with input's data. + @func.func(mem_type) + def copy(input): + new_buf = memref.alloc(mem_type, [], []) + memref.copy(input, new_buf) + return new_buf + + # Free given buffer. + @func.func(mem_type) + def module_dealloc(input): + memref.dealloc(input) + + return module + + +def lower_to_llvm(operation: ir.Operation) -> None: + with operation.context: + pm = PassManager("builtin.module") + pm.add("func.func(llvm-request-c-wrappers)") + pm.add("convert-to-llvm") + pm.add("reconcile-unrealized-casts") + pm.add("cse") + pm.add("canonicalize") + pm.run(operation) + + +def main(): + # Validate basic functionality. + print("Testing memref allocator...") + mem = lh_utils.MemRefManager() + # Check allocation. + buf = mem.alloc(32, 8, 16, ctype=ctypes.c_float) + assert buf.allocated != 0, "Invalid allocation" + assert list(buf.shape) == [32, 8, 16], "Invalid shape" + assert list(buf.strides) == [128, 16, 1], "Invalid strides" + # Check deallocation. + mem.dealloc(buf) + assert buf.allocated == 0, "Failed deallocation" + # Double free must not crash. + mem.dealloc(buf) + + # Zero rank buffer. + buf = mem.alloc(ctype=ctypes.c_float) + mem.dealloc(buf) + # Small buffer. + buf = mem.alloc(8, ctype=ctypes.c_int8) + mem.dealloc(buf) + # Large buffer. + buf = mem.alloc(1024, 1024, ctype=ctypes.c_int32) + mem.dealloc(buf) + + # Validate functionality across Python-MLIR boundary. + print("Testing JIT module memory management...") + # Buffer shape for testing. + shape = [16, 32] + + # Create and compile test module. + ctx = ir.Context() + kernel = create_mlir_module(ctx, shape) + lower_to_llvm(kernel.operation) + eng = ExecutionEngine(kernel, opt_level=3) + eng.initialize() + + # Validate passing memrefs between Python and jitted module. + print("...copy test...") + fn_copy = eng.lookup("copy") + + # Alloc buffer in Python and initialize it. + in_mem = mem.alloc(*shape, ctype=ctypes.c_float) + in_np = np_to_memref.ranked_memref_to_numpy([in_mem]) + assert not in_np.flags.owndata, "Expected non-owning memref conversion" + in_tensor = torch.from_numpy(in_np) + torch.randn(in_tensor.shape, out=in_tensor) + + out_mem = np_to_memref.make_nd_memref_descriptor(in_tensor.dim(), ctypes.c_float)() + out_mem.allocated = 0 + + args = lh_utils.memrefs_to_packed_args([out_mem, in_mem]) + fn_copy(args) + assert out_mem.allocated != 0, "Invalid buffer returned" + + out_tensor = torch.from_numpy(np_to_memref.ranked_memref_to_numpy([out_mem])) + torch.testing.assert_close(out_tensor, in_tensor) + + mem.dealloc(out_mem) + assert out_mem.allocated == 0, "Failed to dealloc returned buffer" + mem.dealloc(in_mem) + + # Validate external allocation with deallocation from within jitted module. + print("...dealloc test...") + fn_mlir_dealloc = eng.lookup("module_dealloc") + buf_mem = mem.alloc(*shape, ctype=ctypes.c_float) + fn_mlir_dealloc(lh_utils.memrefs_to_packed_args([buf_mem])) + + print("SUCCESS") + + +if __name__ == "__main__": + main() diff --git a/python/lighthouse/utils/__init__.py b/python/lighthouse/utils/__init__.py index 4cff9a5..9714f3b 100644 --- a/python/lighthouse/utils/__init__.py +++ b/python/lighthouse/utils/__init__.py @@ -1,5 +1,7 @@ """A collection of utility tools""" +from .memref_manager import MemRefManager + from .runtime_args import ( get_packed_arg, memref_to_ctype, diff --git a/python/lighthouse/utils/memref_manager.py b/python/lighthouse/utils/memref_manager.py new file mode 100644 index 0000000..e10243d --- /dev/null +++ b/python/lighthouse/utils/memref_manager.py @@ -0,0 +1,98 @@ +import ctypes + +from itertools import accumulate +from functools import reduce +import operator + +import mlir.runtime.np_to_memref as np_mem + + +class MemRefManager: + """ + A utility class for manual management of MLIR memrefs. + + When used together with memref operation from within a jitted MLIR module, + it is assumed that Memref dialect allocations and deallocation are performed + through standard runtime `malloc` and `free` functions. + + Custom allocators are currently not supported. For more details, see: + https://mlir.llvm.org/docs/TargetLLVMIR/#generic-alloction-and-deallocation-functions + """ + + def __init__(self) -> None: + # Library name is left unspecified to allow for symbol search + # in the global symbol table of the current process. + # For more details, see: + # https://github.com/python/cpython/issues/78773 + self.dll = ctypes.CDLL(name=None) + self.fn_malloc = self.dll.malloc + self.fn_malloc.argtypes = [ctypes.c_size_t] + self.fn_malloc.restype = ctypes.c_void_p + self.fn_free = self.dll.free + self.fn_free.argtypes = [ctypes.c_void_p] + self.fn_free.restype = None + + def alloc(self, *shape: int, ctype: ctypes._SimpleCData) -> ctypes.Structure: + """ + Allocate an empty memory buffer. + Returns an MLIR ranked memref descriptor. + + Args: + shape: A sequence of integers defining the buffer's shape. + ctype: A C type of buffer's elements. + """ + assert issubclass(ctype, ctypes._SimpleCData), "Expected a simple data ctype" + size_bytes = reduce(operator.mul, shape, ctypes.sizeof(ctype)) + buf = self.fn_malloc(size_bytes) + assert buf, "Failed to allocate memory" + + rank = len(shape) + if rank == 0: + desc = np_mem.make_zero_d_memref_descriptor(ctype)() + desc.allocated = buf + desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype)) + desc.offset = ctypes.c_longlong(0) + return desc + + desc = np_mem.make_nd_memref_descriptor(rank, ctype)() + desc.allocated = buf + desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype)) + desc.offset = ctypes.c_longlong(0) + shape_ctype_t = ctypes.c_longlong * rank + desc.shape = shape_ctype_t(*shape) + + strides = list(accumulate(reversed(shape[1:]), func=operator.mul)) + strides.reverse() + strides.append(1) + desc.strides = shape_ctype_t(*strides) + return desc + + def dealloc(self, memref_desc: ctypes.Structure) -> None: + """ + Free underlying memory buffer. + + Args: + memref_desc: An MLIR memref descriptor. + """ + # TODO: Expose upstream MemrefDescriptor classes for easier handling + assert memref_desc.__class__.__name__ == "MemRefDescriptor" or isinstance( + memref_desc, np_mem.UnrankedMemRefDescriptor + ), "Invalid memref descriptor" + + if isinstance(memref_desc, np_mem.UnrankedMemRefDescriptor): + # Unranked memref holds the underlying descriptor as an opaque pointer. + # Cast the descriptor to a zero ranked memref with an arbitrary type to + # access the base allocated memory pointer. + ranked_desc_type = np_mem.make_zero_d_memref_descriptor(ctypes.c_char) + ranked_desc = ctypes.cast( + memref_desc.descriptor, ctypes.POINTER(ranked_desc_type) + ) + memref_desc = ranked_desc[0] + + alloc_ptr = memref_desc.allocated + if alloc_ptr == 0: + return + + c_ptr = ctypes.cast(alloc_ptr, ctypes.c_void_p) + self.fn_free(c_ptr) + memref_desc.allocated = 0 From be45a896976555ac2d0d410162557fc811448f84 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 20 Nov 2025 10:40:27 +0100 Subject: [PATCH 2/3] Simplify ctx usage --- python/examples/mlir/memref_management.py | 52 +++++++++++------------ 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/python/examples/mlir/memref_management.py b/python/examples/mlir/memref_management.py index 9cf4b4e..2e6e859 100644 --- a/python/examples/mlir/memref_management.py +++ b/python/examples/mlir/memref_management.py @@ -12,35 +12,33 @@ import lighthouse.utils as lh_utils -def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module: - with ctx, ir.Location.unknown(): - module = ir.Module.create() - with ir.InsertionPoint(module.body): - mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) - - # Return a new buffer initialized with input's data. - @func.func(mem_type) - def copy(input): - new_buf = memref.alloc(mem_type, [], []) - memref.copy(input, new_buf) - return new_buf - - # Free given buffer. - @func.func(mem_type) - def module_dealloc(input): - memref.dealloc(input) +def create_mlir_module(shape: list[int]) -> ir.Module: + module = ir.Module.create() + with ir.InsertionPoint(module.body): + mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) + + # Return a new buffer initialized with input's data. + @func.func(mem_type) + def copy(input): + new_buf = memref.alloc(mem_type, [], []) + memref.copy(input, new_buf) + return new_buf + + # Free given buffer. + @func.func(mem_type) + def module_dealloc(input): + memref.dealloc(input) return module def lower_to_llvm(operation: ir.Operation) -> None: - with operation.context: - pm = PassManager("builtin.module") - pm.add("func.func(llvm-request-c-wrappers)") - pm.add("convert-to-llvm") - pm.add("reconcile-unrealized-casts") - pm.add("cse") - pm.add("canonicalize") + pm = PassManager("builtin.module") + pm.add("func.func(llvm-request-c-wrappers)") + pm.add("convert-to-llvm") + pm.add("reconcile-unrealized-casts") + pm.add("cse") + pm.add("canonicalize") pm.run(operation) @@ -75,8 +73,7 @@ def main(): shape = [16, 32] # Create and compile test module. - ctx = ir.Context() - kernel = create_mlir_module(ctx, shape) + kernel = create_mlir_module(shape) lower_to_llvm(kernel.operation) eng = ExecutionEngine(kernel, opt_level=3) eng.initialize() @@ -116,4 +113,5 @@ def main(): if __name__ == "__main__": - main() + with ir.Context(), ir.Location.unknown(): + main() From 62cce0d98032a6125aed62ada9cfaaead2b4de38 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Tue, 25 Nov 2025 13:24:53 +0100 Subject: [PATCH 3/3] Apply formatting --- python/lighthouse/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/lighthouse/utils/__init__.py b/python/lighthouse/utils/__init__.py index 9714f3b..5ea3d0c 100644 --- a/python/lighthouse/utils/__init__.py +++ b/python/lighthouse/utils/__init__.py @@ -11,6 +11,7 @@ ) __all__ = [ + "MemRefManager", "get_packed_arg", "memref_to_ctype", "memrefs_to_packed_args",