[Mlir-commits] [mlir] [mlir] Allow folding dynamic full size subviews (PR #140619)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 19 13:58:10 PDT 2025
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/140619
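Adds support in `TrivialSubViewOpFolder` for folding full-size subviews (zero offsets, unit strides) whose sizes are dynamic. If the sizes are not all static constants matching the source shape, the source op's result shape is reified via `ReifyRankedShapedTypeOpInterface`. The subview is then folded when its sizes are identical to the reified dimensions.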
From dc0b38f99cb68ea07dcf55e2910098d785193387 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Mon, 19 May 2025 20:54:07 +0000
Subject: [PATCH] [mlir] Allow folding dynamic full size subviews
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 27 ++++++++++++++--------
mlir/test/Dialect/MemRef/canonicalize.mlir | 14 +++++++++++
2 files changed, 32 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 82702789c2913..3ccdcbf8c3be5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3105,7 +3105,7 @@ FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
/// is the case if the all offsets are zero, all strides are 1, and the source
/// shape is same as the size of the subview. In such cases, the subview can
/// be folded into its source.
-static bool isTrivialSubViewOp(SubViewOp subViewOp) {
+static bool isTrivialSubViewOp(OpBuilder &b, SubViewOp subViewOp) {
if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
return false;
@@ -3127,15 +3127,24 @@ static bool isTrivialSubViewOp(SubViewOp subViewOp) {
}))
return false;
- // Check all size values are static and matches the (static) source shape.
+ // Check all size values match the source shape.
ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
- for (const auto &size : llvm::enumerate(mixedSizes)) {
- std::optional<int64_t> intValue = getConstantIntValue(size.value());
- if (!intValue || *intValue != sourceShape[size.index()])
- return false;
+ if (llvm::all_of_zip(mixedSizes, sourceShape,
+ [](OpFoldResult mixedSize, int64_t staticSize) {
+ std::optional<int64_t> constSize =
+ getConstantIntValue(mixedSize);
+ return constSize.has_value() &&
+ *constSize == staticSize;
+ })) {
+ return true;
}
- // All conditions met. The `SubViewOp` is foldable as a no-op.
- return true;
+ auto sourceOpResult = dyn_cast<OpResult>(subViewOp.getSource());
+ if (!sourceOpResult)
+ return false;
+ ReifiedRankedShapedTypeDims resultDims;
+ if (failed(reifyResultShapes(b, sourceOpResult.getOwner(), resultDims)))
+ return false;
+ return llvm::equal(mixedSizes, resultDims[sourceOpResult.getResultNumber()]);
}
namespace {
@@ -3206,7 +3215,7 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
LogicalResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
- if (!isTrivialSubViewOp(subViewOp))
+ if (!isTrivialSubViewOp(rewriter, subViewOp))
return failure();
if (subViewOp.getSourceType() == subViewOp.getType()) {
rewriter.replaceOp(subViewOp, subViewOp.getSource());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..ebad9e3eab345 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,20 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
// -----
+// CHECK-LABEL: func @subview_of_dynamic_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<?xi8>
+// CHECK-SAME: %[[SIZE:.+]]: index
+// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape
+// CHECK-NOT: memref.subview
+// CHECK: return %[[EXPAND_SHAPE]] : memref<?x?xi8>
+func.func @subview_of_dynamic_full_size(%arg0 : memref<?xi8>, %size : index) -> memref<?x?xi8> {
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%size, %size] : memref<?xi8> into memref<?x?xi8>
+ %1 = memref.subview %0[0, 0] [%size, %size] [1, 1] : memref<?x?xi8> to memref<?x?xi8>
+ return %1 : memref<?x?xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @negative_subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
// CHECK-SAME: %[[IDX:.+]]: index
More information about the Mlir-commits
mailing list