[Mlir-commits] [mlir] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (PR #189533)

Ming Yan llvmlistbot at llvm.org
Tue Mar 31 01:05:02 PDT 2026


https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/189533

>From 9247e7e58204c11df8b2280e61a080cb338f48c6 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 31 Mar 2026 14:11:50 +0800
Subject: [PATCH] [mlir][memref] Fold memref.reinterpret_cast operations with
 valid offset or size constants.

When an offset or size folds to an invalid (negative) constant, keep only that operand in its original dynamic form and continue folding the remaining valid constant offsets, sizes, and strides, instead of abandoning the whole fold.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 27 +++++++++++-----------
 mlir/test/Dialect/MemRef/canonicalize.mlir |  8 ++-----
 2 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f31811ad7b98e..f3681f38246c3 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2282,29 +2282,28 @@ struct ReinterpretCastOpConstantFolder
     SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
     SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
 
-    // TODO: Using counting comparison instead of direct comparison because
-    // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
-    // IntegerAttrs, while constifyIndexValues (and therefore
-    // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
-    if (srcStaticCount ==
-        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
-                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
-      return failure();
-
     // Do not fold if the offset is a negative constant; ViewLikeInterface
     // verifies that static offsets are non-negative.
     if (auto cst = getConstantIntValue(offsets[0]))
       if (*cst < 0)
-        return rewriter.notifyMatchFailure(
-            op, "negative constant offset is invalid");
+        offsets[0] = op.getMixedOffsets()[0];
 
     // Do not fold if any size is a negative constant; MemRefType::get asserts
     // non-negative static sizes.
-    for (OpFoldResult sizeOfr : sizes)
+    for (auto [srcSizeOfr, sizeOfr] : llvm::zip(op.getMixedSizes(), sizes)) {
       if (auto cst = getConstantIntValue(sizeOfr))
         if (*cst < 0)
-          return rewriter.notifyMatchFailure(
-              op, "negative constant size is invalid");
+          sizeOfr = srcSizeOfr;
+    }
+
+    // TODO: Using counting comparison instead of direct comparison because
+    // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
+    // IntegerAttrs, while constifyIndexValues (and therefore
+    // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
+    if (srcStaticCount ==
+        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
+      return failure();
 
     auto newReinterpretCast = ReinterpretCastOp::create(
         rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index fb1e7d00feb47..ca415ce4f0483 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1292,10 +1292,8 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
 // which triggers an assertion in MemRefType::get (issue #188407).
 // CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[SZ:.*]] = arith.constant -1 : index
-//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]]
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
 func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -1313,10 +1311,8 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
 // ViewLikeInterface constraint that offsets must be non-negative.
 // CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[C2:.*]] = arith.constant 2 : index
 //       CHECK: %[[NEG:.*]] = arith.constant -1 : index
-//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+//       CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
 func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index



More information about the Mlir-commits mailing list