[Mlir-commits] [mlir] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (PR #189533)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 30 22:34:39 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ming Yan (NexMing)

<details>
<summary>Changes</summary>

Previously, the pattern bailed out entirely when any constant offset or size was invalid (negative). Now only the invalid value is left unfolded, and the pattern still folds the remaining valid offsets, sizes, and strides.

---
Full diff: https://github.com/llvm/llvm-project/pull/189533.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+55-27) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+2-6) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f31811ad7b98e..8aef3d38aeb2d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2273,41 +2273,69 @@ struct ReinterpretCastOpConstantFolder
 
   LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                 PatternRewriter &rewriter) const override {
-    unsigned srcStaticCount = llvm::count_if(
-        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
-                                   op.getMixedStrides()),
-        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+    MemRefType srcType = op.getType();
 
-    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+    OpFoldResult offset = op.getConstifiedMixedOffset();
     SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
     SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
 
-    // TODO: Using counting comparison instead of direct comparison because
-    // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
-    // IntegerAttrs, while constifyIndexValues (and therefore
-    // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
-    if (srcStaticCount ==
-        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
-                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
-      return failure();
+    int64_t layoutOffset = ShapedType::kDynamic;
 
-    // Do not fold if the offset is a negative constant; ViewLikeInterface
-    // verifies that static offsets are non-negative.
-    if (auto cst = getConstantIntValue(offsets[0]))
+    if (auto cst = getConstantIntValue(offset)) {
+      // If the offset is a negative constant, we can't fold it because the
+      // resulting memref type would be invalid. In that case, we keep the
+      // original offset.
       if (*cst < 0)
-        return rewriter.notifyMatchFailure(
-            op, "negative constant offset is invalid");
+        offset = op.getMixedOffsets()[0];
+      else
+        layoutOffset = *cst;
+    }
 
-    // Do not fold if any size is a negative constant; MemRefType::get asserts
-    // non-negative static sizes.
-    for (OpFoldResult sizeOfr : sizes)
-      if (auto cst = getConstantIntValue(sizeOfr))
-        if (*cst < 0)
-          return rewriter.notifyMatchFailure(
-              op, "negative constant size is invalid");
+    int64_t lastStride = 1;
+    bool isContiguousMemrefType = (layoutOffset == 0);
+    SmallVector<int64_t> layoutStrides, shapes;
+
+    for (auto [stride, size, srcSize] :
+         llvm::zip(strides, sizes, op.getMixedSizes())) {
+      int64_t layoutStride = ShapedType::kDynamic;
+      if (auto cstStride = getConstantIntValue(stride)) {
+        layoutStride = *cstStride;
+        isContiguousMemrefType &= (layoutStride == lastStride);
+      }
+      layoutStrides.push_back(layoutStride);
+
+      int64_t layoutSize = ShapedType::kDynamic;
+      if (auto cstSize = getConstantIntValue(size)) {
+        // If the size is a negative constant, we can't fold it because the
+        // resulting memref type would be invalid. In that case, we keep the
+        // original size.
+        if (*cstSize < 0)
+          size = srcSize;
+        else
+          layoutSize = *cstSize;
+      }
+      shapes.push_back(layoutSize);
+
+      if (ShapedType::isStatic(lastStride) && ShapedType::isStatic(layoutSize))
+        lastStride = lastStride * layoutSize;
+      else
+        lastStride = ShapedType::kDynamic;
+    }
+
+    MemRefType dstType = MemRefType::get(
+        shapes, srcType.getElementType(),
+        isContiguousMemrefType
+            ? nullptr
+            : StridedLayoutAttr::get(srcType.getContext(), layoutOffset,
+                                     layoutStrides),
+        srcType.getMemorySpace());
+
+    if (dstType == srcType)
+      return failure();
 
-    auto newReinterpretCast = ReinterpretCastOp::create(
-        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+    auto newReinterpretCast =
+        ReinterpretCastOp::create(rewriter, op->getLoc(), dstType,
+                                  op.getSource(), offset, sizes, strides);
 
     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
     return success();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index fb1e7d00feb47..ca415ce4f0483 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1292,10 +1292,8 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
 // which triggers an assertion in MemRefType::get (issue #188407).
 // CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[SZ:.*]] = arith.constant -1 : index
-//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]]
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
 func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -1313,10 +1311,8 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
 // ViewLikeInterface constraint that offsets must be non-negative.
 // CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[C2:.*]] = arith.constant 2 : index
 //       CHECK: %[[NEG:.*]] = arith.constant -1 : index
-//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
 func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/189533


More information about the Mlir-commits mailing list