[Mlir-commits] [mlir] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (PR #189533)
Ming Yan
llvmlistbot at llvm.org
Tue Mar 31 00:01:05 PDT 2026
https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/189533
>From 1419cc52876c48f7d0b50681e31a31252001ac13 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 31 Mar 2026 14:11:50 +0800
Subject: [PATCH 1/2] Add a testcase.
---
mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index fb1e7d00feb47..a13d51b8a54cb 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1350,6 +1350,25 @@ func.func @reinterpret_cast_fold_negative_stride(%arg0: memref<2x3xf32>) -> memr
// -----
+// CHECK-LABEL: func.func @reinterpret_cast_contiguous(
+// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [3, 7], strides: [7, 1] : memref<*xf32> to memref<3x7xf32, strided<[7, 1]>>
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<3x7xf32, strided<[7, 1]>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK: return %[[CAST_0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK: }
+func.func @reinterpret_cast_contiguous(%arg0: memref<*xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c7 = arith.constant 7 : index
+ %output = memref.reinterpret_cast %arg0 to
+ offset: [%c0], sizes: [%c3, %c7], strides: [%c7, %c1]
+ : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %output : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
%arg1 : index) -> memref<?xf32, strided<[?], offset: ?>> {
%c0 = arith.constant 0 : index
>From c360c8fe7eee1654b34eaef32054eee9249f1d0c Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 31 Mar 2026 14:47:33 +0800
Subject: [PATCH 2/2] [mlir][memref] Fold memref.reinterpret_cast operations
with valid offset or size constants.
When a constant offset or size is invalid (i.e. negative), only that value is left unfolded and keeps its original operand; the remaining valid constant offsets, sizes, and strides are still folded.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 83 +++++++++++++++-------
mlir/test/Dialect/MemRef/canonicalize.mlir | 12 ++--
2 files changed, 60 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f31811ad7b98e..5d6dd65624fc7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2273,41 +2273,70 @@ struct ReinterpretCastOpConstantFolder
LogicalResult matchAndRewrite(ReinterpretCastOp op,
PatternRewriter &rewriter) const override {
- unsigned srcStaticCount = llvm::count_if(
- llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
- op.getMixedStrides()),
- [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+ MemRefType resultType = op.getType();
+ SmallVector<OpFoldResult> srcSizes = op.getMixedSizes();
- SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+ OpFoldResult offset = op.getConstifiedMixedOffset();
SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
- // TODO: Using counting comparison instead of direct comparison because
- // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
- // IntegerAttrs, while constifyIndexValues (and therefore
- // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
- if (srcStaticCount ==
- llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
- [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
- return failure();
+ int64_t layoutOffset = ShapedType::kDynamic;
- // Do not fold if the offset is a negative constant; ViewLikeInterface
- // verifies that static offsets are non-negative.
- if (auto cst = getConstantIntValue(offsets[0]))
+ if (auto cst = getConstantIntValue(offset)) {
+ // If the offset is a negative constant, we can't fold it because the
+ // resulting memref type would be invalid. In that case, we keep the
+ // original offset.
if (*cst < 0)
- return rewriter.notifyMatchFailure(
- op, "negative constant offset is invalid");
+ offset = op.getMixedOffsets()[0];
+ else
+ layoutOffset = *cst;
+ }
- // Do not fold if any size is a negative constant; MemRefType::get asserts
- // non-negative static sizes.
- for (OpFoldResult sizeOfr : sizes)
- if (auto cst = getConstantIntValue(sizeOfr))
- if (*cst < 0)
- return rewriter.notifyMatchFailure(
- op, "negative constant size is invalid");
+ int64_t rank = resultType.getRank();
+ int64_t lastStride = 1;
+ bool isContiguousMemrefType = (layoutOffset == 0);
+ SmallVector<int64_t> layoutStrides(rank), shapes(rank);
+
+ for (int64_t dim = rank - 1; dim >= 0; --dim) {
+ int64_t layoutStride = ShapedType::kDynamic;
+ if (auto cstStride = getConstantIntValue(strides[dim])) {
+ layoutStride = *cstStride;
+ isContiguousMemrefType &= (layoutStride == lastStride);
+ }
+ layoutStrides[dim] = layoutStride;
+
+ int64_t layoutSize = ShapedType::kDynamic;
+ if (auto cstSize = getConstantIntValue(sizes[dim])) {
+ // If the size is a negative constant, we can't fold it because the
+ // resulting memref type would be invalid. In that case, we keep the
+ // original size.
+ if (*cstSize < 0)
+ sizes[dim] = srcSizes[dim];
+ else
+ layoutSize = *cstSize;
+ }
+ shapes[dim] = layoutSize;
+
+ if (ShapedType::isStatic(lastStride) && ShapedType::isStatic(layoutSize))
+ lastStride = lastStride * layoutSize;
+ else
+ lastStride = ShapedType::kDynamic;
+ }
+
+ MemRefType newResultType = MemRefType::get(
+ shapes, resultType.getElementType(),
+ isContiguousMemrefType
+ ? nullptr
+ : StridedLayoutAttr::get(resultType.getContext(), layoutOffset,
+ layoutStrides),
+ resultType.getMemorySpace());
+
+ if (newResultType == resultType)
+ return failure();
- auto newReinterpretCast = ReinterpretCastOp::create(
- rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+ auto newReinterpretCast =
+ ReinterpretCastOp::create(rewriter, op->getLoc(), newResultType,
+ op.getSource(), offset, sizes, strides);
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
return success();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a13d51b8a54cb..167971cac42b7 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1292,10 +1292,8 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
// which triggers an assertion in MemRefType::get (issue #188407).
// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[SZ:.*]] = arith.constant -1 : index
-// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]]
+// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -1313,10 +1311,8 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
// ViewLikeInterface constraint that offsets must be non-negative.
// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[NEG:.*]] = arith.constant -1 : index
-// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
@@ -1352,8 +1348,8 @@ func.func @reinterpret_cast_fold_negative_stride(%arg0: memref<2x3xf32>) -> memr
// CHECK-LABEL: func.func @reinterpret_cast_contiguous(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
-// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [3, 7], strides: [7, 1] : memref<*xf32> to memref<3x7xf32, strided<[7, 1]>>
-// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<3x7xf32, strided<[7, 1]>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [3, 7], strides: [7, 1] : memref<*xf32> to memref<3x7xf32>
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<3x7xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK: return %[[CAST_0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK: }
func.func @reinterpret_cast_contiguous(%arg0: memref<*xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
More information about the Mlir-commits
mailing list