[Mlir-commits] [mlir] [mlir] Allow folding dynamic full size subviews (PR #140619)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 19 13:58:10 PDT 2025


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/140619

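Adds support for folding full-size subviews with dynamic sizes in `TrivialSubViewOpFolder`: when the subview sizes are not all static, the source's shape is reified through `ReifyRankedShapedTypeOpInterface` and compared against the subview's sizes.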

>From dc0b38f99cb68ea07dcf55e2910098d785193387 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Mon, 19 May 2025 20:54:07 +0000
Subject: [PATCH] [mlir] Allow folding dynamic full size subviews

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 27 ++++++++++++++--------
 mlir/test/Dialect/MemRef/canonicalize.mlir | 14 +++++++++++
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 82702789c2913..3ccdcbf8c3be5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3105,7 +3105,7 @@ FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
 /// is the case if the all offsets are zero, all strides are 1, and the source
 /// shape is same as the size of the subview. In such cases, the subview can
 /// be folded into its source.
-static bool isTrivialSubViewOp(SubViewOp subViewOp) {
+static bool isTrivialSubViewOp(OpBuilder &b, SubViewOp subViewOp) {
   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
     return false;
 
@@ -3127,15 +3127,24 @@ static bool isTrivialSubViewOp(SubViewOp subViewOp) {
       }))
     return false;
 
-  // Check all size values are static and matches the (static) source shape.
+  // Check all size values match the source shape.
   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
-  for (const auto &size : llvm::enumerate(mixedSizes)) {
-    std::optional<int64_t> intValue = getConstantIntValue(size.value());
-    if (!intValue || *intValue != sourceShape[size.index()])
-      return false;
+  if (llvm::all_of_zip(mixedSizes, sourceShape,
+                       [](OpFoldResult mixedSize, int64_t staticSize) {
+                         std::optional<int64_t> constSize =
+                             getConstantIntValue(mixedSize);
+                         return constSize.has_value() &&
+                                *constSize == staticSize;
+                       })) {
+    return true;
   }
-  // All conditions met. The `SubViewOp` is foldable as a no-op.
-  return true;
+  auto sourceOpResult = dyn_cast<OpResult>(subViewOp.getSource());
+  if (!sourceOpResult)
+    return false;
+  ReifiedRankedShapedTypeDims resultDims;
+  if (failed(reifyResultShapes(b, sourceOpResult.getOwner(), resultDims)))
+    return false;
+  return llvm::equal(mixedSizes, resultDims[sourceOpResult.getResultNumber()]);
 }
 
 namespace {
@@ -3206,7 +3215,7 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
 
   LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                 PatternRewriter &rewriter) const override {
-    if (!isTrivialSubViewOp(subViewOp))
+    if (!isTrivialSubViewOp(rewriter, subViewOp))
       return failure();
     if (subViewOp.getSourceType() == subViewOp.getType()) {
       rewriter.replaceOp(subViewOp, subViewOp.getSource());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..ebad9e3eab345 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,37 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4

 // -----

+// CHECK-LABEL: func @subview_of_dynamic_full_size
+//  CHECK-SAME:   %[[ARG0:.+]]: memref<?xi8>
+//  CHECK-SAME:   %[[SIZE:.+]]: index
+//       CHECK:   %[[EXPAND_SHAPE:.+]] = memref.expand_shape
+//   CHECK-NOT:   memref.subview
+//       CHECK:   return %[[EXPAND_SHAPE]] : memref<?x?xi8>
+func.func @subview_of_dynamic_full_size(%arg0 : memref<?xi8>, %size : index) -> memref<?x?xi8> {
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%size, %size] : memref<?xi8> into memref<?x?xi8>
+  %1 = memref.subview %0[0, 0] [%size, %size] [1, 1] : memref<?x?xi8> to memref<?x?xi8>
+  return %1 : memref<?x?xi8>
+}
+
+// -----
+
+// The dynamic sizes of the subview do not match the reified source shape, so
+// the subview must not be folded away.
+// CHECK-LABEL: func @negative_subview_of_dynamic_full_size
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?xi8>
+//  CHECK-SAME:   %[[SIZE:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[SUB_SIZE:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[EXPAND_SHAPE:.+]] = memref.expand_shape %[[ARG0]]
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[EXPAND_SHAPE]][0, 0] [%[[SIZE]], %[[SUB_SIZE]]] [1, 1]
+//       CHECK:   return %[[SUBVIEW]]
+func.func @negative_subview_of_dynamic_full_size(%arg0 : memref<?xi8>, %size : index, %sub_size : index) -> memref<?x?xi8, strided<[?, 1]>> {
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%size, %size] : memref<?xi8> into memref<?x?xi8>
+  %1 = memref.subview %0[0, 0] [%size, %sub_size] [1, 1] : memref<?x?xi8> to memref<?x?xi8, strided<[?, 1]>>
+  return %1 : memref<?x?xi8, strided<[?, 1]>>
+}
+
+// -----
+
 // CHECK-LABEL: func @negative_subview_of_static_full_size
 //  CHECK-SAME:   %[[ARG0:.+]]: memref<16x4xf32,  strided<[4, 1], offset: ?>>
 //  CHECK-SAME:   %[[IDX:.+]]: index



More information about the Mlir-commits mailing list