[Mlir-commits] [mlir] c325e97 - [mlir][linalg] Swap tensor.extract_slice(linalg.fill)
Lei Zhang
llvmlistbot at llvm.org
Tue Sep 20 14:35:07 PDT 2022
Author: Lei Zhang
Date: 2022-09-20T17:31:22-04:00
New Revision: c325e978b529e7e899a31cd2693fbc1a676afcf0
URL: https://github.com/llvm/llvm-project/commit/c325e978b529e7e899a31cd2693fbc1a676afcf0
DIFF: https://github.com/llvm/llvm-project/commit/c325e978b529e7e899a31cd2693fbc1a676afcf0.diff
LOG: [mlir][linalg] Swap tensor.extract_slice(linalg.fill)
This commit adds a pattern to swap
```
tensor.extract_slice(linalg.fill(%cst, %init))
```
into
```
linalg.fill(%cst, tensor.extract_slice(%init))
```
when the linalg.fill op has no other users.
This helps to reduce the fill footprint.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D134102
Added:
mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 23821887cc645..b6586718ad48a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -110,6 +110,10 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
/// Patterns that are used to bubble up extract slice op above linalg op.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+/// Adds patterns that swap tensor.extract_slice(linalg.fill(%cst, %init)) into
+/// linalg.fill(%cst, tensor.extract_slice(%init)).
+void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
+
/// Return true if two `linalg.generic` operations with producer/consumer
/// relationship through `fusedOperand` can be fused using elementwise op
/// fusion.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 6d97dfc6d84fb..90f31fd679be6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Promotion.cpp
Split.cpp
SplitReduction.cpp
+ SwapExtractSliceWithFillPatterns.cpp
Tiling.cpp
TilingInterfaceImpl.cpp
Transforms.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
new file mode 100644
index 0000000000000..425ef2f3068b2
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
@@ -0,0 +1,41 @@
+//===- SwapExtractSliceWithFillPatterns.cpp -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst,
+/// tensor.extract_slice(%init)) when the linalg.fill op has no other users.
+/// This reduces the fill footprint: only the extracted region is filled
+/// instead of the whole %init tensor.
+struct SwapExtractSliceOfFill final
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = extractOp.getSource().getDefiningOp<FillOp>();
+    // Any other user of the fill still needs the fully filled tensor, so
+    // swapping would not shrink the fill; only fire for a single use.
+    if (!fillOp || !fillOp->hasOneUse())
+      return failure();
+
+    // Slice the fill's destination with the original offsets/sizes/strides;
+    // reusing the original result type preserves any rank reduction.
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0],
+        extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
+        extractOp.getMixedStrides());
+    // Fill only the sliced region with the same input value(s) as before.
+    rewriter.replaceOpWithNewOp<FillOp>(extractOp, fillOp.getInputs(),
+                                        ValueRange{newExtractOp.getResult()});
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the SwapExtractSliceOfFill rewrite into `patterns`, constructing
+/// it with the pattern set's own MLIRContext.
+void mlir::linalg::populateSwapExtractSliceWithFillPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<SwapExtractSliceOfFill>(ctx);
+}
diff --git a/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir b/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir
new file mode 100644
index 0000000000000..0309301628714
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir
@@ -0,0 +1,28 @@
+//RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-swap-extract-slice-with-fill-pattern %s | FileCheck %s
+
+// Single-user fill: the extract_slice is moved above the fill, so only the
+// sliced region gets filled.
+// CHECK-LABEL: func.func @swap_fill_extract_slice
+// CHECK-SAME: (%[[INIT:.+]]: tensor<?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[SIZE1:.+]]: index)
+// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EXT:.+]] = tensor.extract_slice %[[INIT]][%[[OFFSET0]], 8, 4] [1, %[[SIZE1]], 6] [1, 3, 1]
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EXT]] : tensor<?x6xf32>) -> tensor<?x6xf32>
+// CHECK: return %[[FILL]]
+func.func @swap_fill_extract_slice(%init : tensor<?x?x?xf32>, %offset0: index, %size1: index) -> tensor<?x6xf32> {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%offset0, 8, 4] [1, %size1, 6] [1, 3, 1]
+    : tensor<?x?x?xf32> to tensor<?x6xf32>
+  return %1: tensor<?x6xf32>
+}
+
+// -----
+
+// Multi-user fill: the full fill result is also returned, so the pattern must
+// not fire. Pin the slice's source to the original fill and make sure no
+// second (sliced) fill was created.
+// CHECK-LABEL: func.func @dont_swap_fill_extract_slice_multi_user
+// CHECK-SAME: (%[[INIT:.+]]: tensor<?x?x?xf32>, %{{.+}}: index, %{{.+}}: index)
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}} : f32) outs(%[[INIT]] : tensor<?x?x?xf32>)
+// CHECK: %[[EXT:.+]] = tensor.extract_slice %[[FILL]][%{{.+}}, 8, 4] [2, %{{.+}}, 6] [1, 3, 1]
+// CHECK-NOT: linalg.fill
+// CHECK: return %[[FILL]], %[[EXT]]
+func.func @dont_swap_fill_extract_slice_multi_user(%init : tensor<?x?x?xf32>, %offset0: index, %size1: index) -> (tensor<?x?x?xf32>, tensor<2x?x6xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%offset0, 8, 4] [2, %size1, 6] [1, 3, 1]
+    : tensor<?x?x?xf32> to tensor<2x?x6xf32>
+  return %0, %1: tensor<?x?x?xf32>, tensor<2x?x6xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 50a2abc2e07c3..8e13b9801c5ba 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,11 @@ struct TestLinalgTransforms
llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
"extract_slice + linalgOp"),
llvm::cl::init(false)};
+ Option<bool> testSwapExtractSliceWithFill{
+ *this, "test-swap-extract-slice-with-fill-pattern",
+ llvm::cl::desc(
+ "Test patterns to swap tensor.extract_slice(linalg.fill())"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -508,6 +513,12 @@ static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+/// Runs the tensor.extract_slice(linalg.fill) swap patterns greedily over
+/// `funcOp`. The LogicalResult returned by the greedy driver is discarded.
+static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
+  RewritePatternSet swapPatterns(ctx);
+  populateSwapExtractSliceWithFillPatterns(swapPatterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(swapPatterns));
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
auto lambda = [&](void *) {
@@ -551,6 +562,8 @@ void TestLinalgTransforms::runOnOperation() {
return applySplitReduction(getOperation());
if (testBubbleUpExtractSliceOpPattern)
return applyBubbleUpExtractSliceOpPattern(getOperation());
+ if (testSwapExtractSliceWithFill)
+ return applySwapExtractSliceWithFillPattern(getOperation());
}
namespace mlir {
More information about the Mlir-commits
mailing list