[Mlir-commits] [mlir] 7d249df - [mlir][linalg] NFC: minor cleanups after moving pad to tensor dialect
Lei Zhang
llvmlistbot at llvm.org
Thu Mar 3 06:48:03 PST 2022
Author: Lei Zhang
Date: 2022-03-03T09:44:54-05:00
New Revision: 7d249dfd7da55eed6992aba26b3544b69025bcc2
URL: https://github.com/llvm/llvm-project/commit/7d249dfd7da55eed6992aba26b3544b69025bcc2
DIFF: https://github.com/llvm/llvm-project/commit/7d249dfd7da55eed6992aba26b3544b69025bcc2.diff
LOG: [mlir][linalg] NFC: minor cleanups after moving pad to tensor dialect
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D120627
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 87f5cc17b4221..d4cd9a550b3a9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -103,10 +103,9 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
-/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its
-/// source, if the producer is a `linalg` operation with all parallel iterator
-/// types.
-void populateFusePadTensorWithProducerLinalgOpPatterns(
+/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
+/// if the producer is a `linalg` operation with all parallel iterator types.
+void populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns);
/// Patterns to convert from one named op to another. These can be seen as
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
index 78b5305c8a1ec..1777201e6037d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
@@ -1,4 +1,4 @@
-//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
+//===- PadOpInterchange.cpp - Interchange tensor.pad with linalg producer -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,9 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements patterns that intechanges a generic op -> pad_tensor
-// pattern into extract_slice -> generic_op.
+// This file implements patterns that interchange a linalg.generic -> tensor.pad
+// op chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice
+// op chain.
//
//===----------------------------------------------------------------------===//
@@ -17,7 +18,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
-using namespace mlir::linalg;
namespace {
@@ -25,7 +25,7 @@ namespace {
///
/// ```mlir
/// %0 = linalg. ...
-/// %1 = linalg.pad_tensor %0 ...
+/// %1 = tensor.pad %0 ...
/// ```
///
/// can be replaced with
@@ -40,6 +40,7 @@ namespace {
/// if the `linalg.generic` has all parallel iterator types.
struct FusePadOp : OpRewritePattern<tensor::PadOp> {
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
// Only works on padding op that sets the padded value to a constant.
@@ -50,7 +51,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// This pattern could work for any Linalg op. For now restrict it to generic
// ops.
Value source = padOp.source();
- auto linalgOp = source.getDefiningOp<GenericOp>();
+ auto linalgOp = source.getDefiningOp<linalg::GenericOp>();
if (!linalgOp) {
return rewriter.notifyMatchFailure(
padOp, "expected source to be linalg.generic op");
@@ -75,14 +76,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// Create the tensor of same size as output of the pad op.
RankedTensorType padResultType = padOp.getResultType();
auto resultSizes = getAsOpFoldResult(resultShape[0]);
- auto initTensor = rewriter.create<InitTensorOp>(
+ auto initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultSizes, padResultType.getElementType());
// Fill the tensor with the pad value.
// TODO: There is an option to fill only the boundaries. For now just
// filling the whole tensor.
auto fillTensor =
- rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
+ rewriter.create<linalg::FillOp>(loc, padValue, initTensor.getResult());
// Construct a slice of the fill result that is to be replaced with the
// result of the generic op. The low pad values are the offsets, the size of
@@ -107,7 +108,8 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
loc, fillTensor.getResult(0), offsets, sizes, strides);
// Clone the generic op.
- auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
+ auto clonedOp =
+ cast<linalg::GenericOp>(rewriter.clone(*linalgOp.getOperation()));
clonedOp.setOutputOperand(resultNumber, slice.getResult());
// Insert it back into the result of the fill.
@@ -119,7 +121,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
};
} // namespace
-void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
+void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns) {
patterns.add<FusePadOp>(patterns.getContext());
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
index 0c3f51d309635..b811e3c386c73 100644
--- a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
@@ -34,7 +34,7 @@ struct TestPadFusionPass
MLIRContext *context = &getContext();
FuncOp funcOp = getOperation();
RewritePatternSet patterns(context);
- linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
+ linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits
mailing list