[Mlir-commits] [mlir] e7cb716 - [mlir][Linalg] Pattern to fuse pad operation with elementwise operations.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 11 13:37:33 PST 2022
Author: MaheshRavishankar
Date: 2022-01-11T13:37:25-08:00
New Revision: e7cb716ef95551000d9de2d0334ab1bf84a120a3
URL: https://github.com/llvm/llvm-project/commit/e7cb716ef95551000d9de2d0334ab1bf84a120a3
DIFF: https://github.com/llvm/llvm-project/commit/e7cb716ef95551000d9de2d0334ab1bf84a120a3.diff
LOG: [mlir][Linalg] Pattern to fuse pad operation with elementwise operations.
Most convolution operations need explicit padding of the input to
ensure all accesses are inbounds. In such cases, having a pad
operation can be a significant overhead. One way to reduce that
overhead is to try to fuse the pad operation with the producer of its
source.
A sequence
```
linalg.generic -> linalg.pad_tensor
```
can be replaced with
```
linalg.fill -> tensor.extract_slice -> linalg.generic ->
tensor.insert_slice.
```
if the `linalg.generic` has all parallel iterator types.
Differential Revision: https://reviews.llvm.org/D116418
Added:
mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
mlir/test/Dialect/Linalg/pad_fusion.mlir
mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ee6e948dd4698..839064ebd36f4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -87,6 +87,12 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
+/// Adds a pattern that fuses a `linalg.pad_tensor` (with a constant padding
+/// value) into the `linalg.generic` producing its source; applies only when
+/// that generic op has all parallel iterator types.
+void populateFusePadTensorWithProducerLinalgOpPatterns(
+ RewritePatternSet &patterns);
+
/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 5df61c73fcc6f..024499a01209c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Loops.cpp
LinalgStrategyPasses.cpp
NamedOpConversions.cpp
+ PadOpInterchange.cpp
Promotion.cpp
Tiling.cpp
Transforms.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
new file mode 100644
index 0000000000000..65bc8bc061b2d
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
@@ -0,0 +1,122 @@
+//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pattern that interchanges a generic op -> pad_tensor
+// sequence into fill -> extract_slice -> generic op -> insert_slice.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Fuses a `linalg.pad_tensor` with the `linalg.generic` producing its source.
+///
+/// ```mlir
+/// %0 = linalg.generic ...
+/// %1 = linalg.pad_tensor %0 ...
+/// ```
+///
+/// is rewritten, when the padding value is a constant, into
+///
+/// ```mlir
+/// %0 = linalg.fill(%pad_value, %init)
+/// %1 = tensor.extract_slice %0 ...
+/// %2 = linalg.generic ... outs(..., %1, ...) ...
+/// %3 = tensor.insert_slice %2 into %0 ...
+/// ```
+///
+/// if the `linalg.generic` has all parallel iterator types.
+struct FusePadTensorOp : OpRewritePattern<PadTensorOp> {
+ using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(PadTensorOp padOp,
+ PatternRewriter &rewriter) const override {
+ // Only applies when the pad op sets the padded region to a constant value.
+ Value padValue = padOp.getConstantPaddingValue();
+ if (!padValue)
+ return rewriter.notifyMatchFailure(padOp, "non constant padding");
+
+ // The rewrite could work for any Linalg op; for now it is restricted to
+ // sources defined by a `linalg.generic`.
+ Value source = padOp.source();
+ auto linalgOp = source.getDefiningOp<GenericOp>();
+ if (!linalgOp) {
+ return rewriter.notifyMatchFailure(
+ padOp, "expected source to be linalg.generic op");
+ }
+ // Every loop of the producer must be parallel.
+ if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) {
+ return rewriter.notifyMatchFailure(
+ padOp, "only supported for ops with all parallel iterator types");
+ }
+ ReifiedRankedShapedTypeDims resultShape;
+ if (failed(padOp.reifyResultShapes(rewriter, resultShape)) ||
+ resultShape.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ padOp, "failed to get shape of pad op result");
+ }
+
+ Location loc = padOp.getLoc();
+
+ // Create an init tensor with the shape and element type of the pad result.
+ RankedTensorType padResultType = padOp.getResultType();
+ auto resultSizes = getAsOpFoldResult(resultShape[0]);
+ auto initTensor = rewriter.create<InitTensorOp>(
+ loc, resultSizes, padResultType.getElementType());
+
+ // Fill the whole tensor with the pad value.
+ // TODO: There is an option to fill only the boundaries. For now just
+ // filling the whole tensor.
+ auto fillTensor =
+ rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
+
+ // Carve out the slice of the filled tensor that the generic op's result will
+ // occupy: offsets are the low pad amounts, sizes are the source dimensions
+ // (static ones as attributes, dynamic ones via tensor.dim), strides are 1.
+ // TODO: This insert/extract could be potentially made a utility method.
+ unsigned resultNumber = source.cast<OpResult>().getResultNumber();
+ SmallVector<OpFoldResult> offsets = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> sizes;
+ sizes.reserve(offsets.size());
+ for (auto shape : llvm::enumerate(
+ source.getType().cast<RankedTensorType>().getShape())) {
+ if (ShapedType::isDynamic(shape.value())) {
+ sizes.push_back(
+ rewriter.create<tensor::DimOp>(loc, source, shape.index())
+ .getResult());
+ } else {
+ sizes.push_back(rewriter.getIndexAttr(shape.value()));
+ }
+ }
+ SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
+ auto slice = rewriter.create<tensor::ExtractSliceOp>(
+ loc, fillTensor.getResult(0), offsets, sizes, strides);
+
+ // Clone the generic op, redirecting the padded result's output to the slice.
+ auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
+ clonedOp.setOutputOperand(resultNumber, slice.getResult());
+
+ // Insert the clone's result back into the filled tensor, replacing the pad.
+ rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+ padOp, clonedOp.getResult(resultNumber), fillTensor.getResult(0),
+ offsets, sizes, strides);
+ return success();
+ }
+};
+} // namespace
+/// Adds `FusePadTensorOp` to `patterns`, using the context `patterns` holds.
+void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FusePadTensorOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/pad_fusion.mlir b/mlir/test/Dialect/Linalg/pad_fusion.mlir
new file mode 100644
index 0000000000000..7f6bd150f3de9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pad_fusion.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt -test-linalg-pad-fusion -split-input-file %s | FileCheck %s
+// A fully dynamic 2-D generic result padded with dynamic low/high amounts.
+func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index, %arg5 : f32) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
+ ^bb0(%arg6 : f32, %arg7 : f32):
+ %1 = arith.mulf %arg6, %arg6 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ %1 = linalg.pad_tensor %0 low [%arg1, %arg2] high [%arg3, %arg4] {
+ ^bb0(%arg6: index, %arg7 : index):
+ linalg.yield %arg5 : f32
+ } : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
+// CHECK: func @dynamic_pad_fusion
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
+// CHECK-DAG: %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
+// CHECK-DAG: %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]]
+// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG5]], %[[INIT]])
+// CHECK-DAG: %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
+// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
+// CHECK: %[[SOURCE:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
+// CHECK: return %[[RESULT]]
+
+// -----
+// Transposing generic with a mixed static/dynamic result and mixed pad amounts.
+func @mixed_pad_fusion(%arg0 : tensor<?x42xf32>, %arg1 : index, %arg2 : index,
+ %arg3 : f32) -> tensor<49x?xf32> {
+ %c0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x42xf32>
+ %init = linalg.init_tensor [42, %d0] : tensor<42x?xf32>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?x42xf32>) outs(%init : tensor<42x?xf32>) {
+ ^bb0(%arg4 : f32, %arg5 : f32):
+ %1 = arith.mulf %arg4, %arg4 : f32
+ linalg.yield %1 : f32
+ } -> tensor<42x?xf32>
+ %1 = linalg.pad_tensor %0 low [3, %arg1] high [4, %arg2] {
+ ^bb0(%arg4: index, %arg5 : index):
+ linalg.yield %arg3 : f32
+ } : tensor<42x?xf32> to tensor<49x?xf32>
+ return %1 : tensor<49x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
+// CHECK: func @mixed_pad_fusion
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x42xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
+// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG3]], %[[INIT]])
+// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
+// CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
+// CHECK: %[[SOURCE:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[SLICE]] : tensor<42x?xf32>)
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
+// CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 440d62ebd81d3..4a6c9d67c0d6f 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgFusionTransforms.cpp
TestLinalgHoisting.cpp
TestLinalgTransforms.cpp
+ TestPadFusion.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
new file mode 100644
index 0000000000000..bfe84e5d63d05
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
@@ -0,0 +1,48 @@
+//===- TestPadFusion.cpp - Test fusion of pad op with Linalg ops ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing fusion of pad ops with its producer
+// Linalg op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+namespace {
+struct TestPadFusionPass : public PassWrapper<TestPadFusionPass, FunctionPass> {
+ // Test-only function pass that greedily applies the pad-fusion pattern.
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+ }
+ // Command-line flag and help text under which the pass is registered.
+ StringRef getArgument() const final { return "test-linalg-pad-fusion"; }
+ StringRef getDescription() const final { return "Test PadOp fusion"; }
+ // Runs the fusion pattern on the function; fails the pass if the driver fails.
+ void runOnFunction() override {
+ MLIRContext *context = &getContext();
+ FuncOp funcOp = getFunction();
+ RewritePatternSet patterns(context);
+ linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+// Registration hook; mlir-opt calls it from registerTestPasses().
+namespace test {
+void registerTestPadFusion() { PassRegistration<TestPadFusionPass>(); }
+} // namespace test
+
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 72333a9044dcd..3a2d83ebbe77b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -103,6 +103,7 @@ void registerTestMemRefStrideCalculation();
void registerTestNumberOfBlockExecutionsPass();
void registerTestNumberOfOperationExecutionsPass();
void registerTestOpaqueLoc();
+void registerTestPadFusion();
void registerTestPDLByteCodePass();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
@@ -195,6 +196,7 @@ void registerTestPasses() {
mlir::test::registerTestNumberOfBlockExecutionsPass();
mlir::test::registerTestNumberOfOperationExecutionsPass();
mlir::test::registerTestOpaqueLoc();
+ mlir::test::registerTestPadFusion();
mlir::test::registerTestPDLByteCodePass();
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();
More information about the Mlir-commits
mailing list