[Mlir-commits] [mlir] [mlir][memref] Support ignoring ValueRange in foldMemrefCast (PR #171337)

llvmlistbot at llvm.org
Tue Dec 9 01:50:33 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Shay Kleiman (shay-kl)

<details>
<summary>Changes</summary>

Currently, foldMemRefCast accepts a single operand that is excluded from folding. This patch changes that parameter to a ValueRange, so callers can exclude several operands at once. Because a Value implicitly converts to a ValueRange, existing callers that pass a single Value, or that rely on the default argument, should be unaffected.

---
Full diff: https://github.com/llvm/llvm-project/pull/171337.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRef.h (+3-2) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+5-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index b7abcdea10a2a..c4f2cf2413165 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -50,8 +50,9 @@ namespace memref {
 
 /// This is a common utility used for patterns of the form
 /// "someop(memref.cast) -> someop". It folds the source of any memref.cast
-/// into the root operation directly.
-LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr);
+/// into the root operation directly. Operands in `ignoredOperands` are excluded
+/// from folding.
+LogicalResult foldMemRefCast(Operation *op, ValueRange ignoredOperands = {});
 
 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
 /// type.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..6b82a550668b2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -41,12 +41,14 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
 
 /// This is a common class used for patterns of the form
 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
-/// into the root operation directly.
-LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
+/// into the root operation directly. Operands in `ignoredOperands` are excluded
+/// from folding.
+LogicalResult mlir::memref::foldMemRefCast(Operation *op,
+                                           ValueRange ignoredOperands) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto cast = operand.get().getDefiningOp<CastOp>();
-    if (cast && operand.get() != inner &&
+    if (cast && !llvm::is_contained(ignoredOperands, operand.get()) &&
         !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
       operand.set(cast.getOperand());
       folded = true;

``````````

</details>


https://github.com/llvm/llvm-project/pull/171337

