[Mlir-commits] [mlir] [mlir][memref] Canonicalize memref.reinterpret_cast when offset/sizes/strides are constants. (PR #163505)

Ming Yan llvmlistbot at llvm.org
Fri Oct 17 03:07:53 PDT 2025


https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/163505

>From f62bd0f649b55750c1d61b9732b99ad15f65092d Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Wed, 15 Oct 2025 14:10:24 +0800
Subject: [PATCH 1/3] [mlir][memref] Canonicalize memref.reinterpret_cast when
 offset/sizes/strides are constants.

Add a canonicalization pattern that folds compile-time-constant offset,
size, and stride operands of memref.reinterpret_cast into static
attributes. The rewritten op gets a more static result type, and a
memref.cast back to the original type keeps existing users valid. This
exposes static shape information to further lowering and backend
optimizations.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 24 ++++++++++++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir | 30 +++++++++++++++-------
 2 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..de797c4789480 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,33 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
     return success();
   }
 };
+
+struct ReinterpretCastOpConstantFolder
+    : public OpRewritePattern<ReinterpretCastOp> {
+public:
+  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReinterpretCastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getOffsets(), op.getSizes(),
+                                                 op.getStrides()),
+                      getConstantIntValue))
+      return failure();
+
+    auto newReinterpretCast = ReinterpretCastOp::create(
+        rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
+        op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+
+    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+    return success();
+  }
+};
 } // namespace
 
 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
-  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+              ReinterpretCastOpConstantFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c8bcb08..7160b52af6353 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @reinterpret_constant_fold
+//  CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_of_reinterpret
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 // when the strides don't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 // when the offset doesn't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

>From efc5e8264776ca5685e562517befce3f25c0e8fe Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Fri, 17 Oct 2025 17:34:52 +0800
Subject: [PATCH 2/3] Ensure that success() is returned only if the IR has been
 modified.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index de797c4789480..dbfe9988533ce 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2166,14 +2166,22 @@ struct ReinterpretCastOpConstantFolder
 
   LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getOffsets(), op.getSizes(),
-                                                 op.getStrides()),
-                      getConstantIntValue))
+    unsigned srcStaticCount = llvm::count_if(
+        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
+                                   op.getMixedStrides()),
+        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+
+    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+    SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
+
+    if (srcStaticCount ==
+        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
       return failure();
 
     auto newReinterpretCast = ReinterpretCastOp::create(
-        rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
-        op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
 
     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
     return success();

>From 2f5f8207730ff05cb392e6900100a0a3ae64df47 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Fri, 17 Oct 2025 18:06:37 +0800
Subject: [PATCH 3/3] Add comments

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dbfe9988533ce..36b995354c38c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2166,6 +2166,10 @@ struct ReinterpretCastOpConstantFolder
 
   LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Using counting comparison instead of direct comparison because
+    // getMixedValues (and consequently ReinterpretCastOp::getMixed...) returns
+    // IntegerAttrs, while constifyIndexValues (and consequently
+    // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
     unsigned srcStaticCount = llvm::count_if(
         llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
                                    op.getMixedStrides()),



More information about the Mlir-commits mailing list