[Mlir-commits] [mlir] bb4c53b - [mlir][tensor] Merge consecutive insert_slice/extract_slice ops
Lei Zhang
llvmlistbot at llvm.org
Tue Sep 20 17:00:34 PDT 2022
Author: Lei Zhang
Date: 2022-09-20T19:52:56-04:00
New Revision: bb4c53b7ba113b274ad0fd8d881313509947c896
URL: https://github.com/llvm/llvm-project/commit/bb4c53b7ba113b274ad0fd8d881313509947c896
DIFF: https://github.com/llvm/llvm-project/commit/bb4c53b7ba113b274ad0fd8d881313509947c896.diff
LOG: [mlir][tensor] Merge consecutive insert_slice/extract_slice ops
Consecutive tensor.insert_slice/tensor.extract_slice ops can be
created in cases like tiling convolutions and then downsizing
2-D convolutions into 1-D ones. They hinder further transformations,
so this commit adds patterns to clean them up.
Given that bufferization is sensitive to and has requirements on
the IR structure (see https://reviews.llvm.org/D132666),
these patterns are put in Transforms/ with separate entry points
for explicit collection.
Reviewed By: ThomasRaoux, mravishankar
Differential Revision: https://reviews.llvm.org/D133871
Added:
mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 28c22aecdf318..13ff67e3af433 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -29,6 +29,13 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns,
FailureOr<Value> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
+/// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
+/// into one. These patterns are in this separate entry point because
+/// bufferization is sensitive to IR structure, particularly the
+/// tensor.extract_slice and tensor.insert_slice ops that create the slices.
+void populateMergeConsecutiveInsertExtractSlicePatterns(
+ RewritePatternSet &patterns);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 0b200e03226da..73bab5685ab6a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ExtractSliceFromReshape.cpp
+ MergeConsecutiveInsertExtractSlicePatterns.cpp
SplitPadding.cpp
SwapExtractSliceWithProducer.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
new file mode 100644
index 0000000000000..48977a90ffb3b
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -0,0 +1,117 @@
+//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Returns the element-wise sums `offsets1[i] + offsets2[i]`, each built as a
+/// composed affine.apply at `loc`; both lists must have the same length.
+static SmallVector<OpFoldResult> mergeOffsets(Location loc,
+ ArrayRef<OpFoldResult> offsets1,
+ ArrayRef<OpFoldResult> offsets2,
+ OpBuilder &builder) {
+ SmallVector<OpFoldResult> foldedOffsets;
+ assert(offsets1.size() == offsets2.size());
+ foldedOffsets.reserve(offsets1.size());
+ // `dim1 + dim2` is the affine expression applied to every offset pair.
+ AffineExpr dim1, dim2;
+ bindDims(builder.getContext(), dim1, dim2);
+ // Each iteration sums one pair and records the result of the apply op.
+ for (const auto &pair : llvm::zip(offsets1, offsets2)) {
+ auto offset0 =
+ getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
+ auto offset1 =
+ getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
+ auto foldedOffset =
+ makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
+ foldedOffsets.push_back(foldedOffset.getResult());
+ }
+ return foldedOffsets;
+}
+
+namespace {
+/// Merges an extract_slice of an extract_slice into one op (unit strides only).
+struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
+ PatternRewriter &rewriter) const override {
+ auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
+ if (!prevOp)
+ return failure();
+ // Only offsets are combined below, so both ops must have unit strides.
+ if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+ return failure();
+ // A producer whose result rank differs from its source rank is rejected.
+ auto prevResultType = prevOp.getType().cast<ShapedType>();
+ if (prevOp.getSourceType().getRank() != prevResultType.getRank())
+ return rewriter.notifyMatchFailure(
+ prevOp, "rank-reducing producder case unimplemented");
+
+ Location loc = nextOp.getLoc();
+
+ SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
+ SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
+ SmallVector<OpFoldResult> foldedOffsets =
+ mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
+ // The merged op reads prevOp's source but keeps nextOp's type/sizes/strides.
+ rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+ nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
+ nextOp.getMixedSizes(), nextOp.getMixedStrides());
+ return success();
+ }
+};
+
+/// Merges two chained insert_slice ops by inserting the inner source directly.
+struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertSliceOp nextOp,
+ PatternRewriter &rewriter) const override {
+ auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
+ if (!prevOp)
+ return failure();
+ // nextOp's strides are reused as-is below; require unit strides on both.
+ if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+ return failure();
+
+ // The first insert_slice op should be rank reducing to make sure we cover
+ // the full source tensor to be inserted in the second insert_slice op.
+ SliceVerificationResult result =
+ isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
+ if (result != SliceVerificationResult::Success)
+ return failure();
+
+ // Dynamic dimensions can pass the rank-reducing check above, e.g.,
+ // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
+ // the dynamic size covers the full tensor.
+ if (!prevOp.getSourceType().hasStaticShape() ||
+ !prevOp.getDestType().hasStaticShape())
+ return failure();
+ // Bypass prevOp: insert its source with nextOp's offsets/sizes/strides.
+ rewriter.replaceOpWithNewOp<InsertSliceOp>(
+ nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
+ nextOp.getMixedSizes(), nextOp.getMixedStrides());
+ return success();
+ }
+};
+} // namespace
+
+void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>(
+ patterns.getContext()); // Both patterns are built with the set's context.
+}
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
new file mode 100644
index 0000000000000..45a3f37ea0679
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-consecutive-insert-extract-slice -canonicalize -mlir-print-local-scope %s | FileCheck %s
+// Two rank-preserving slices: expect one slice with pairwise-summed offsets.
+func.func @extract_slice_same_rank(
+ %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x16x32x?xf32> {
+ %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
+ %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [8, 16, 32, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<8x16x32x?xf32>
+ return %1: tensor<8x16x32x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_same_rank
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
+// CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32>
+// Only the second (consumer) slice drops unit dims: the pair should still merge.
+func.func @extract_slice_rank_reducing_consumer(
+ %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
+ %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
+ %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [1, 16, 1, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<16x?xf32>
+ return %1: tensor<16x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
+// CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
+// The first (producer) slice drops unit dims: both slices should remain.
+func.func @extract_slice_rank_reducing_producer(
+ %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
+ %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
+ %1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
+ return %1: tensor<8x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
+// CHECK-COUNT-2: tensor.extract_slice
+
+// -----
+// Static shapes: %src should end up inserted straight into %dst.
+func.func @insert_slice_rank_reducing(
+ %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x16x1xf32>, %src: tensor<16xf32>, %offset: index) -> tensor<128x128x128x128xf32> {
+ %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, 16, 1] [1, 1, 1] : tensor<16xf32> into tensor<1x16x1xf32>
+ %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, 16, 1] [1, 1, 1, 1] : tensor<1x16x1xf32> into tensor<128x128x128x128xf32>
+ return %1: tensor<128x128x128x128xf32>
+}
+
+// CHECK-LABEL: func.func @insert_slice_rank_reducing
+// CHECK-SAME: (%[[DST:.+]]: tensor<128x128x128x128xf32>, %{{.+}}: tensor<1x16x1xf32>, %[[SRC:.+]]: tensor<16xf32>, %[[IDX:.+]]: index)
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
+// CHECK: return %[[INSERT]]
+// Dynamic sizes: both insert_slice ops should remain unmerged.
+func.func @insert_slice_rank_reducing_dynamic_shape(
+ %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
+ %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>
+ %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<1x?x1xf32> into tensor<128x128x128x128xf32>
+ return %1: tensor<128x128x128x128xf32>
+}
+
+// CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
+// CHECK-COUNT-2: tensor.insert_slice
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 5dd5d763388a9..e06607cb30ed8 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -53,6 +53,12 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
llvm::cl::init(false)};
+ Option<bool> testFoldConsecutiveInsertExtractSlice{
+ *this, "test-fold-consecutive-insert-extract-slice",
+ llvm::cl::desc(
+ "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
+ llvm::cl::init(false)};
+
Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
*this, "test-rewrite-extract-slice-from-collapse-shape",
llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
@@ -90,6 +96,12 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); // Result is ignored.
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -233,6 +245,8 @@ void TestTensorTransforms::runOnOperation() {
applySplitPaddingPatterns(rootOp);
if (testFoldConstantExtractSlice)
applyFoldConstantExtractSlicePatterns(rootOp);
+ if (testFoldConsecutiveInsertExtractSlice)
+ applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
if (testRewriteExtractSliceWithTiledCollapseShape) {
if (failed(
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
More information about the Mlir-commits
mailing list