[Mlir-commits] [mlir] ccb8a4e - [mlir][memref] Fold subview(subview(x))
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 15 08:55:14 PST 2022
Author: Matthias Springer
Date: 2022-12-15T17:50:12+01:00
New Revision: ccb8a4e3f36937a71a526ca4a3fa29895253ee9d
URL: https://github.com/llvm/llvm-project/commit/ccb8a4e3f36937a71a526ca4a3fa29895253ee9d
DIFF: https://github.com/llvm/llvm-project/commit/ccb8a4e3f36937a71a526ca4a3fa29895253ee9d.diff
LOG: [mlir][memref] Fold subview(subview(x))
Folding of rank-reduced subviews is also supported.
Differential Revision: https://reviews.llvm.org/D140110
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 8357ec49e79fd..92f02c068d2b9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -294,6 +294,58 @@ class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
PatternRewriter &rewriter) const override;
};
+/// Folds subview(subview(x)) into a single subview(x).
+///
+/// Only unit strides are handled. With stride 1, the composed offset of every
+/// surviving dimension is `srcOffset + offset`, the composed size is the size
+/// of the outer subview, and the composed strides stay 1.
+///
+/// Rank reduction on either op is supported:
+/// - Dimensions dropped by the producer (inner) subview do not appear in the
+///   consumer's operands. They are re-inserted with the producer's offset and
+///   unit size.
+/// - Dimensions dropped by the consumer (outer) subview are expressed by
+///   reusing the consumer's result type for the new op.
+class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
+public:
+  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp subView,
+                                PatternRewriter &rewriter) const override {
+    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!srcSubView)
+      return rewriter.notifyMatchFailure(subView,
+                                         "source is not a memref.subview");
+
+    // TODO: Only stride 1 is supported. Checking each op through a helper
+    // avoids copying the stride vectors into a temporary initializer_list.
+    auto hasUnitStrides = [](memref::SubViewOp op) {
+      return llvm::all_of(op.getMixedStrides(), [](OpFoldResult ofr) {
+        return isConstantIntValue(ofr, 1);
+      });
+    };
+    if (!hasUnitStrides(subView) || !hasUnitStrides(srcSubView))
+      return rewriter.notifyMatchFailure(subView, "requires unit strides");
+
+    Location loc = subView.getLoc();
+    int64_t srcRank = srcSubView.getSourceType().getRank();
+
+    // `offsets`/`sizes` are indexed by the dims of srcSubView's *result*
+    // (which is subView's source). `srcOffsets`/`srcSizes` are indexed by the
+    // dims of srcSubView's *source*.
+    SmallVector<OpFoldResult> offsets = subView.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = subView.getMixedSizes();
+    SmallVector<OpFoldResult> srcOffsets = srcSubView.getMixedOffsets();
+    SmallVector<OpFoldResult> srcSizes = srcSubView.getMixedSizes();
+
+    // newOffset = srcOffset + offset. This is valid only because all strides
+    // are 1. The expression is loop-invariant, so build it once.
+    AffineExpr sym0, sym1;
+    bindSymbols(subView.getContext(), sym0, sym1);
+    AffineExpr addOffsets = sym0 + sym1;
+
+    // Walk the dims of the original source. `dim` tracks the matching dim of
+    // srcSubView's result and only advances on dims that were not dropped.
+    llvm::SmallBitVector srcReducedDims = srcSubView.getDroppedDims();
+    SmallVector<OpFoldResult> newOffsets, newSizes;
+    newOffsets.reserve(srcRank);
+    newSizes.reserve(srcRank);
+    int64_t dim = 0;
+    for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
+      if (srcReducedDims[srcDim]) {
+        // Dim is reduced in srcSubView. subView has no operand for it, so
+        // carry over srcSubView's offset and (unit) size unchanged.
+        assert(isConstantIntValue(srcSizes[srcDim], 1) && "expected size 1");
+        newOffsets.push_back(srcOffsets[srcDim]);
+        newSizes.push_back(srcSizes[srcDim]);
+        continue;
+      }
+      newOffsets.push_back(makeComposedFoldedAffineApply(
+          rewriter, loc, addOffsets, {srcOffsets[srcDim], offsets[dim]}));
+      newSizes.push_back(sizes[dim]);
+      ++dim;
+    }
+    assert(dim == static_cast<int64_t>(offsets.size()) &&
+           "expected one subView operand per non-dropped srcSubView dim");
+
+    // Replace original op. subView's result type already encodes any rank
+    // reduction it performs. Because all strides are 1, srcSubView's strides
+    // (one per source dim) are also the composed strides.
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        subView, subView.getType(), srcSubView.getSource(), newOffsets,
+        newSizes, srcSubView.getMixedStrides());
+    return success();
+  }
+};
} // namespace
static SmallVector<Value>
@@ -533,8 +585,8 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
LoadOpOfCollapseShapeOpFolder<AffineLoadOp>,
LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
StoreOpOfCollapseShapeOpFolder<AffineStoreOp>,
- StoreOpOfCollapseShapeOpFolder<memref::StoreOp>>(
- patterns.getContext());
+ StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+ SubViewOfSubViewFolder>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 8ed0cf06fc769..c2ecc90be8ddf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s | FileCheck %s
func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
@@ -465,3 +465,40 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
// CHECK-NEXT: affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>
+
+// -----
+
+// Nested unit-stride subviews: offsets are added per dim ([3, pos] + [1, 2])
+// and sizes are taken from the outer subview.
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
+// CHECK-LABEL: func @subview_of_subview(
+// CHECK-SAME: %[[m:.*]]: memref<1x1024xf32, 3>, %[[pos:.*]]: index
+// CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+// CHECK: memref.subview %[[m]][4, %[[add]]] [1, 1] [1, 1] : memref<1x1024xf32, 3> to memref<f32, strided<[], offset: ?>, 3>
+func.func @subview_of_subview(%m: memref<1x1024xf32, 3>, %pos: index)
+ -> memref<f32, strided<[], offset: ?>, 3>
+{
+ %0 = memref.subview %m[3, %pos] [1, 2] [1, 1]
+ : memref<1x1024xf32, 3>
+ to memref<1x2xf32, strided<[1024, 2], offset: ?>, 3>
+ %1 = memref.subview %0[1, 2] [1, 1] [1, 1]
+ : memref<1x2xf32, strided<[1024, 2], offset: ?>, 3>
+ to memref<f32, strided<[], offset: ?>, 3>
+ return %1 : memref<f32, strided<[], offset: ?>, 3>
+}
+
+// -----
+
+// The inner subview drops dims 0 and 2, so they keep its offsets (3, 8).
+// The surviving dim composes 1 + 6 = 7.
+// CHECK-LABEL: func @subview_of_subview_rank_reducing(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>
+// CHECK: memref.subview %[[m]][3, 7, 8] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32> to memref<f32, strided<[], offset: ?>>
+func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
+ %sz: index, %pos: index)
+ -> memref<f32, strided<[], offset: ?>>
+{
+ %0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
+ : memref<?x?x?xf32>
+ to memref<?xf32, strided<[1], offset: ?>>
+ %1 = memref.subview %0[6] [1] [1]
+ : memref<?xf32, strided<[1], offset: ?>>
+ to memref<f32, strided<[], offset: ?>>
+ return %1 : memref<f32, strided<[], offset: ?>>
+}
More information about the Mlir-commits
mailing list