[Mlir-commits] [mlir] 7d249df - [mlir][linalg] NFC: minor cleanups after moving pad to tensor dialect

Lei Zhang llvmlistbot at llvm.org
Thu Mar 3 06:48:03 PST 2022


Author: Lei Zhang
Date: 2022-03-03T09:44:54-05:00
New Revision: 7d249dfd7da55eed6992aba26b3544b69025bcc2

URL: https://github.com/llvm/llvm-project/commit/7d249dfd7da55eed6992aba26b3544b69025bcc2
DIFF: https://github.com/llvm/llvm-project/commit/7d249dfd7da55eed6992aba26b3544b69025bcc2.diff

LOG: [mlir][linalg] NFC: minor cleanups after moving pad to tensor dialect

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D120627

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
    mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 87f5cc17b4221..d4cd9a550b3a9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -103,10 +103,9 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
 void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns);
 
-/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its
-/// source, if the producer is a `linalg` operation with all parallel iterator
-/// types.
-void populateFusePadTensorWithProducerLinalgOpPatterns(
+/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
+/// if the producer is a `linalg` operation with all parallel iterator types.
+void populateFuseTensorPadWithProducerLinalgOpPatterns(
     RewritePatternSet &patterns);
 
 /// Patterns to convert from one named op to another. These can be seen as

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
index 78b5305c8a1ec..1777201e6037d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
@@ -1,4 +1,4 @@
-//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
+//===- PadOpInterchange.cpp - Interchange tensor.pad with linalg producer -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements patterns that intechanges a generic op -> pad_tensor
-// pattern into extract_slice -> generic_op.
+// This file implements patterns that interchange a linalg.generic -> tensor.pad
+// op chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice
+// op chain.
 //
 //===----------------------------------------------------------------------===//
 
@@ -17,7 +18,6 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
-using namespace mlir::linalg;
 
 namespace {
 
@@ -25,7 +25,7 @@ namespace {
 ///
 /// ```mlir
 /// %0 = linalg. ...
-/// %1 = linalg.pad_tensor %0 ...
+/// %1 = tensor.pad %0 ...
 /// ```
 ///
 /// can be replaced with
@@ -40,6 +40,7 @@ namespace {
 /// if the `linalg.generic` has all parallel iterator types.
 struct FusePadOp : OpRewritePattern<tensor::PadOp> {
   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
   LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                 PatternRewriter &rewriter) const override {
     // Only works on padding op that sets the padded value to a constant.
@@ -50,7 +51,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
     // This pattern could work for any Linalg op. For now restrict it to generic
     // ops.
     Value source = padOp.source();
-    auto linalgOp = source.getDefiningOp<GenericOp>();
+    auto linalgOp = source.getDefiningOp<linalg::GenericOp>();
     if (!linalgOp) {
       return rewriter.notifyMatchFailure(
           padOp, "expected source to be linalg.generic op");
@@ -75,14 +76,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
     // Create the tensor of same size as output of the pad op.
     RankedTensorType padResultType = padOp.getResultType();
     auto resultSizes = getAsOpFoldResult(resultShape[0]);
-    auto initTensor = rewriter.create<InitTensorOp>(
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, resultSizes, padResultType.getElementType());
 
     // Fill the tensor with the pad value.
     // TODO: There is an option to fill only the boundaries. For now just
     // filling the whole tensor.
     auto fillTensor =
-        rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
+        rewriter.create<linalg::FillOp>(loc, padValue, initTensor.getResult());
 
     // Construct a slice of the fill result that is to be replaced with the
     // result of the generic op. The low pad values are the offsets, the size of
@@ -107,7 +108,8 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
         loc, fillTensor.getResult(0), offsets, sizes, strides);
 
     // Clone the generic op.
-    auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
+    auto clonedOp =
+        cast<linalg::GenericOp>(rewriter.clone(*linalgOp.getOperation()));
     clonedOp.setOutputOperand(resultNumber, slice.getResult());
 
     // Insert it back into the result of the fill.
@@ -119,7 +121,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
 };
 } // namespace
 
-void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
+void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(
     RewritePatternSet &patterns) {
   patterns.add<FusePadOp>(patterns.getContext());
 }

diff  --git a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
index 0c3f51d309635..b811e3c386c73 100644
--- a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
@@ -34,7 +34,7 @@ struct TestPadFusionPass
     MLIRContext *context = &getContext();
     FuncOp funcOp = getOperation();
     RewritePatternSet patterns(context);
-    linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
+    linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns);
     if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                             std::move(patterns))))
       return signalPassFailure();


        


More information about the Mlir-commits mailing list