[Mlir-commits] [mlir] c325e97 - [mlir][linalg] Swap tensor.extract_slice(linalg.fill)

Lei Zhang llvmlistbot at llvm.org
Tue Sep 20 14:35:07 PDT 2022


Author: Lei Zhang
Date: 2022-09-20T17:31:22-04:00
New Revision: c325e978b529e7e899a31cd2693fbc1a676afcf0

URL: https://github.com/llvm/llvm-project/commit/c325e978b529e7e899a31cd2693fbc1a676afcf0
DIFF: https://github.com/llvm/llvm-project/commit/c325e978b529e7e899a31cd2693fbc1a676afcf0.diff

LOG: [mlir][linalg] Swap tensor.extract_slice(linalg.fill)

This commit adds a pattern to swap

```
tensor.extract_slice(linalg.fill(%cst, %init))
```
into
```
linalg.fill(%cst, tensor.extract_slice(%init))
```
when the linalg.fill op has no other users.
This helps to reduce the fill footprint.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D134102

Added: 
    mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
    mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 23821887cc645..b6586718ad48a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -110,6 +110,10 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
+/// Adds patterns that swap tensor.extract_slice(linalg.fill(%cst, %init)) into
+/// linalg.fill(%cst, tensor.extract_slice(%init)). The swap only fires when the
+/// tensor.extract_slice is the sole user of the linalg.fill result.
+void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
+
 /// Return true if two `linalg.generic` operations with producer/consumer
 /// relationship through `fusedOperand` can be fused using elementwise op
 /// fusion.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 6d97dfc6d84fb..90f31fd679be6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Promotion.cpp
   Split.cpp
   SplitReduction.cpp
+  SwapExtractSliceWithFillPatterns.cpp
   Tiling.cpp
   TilingInterfaceImpl.cpp
   Transforms.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
new file mode 100644
index 0000000000000..425ef2f3068b2
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
@@ -0,0 +1,41 @@
+//===- SwapExtractSliceWithFillPatterns.cpp -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst,
+/// tensor.extract_slice(%init)) when the linalg.fill op has no other users.
+/// Filling only the extracted slice instead of the whole %init tensor reduces
+/// the fill footprint.
+struct SwapExtractSliceOfFill final
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = extractOp.getSource().getDefiningOp<FillOp>();
+    if (!fillOp)
+      return rewriter.notifyMatchFailure(
+          extractOp, "source is not produced by linalg.fill");
+    // If the full fill result has other users it must stay alive; swapping
+    // would then add a second fill instead of shrinking the existing one.
+    if (!fillOp->hasOneUse())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "linalg.fill result has other users");
+
+    // Slice the fill's init operand with exactly the same offsets, sizes,
+    // strides and result type as the original slice, so any rank reduction
+    // performed by the original tensor.extract_slice is preserved.
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0],
+        extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
+        extractOp.getMixedStrides());
+    // Fill only the slice with the original fill value. The original fill is
+    // not erased here; once `extractOp` is replaced it has no remaining users.
+    rewriter.replaceOpWithNewOp<FillOp>(extractOp, fillOp.getInputs(),
+                                        ValueRange{newExtractOp.getResult()});
+    return success();
+  }
+};
+} // namespace
+
+/// Registers SwapExtractSliceOfFill into `patterns`, using the context the
+/// pattern set was created with.
+void mlir::linalg::populateSwapExtractSliceWithFillPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<SwapExtractSliceOfFill>(context);
+}

diff  --git a/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir b/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir
new file mode 100644
index 0000000000000..0309301628714
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir
@@ -0,0 +1,28 @@
+//RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-swap-extract-slice-with-fill-pattern %s | FileCheck %s
+
+// The slice is the fill's only user, so the fill is shrunk to the slice and
+// the original full-size fill must disappear.
+// CHECK-LABEL: func.func @swap_fill_extract_slice
+//  CHECK-SAME: (%[[INIT:.+]]: tensor<?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[SIZE1:.+]]: index)
+//       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+//   CHECK-NOT:   linalg.fill
+//       CHECK:   %[[EXT:.+]] = tensor.extract_slice %[[INIT]][%[[OFFSET0]], 8, 4] [1, %[[SIZE1]], 6] [1, 3, 1]
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EXT]] : tensor<?x6xf32>) -> tensor<?x6xf32>
+//       CHECK:   return %[[FILL]]
+func.func @swap_fill_extract_slice(%init : tensor<?x?x?xf32>, %offset0: index, %size1: index) -> tensor<?x6xf32> {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%offset0, 8, 4] [1, %size1, 6] [1, 3, 1]
+    : tensor<?x?x?xf32> to tensor<?x6xf32>
+  return %1: tensor<?x6xf32>
+}
+
+// -----
+
+// The full fill result is also returned, so the pattern must not fire: the
+// slice must still be taken from the original fill result, not from %init.
+// CHECK-LABEL: func.func @dont_swap_fill_extract_slice_multi_user
+//  CHECK-SAME: (%[[INIT:.+]]: tensor<?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[SIZE1:.+]]: index)
+//       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[INIT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+//       CHECK:   %[[EXT:.+]] = tensor.extract_slice %[[FILL]][%[[OFFSET0]], 8, 4] [2, %[[SIZE1]], 6] [1, 3, 1]
+//       CHECK:   return %[[FILL]], %[[EXT]]
+func.func @dont_swap_fill_extract_slice_multi_user(%init : tensor<?x?x?xf32>, %offset0: index, %size1: index) -> (tensor<?x?x?xf32>, tensor<2x?x6xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%offset0, 8, 4] [2, %size1, 6] [1, 3, 1]
+    : tensor<?x?x?xf32> to tensor<2x?x6xf32>
+  return %0, %1: tensor<?x?x?xf32>, tensor<2x?x6xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 50a2abc2e07c3..8e13b9801c5ba 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,11 @@ struct TestLinalgTransforms
       llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
                      "extract_slice + linalgOp"),
       llvm::cl::init(false)};
+  // When set, runOnOperation calls applySwapExtractSliceWithFillPattern on the
+  // function, rewriting tensor.extract_slice(linalg.fill(%cst, %init)) into
+  // linalg.fill(%cst, tensor.extract_slice(%init)).
+  Option<bool> testSwapExtractSliceWithFill{
+      *this, "test-swap-extract-slice-with-fill-pattern",
+      llvm::cl::desc(
+          "Test patterns to swap tensor.extract_slice(linalg.fill())"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -508,6 +513,12 @@ static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+/// Greedily applies the tensor.extract_slice(linalg.fill) swap patterns to
+/// `funcOp`. The LogicalResult of the greedy driver is deliberately discarded;
+/// the test only inspects the resulting IR.
+static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
+  MLIRContext *context = funcOp.getContext();
+  RewritePatternSet swapPatterns(context);
+  populateSwapExtractSliceWithFillPatterns(swapPatterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(swapPatterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   auto lambda = [&](void *) {
@@ -551,6 +562,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applySplitReduction(getOperation());
   if (testBubbleUpExtractSliceOpPattern)
     return applyBubbleUpExtractSliceOpPattern(getOperation());
+  if (testSwapExtractSliceWithFill)
+    return applySwapExtractSliceWithFillPattern(getOperation());
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list