[Mlir-commits] [mlir] ea069ae - [mlir][Linalg] NFC: Move populatePatterns* method into linalg namespace.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 5 11:16:19 PDT 2021


Author: MaheshRavishankar
Date: 2021-04-05T11:16:02-07:00
New Revision: ea069aebccd317f350be3cabdcd848476616d4da

URL: https://github.com/llvm/llvm-project/commit/ea069aebccd317f350be3cabdcd848476616d4da
DIFF: https://github.com/llvm/llvm-project/commit/ea069aebccd317f350be3cabdcd848476616d4da.diff

LOG: [mlir][Linalg] NFC: Move populatePatterns* method into linalg namespace.

The moved `populate` methods are only relevant to Linalg
operations, so they are better off in the `linalg` namespace. Also rename
`populateLinalgTensorOpsFusionPatterns` to
`populateElementwiseOpsFusionPatterns`. This makes the scope of these
patterns explicit and disambiguates them from fusion on tensors using
tile + fuse.

Differential Revision: https://reviews.llvm.org/D99819

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 18820d4316b91..17fcf08dba96b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -50,10 +50,6 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
 
-/// Populate patterns that convert `ElementwiseMappable` ops to linalg
-/// parallel loops.
-void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
-
 /// Create a pass to conver named Linalg operations to Linalg generic
 /// operations.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
@@ -62,35 +58,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
 
-/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
-/// producer (consumer) generic operation by expanding the dimensionality of the
-/// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation.
-void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation. The patterns are applied only when
-/// the tensor reshape involved is collapsing (introducing) unit-extent
-/// dimensions.
-void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-    RewritePatternSet &patterns);
-
-/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-
-/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors.
-void populateLinalgFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 21e6cba9dc3ce..8ce5677762695 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -36,10 +36,43 @@ void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
+/// producer (consumer) generic operation by expanding the dimensionality of the
+/// loop in the generic op.
+void populateFoldReshapeOpsByExpansionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation.
+void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    RewritePatternSet &patterns);
+
 /// Populates the given list with patterns to bufferize linalg ops.
 void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
                                      RewritePatternSet &patterns);
 
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors.
+void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
+
+/// Patterns for fusing linalg operation on tensors.
+void populateElementwiseOpsFusionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 33efeddadc9e5..4ee63a92875f3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -136,11 +136,6 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
                                           OpResult producerOpResult,
                                           OpOperand &consumerOpOperand);
 
-/// Fuse linalg operation on tensors, with the producer of the operand at
-/// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
-                                              OpOperand &consumerOpOperand);
-
 //===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index d5f08056d551f..7aefefd642ead 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -556,7 +557,7 @@ struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
 
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
-void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
@@ -580,7 +581,7 @@ struct LinalgFoldUnitExtentDimsPass
           .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
               context);
     else
-      populateLinalgFoldUnitExtentDimsPatterns(patterns);
+      populateFoldUnitExtentDimsPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
   }
 };

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 86b7eafa4ecce..1daaafa8a3643 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -10,6 +10,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
@@ -115,7 +116,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
 };
 } // namespace
 
-void mlir::populateElementwiseToLinalgConversionPatterns(
+void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
       patterns.getContext());
@@ -131,7 +132,7 @@ class ConvertElementwiseToLinalgPass
     ConversionTarget target(*context);
     RewritePatternSet patterns(context);
 
-    populateElementwiseToLinalgConversionPatterns(patterns);
+    mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
     target.markUnknownOpDynamicallyLegal([](Operation *op) {
       return !isElementwiseMappableOpOnRankedTensors(op);
     });

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 71f3bef56969e..91848edee6f01 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -26,8 +26,8 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 /// Implementation of fusion of generic ops and indexed_generic ops.
-static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
-                                unsigned consumerIdx) {
+static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
+                                     unsigned consumerIdx) {
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -91,11 +91,11 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
 
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
-static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
-                                        Operation *fusedOp, LinalgOp producer,
-                                        LinalgOp consumer,
-                                        AffineMap consumerToProducerLoopsMap,
-                                        unsigned consumerIdx, unsigned nloops) {
+static void
+generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
+                                 LinalgOp producer, LinalgOp consumer,
+                                 AffineMap consumerToProducerLoopsMap,
+                                 unsigned consumerIdx, unsigned nloops) {
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
@@ -208,11 +208,11 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
 }
 
 static Optional<SmallVector<Value, 1>>
-fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
-                  PatternRewriter &rewriter) {
+fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+                       PatternRewriter &rewriter) {
   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-  if (!areTensorOpsFusable(producer, consumer, consumerIdx))
+  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
     return llvm::None;
 
   unsigned numFusedOperands =
@@ -291,9 +291,9 @@ fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
   AffineMap consumerToProducerLoopsMap =
       invProducerResultIndexMap.compose(consumerResultIndexMap);
 
-  generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
-                              consumer, consumerToProducerLoopsMap, consumerIdx,
-                              consumer.getNumLoops());
+  generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
+                                   consumer, consumerToProducerLoopsMap,
+                                   consumerIdx, consumer.getNumLoops());
   return SmallVector<Value, 1>(fusedOp->getResults());
 }
 
@@ -1102,9 +1102,8 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
 };
 } // namespace
 
-Optional<SmallVector<Value, 1>>
-mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
-                            OpOperand &consumerOpOperand) {
+static Optional<SmallVector<Value, 1>>
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
   Operation *producer = consumerOpOperand.get().getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
     return llvm::None;
@@ -1114,14 +1113,14 @@ mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
       !isa<GenericOp, IndexedGenericOp>(producer))
     return llvm::None;
 
-  return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
-                           rewriter);
+  return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
+                                rewriter);
 }
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
-struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
+struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(LinalgOpTy op,
@@ -1133,7 +1132,7 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
       if (!producerOp || !producerOp.hasTensorSemantics())
         continue;
       Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseTensorOps(rewriter, opOperand);
+          fuseElementwiseOps(rewriter, opOperand);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
         return success();
@@ -1149,8 +1148,7 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
-    populateLinalgTensorOpsFusionPatterns(patterns,
-                                          allowFoldingUnitDimReshapes);
+    populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1170,7 +1168,7 @@ struct FoldReshapeOpsByLinearizationPass
 
 } // namespace
 
-void mlir::populateFoldReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
                FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
@@ -1178,7 +1176,7 @@ void mlir::populateFoldReshapeOpsByLinearizationPatterns(
       patterns.getContext());
 }
 
-void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
                FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
@@ -1186,7 +1184,7 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
       patterns.getContext());
 }
 
-void mlir::populateFoldReshapeOpsByExpansionPatterns(
+void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
   patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
@@ -1194,11 +1192,11 @@ void mlir::populateFoldReshapeOpsByExpansionPatterns(
       patterns.getContext(), allowFoldingUnitDimReshapes);
 }
 
-void mlir::populateLinalgTensorOpsFusionPatterns(
+void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   auto *context = patterns.getContext();
   patterns
-      .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+      .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
            FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
           context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,


        


More information about the Mlir-commits mailing list