Skip to content

Conversation

@shay-kl
Copy link
Contributor

@shay-kl shay-kl commented Dec 9, 2025

Currently foldMemrefCast allows passing a single operand that should be ignored and not folded. Added support for passing ValueRange instead. Since Value can be implicitly converted to ValueRange, this shouldn't affect existing usage of the function.

Currently foldMemrefCast allows passing a single operand that should
be ignored and not folded. Added support for passing ValueRange instead.
Since Value can be implicitly converted to ValueRange this shouldn't
affect existing usage of the function.
@llvmbot
Copy link
Member

llvmbot commented Dec 9, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Shay Kleiman (shay-kl)

Changes

Currently foldMemrefCast allows passing a single operand that should be ignored and not folded. Added support for passing ValueRange instead. Since Value can be implicitly converted to ValueRange, this shouldn't affect existing usage of the function.


Full diff: https://github.com/llvm/llvm-project/pull/171337.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRef.h (+3-2)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+5-3)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index b7abcdea10a2a..c4f2cf2413165 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -50,8 +50,9 @@ namespace memref {
 
 /// This is a common utility used for patterns of the form
 /// "someop(memref.cast) -> someop". It folds the source of any memref.cast
-/// into the root operation directly.
-LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr);
+/// into the root operation directly. Operands in `ignoredOperands` are excluded
+/// from folding.
+LogicalResult foldMemRefCast(Operation *op, ValueRange ignoredOperands = {});
 
 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
 /// type.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..6b82a550668b2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -41,12 +41,14 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
 
 /// This is a common class used for patterns of the form
 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
-/// into the root operation directly.
-LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
+/// into the root operation directly. Operands in `ignoredOperands` are excluded
+/// from folding.
+LogicalResult mlir::memref::foldMemRefCast(Operation *op,
+                                           ValueRange ignoredOperands) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto cast = operand.get().getDefiningOp<CastOp>();
-    if (cast && operand.get() != inner &&
+    if (cast && !llvm::is_contained(ignoredOperands, operand.get()) &&
         !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
       operand.set(cast.getOperand());
       folded = true;

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants