[Mlir-commits] [mlir] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (PR #189533)

Ming Yan llvmlistbot at llvm.org
Mon Mar 30 22:34:06 PDT 2026


https://github.com/NexMing created https://github.com/llvm/llvm-project/pull/189533

When an invalid (negative) constant offset or size is encountered, only that value is left unfolded; the pattern keeps its original operand and continues folding the remaining valid offsets and sizes, instead of abandoning the whole fold.

From 4e76f6cdefa2b7c363e12cfe81bb8670e450c404 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 31 Mar 2026 13:17:41 +0800
Subject: [PATCH] [mlir][memref] Fold memref.reinterpret_cast operations with
 valid offset or size constants. When encountering an invalid offset or size,
 we only skip the current invalid value and continue attempting to fold other
 valid offsets or sizes.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 82 +++++++++++++++-------
 mlir/test/Dialect/MemRef/canonicalize.mlir |  8 +--
 2 files changed, 57 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f31811ad7b98e..8aef3d38aeb2d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2273,41 +2273,69 @@ struct ReinterpretCastOpConstantFolder
 
   LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                 PatternRewriter &rewriter) const override {
-    unsigned srcStaticCount = llvm::count_if(
-        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
-                                   op.getMixedStrides()),
-        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+    MemRefType srcType = op.getType();
 
-    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+    OpFoldResult offset = op.getConstifiedMixedOffset();
     SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
     SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
 
-    // TODO: Using counting comparison instead of direct comparison because
-    // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
-    // IntegerAttrs, while constifyIndexValues (and therefore
-    // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
-    if (srcStaticCount ==
-        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
-                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
-      return failure();
+    int64_t layoutOffset = ShapedType::kDynamic;
 
-    // Do not fold if the offset is a negative constant; ViewLikeInterface
-    // verifies that static offsets are non-negative.
-    if (auto cst = getConstantIntValue(offsets[0]))
+    if (auto cst = getConstantIntValue(offset)) {
+      // If the offset is a negative constant, we can't fold it because the
+      // resulting memref type would be invalid. In that case, we keep the
+      // original offset.
       if (*cst < 0)
-        return rewriter.notifyMatchFailure(
-            op, "negative constant offset is invalid");
+        offset = op.getMixedOffsets()[0];
+      else
+        layoutOffset = *cst;
+    }
 
-    // Do not fold if any size is a negative constant; MemRefType::get asserts
-    // non-negative static sizes.
-    for (OpFoldResult sizeOfr : sizes)
-      if (auto cst = getConstantIntValue(sizeOfr))
-        if (*cst < 0)
-          return rewriter.notifyMatchFailure(
-              op, "negative constant size is invalid");
+    int64_t lastStride = 1;
+    bool isContiguousMemrefType = (layoutOffset == 0);
+    SmallVector<int64_t> layoutStrides, shapes;
+
+    for (auto [stride, size, srcSize] :
+         llvm::zip(strides, sizes, op.getMixedSizes())) {
+      int64_t layoutStride = ShapedType::kDynamic;
+      if (auto cstStride = getConstantIntValue(stride)) {
+        layoutStride = *cstStride;
+        isContiguousMemrefType &= (layoutStride == lastStride);
+      }
+      layoutStrides.push_back(layoutStride);
+
+      int64_t layoutSize = ShapedType::kDynamic;
+      if (auto cstSize = getConstantIntValue(size)) {
+        // If the size is a negative constant, we can't fold it because the
+        // resulting memref type would be invalid. In that case, we keep the
+        // original size.
+        if (*cstSize < 0)
+          size = srcSize;
+        else
+          layoutSize = *cstSize;
+      }
+      shapes.push_back(layoutSize);
+
+      if (ShapedType::isStatic(lastStride) && ShapedType::isStatic(layoutSize))
+        lastStride = lastStride * layoutSize;
+      else
+        lastStride = ShapedType::kDynamic;
+    }
+
+    MemRefType dstType = MemRefType::get(
+        shapes, srcType.getElementType(),
+        isContiguousMemrefType
+            ? nullptr
+            : StridedLayoutAttr::get(srcType.getContext(), layoutOffset,
+                                     layoutStrides),
+        srcType.getMemorySpace());
+
+    if (dstType == srcType)
+      return failure();
 
-    auto newReinterpretCast = ReinterpretCastOp::create(
-        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+    auto newReinterpretCast =
+        ReinterpretCastOp::create(rewriter, op->getLoc(), dstType,
+                                  op.getSource(), offset, sizes, strides);
 
     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
     return success();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index fb1e7d00feb47..ca415ce4f0483 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1292,10 +1292,8 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
 // which triggers an assertion in MemRefType::get (issue #188407).
 // CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[SZ:.*]] = arith.constant -1 : index
-//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]]
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
 func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -1313,10 +1311,8 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
 // ViewLikeInterface constraint that offsets must be non-negative.
 // CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[C2:.*]] = arith.constant 2 : index
 //       CHECK: %[[NEG:.*]] = arith.constant -1 : index
-//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
 func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index



More information about the Mlir-commits mailing list