[Mlir-commits] [mlir] 27cc31b - [mlir][vector] NFC - Clean up vector patterns and propagate benefit through populate functions
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Sep 9 02:45:31 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-09T02:45:22-07:00
New Revision: 27cc31b64c0491725aa88a6822f0f2a2c18914d7
URL: https://github.com/llvm/llvm-project/commit/27cc31b64c0491725aa88a6822f0f2a2c18914d7
DIFF: https://github.com/llvm/llvm-project/commit/27cc31b64c0491725aa88a6822f0f2a2c18914d7.diff
LOG: [mlir][vector] NFC - Clean up vector patterns and propagate benefit through populate functions
Differential Revision: https://reviews.llvm.org/D133559
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index cd92dfd1e2217..f4316d0615609 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -62,11 +62,12 @@ isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims = nullptr);
/// Collect a set of vector-to-vector canonicalization patterns.
-void populateVectorToVectorCanonicalizationPatterns(
- RewritePatternSet &patterns);
+void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collect a set of vector.shape_cast folding patterns.
-void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
+void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collect a set of leading one dimension removal patterns.
///
@@ -74,14 +75,16 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
/// to expose more canonical forms of read/write/insert/extract operations.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
-void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
+void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collect a set of unit dimension removal patterns.
///
/// These patterns insert rank-reducing memref.subview ops to remove unit
/// dimensions. With them, there are more chances that we can avoid
/// potentially expensive vector.shape_cast operations.
-void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
+void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
@@ -89,14 +92,16 @@ void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
/// to transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
+void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
-void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
+void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collect a set of transfer read/write lowering patterns.
///
@@ -106,28 +111,34 @@ void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
/// VectorToSCF, which reduces the rank of vector transfer ops.
void populateVectorTransferLoweringPatterns(
RewritePatternSet &patterns,
- llvm::Optional<unsigned> maxTransferRank = llvm::None);
+ llvm::Optional<unsigned> maxTransferRank = llvm::None,
+ PatternBenefit benefit = 1);
/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
- bool force32BitVectorIndices);
+ bool force32BitVectorIndices,
+ PatternBenefit benefit = 1);
/// Collect a set of patterns to propagate insert_map/extract_map in the SSA
/// chain.
-void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
+void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collects patterns to progressively lower vector.broadcast ops on high-D
/// vectors to low-D vector ops.
-void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collects patterns to progressively lower vector mask ops into elementary
/// selection and insertion ops.
-void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collects patterns to progressively lower vector.shape_cast ops on high-D
/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
/// ops.
-void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 9174d0e2ab53e..204b322e2deae 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -37,7 +37,8 @@ struct WarpExecuteOnLane0LoweringOptions {
void populateWarpExecuteOnLane0OpToScfForPattern(
RewritePatternSet &patterns,
- const WarpExecuteOnLane0LoweringOptions &options);
+ const WarpExecuteOnLane0LoweringOptions &options,
+ PatternBenefit benefit = 1);
using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
@@ -59,7 +60,8 @@ using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
/// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
void populateDistributeTransferWriteOpPatterns(
- RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn);
+ RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
+ PatternBenefit benefit = 1);
/// Move scalar operations with no dependency on the warp op outside of the
/// region.
@@ -67,7 +69,7 @@ void moveScalarUniformCode(WarpExecuteOnLane0Op op);
/// Collect patterns to propagate warp distribution.
void populatePropagateWarpVectorDistributionPatterns(
- RewritePatternSet &pattern);
+ RewritePatternSet &pattern, PatternBenefit benefit = 1);
/// Lambda signature to compute a reduction of a distributed value for the given
/// reduction kind and size.
@@ -78,7 +80,8 @@ using DistributedReductionFn =
/// distribute reduction op.
void populateDistributeReduction(
RewritePatternSet &pattern,
- const DistributedReductionFn &distributedReductionFn);
+ const DistributedReductionFn &distributedReductionFn,
+ PatternBenefit benefit = 1);
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ba4f6b3788c32..e7169b60285a9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -150,7 +150,8 @@ struct UnrollVectorOptions {
/// Insert TransposeLowering patterns into extraction/insertion.
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
- VectorTransformsOptions options = VectorTransformsOptions());
+ VectorTransformsOptions options = VectorTransformsOptions(),
+ PatternBenefit benefit = 1);
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
@@ -175,20 +176,24 @@ void populateVectorTransposeLoweringPatterns(
/// the other patterns can kick in, thus fully exiting out of the
/// vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options);
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
/// Collects patterns to progressively lower vector contraction ops on high-D
/// into low-D reduction and product ops.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
- VectorTransformsOptions options = VectorTransformsOptions());
+ VectorTransformsOptions options = VectorTransformsOptions(),
+ PatternBenefit benefit = 1);
/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
-void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
+void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Collect patterns to convert vector.scan ops.
-void populateVectorScanLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
//===----------------------------------------------------------------------===//
// Vector.transfer patterns.
@@ -246,14 +251,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns);
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contiguous vector.
/// These patterns are useful when lowering to dialects with 1-D vector types
/// such as LLVM, and they result in fewer memory reads.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Populate `patterns` with the following patterns.
///
@@ -278,7 +283,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Populate `patterns` with the following patterns.
///
@@ -299,7 +304,7 @@ void populateVectorInsertExtractStridedSliceDecompositionPatterns(
/// =========================================
/// For such cases, we can lower it to a ShuffleOp.
void populateVectorInsertExtractStridedSliceTransforms(
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Collect a set of patterns to unroll vector operations to smaller shapes.
/// `options` structure controls which operations are unrolled and the target
@@ -332,7 +337,8 @@ void populateVectorInsertExtractStridedSliceTransforms(
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
- const UnrollVectorOptions &options);
+ const UnrollVectorOptions &options,
+ PatternBenefit benefit = 1);
//===----------------------------------------------------------------------===//
// Finer-grained patterns exposed for more control over individual lowerings.
@@ -377,7 +383,8 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
class ContractionOpToMatmulOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
+
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -387,8 +394,9 @@ class ContractionOpToMatmulOpLowering
ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
@@ -419,7 +427,8 @@ class ContractionOpToMatmulOpLowering
class ContractionOpToOuterProductOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
+
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -429,8 +438,9 @@ class ContractionOpToOuterProductOpLowering
ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
@@ -464,7 +474,8 @@ class ContractionOpToOuterProductOpLowering
class ContractionOpToDotLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
+
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -474,9 +485,9 @@ class ContractionOpToDotLowering
ContractionOpToDotLowering(
vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context,
+ MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+ vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
@@ -504,7 +515,7 @@ class ContractionOpToDotLowering
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -513,9 +524,9 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
}
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context,
+ MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79c3719a8337b..574b4b977961a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1464,7 +1464,7 @@ namespace {
// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
- using OpRewritePattern<ExtractOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
@@ -1494,7 +1494,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
// Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
- using OpRewritePattern<ExtractOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
@@ -1681,7 +1681,7 @@ namespace {
// Fold broadcast1(broadcast2(x)) into broadcast1(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
- using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
@@ -1828,7 +1828,7 @@ namespace {
// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
// to a broadcast.
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
- using OpRewritePattern<ShuffleOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
@@ -1852,7 +1852,7 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
- using OpRewritePattern<ShuffleOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
@@ -1979,7 +1979,7 @@ namespace {
// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
- using OpRewritePattern<InsertOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOp insertOp,
PatternRewriter &rewriter) const override {
@@ -1996,7 +1996,7 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
/// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
- using OpRewritePattern<InsertOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
@@ -2202,7 +2202,7 @@ namespace {
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -2227,7 +2227,7 @@ class FoldInsertStridedSliceSplat final
class FoldInsertStridedSliceOfExtract final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -2587,7 +2587,7 @@ namespace {
class StridedSliceConstantMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -2640,7 +2640,7 @@ class StridedSliceConstantMaskFolder final
class StridedSliceConstantFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -2666,7 +2666,7 @@ class StridedSliceConstantFolder final
class StridedSliceBroadcast final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -2709,7 +2709,7 @@ class StridedSliceBroadcast final
/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -3182,7 +3182,7 @@ namespace {
struct FoldExtractSliceIntoTransferRead
: public OpRewritePattern<TransferReadOp> {
public:
- using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
@@ -3279,7 +3279,7 @@ struct FoldExtractSliceIntoTransferRead
/// ```
struct TransferReadAfterWriteToBroadcast
: public OpRewritePattern<TransferReadOp> {
- using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
@@ -3628,7 +3628,7 @@ namespace {
/// any other uses.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
- using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
if (!writeOp.getShapedType().isa<RankedTensorType>())
@@ -3674,7 +3674,7 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
struct FoldInsertSliceIntoTransferWrite
: public OpRewritePattern<tensor::InsertSliceOp> {
public:
- using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
PatternRewriter &rewriter) const override {
@@ -3768,7 +3768,7 @@ struct FoldInsertSliceIntoTransferWrite
struct SwapExtractSliceOfTransferWrite
: public OpRewritePattern<tensor::InsertSliceOp> {
public:
- using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
PatternRewriter &rewriter) const override {
@@ -3947,7 +3947,7 @@ LogicalResult MaskedLoadOp::verify() {
namespace {
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
- using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedLoadOp load,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(load.getMask())) {
@@ -3998,7 +3998,7 @@ LogicalResult MaskedStoreOp::verify() {
namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
- using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedStoreOp store,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(store.getMask())) {
@@ -4056,7 +4056,7 @@ LogicalResult GatherOp::verify() {
namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
- using OpRewritePattern<GatherOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(gather.getMask())) {
@@ -4102,7 +4102,7 @@ LogicalResult ScatterOp::verify() {
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
- using OpRewritePattern<ScatterOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(scatter.getMask())) {
@@ -4148,7 +4148,7 @@ LogicalResult ExpandLoadOp::verify() {
namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
- using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandLoadOp expand,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(expand.getMask())) {
@@ -4193,7 +4193,7 @@ LogicalResult CompressStoreOp::verify() {
namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
- using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CompressStoreOp compress,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(compress.getMask())) {
@@ -4333,7 +4333,7 @@ namespace {
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
- using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
@@ -4359,7 +4359,7 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
/// enough to capture the result in a single op).
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
- using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
@@ -4589,7 +4589,7 @@ namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
@@ -4651,7 +4651,7 @@ struct FoldTransposedScalarBroadcast final
// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
- using OpRewritePattern<TransposeOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
@@ -4751,7 +4751,7 @@ namespace {
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
- using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
@@ -4850,12 +4850,12 @@ LogicalResult ScanOp::verify() {
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
.add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
StridedSliceConstantMaskFolder, TransposeFolder>(
- patterns.getContext());
+ patterns.getContext(), benefit);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index a7667098125ea..f356248f18311 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1056,26 +1056,30 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
RewritePatternSet &patterns,
- const WarpExecuteOnLane0LoweringOptions &options) {
- patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
+ const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
+ patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options, benefit);
}
void mlir::vector::populateDistributeTransferWriteOpPatterns(
- RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
- patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
+ RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
+ PatternBenefit benefit) {
+ patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
+ benefit);
}
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
+ WarpOpScfForOp, WarpOpConstant>(patterns.getContext(), benefit);
}
void mlir::vector::populateDistributeReduction(
RewritePatternSet &patterns,
- const DistributedReductionFn &distributedReductionFn) {
- patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
+ const DistributedReductionFn &distributedReductionFn,
+ PatternBenefit benefit) {
+ patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
+ benefit);
}
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 07782eb48826e..47f00fa7ec495 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -401,8 +401,9 @@ struct CastAwayContractionLeadingOneDim
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
- CastAwayElementwiseLeadingOneDim(MLIRContext *context)
- : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+ CastAwayElementwiseLeadingOneDim(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
@@ -436,12 +437,12 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
} // namespace
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
.add<CastAwayExtractStridedSliceLeadingOneDim,
CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
- CastAwayContractionLeadingOneDim>(patterns.getContext());
- populateShapeCastFoldingPatterns(patterns);
+ CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
+ populateShapeCastFoldingPatterns(patterns, benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 224c7fbc01df4..9582a9d525008 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -286,15 +286,17 @@ class DecomposeNDExtractStridedSlice
};
void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DecomposeDifferentRankInsertStridedSlice,
- DecomposeNDExtractStridedSlice>(patterns.getContext());
+ DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
- RewritePatternSet &patterns) {
- populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns);
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
+ benefit);
patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
- Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext());
+ Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
+ benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 2582781aaab08..f8c8f9e74fbc3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -29,11 +29,12 @@ using namespace mlir;
class InnerOuterDimReductionConversion
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
explicit InnerOuterDimReductionConversion(
- MLIRContext *context, vector::VectorMultiReductionLowering options)
- : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+ MLIRContext *context, vector::VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1)
+ : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
useInnerDimsForReduction(
options == vector::VectorMultiReductionLowering::InnerReduction) {}
@@ -101,11 +102,12 @@ class InnerOuterDimReductionConversion
class ReduceMultiDimReductionRank
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
explicit ReduceMultiDimReductionRank(
- MLIRContext *context, vector::VectorMultiReductionLowering options)
- : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+ MLIRContext *context, vector::VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1)
+ : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
useInnerDimsForReduction(
options == vector::VectorMultiReductionLowering::InnerReduction) {}
@@ -224,7 +226,7 @@ class ReduceMultiDimReductionRank
/// and combines results
struct TwoDimMultiReductionToElementWise
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -261,7 +263,7 @@ struct TwoDimMultiReductionToElementWise
/// a sequence of vector.reduction ops.
struct TwoDimMultiReductionToReduction
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -301,7 +303,7 @@ struct TwoDimMultiReductionToReduction
/// separately.
struct OneDimMultiReductionToTwoDim
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -338,12 +340,15 @@ struct OneDimMultiReductionToTwoDim
};
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options) {
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
- patterns.getContext(), options);
- patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
+ patterns.getContext(), options, benefit);
+ patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
if (options == VectorMultiReductionLowering ::InnerReduction)
- patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
+ patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
+ benefit);
else
- patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
+ patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
+ benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 9125aae4ccb9b..5fe393b48b10f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -257,7 +257,7 @@ static bool isZero(Value v) {
/// inserting a memref.subview dropping those unit dims.
class TransferReadDropUnitDimsPattern
: public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
@@ -300,7 +300,7 @@ class TransferReadDropUnitDimsPattern
/// unit dims, by inserting a memref.subview dropping those unit dims.
class TransferWriteDropUnitDimsPattern
: public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
@@ -412,7 +412,7 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
/// already reduced i.e. without unit dims.
class FlattenContiguousRowMajorTransferReadPattern
: public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
@@ -470,7 +470,7 @@ class FlattenContiguousRowMajorTransferReadPattern
/// already reduced i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
@@ -543,17 +543,17 @@ void mlir::vector::transferOpflowOpt(Operation *rootOp) {
}
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
.add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
- patterns.getContext());
+ patterns.getContext(), benefit);
populateShapeCastFoldingPatterns(patterns);
}
void mlir::vector::populateFlattenVectorTransferPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<FlattenContiguousRowMajorTransferReadPattern,
FlattenContiguousRowMajorTransferWritePattern>(
- patterns.getContext());
- populateShapeCastFoldingPatterns(patterns);
+ patterns.getContext(), benefit);
+ populateShapeCastFoldingPatterns(patterns, benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index e419fc2db87a0..de72c6d2cfeac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -53,7 +53,7 @@ transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
: public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
@@ -142,7 +142,7 @@ struct TransferReadPermutationLowering
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
: public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
PatternRewriter &rewriter) const override {
@@ -201,7 +201,7 @@ struct TransferWritePermutationLowering
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
@@ -271,8 +271,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
};
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<TransferReadPermutationLowering,
TransferWritePermutationLowering, TransferOpReduceRank>(
- patterns.getContext());
+ patterns.getContext(), benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8c3d3e9bc6b15..3ff045031be10 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -216,7 +216,7 @@ namespace {
// %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
@@ -250,7 +250,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
- using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
@@ -381,11 +381,11 @@ void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
/// %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context)
- : OpRewritePattern<vector::TransposeOp>(context),
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
@@ -470,12 +470,12 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
class TransposeOp2DToShuffleLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
TransposeOp2DToShuffleLowering(
vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context)
- : OpRewritePattern<vector::TransposeOp>(context),
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
@@ -534,7 +534,7 @@ class TransposeOp2DToShuffleLowering
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
- using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
@@ -593,9 +593,9 @@ struct ContractOpToElementwise
}
ContractOpToElementwise(
vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context,
+ MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
@@ -715,7 +715,7 @@ struct ContractOpToElementwise
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
- using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
PatternRewriter &rewriter) const override {
@@ -789,7 +789,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
- using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
@@ -835,7 +835,7 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
class ShapeCastOp2DDownCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
@@ -868,7 +868,7 @@ class ShapeCastOp2DDownCastRewritePattern
class ShapeCastOp2DUpCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
@@ -900,7 +900,7 @@ class ShapeCastOp2DUpCastRewritePattern
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
@@ -974,7 +974,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
/// ```
struct MultiReduceToContract
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
PatternRewriter &rewriter) const override {
@@ -1030,7 +1030,7 @@ struct MultiReduceToContract
/// ```
struct CombineContractTranspose
: public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
@@ -1087,7 +1087,7 @@ struct CombineContractTranspose
/// ```
struct CombineContractBroadcast
: public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
@@ -2036,8 +2036,9 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
struct TransferReadToVectorLoadLowering
: public OpRewritePattern<vector::TransferReadOp> {
TransferReadToVectorLoadLowering(MLIRContext *context,
- llvm::Optional<unsigned> maxRank)
- : OpRewritePattern<vector::TransferReadOp>(context),
+ llvm::Optional<unsigned> maxRank,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransferReadOp>(context, benefit),
maxTransferRank(maxRank) {}
LogicalResult matchAndRewrite(vector::TransferReadOp read,
@@ -2124,7 +2125,7 @@ struct TransferReadToVectorLoadLowering
// trivial case (for architectures for which this matters).
struct VectorLoadToMemrefLoadLowering
: public OpRewritePattern<vector::LoadOp> {
- using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::LoadOp loadOp,
PatternRewriter &rewriter) const override {
@@ -2142,7 +2143,7 @@ struct VectorLoadToMemrefLoadLowering
/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
: public OpRewritePattern<vector::StoreOp> {
- using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::StoreOp storeOp,
PatternRewriter &rewriter) const override {
@@ -2177,8 +2178,9 @@ struct VectorStoreToMemrefStoreLowering
struct TransferWriteToVectorStoreLowering
: public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteToVectorStoreLowering(MLIRContext *context,
- llvm::Optional<unsigned> maxRank)
- : OpRewritePattern<vector::TransferWriteOp>(context),
+ llvm::Optional<unsigned> maxRank,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
maxTransferRank(maxRank) {}
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
@@ -2415,6 +2417,7 @@ struct BubbleDownBitCastForStridedSliceExtract
struct BubbleUpBitCastForStridedSliceInsert
: public OpRewritePattern<vector::BitCastOp> {
using OpRewritePattern::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
PatternRewriter &rewriter) const override {
VectorType castSrcType = bitcastOp.getSourceVectorType();
@@ -2530,8 +2533,9 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
- explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
- : mlir::OpRewritePattern<ConcreteOp>(context),
+ explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
+ PatternBenefit benefit = 1)
+ : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
force32BitVectorIndices(enableIndexOpt) {}
LogicalResult matchAndRewrite(ConcreteOp xferOp,
@@ -2583,8 +2587,9 @@ class VectorCreateMaskOpConversion
: public OpRewritePattern<vector::CreateMaskOp> {
public:
explicit VectorCreateMaskOpConversion(MLIRContext *context,
- bool enableIndexOpt)
- : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
+ bool enableIndexOpt,
+ PatternBenefit benefit = 1)
+ : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
force32BitVectorIndices(enableIndexOpt) {}
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
@@ -2608,7 +2613,7 @@ class VectorCreateMaskOpConversion
// Drop inner most contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
@@ -2815,7 +2820,7 @@ static Value genOperator(Location loc, Value x, Value y,
/// vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
- using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ScanOp scanOp,
PatternRewriter &rewriter) const override {
@@ -2896,81 +2901,87 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
} // namespace
void mlir::vector::populateVectorMaskMaterializationPatterns(
- RewritePatternSet &patterns, bool force32BitVectorIndices) {
+ RewritePatternSet &patterns, bool force32BitVectorIndices,
+ PatternBenefit benefit) {
patterns.add<VectorCreateMaskOpConversion,
MaterializeTransferMask<vector::TransferReadOp>,
MaterializeTransferMask<vector::TransferWriteOp>>(
- patterns.getContext(), force32BitVectorIndices);
+ patterns.getContext(), force32BitVectorIndices, benefit);
}
-void mlir::vector::populateShapeCastFoldingPatterns(
- RewritePatternSet &patterns) {
- patterns.add<ShapeCastOpFolder>(patterns.getContext());
+void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<BubbleDownVectorBitCastForExtract,
BubbleDownBitCastForStridedSliceExtract,
- BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
+ BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
+ benefit);
}
void mlir::vector::populateVectorBroadcastLoweringPatterns(
- RewritePatternSet &patterns) {
- patterns.add<BroadcastOpLowering>(patterns.getContext());
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
}
void mlir::vector::populateVectorMaskOpLoweringPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
- patterns.getContext());
+ patterns.getContext(), benefit);
}
void mlir::vector::populateVectorShapeCastLoweringPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
- patterns.getContext());
+ patterns.getContext(), benefit);
}
void mlir::vector::populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options) {
- patterns.add<OuterProductOpLowering>(patterns.getContext());
+ RewritePatternSet &patterns, VectorTransformsOptions options,
+ PatternBenefit benefit) {
+ patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
- ContractionOpToOuterProductOpLowering>(options,
- patterns.getContext());
+ ContractionOpToOuterProductOpLowering>(
+ options, patterns.getContext(), benefit);
}
void mlir::vector::populateVectorTransposeLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options) {
+ RewritePatternSet &patterns, VectorTransformsOptions options,
+ PatternBenefit benefit) {
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
- options, patterns.getContext());
+ options, patterns.getContext(), benefit);
}
void mlir::vector::populateVectorReductionToContractPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
CombineContractTranspose, ReorderCastOpsOnBroadcast,
- ReorderElementwiseOpsOnTranspose>(patterns.getContext());
+ ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
+ benefit);
}
void mlir::vector::
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
- RewritePatternSet &patterns) {
- patterns.add<DropInnerMostUnitDims>(patterns.getContext());
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
}
void mlir::vector::populateVectorTransferLoweringPatterns(
- RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
+ RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank,
+ PatternBenefit benefit) {
patterns.add<TransferReadToVectorLoadLowering,
TransferWriteToVectorStoreLowering>(patterns.getContext(),
- maxTransferRank);
+ maxTransferRank, benefit);
patterns
.add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
- patterns.getContext());
+ patterns.getContext(), benefit);
}
void mlir::vector::populateVectorScanLoweringPatterns(
- RewritePatternSet &patterns) {
- patterns.add<ScanToArithOps>(patterns.getContext());
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index a7461456682c3..dc314aa2141f4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -181,9 +181,11 @@ namespace {
struct UnrollTransferReadPattern
: public OpRewritePattern<vector::TransferReadOp> {
UnrollTransferReadPattern(MLIRContext *context,
- const vector::UnrollVectorOptions &options)
- : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransferReadOp>(context, benefit),
options(options) {}
+
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
@@ -236,9 +238,11 @@ struct UnrollTransferReadPattern
struct UnrollTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
UnrollTransferWritePattern(MLIRContext *context,
- const vector::UnrollVectorOptions &options)
- : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
options(options) {}
+
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
@@ -306,8 +310,9 @@ struct OffsetMapInfo {
struct UnrollContractionPattern
: public OpRewritePattern<vector::ContractionOp> {
UnrollContractionPattern(MLIRContext *context,
- const vector::UnrollVectorOptions &options)
- : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
options(options) {}
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
@@ -408,8 +413,9 @@ struct UnrollContractionPattern
struct UnrollMultiReductionPattern
: public OpRewritePattern<vector::MultiDimReductionOp> {
UnrollMultiReductionPattern(MLIRContext *context,
- const vector::UnrollVectorOptions &options)
- : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
options(options) {}
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
@@ -481,9 +487,11 @@ struct UnrollMultiReductionPattern
struct UnrollElementwisePattern : public RewritePattern {
UnrollElementwisePattern(MLIRContext *context,
- const vector::UnrollVectorOptions &options)
- : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
options(options) {}
+
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
@@ -539,7 +547,8 @@ struct UnrollElementwisePattern : public RewritePattern {
/// %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
- using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
PatternRewriter &rewriter) const override {
Operation *definedOp = extract.getVector().getDefiningOp();
@@ -570,7 +579,8 @@ struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
/// Canonicalize an extract_map using the result of a contract operation.
/// This propagate the extract_map to operands.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
- using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+ using OpRewritePattern::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
PatternRewriter &rewriter) const override {
Operation *definedOp = extract.getVector().getDefiningOp();
@@ -631,8 +641,8 @@ struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
/// ```
struct TransferReadExtractPattern
: public OpRewritePattern<vector::TransferReadOp> {
- TransferReadExtractPattern(MLIRContext *context)
- : OpRewritePattern<vector::TransferReadOp>(context) {}
+ using OpRewritePattern::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
@@ -682,8 +692,8 @@ struct TransferReadExtractPattern
struct TransferWriteInsertPattern
: public OpRewritePattern<vector::TransferWriteOp> {
- TransferWriteInsertPattern(MLIRContext *context)
- : OpRewritePattern<vector::TransferWriteOp>(context) {}
+ using OpRewritePattern::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
@@ -726,8 +736,9 @@ struct TransferWriteInsertPattern
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
UnrollReductionPattern(MLIRContext *context,
- const vector::UnrollVectorOptions &options)
- : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ReductionOp>(context, benefit),
options(options) {}
LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
@@ -772,9 +783,11 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
UnrollTranposePattern(MLIRContext *context,
- const vector::UnrollVectorOptions &options)
- : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit),
options(options) {}
+
LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
PatternRewriter &rewriter) const override {
if (tranposeOp.getResultType().getRank() == 0)
@@ -821,16 +834,17 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
- RewritePatternSet &patterns, const UnrollVectorOptions &options) {
+ RewritePatternSet &patterns, const UnrollVectorOptions &options,
+ PatternBenefit benefit) {
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
- UnrollTranposePattern>(patterns.getContext(), options);
+ UnrollTranposePattern>(patterns.getContext(), options, benefit);
}
void mlir::vector::populatePropagateVectorDistributionPatterns(
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<PointwiseExtractPattern, ContractExtractPattern,
TransferReadExtractPattern, TransferWriteInsertPattern>(
- patterns.getContext());
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 20ce2471cf17b..eec8de0094477 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -160,7 +160,7 @@ struct TestVectorContractionLowering
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
patterns.add<ContractionOpToOuterProductOpLowering>(
- options, &getContext(), [](vector::ContractionOp op) {
+ options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
// Only lowers vector.contract where the lhs as a type vector<MxNx?>
// where M is not 4.
if (op.getRhsType().getShape()[0] == 4)
More information about the Mlir-commits
mailing list