[Mlir-commits] [mlir] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (PR #189533)
Ming Yan
llvmlistbot at llvm.org
Tue Mar 31 03:26:47 PDT 2026
https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/189533
>From 9247e7e58204c11df8b2280e61a080cb338f48c6 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 31 Mar 2026 14:11:50 +0800
Subject: [PATCH 1/4] [mlir][memref] Fold memref.reinterpret_cast operations
with valid offset or size constants.
When an offset or size constant is invalid (negative), only that value is left unfolded; folding of the remaining valid offsets, sizes, and strides is still attempted.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 27 +++++++++++-----------
mlir/test/Dialect/MemRef/canonicalize.mlir | 8 ++-----
2 files changed, 15 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f31811ad7b98e..f3681f38246c3 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2282,29 +2282,28 @@ struct ReinterpretCastOpConstantFolder
SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
- // TODO: Using counting comparison instead of direct comparison because
- // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
- // IntegerAttrs, while constifyIndexValues (and therefore
- // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
- if (srcStaticCount ==
- llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
- [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
- return failure();
-
// Do not fold if the offset is a negative constant; ViewLikeInterface
// verifies that static offsets are non-negative.
if (auto cst = getConstantIntValue(offsets[0]))
if (*cst < 0)
- return rewriter.notifyMatchFailure(
- op, "negative constant offset is invalid");
+ offsets[0] = op.getMixedOffsets()[0];
// Do not fold if any size is a negative constant; MemRefType::get asserts
// non-negative static sizes.
- for (OpFoldResult sizeOfr : sizes)
+ for (auto [srcSizeOfr, sizeOfr] : llvm::zip(op.getMixedSizes(), sizes)) {
if (auto cst = getConstantIntValue(sizeOfr))
if (*cst < 0)
- return rewriter.notifyMatchFailure(
- op, "negative constant size is invalid");
+ sizeOfr = srcSizeOfr;
+ }
+
+ // TODO: Using counting comparison instead of direct comparison because
+ // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
+ // IntegerAttrs, while constifyIndexValues (and therefore
+ // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
+ if (srcStaticCount ==
+ llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
+ return failure();
auto newReinterpretCast = ReinterpretCastOp::create(
rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index fb1e7d00feb47..ca415ce4f0483 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1292,10 +1292,8 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
// which triggers an assertion in MemRefType::get (issue #188407).
// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[SZ:.*]] = arith.constant -1 : index
-// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]]
+// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -1313,10 +1311,8 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
// ViewLikeInterface constraint that offsets must be non-negative.
// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[NEG:.*]] = arith.constant -1 : index
-// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
>From 8f219c495d06993b3d4c0440864795a4fe21e3c0 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 31 Mar 2026 17:50:38 +0800
Subject: [PATCH 2/4] Update comments explaining why negative offsets/sizes stay unfolded.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f3681f38246c3..5493ad6c013f5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2282,14 +2282,16 @@ struct ReinterpretCastOpConstantFolder
SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
- // Do not fold if the offset is a negative constant; ViewLikeInterface
- // verifies that static offsets are non-negative.
+ // A negative constant offset cannot be folded: ViewLikeInterface verifies
+ // that static offsets are non-negative. In that case, keep the original
+ // offset and continue folding the remaining values.
if (auto cst = getConstantIntValue(offsets[0]))
if (*cst < 0)
offsets[0] = op.getMixedOffsets()[0];
- // Do not fold if any size is a negative constant; MemRefType::get asserts
- // non-negative static sizes.
+ // If the size is a negative constant, we can't fold it because the
+ // resulting memref type would be invalid. In that case, we keep the
+ // original size.
for (auto [srcSizeOfr, sizeOfr] : llvm::zip(op.getMixedSizes(), sizes)) {
if (auto cst = getConstantIntValue(sizeOfr))
if (*cst < 0)
>From e3c91bb99e15b1787fbdd939e4709b0c965102bc Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 31 Mar 2026 18:25:31 +0800
Subject: [PATCH 3/4] Use std::get instead of structured bindings over llvm::zip.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5493ad6c013f5..5897423cdae87 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2292,7 +2292,9 @@ struct ReinterpretCastOpConstantFolder
// If the size is a negative constant, we can't fold it because the
// resulting memref type would be invalid. In that case, we keep the
// original size.
- for (auto [srcSizeOfr, sizeOfr] : llvm::zip(op.getMixedSizes(), sizes)) {
+ for (auto it : llvm::zip(op.getMixedSizes(), sizes)) {
+ auto &srcSizeOfr = std::get<0>(it);
+ auto &sizeOfr = std::get<1>(it);
if (auto cst = getConstantIntValue(sizeOfr))
if (*cst < 0)
sizeOfr = srcSizeOfr;
>From 61d3028f499c806f2216129f597a25679f1644e0 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 31 Mar 2026 18:25:50 +0800
Subject: [PATCH 4/4] Rename and extend negative-constant reinterpret_cast tests.
---
mlir/test/Dialect/MemRef/canonicalize.mlir | 45 +++++++++++++++++++---
1 file changed, 39 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index ca415ce4f0483..6c4fd6f8f58d6 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1287,14 +1287,14 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
// -----
-// Check that reinterpret_cast with a negative constant size is not folded.
+// Check that reinterpret_cast keeps a negative constant size dynamic.
// Folding would attempt to create a MemRefType with a negative static dimension,
// which triggers an assertion in MemRefType::get (issue #188407).
-// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
+// CHECK-LABEL: func @reinterpret_cast_with_negative_size
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
// CHECK: %[[SZ:.*]] = arith.constant -1 : index
// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
-func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_cast_with_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%sz = arith.constant -1 : index
@@ -1306,14 +1306,14 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
// -----
-// Check that reinterpret_cast with a negative constant offset is not folded.
+// Check that reinterpret_cast keeps a negative constant offset dynamic.
// Folding would create an op with a static negative offset, which violates the
// ViewLikeInterface constraint that offsets must be non-negative.
-// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
+// CHECK-LABEL: func @reinterpret_cast_with_negative_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
// CHECK: %[[NEG:.*]] = arith.constant -1 : index
// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
-func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+func.func @reinterpret_cast_with_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%neg = arith.constant -1 : index
@@ -1325,6 +1325,39 @@ func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> m
// -----
+// Check that reinterpret_cast keeps negative constant size and offset dynamic.
+// CHECK-LABEL: func @reinterpret_cast_with_negative_size_and_offset
+// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
+// CHECK: %[[NEG:.*]] = arith.constant -1 : index
+// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, %[[NEG]]], strides: [2, 1]
+func.func @reinterpret_cast_with_negative_size_and_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %neg = arith.constant -1 : index
+ %output = memref.reinterpret_cast %arg0 to
+ offset: [%neg], sizes: [%c1, %neg], strides: [%c2, %c1]
+ : memref<2x3xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %output : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
+// Check that reinterpret_cast is not folded when the offset and every size are
+// negative constants.
+// CHECK-LABEL: func @reinterpret_cast_no_fold_with_all_negative_size_and_offset
+// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
+// CHECK: %[[NEG:.*]] = arith.constant -1 : index
+// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[NEG]], %[[NEG]]], strides: [2, 1]
+func.func @reinterpret_cast_no_fold_with_all_negative_size_and_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %neg = arith.constant -1 : index
+ %output = memref.reinterpret_cast %arg0 to
+ offset: [%neg], sizes: [%neg, %neg], strides: [2, 1]
+ : memref<2x3xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %output : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
// Check that reinterpret_cast with a negative constant stride IS folded.
// Negative strides are valid in MemRef layouts (e.g. reverse iteration),
// and the ViewLikeInterface places no non-negativity constraint on strides.
More information about the Mlir-commits
mailing list