From 8ff573c0bae2439ca2960190bce4448b2f22aae8 Mon Sep 17 00:00:00 2001 From: Shay Kleiman Date: Tue, 9 Dec 2025 11:47:11 +0200 Subject: [PATCH] [mlir][memref] Support ignoring ValueRange in foldMemrefCast Currently foldMemrefCast allows passing a single operand that should be ignored and not folded. Added support for passing ValueRange instead. Since Value can be implicitly converted to ValueRange this shouldn't affect existing usage of the function. --- mlir/include/mlir/Dialect/MemRef/IR/MemRef.h | 5 +++-- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h index b7abcdea10a2a..c4f2cf2413165 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -50,8 +50,9 @@ namespace memref { /// This is a common utility used for patterns of the form /// "someop(memref.cast) -> someop". It folds the source of any memref.cast -/// into the root operation directly. -LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr); +/// into the root operation directly. Operands in `ignoredOperands` are excluded +/// from folding. +LogicalResult foldMemRefCast(Operation *op, ValueRange ignoredOperands = {}); /// Return an unranked/ranked tensor type for the given unranked/ranked memref /// type. diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1035d7cb46e6e..6b82a550668b2 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -41,12 +41,14 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder, /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref.cast -/// into the root operation directly. -LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { +/// into the root operation directly. Operands in `ignoredOperands` are excluded +/// from folding. +LogicalResult mlir::memref::foldMemRefCast(Operation *op, + ValueRange ignoredOperands) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); - if (cast && operand.get() != inner && + if (cast && !llvm::is_contained(ignoredOperands, operand.get()) && !llvm::isa(cast.getOperand().getType())) { operand.set(cast.getOperand()); folded = true;