[Mlir-commits] [mlir] ccb8a4e - [mlir][memref] Fold subview(subview(x))

Matthias Springer llvmlistbot at llvm.org
Thu Dec 15 08:55:14 PST 2022


Author: Matthias Springer
Date: 2022-12-15T17:50:12+01:00
New Revision: ccb8a4e3f36937a71a526ca4a3fa29895253ee9d

URL: https://github.com/llvm/llvm-project/commit/ccb8a4e3f36937a71a526ca4a3fa29895253ee9d
DIFF: https://github.com/llvm/llvm-project/commit/ccb8a4e3f36937a71a526ca4a3fa29895253ee9d.diff

LOG: [mlir][memref] Fold subview(subview(x))

Folding of rank-reduced subviews is also supported. For now, only subviews with unit strides are folded.

Differential Revision: https://reviews.llvm.org/D140110

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 8357ec49e79fd..92f02c068d2b9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -294,6 +294,58 @@ class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Folds subview(subview(x)) into one subview(x); offsets are added per dim and only unit strides are supported.
+class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
+public:
+  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp subView,
+                                PatternRewriter &rewriter) const override {
+    Location loc = subView.getLoc();
+    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!srcSubView) // Only matches when the source is itself a subview.
+      return failure();
+    int64_t srcRank = srcSubView.getSourceType().getRank(); // Rank of x.
+
+    // TODO: Only stride 1 is supported; bail out if either subview has any other stride.
+    for (auto s : {subView.getMixedStrides(), srcSubView.getMixedStrides()})
+      if (!llvm::all_of(
+              s, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }))
+        return failure();
+
+    // Get original offsets and sizes (`subView`'s are indexed by srcSubView's result dims).
+    SmallVector<OpFoldResult> offsets = subView.getMixedOffsets();
+    SmallVector<OpFoldResult> srcOffsets = srcSubView.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = subView.getMixedSizes();
+    SmallVector<OpFoldResult> srcSizes = srcSubView.getMixedSizes();
+
+    // Compute new offsets and sizes, one entry per dim of x (srcRank entries).
+    llvm::SmallBitVector srcReducedDims = srcSubView.getDroppedDims();
+    SmallVector<OpFoldResult> newOffsets, newSizes;
+    int64_t dim = 0; // Next dim of srcSubView's (possibly rank-reduced) result.
+    for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
+      if (srcReducedDims[srcDim]) {
+        // Dim is reduced in srcSubView (size 1): keep its offset and size as is.
+        assert(isConstantIntValue(srcSizes[srcDim], 1) && "expected size 1");
+        newOffsets.push_back(srcOffsets[srcDim]);
+        newSizes.push_back(srcSizes[srcDim]);
+        continue;
+      }
+      AffineExpr sym0, sym1;
+      bindSymbols(subView.getContext(), sym0, sym1);
+      newOffsets.push_back(makeComposedFoldedAffineApply( // srcOffset + offset
+          rewriter, loc, sym0 + sym1, {srcOffsets[srcDim], offsets[dim]}));
+      newSizes.push_back(sizes[dim]);
+      ++dim;
+    }
+
+    // Replace original op with a subview of x, reusing subView's result type as is.
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        subView, subView.getType(), srcSubView.getSource(), newOffsets,
+        newSizes, srcSubView.getMixedStrides()); // All 1s (checked above).
+    return success();
+  }
+};
 } // namespace
 
 static SmallVector<Value>
@@ -533,8 +585,8 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                LoadOpOfCollapseShapeOpFolder<AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
                StoreOpOfCollapseShapeOpFolder<AffineStoreOp>,
-               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>>(
-      patterns.getContext());
+               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+               SubViewOfSubViewFolder>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 8ed0cf06fc769..c2ecc90be8ddf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s | FileCheck %s
 
 func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
@@ -465,3 +465,40 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with
 // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
 // CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
 // CHECK-NEXT:   affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
+// CHECK-LABEL: func @subview_of_subview(
+//  CHECK-SAME:     %[[m:.*]]: memref<8x1024xf32, 3>, %[[pos:.*]]: index
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+//       CHECK:   memref.subview %[[m]][4, %[[add]]] [1, 1] [1, 1] : memref<8x1024xf32, 3> to memref<f32, strided<[], offset: ?>, 3>
+func.func @subview_of_subview(%m: memref<8x1024xf32, 3>, %pos: index)
+    -> memref<f32, strided<[], offset: ?>, 3>
+{
+  %0 = memref.subview %m[3, %pos] [2, 4] [1, 1]
+      : memref<8x1024xf32, 3>
+        to memref<2x4xf32, strided<[1024, 1], offset: ?>, 3>
+  %1 = memref.subview %0[1, 2] [1, 1] [1, 1]
+      : memref<2x4xf32, strided<[1024, 1], offset: ?>, 3>
+        to memref<f32, strided<[], offset: ?>, 3>
+  return %1 : memref<f32, strided<[], offset: ?>, 3>
+}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_subview_rank_reducing(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>
+//       CHECK:   memref.subview %[[m]][3, 7, 8] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32> to memref<f32, strided<[], offset: ?>>
+func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
+                                            %sz: index, %pos: index)
+    -> memref<f32, strided<[], offset: ?>>
+{
+  %0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
+      : memref<?x?x?xf32>
+        to memref<?xf32, strided<[?], offset: ?>>
+  %1 = memref.subview %0[6] [1] [1]
+      : memref<?xf32, strided<[?], offset: ?>>
+        to memref<f32, strided<[], offset: ?>>
+  return %1 : memref<f32, strided<[], offset: ?>>
+}


        


More information about the Mlir-commits mailing list