[Mlir-commits] [mlir] e7cb716 - [mlir][Linalg] Pattern to fuse pad operation with elementwise operations.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 11 13:37:33 PST 2022


Author: MaheshRavishankar
Date: 2022-01-11T13:37:25-08:00
New Revision: e7cb716ef95551000d9de2d0334ab1bf84a120a3

URL: https://github.com/llvm/llvm-project/commit/e7cb716ef95551000d9de2d0334ab1bf84a120a3
DIFF: https://github.com/llvm/llvm-project/commit/e7cb716ef95551000d9de2d0334ab1bf84a120a3.diff

LOG: [mlir][Linalg] Pattern to fuse pad operation with elementwise operations.

Most convolution operations need explicit padding of the input to
ensure all accesses are in bounds. In such cases, having a pad
operation can be a significant overhead. One way to reduce that
overhead is to try to fuse the pad operation with the producer of its
source.

A sequence

```
linalg.generic -> linalg.pad_tensor
```

can be replaced with

```
linalg.fill -> tensor.extract_slice -> linalg.generic ->
tensor.insert_slice
```

if the `linalg.generic` has all parallel iterator types.

Differential Revision: https://reviews.llvm.org/D116418

Added: 
    mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
    mlir/test/Dialect/Linalg/pad_fusion.mlir
    mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ee6e948dd4698..839064ebd36f4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -87,6 +87,12 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
 void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with a pattern that fuses a `linalg.pad_tensor`
+/// operation with the producer of its source. Currently the producer must be
+/// a `linalg.generic` operation with all parallel iterator types, and the pad
+/// must use a single padding value for every padded element.
+void populateFusePadTensorWithProducerLinalgOpPatterns(
+    RewritePatternSet &patterns);
+
 /// Patterns to convert from one named op to another. These can be seen as
 /// canonicalizations of named ops into another named op.
 void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 5df61c73fcc6f..024499a01209c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Loops.cpp
   LinalgStrategyPasses.cpp
   NamedOpConversions.cpp
+  PadOpInterchange.cpp
   Promotion.cpp
   Tiling.cpp
   Transforms.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
new file mode 100644
index 0000000000000..65bc8bc061b2d
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
@@ -0,0 +1,122 @@
+//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pattern that interchanges a generic op -> pad_tensor
+// sequence into fill -> extract_slice -> generic op -> insert_slice.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Fuses a `linalg.pad_tensor` op with the `linalg.generic` op that produces
+/// its source. A sequence of operations
+///
+/// ```mlir
+/// %0 = linalg.generic ... outs(..., %init, ...) ...
+/// %1 = linalg.pad_tensor %0 ...
+/// ```
+///
+/// is replaced with
+///
+/// ```mlir
+/// %0 = linalg.init_tensor ...            // shape of the pad result
+/// %1 = linalg.fill(%pad_value, %0) ...
+/// %2 = tensor.extract_slice %1 ...        // offsets = low padding amounts
+/// %3 = linalg.generic ... outs(..., %2, ...) ...
+/// %4 = tensor.insert_slice %3 into %1 ...
+/// ```
+///
+/// The rewrite applies only when:
+///  - the pad op has a single padding value (`getConstantPaddingValue()`
+///    returns non-null; per the tests this may be an SSA value such as a
+///    function argument, not only a literal constant),
+///  - the source is produced by a `linalg.generic` op whose iterator types are
+///    all parallel, and
+///  - the payload of that op does not read the output operand tied to the
+///    padded result. That operand is replaced with a slice of the filled
+///    tensor, so a payload reading it would see the pad value instead of the
+///    original init value.
+struct FusePadTensorOp : OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // Only works on a pad op that uses the same padding value everywhere.
+    Value padValue = padOp.getConstantPaddingValue();
+    if (!padValue)
+      return rewriter.notifyMatchFailure(padOp, "non constant padding");
+
+    // This pattern could work for any Linalg op. For now restrict it to generic
+    // ops.
+    Value source = padOp.source();
+    auto linalgOp = source.getDefiningOp<GenericOp>();
+    if (!linalgOp) {
+      return rewriter.notifyMatchFailure(
+          padOp, "expected source to be linalg.generic op");
+    }
+    // All iterator types need to be parallel.
+    if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "only supported for ops with all parallel iterator types");
+    }
+    // The output operand tied to the padded result is replaced below with a
+    // slice of the filled tensor. If the payload reads that operand, the
+    // rewrite would change the computed values, so bail out.
+    unsigned resultNumber = source.cast<OpResult>().getResultNumber();
+    if (linalgOp.payloadUsesValueFromOperand(
+            linalgOp.getOutputOperand(resultNumber))) {
+      return rewriter.notifyMatchFailure(
+          padOp, "payload reads the output operand of the padded result");
+    }
+
+    // All match checks are done above: reifying the result shape creates new
+    // ops, and a pattern must not create IR and then report failure.
+    ReifiedRankedShapedTypeDims resultShape;
+    if (failed(padOp.reifyResultShapes(rewriter, resultShape)) ||
+        resultShape.size() != 1) {
+      return rewriter.notifyMatchFailure(
+          padOp, "failed to get shape of pad op result");
+    }
+
+    Location loc = padOp.getLoc();
+
+    // Create the tensor of same size as output of the pad op.
+    RankedTensorType padResultType = padOp.getResultType();
+    auto resultSizes = getAsOpFoldResult(resultShape[0]);
+    auto initTensor = rewriter.create<InitTensorOp>(
+        loc, resultSizes, padResultType.getElementType());
+
+    // Fill the tensor with the pad value.
+    // TODO: There is an option to fill only the boundaries. For now just
+    // filling the whole tensor.
+    auto fillTensor =
+        rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
+
+    // Construct a slice of the fill result that is to be replaced with the
+    // result of the generic op. The low pad values are the offsets, the size of
+    // the source is the size of the slice.
+    // TODO: This insert/extract could be potentially made a utility method.
+    SmallVector<OpFoldResult> offsets = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> sizes;
+    sizes.reserve(offsets.size());
+    for (auto shape : llvm::enumerate(
+             source.getType().cast<RankedTensorType>().getShape())) {
+      if (ShapedType::isDynamic(shape.value())) {
+        sizes.push_back(
+            rewriter.create<tensor::DimOp>(loc, source, shape.index())
+                .getResult());
+      } else {
+        sizes.push_back(rewriter.getIndexAttr(shape.value()));
+      }
+    }
+    SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
+    auto slice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, fillTensor.getResult(0), offsets, sizes, strides);
+
+    // Clone the generic op so that it writes into the slice of the fill.
+    auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
+    clonedOp.setOutputOperand(resultNumber, slice.getResult());
+
+    // Insert it back into the result of the fill.
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+        padOp, clonedOp.getResult(resultNumber), fillTensor.getResult(0),
+        offsets, sizes, strides);
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the `FusePadTensorOp` pattern to `patterns`, constructed with the
+/// context that owns the pattern set.
+void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
+    RewritePatternSet &patterns) {
+  auto *ctx = patterns.getContext();
+  patterns.add<FusePadTensorOp>(ctx);
+}

diff  --git a/mlir/test/Dialect/Linalg/pad_fusion.mlir b/mlir/test/Dialect/Linalg/pad_fusion.mlir
new file mode 100644
index 0000000000000..7f6bd150f3de9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pad_fusion.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt -test-linalg-pad-fusion -split-input-file %s | FileCheck %s
+
+// Dynamic source shape with dynamic low/high padding: the fill's shape is
+// computed with `affine.apply` from the source dims and the padding amounts,
+// and the slice sizes are read from the source with `tensor.dim`.
+func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index,
+    %arg3 : index, %arg4 : index, %arg5 : f32) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]} 
+    ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32):
+      %1 = arith.mulf %arg6, %arg6 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low [%arg1, %arg2] high [%arg3, %arg4] {
+    ^bb0(%arg6: index, %arg7 : index):
+      linalg.yield %arg5 : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
+//      CHECK: func @dynamic_pad_fusion
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG5:[a-zA-Z0-9]+]]: f32
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[SOURCE:.+]] = linalg.generic
+//  CHECK-DAG:   %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
+//  CHECK-DAG:   %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]]
+//  CHECK-DAG:   %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+//  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]] 
+//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[ARG5]], %[[INIT]])
+//  CHECK-DAG:   %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
+//  CHECK-DAG:   %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+//      CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
+// CHECK-SAME:       [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
+//      CHECK:   %[[SOURCE:.+]] = linalg.generic
+// CHECK-SAME:       outs(%[[SLICE]] : tensor<?x?xf32>)
+//      CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
+// CHECK-SAME:       [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+// Mixed static/dynamic case: dim 0 has a static size (42) and static padding
+// (low 3, high 4), while dim 1 is dynamic. Static sizes and offsets stay as
+// constants in the extract_slice/insert_slice. The producer also transposes
+// its input through the output indexing map.
+func @mixed_pad_fusion(%arg0 : tensor<?x42xf32>, %arg1 : index, %arg2 : index,
+    %arg3 : f32) -> tensor<49x?xf32> {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x42xf32>
+  %init = linalg.init_tensor [42, %d0] : tensor<42x?xf32>
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
+    iterator_types = ["parallel", "parallel"]} 
+    ins(%arg0 : tensor<?x42xf32>) outs(%init : tensor<42x?xf32>) {
+    ^bb0(%arg4 : f32, %arg5 : f32):
+      %1 = arith.mulf %arg4, %arg4 : f32
+      linalg.yield %1 : f32
+    } -> tensor<42x?xf32>
+  %1 = linalg.pad_tensor %0 low [3, %arg1] high [4, %arg2] {
+    ^bb0(%arg4: index, %arg5 : index):
+      linalg.yield %arg3 : f32
+    } : tensor<42x?xf32> to tensor<49x?xf32>
+  return %1 : tensor<49x?xf32>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
+//      CHECK: func @mixed_pad_fusion
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x42xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: f32
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[SOURCE:.+]] = linalg.generic
+//  CHECK-DAG:   %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+//  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]] 
+//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[ARG3]], %[[INIT]])
+//  CHECK-DAG:   %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+//      CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
+// CHECK-SAME:       [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
+//      CHECK:   %[[SOURCE:.+]] = linalg.generic
+// CHECK-SAME:       outs(%[[SLICE]] : tensor<42x?xf32>)
+//      CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
+// CHECK-SAME:       [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
+//      CHECK:   return %[[RESULT]]

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 440d62ebd81d3..4a6c9d67c0d6f 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
   TestLinalgTransforms.cpp
+  TestPadFusion.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
new file mode 100644
index 0000000000000..bfe84e5d63d05
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
@@ -0,0 +1,48 @@
+//===- TestPadFusion.cpp - Test fusion of pad op with Linalg ops ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing fusion of pad ops with their
+// producer Linalg ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+namespace {
+/// Test pass that greedily applies the pad-tensor/producer fusion patterns to
+/// the body of each function it runs on.
+struct TestPadFusionPass : public PassWrapper<TestPadFusionPass, FunctionPass> {
+  StringRef getArgument() const final { return "test-linalg-pad-fusion"; }
+  StringRef getDescription() const final { return "Test PadOp fusion"; }
+
+  // Dialects the rewrite creates ops from (e.g. affine.apply from shape
+  // reification, linalg.fill, tensor.extract_slice/insert_slice).
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>();
+    registry.insert<linalg::LinalgDialect>();
+    registry.insert<tensor::TensorDialect>();
+  }
+
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
+
+    FuncOp func = getFunction();
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace test {
+/// Registers `TestPadFusionPass` (flag `test-linalg-pad-fusion`) so that
+/// mlir-opt can run it.
+void registerTestPadFusion() {
+  PassRegistration<TestPadFusionPass>();
+}
+} // namespace test
+
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 72333a9044dcd..3a2d83ebbe77b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -103,6 +103,7 @@ void registerTestMemRefStrideCalculation();
 void registerTestNumberOfBlockExecutionsPass();
 void registerTestNumberOfOperationExecutionsPass();
 void registerTestOpaqueLoc();
+void registerTestPadFusion();
 void registerTestPDLByteCodePass();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
@@ -195,6 +196,7 @@ void registerTestPasses() {
   mlir::test::registerTestNumberOfBlockExecutionsPass();
   mlir::test::registerTestNumberOfOperationExecutionsPass();
   mlir::test::registerTestOpaqueLoc();
+  mlir::test::registerTestPadFusion();
   mlir::test::registerTestPDLByteCodePass();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();


        


More information about the Mlir-commits mailing list