[Mlir-commits] [mlir] [mlir][memref] Canonicalize memref.reinterpret_cast when offset/sizes/strides are constants. (PR #163505)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 14 23:26:49 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ming Yan (NexMing)

<details>
<summary>Changes</summary>

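Implement folding logic to canonicalize memref.reinterpret_cast ops when any of the offset, sizes, or strides operands are SSA values with compile-time constant values. The pattern rewrites such an op into a memref.reinterpret_cast whose constant operands are expressed as static offset/sizes/strides, followed by a memref.cast back to the original result type. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.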

---
Full diff: https://github.com/llvm/llvm-project/pull/163505.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+26-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+21-9) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda296da5..f914b292eba83 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
     return success();
   }
 };
+
+struct ReinterpretCastOpConstantFolder
+    : public OpRewritePattern<ReinterpretCastOp> {
+public:
+  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReinterpretCastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(),
+                                                 op.getMixedSizes(),
+                                                 op.getMixedStrides()),
+                      [](OpFoldResult ofr) {
+                        return isa<Value>(ofr) && getConstantIntValue(ofr);
+                      }))
+      return failure();
+
+    auto newReinterpretCast = ReinterpretCastOp::create(
+        rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(),
+        op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides());
+
+    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+    return success();
+  }
+};
 } // namespace
 
 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
-  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+              ReinterpretCastOpConstantFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c8bcb08..7160b52af6353 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @reinterpret_constant_fold
+//  CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_of_reinterpret
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 // when the strides don't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 // when the offset doesn't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-//       CHECK: return %[[RES]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+//       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
   %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

``````````

</details>


https://github.com/llvm/llvm-project/pull/163505


More information about the Mlir-commits mailing list