[Mlir-commits] [mlir] [mlir][memref] memref.view canonicalizations fixes (PR #173237)

Ivan Butygin llvmlistbot at llvm.org
Mon Dec 22 02:18:56 PST 2025


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/173237

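* Do not fold `memref.view` to its source when the byte offset is not zero
* Remove the unnecessary `memref.alloc` check from the `view(cast(%source))` folding pattern
* Convert `memref.view` to `memref.cast` if the offset is zero, the result shape is static, and the source and result types are cast-compatible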

>From c166908205dec2f59cf0922e70b46074869e1c52 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 22 Dec 2025 11:04:11 +0100
Subject: [PATCH] [mlir][memref] memref.view canonicalizations fixes

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 48 ++++++++++++++-----
 mlir/test/Dialect/MemRef/canonicalize.mlir | 54 ++++++++++++++++++++--
 2 files changed, 85 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 620cc97b9e3a2..c872bd7805af5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3675,7 +3675,8 @@ OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
   MemRefType sourceMemrefType = getSource().getType();
   MemRefType resultMemrefType = getResult().getType();
 
-  if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
+  if (resultMemrefType == sourceMemrefType &&
+      resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
     return getViewSource();
 
   return {};
@@ -3684,7 +3685,7 @@ OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
 namespace {
 
 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
-  using OpRewritePattern<ViewOp>::OpRewritePattern;
+  using Base::Base;
 
   LogicalResult matchAndRewrite(ViewOp viewOp,
                                 PatternRewriter &rewriter) const override {
@@ -3751,22 +3752,43 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
   }
 };
 
+/// view(memref.cast(%source)) -> view(%source).
 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
-  using OpRewritePattern<ViewOp>::OpRewritePattern;
+  using Base::Base;
 
   LogicalResult matchAndRewrite(ViewOp viewOp,
                                 PatternRewriter &rewriter) const override {
-    Value memrefOperand = viewOp.getOperand(0);
-    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
+    auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
     if (!memrefCastOp)
       return failure();
-    Value allocOperand = memrefCastOp.getOperand();
-    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
-    if (!allocOp)
+
+    rewriter.replaceOpWithNewOp<ViewOp>(
+        viewOp, viewOp.getType(), memrefCastOp.getSource(),
+        viewOp.getByteShift(), viewOp.getSizes());
+    return success();
+  }
+};
+
+/// view %source[0] -> cast(%source) if static shapes.
+struct ViewOpZeroOffsetFolder : public OpRewritePattern<ViewOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ViewOp viewOp,
+                                PatternRewriter &rewriter) const override {
+    if (!isZeroInteger(viewOp.getByteShift()))
+      return failure();
+
+    Value source = viewOp.getSource();
+    auto sourceMemrefType = cast<MemRefType>(source.getType());
+    auto resultMemrefType = cast<MemRefType>(viewOp.getType());
+    if (sourceMemrefType == resultMemrefType)
+      return failure(); // Handled by folder
+
+    if (!resultMemrefType.hasStaticShape() ||
+        !CastOp::areCastCompatible(sourceMemrefType, resultMemrefType))
       return failure();
-    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
-                                        viewOp.getByteShift(),
-                                        viewOp.getSizes());
+
+    rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultMemrefType, source);
     return success();
   }
 };
@@ -3775,7 +3797,9 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
 
 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
+  results
+      .add<ViewOpShapeFolder, ViewOpMemrefCastFolder, ViewOpZeroOffsetFolder>(
+          context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 60311306b984d..cd7a18a27e634 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1336,21 +1336,65 @@ func.func @fold_assume_alignment_chain(%0: memref<128xf32>) -> memref<128xf32> {
 
 // -----
 
+// CHECK-LABEL: func @fold_view_cast
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<128xi8>)
+func.func @fold_view_cast(%0: memref<128xi8>) -> memref<i32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C0]]][] : memref<128xi8> to memref<i32>
+  // CHECK: return %[[RES]]
+  %1 = memref.cast %0 : memref<128xi8> to memref<?xi8>
+  %res = memref.view %1[%c0][] : memref<?xi8> to memref<i32>
+  return %res : memref<i32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_view_same_source_result_types
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<128xi8>)
 func.func @fold_view_same_source_result_types(%0: memref<128xi8>) -> memref<128xi8> {
-  %c0 = arith.constant 0: index
+  %c0 = arith.constant 0 : index
   // CHECK-NOT: memref.view
+  // CHECK: return %[[ARG]]
   %res = memref.view %0[%c0][] : memref<128xi8> to memref<128xi8>
   return %res : memref<128xi8>
 }
 
 // -----
 
-// CHECK-LABEL: func @non_fold_view_same_source_res_types
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index) -> memref<?xi8> {
+// CHECK-LABEL: func @fold_view_compatible_source_result_types
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi8>)
+func.func @fold_view_compatible_source_result_types(%0: memref<?xi8>) -> memref<128xi8> {
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT: memref.view
+  // CHECK: %[[RES:.*]] = memref.cast %[[ARG]] : memref<?xi8> to memref<128xi8>
+  // CHECK: return %[[RES]]
+  %res = memref.view %0[%c0][] : memref<?xi8> to memref<128xi8>
+  return %res : memref<128xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_fold_view_non_zero_offset
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<128xi8>)
+func.func @non_fold_view_non_zero_offset(%0: memref<128xi8>) -> memref<128xi8> {
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C1]]][] : memref<128xi8> to memref<128xi8>
+  // CHECK: return %[[RES]]
+  %res = memref.view %0[%c1][] : memref<128xi8> to memref<128xi8>
+  return %res : memref<128xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_fold_view_same_source_dynamic_size
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi8>, %[[SIZE:.*]]: index)
+func.func @non_fold_view_same_source_dynamic_size(%0: memref<?xi8>, %arg0 : index) -> memref<?xi8> {
   %c0 = arith.constant 0: index
-  // CHECK: memref.view
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C0]]][%[[SIZE]]] : memref<?xi8> to memref<?xi8>
+  // CHECK: return %[[RES]]
   %res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
   return %res : memref<?xi8>
 }



More information about the Mlir-commits mailing list