[Mlir-commits] [mlir] 158f10f - [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (#189533)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 1 06:30:17 PDT 2026


Author: Ming Yan
Date: 2026-04-01T15:30:12+02:00
New Revision: 158f10fe24a39208e45d6039dfc6d605967ade2a

URL: https://github.com/llvm/llvm-project/commit/158f10fe24a39208e45d6039dfc6d605967ade2a
DIFF: https://github.com/llvm/llvm-project/commit/158f10fe24a39208e45d6039dfc6d605967ade2a.diff

LOG: [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (#189533)

When an offset or size is an invalid (negative) constant, we skip only that
value, keeping it dynamic, and still fold the remaining valid offsets,
sizes and strides.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 31546f123b512..27c1649ee4ed3 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2331,6 +2331,24 @@ struct ReinterpretCastOpConstantFolder
     SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
     SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
 
+    // If the offset is a negative constant, we can't fold it because the
+    // resulting memref type would be invalid. In that case, we keep the
+    // original offset.
+    if (auto cst = getConstantIntValue(offsets[0]))
+      if (*cst < 0)
+        offsets[0] = op.getMixedOffsets()[0];
+
+    // If the size is a negative constant, we can't fold it because the
+    // resulting memref type would be invalid. In that case, we keep the
+    // original size.
+    for (auto it : llvm::zip(op.getMixedSizes(), sizes)) {
+      auto &srcSizeOfr = std::get<0>(it);
+      auto &sizeOfr = std::get<1>(it);
+      if (auto cst = getConstantIntValue(sizeOfr))
+        if (*cst < 0)
+          sizeOfr = srcSizeOfr;
+    }
+
     // TODO: Using counting comparison instead of direct comparison because
     // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
     // IntegerAttrs, while constifyIndexValues (and therefore
@@ -2340,21 +2358,6 @@ struct ReinterpretCastOpConstantFolder
                        [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
       return failure();
 
-    // Do not fold if the offset is a negative constant; ViewLikeInterface
-    // verifies that static offsets are non-negative.
-    if (auto cst = getConstantIntValue(offsets[0]))
-      if (*cst < 0)
-        return rewriter.notifyMatchFailure(
-            op, "negative constant offset is invalid");
-
-    // Do not fold if any size is a negative constant; MemRefType::get asserts
-    // non-negative static sizes.
-    for (OpFoldResult sizeOfr : sizes)
-      if (auto cst = getConstantIntValue(sizeOfr))
-        if (*cst < 0)
-          return rewriter.notifyMatchFailure(
-              op, "negative constant size is invalid");
-
     auto newReinterpretCast = ReinterpretCastOp::create(
         rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
 

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index fb1e7d00feb47..6c4fd6f8f58d6 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1287,16 +1287,14 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
 
 // -----
 
-// Check that reinterpret_cast with a negative constant size is not folded.
+// Check that a negative constant size stays dynamic while other operands fold.
 // Folding would attempt to create a MemRefType with a negative static dimension,
 // which triggers an assertion in MemRefType::get (issue #188407).
-// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
+// CHECK-LABEL: func @reinterpret_cast_with_negative_size
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[SZ:.*]] = arith.constant -1 : index
-//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]]
-func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
+func.func @reinterpret_cast_with_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %sz = arith.constant -1 : index
@@ -1308,16 +1306,14 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
 
 // -----
 
-// Check that reinterpret_cast with a negative constant offset is not folded.
+// Check that a negative constant offset stays dynamic while other operands fold.
 // Folding would create an op with a static negative offset, which violates the
 // ViewLikeInterface constraint that offsets must be non-negative.
-// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
+// CHECK-LABEL: func @reinterpret_cast_with_negative_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[C2:.*]] = arith.constant 2 : index
 //       CHECK: %[[NEG:.*]] = arith.constant -1 : index
-//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
+func.func @reinterpret_cast_with_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
   %neg = arith.constant -1 : index
@@ -1329,6 +1325,39 @@ func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> m
 
 // -----
 
+// Check that a negative constant size and offset stay dynamic; the rest folds.
+// CHECK-LABEL: func @reinterpret_cast_with_negative_size_and_offset
+//  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
+//       CHECK: %[[NEG:.*]] = arith.constant -1 : index
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, %[[NEG]]], strides: [2, 1]
+func.func @reinterpret_cast_with_negative_size_and_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %neg = arith.constant -1 : index
+  %output = memref.reinterpret_cast %arg0 to
+            offset: [%neg], sizes: [%c1, %neg], strides: [%c2, %c1]
+            : memref<2x3xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %output : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
+// Check that reinterpret_cast is not folded when all of its sizes and its
+// offset are negative constants.
+// CHECK-LABEL: func @reinterpret_cast_no_fold_with_all_negative_size_and_offset
+//  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
+//       CHECK: %[[NEG:.*]] = arith.constant -1 : index
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[NEG]], %[[NEG]]], strides: [2, 1]
+func.func @reinterpret_cast_no_fold_with_all_negative_size_and_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %neg = arith.constant -1 : index
+  %output = memref.reinterpret_cast %arg0 to
+            offset: [%neg], sizes: [%neg, %neg], strides: [2, 1]
+            : memref<2x3xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %output : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
 // Check that reinterpret_cast with a negative constant stride IS folded.
 // Negative strides are valid in MemRef layouts (e.g. reverse iteration),
 // and the ViewLikeInterface places no non-negativity constraint on strides.


        


More information about the Mlir-commits mailing list