[Mlir-commits] [mlir] ea069ae - [mlir][Linalg] NFC: Move populatePatterns* method into linalg namespace.
llvmlistbot at llvm.org
Mon Apr 5 11:16:19 PDT 2021
Author: MaheshRavishankar
Date: 2021-04-05T11:16:02-07:00
New Revision: ea069aebccd317f350be3cabdcd848476616d4da
URL: https://github.com/llvm/llvm-project/commit/ea069aebccd317f350be3cabdcd848476616d4da
DIFF: https://github.com/llvm/llvm-project/commit/ea069aebccd317f350be3cabdcd848476616d4da.diff
LOG: [mlir][Linalg] NFC: Move populatePatterns* method into linalg namespace.
The moved `populate` methods are only relevant to Linalg
operations, so they are better off in the `linalg` namespace. Also rename
`populateLinalgTensorOpsFusionPatterns` to
`populateElementwiseOpsFusionPatterns`. This makes the scope of these
patterns explicit and disambiguates them from fusion on tensors using
tile + fuse.
Differential Revision: https://reviews.llvm.org/D99819
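As a rough illustration of the API change, a downstream caller that previously used the
top-level entry point would now go through the `linalg` namespace under the new name. The
sketch below is not part of the patch; `runFusionPatterns` is a hypothetical helper, and it
simply mirrors the call sites touched in this commit.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper, shown only to illustrate the rename and namespace move.
static void runFusionPatterns(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  // Before: populateLinalgTensorOpsFusionPatterns(patterns, /*allowFoldingUnitDimReshapes=*/false);
  // After: the populate method lives in mlir::linalg and carries the new name.
  mlir::linalg::populateElementwiseOpsFusionPatterns(
      patterns, /*allowFoldingUnitDimReshapes=*/false);
  (void)mlir::applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
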
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 18820d4316b91..17fcf08dba96b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -50,10 +50,6 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
/// buffers instead.
std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
-/// Populate patterns that convert `ElementwiseMappable` ops to linalg
-/// parallel loops.
-void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
-
/// Create a pass to convert named Linalg operations to Linalg generic
/// operations.
std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
@@ -62,35 +58,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
/// work on primitive types, if possible.
std::unique_ptr<Pass> createLinalgDetensorizePass();
-/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
-/// producer (consumer) generic operation by expanding the dimensionality of the
-/// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(
- RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation.
-void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation. The patterns are applied only when
-/// the tensor reshape involved is collapsing (introducing) unit-extent
-/// dimensions.
-void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- RewritePatternSet &patterns);
-
-/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(
- RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-
-/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors.
-void populateLinalgFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 21e6cba9dc3ce..8ce5677762695 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -36,10 +36,43 @@ void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
+/// producer (consumer) generic operation by expanding the dimensionality of the
+/// loop in the generic op.
+void populateFoldReshapeOpsByExpansionPatterns(
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation.
+void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+ RewritePatternSet &patterns);
+
/// Populates the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
RewritePatternSet &patterns);
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors.
+void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
+
+/// Patterns for fusing linalg operation on tensors.
+void populateElementwiseOpsFusionPatterns(
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 33efeddadc9e5..4ee63a92875f3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -136,11 +136,6 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
OpResult producerOpResult,
OpOperand &consumerOpOperand);
-/// Fuse linalg operation on tensors, with the producer of the operand at
-/// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
- OpOperand &consumerOpOperand);
-
//===----------------------------------------------------------------------===//
// Distribution utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index d5f08056d551f..7aefefd642ead 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
@@ -556,7 +557,7 @@ struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
-void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
@@ -580,7 +581,7 @@ struct LinalgFoldUnitExtentDimsPass
.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
context);
else
- populateLinalgFoldUnitExtentDimsPatterns(patterns);
+ populateFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 86b7eafa4ecce..1daaafa8a3643 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -10,6 +10,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
@@ -115,7 +116,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
};
} // namespace
-void mlir::populateElementwiseToLinalgConversionPatterns(
+void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
patterns.getContext());
@@ -131,7 +132,7 @@ class ConvertElementwiseToLinalgPass
ConversionTarget target(*context);
RewritePatternSet patterns(context);
- populateElementwiseToLinalgConversionPatterns(patterns);
+ mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return !isElementwiseMappableOpOnRankedTensors(op);
});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 71f3bef56969e..91848edee6f01 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -26,8 +26,8 @@ using namespace mlir;
using namespace mlir::linalg;
/// Implementation of fusion of generic ops and indexed_generic ops.
-static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
- unsigned consumerIdx) {
+static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
+ unsigned consumerIdx) {
// Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
return false;
@@ -91,11 +91,11 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
-static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
- Operation *fusedOp, LinalgOp producer,
- LinalgOp consumer,
- AffineMap consumerToProducerLoopsMap,
- unsigned consumerIdx, unsigned nloops) {
+static void
+generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
+ LinalgOp producer, LinalgOp consumer,
+ AffineMap consumerToProducerLoopsMap,
+ unsigned consumerIdx, unsigned nloops) {
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
@@ -208,11 +208,11 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
}
static Optional<SmallVector<Value, 1>>
-fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
- PatternRewriter &rewriter) {
+fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+ PatternRewriter &rewriter) {
LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
unsigned consumerIdx = consumerOpOperand.getOperandNumber();
- if (!areTensorOpsFusable(producer, consumer, consumerIdx))
+ if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
return llvm::None;
unsigned numFusedOperands =
@@ -291,9 +291,9 @@ fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
AffineMap consumerToProducerLoopsMap =
invProducerResultIndexMap.compose(consumerResultIndexMap);
- generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
- consumer, consumerToProducerLoopsMap, consumerIdx,
- consumer.getNumLoops());
+ generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
+ consumer, consumerToProducerLoopsMap,
+ consumerIdx, consumer.getNumLoops());
return SmallVector<Value, 1>(fusedOp->getResults());
}
@@ -1102,9 +1102,8 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
};
} // namespace
-Optional<SmallVector<Value, 1>>
-mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
- OpOperand &consumerOpOperand) {
+static Optional<SmallVector<Value, 1>>
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
Operation *producer = consumerOpOperand.get().getDefiningOp();
if (!producer || producer->getNumResults() != 1)
return llvm::None;
@@ -1114,14 +1113,14 @@ mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
!isa<GenericOp, IndexedGenericOp>(producer))
return llvm::None;
- return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
- rewriter);
+ return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
+ rewriter);
}
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
template <typename LinalgOpTy>
-struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
+struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(LinalgOpTy op,
@@ -1133,7 +1132,7 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
if (!producerOp || !producerOp.hasTensorSemantics())
continue;
Optional<SmallVector<Value, 1>> fusedOpResults =
- fuseTensorOps(rewriter, opOperand);
+ fuseElementwiseOps(rewriter, opOperand);
if (fusedOpResults) {
rewriter.replaceOp(op, *fusedOpResults);
return success();
@@ -1149,8 +1148,7 @@ struct FusionOfTensorOpsPass
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
- populateLinalgTensorOpsFusionPatterns(patterns,
- allowFoldingUnitDimReshapes);
+ populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
@@ -1170,7 +1168,7 @@ struct FoldReshapeOpsByLinearizationPass
} // namespace
-void mlir::populateFoldReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
@@ -1178,7 +1176,7 @@ void mlir::populateFoldReshapeOpsByLinearizationPatterns(
patterns.getContext());
}
-void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
@@ -1186,7 +1184,7 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
patterns.getContext());
}
-void mlir::populateFoldReshapeOpsByExpansionPatterns(
+void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
@@ -1194,11 +1192,11 @@ void mlir::populateFoldReshapeOpsByExpansionPatterns(
patterns.getContext(), allowFoldingUnitDimReshapes);
}
-void mlir::populateLinalgTensorOpsFusionPatterns(
+void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
auto *context = patterns.getContext();
patterns
- .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+ .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
context);
populateFoldReshapeOpsByExpansionPatterns(patterns,