[Mlir-commits] [mlir] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (PR #189533)

Ming Yan llvmlistbot at llvm.org
Tue Mar 31 02:03:03 PDT 2026


================
@@ -2282,29 +2282,28 @@ struct ReinterpretCastOpConstantFolder
     SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
     SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
 
-    // TODO: Using counting comparison instead of direct comparison because
-    // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
-    // IntegerAttrs, while constifyIndexValues (and therefore
-    // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
-    if (srcStaticCount ==
-        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
-                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
-      return failure();
-
     // Do not fold if the offset is a negative constant; ViewLikeInterface
     // verifies that static offsets are non-negative.
     if (auto cst = getConstantIntValue(offsets[0]))
       if (*cst < 0)
-        return rewriter.notifyMatchFailure(
-            op, "negative constant offset is invalid");
+        offsets[0] = op.getMixedOffsets()[0];
 
     // Do not fold if any size is a negative constant; MemRefType::get asserts
     // non-negative static sizes.
-    for (OpFoldResult sizeOfr : sizes)
+    for (auto [srcSizeOfr, sizeOfr] : llvm::zip(op.getMixedSizes(), sizes)) {
       if (auto cst = getConstantIntValue(sizeOfr))
         if (*cst < 0)
-          return rewriter.notifyMatchFailure(
-              op, "negative constant size is invalid");
+          sizeOfr = srcSizeOfr;
+    }
+
+    // TODO: Using counting comparison instead of direct comparison because
+    // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
+    // IntegerAttrs, while constifyIndexValues (and therefore
+    // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
+    if (srcStaticCount ==
+        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
+      return failure();
 
----------------
NexMing wrote:

Right.

https://github.com/llvm/llvm-project/pull/189533


More information about the Mlir-commits mailing list