[Mlir-commits] [mlir] 27cc31b - [mlir][vector] NFC - Clean up vector patterns and propagate benefit through populate functions

Nicolas Vasilache llvmlistbot at llvm.org
Fri Sep 9 02:45:31 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-09T02:45:22-07:00
New Revision: 27cc31b64c0491725aa88a6822f0f2a2c18914d7

URL: https://github.com/llvm/llvm-project/commit/27cc31b64c0491725aa88a6822f0f2a2c18914d7
DIFF: https://github.com/llvm/llvm-project/commit/27cc31b64c0491725aa88a6822f0f2a2c18914d7.diff

LOG: [mlir][vector] NFC - Clean up vector patterns and propagate benefit through populate functions

Differential Revision: https://reviews.llvm.org/D133559

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index cd92dfd1e2217..f4316d0615609 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -62,11 +62,12 @@ isBroadcastableTo(Type srcType, VectorType dstVectorType,
                   std::pair<int, int> *mismatchingDims = nullptr);
 
 /// Collect a set of vector-to-vector canonicalization patterns.
-void populateVectorToVectorCanonicalizationPatterns(
-    RewritePatternSet &patterns);
+void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit = 1);
 
 /// Collect a set of vector.shape_cast folding patterns.
-void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
+void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
 
 /// Collect a set of leading one dimension removal patterns.
 ///
@@ -74,14 +75,16 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
 /// to expose more canonical forms of read/write/insert/extract operations.
 /// With them, there are more chances that we can cancel out extract-insert
 /// pairs or forward write-read pairs.
-void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
+void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit = 1);
 
 /// Collect a set of one dimension removal patterns.
 ///
 /// These patterns insert rank-reducing memref.subview ops to remove one
 /// dimensions. With them, there are more chances that we can avoid
 /// potentially exensive vector.shape_cast operations.
-void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
+void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
+                                                PatternBenefit benefit = 1);
 
 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
 /// memref.
@@ -89,14 +92,16 @@ void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
 /// to transform multiple small n-D transfers into a larger 1-D transfer where
 /// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
+void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
+                                           PatternBenefit benefit = 1);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after
 /// extract ops where suitable. With them, bitcast will happen on smaller
 /// vectors and there are more chances to share extract/insert ops.
-void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
+void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
+                                           PatternBenefit benefit = 1);
 
 /// Collect a set of transfer read/write lowering patterns.
 ///
@@ -106,28 +111,34 @@ void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
 /// VectorToSCF, which reduces the rank of vector transfer ops.
 void populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns,
-    llvm::Optional<unsigned> maxTransferRank = llvm::None);
+    llvm::Optional<unsigned> maxTransferRank = llvm::None,
+    PatternBenefit benefit = 1);
 
 /// These patterns materialize masks for various vector ops such as transfers.
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
-                                               bool force32BitVectorIndices);
+                                               bool force32BitVectorIndices,
+                                               PatternBenefit benefit = 1);
 
 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa
 /// chain.
-void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
+void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit = 1);
 
 /// Collects patterns to progressively lower vector.broadcast ops on high-D
 /// vectors to low-D vector ops.
-void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
 
 /// Collects patterns to progressively lower vector mask ops into elementary
 /// selection and insertion ops.
-void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
 
 /// Collects patterns to progressively lower vector.shape_cast ops on high-D
 /// vectors into 1-D/2-D vector ops by generating data movement extract/insert
 /// ops.
-void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
 
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 9174d0e2ab53e..204b322e2deae 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -37,7 +37,8 @@ struct WarpExecuteOnLane0LoweringOptions {
 
 void populateWarpExecuteOnLane0OpToScfForPattern(
     RewritePatternSet &patterns,
-    const WarpExecuteOnLane0LoweringOptions &options);
+    const WarpExecuteOnLane0LoweringOptions &options,
+    PatternBenefit benefit = 1);
 
 using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
 
@@ -59,7 +60,8 @@ using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
 /// }
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
 void populateDistributeTransferWriteOpPatterns(
-    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn);
+    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
+    PatternBenefit benefit = 1);
 
 /// Move scalar operations with no dependency on the warp op outside of the
 /// region.
@@ -67,7 +69,7 @@ void moveScalarUniformCode(WarpExecuteOnLane0Op op);
 
 /// Collect patterns to propagate warp distribution.
 void populatePropagateWarpVectorDistributionPatterns(
-    RewritePatternSet &pattern);
+    RewritePatternSet &pattern, PatternBenefit benefit = 1);
 
 /// Lambda signature to compute a reduction of a distributed value for the given
 /// reduction kind and size.
@@ -78,7 +80,8 @@ using DistributedReductionFn =
 /// distribute reduction op.
 void populateDistributeReduction(
     RewritePatternSet &pattern,
-    const DistributedReductionFn &distributedReductionFn);
+    const DistributedReductionFn &distributedReductionFn,
+    PatternBenefit benefit = 1);
 
 } // namespace vector
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ba4f6b3788c32..e7169b60285a9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -150,7 +150,8 @@ struct UnrollVectorOptions {
 /// Insert TransposeLowering patterns into extraction/insertion.
 void populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
-    VectorTransformsOptions options = VectorTransformsOptions());
+    VectorTransformsOptions options = VectorTransformsOptions(),
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns to convert vector.multi_reduction op into
 /// a sequence of vector.reduction ops. The patterns comprise:
@@ -175,20 +176,24 @@ void populateVectorTransposeLoweringPatterns(
 /// the other patterns can kick in, thus fully exiting out of the
 /// vector.multi_reduction abstraction.
 void populateVectorMultiReductionLoweringPatterns(
-    RewritePatternSet &patterns, VectorMultiReductionLowering options);
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
 
 /// Collects patterns to progressively lower vector contraction ops on high-D
 /// into low-D reduction and product ops.
 void populateVectorContractLoweringPatterns(
     RewritePatternSet &patterns,
-    VectorTransformsOptions options = VectorTransformsOptions());
+    VectorTransformsOptions options = VectorTransformsOptions(),
+    PatternBenefit benefit = 1);
 
 /// Collect patterns to convert reduction op to vector.contract and fold
 /// transpose/broadcast ops into the contract.
-void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
+void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
+                                               PatternBenefit benefit = 1);
 
 /// Collect patterns to convert scan op
-void populateVectorScanLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
+                                        PatternBenefit benefit = 1);
 
 //===----------------------------------------------------------------------===//
 // Vector.transfer patterns.
@@ -246,14 +251,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns);
 ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
 ///     vector.broadcast %v
 void populateVectorTransferPermutationMapLoweringPatterns(
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
 /// Collect a set of patterns to reduce the rank of the operands of vector
 /// transfer ops to operate on the largest contigious vector.
 /// These patterns are useful when lowering to dialects with 1d vector type
 /// such as llvm and it will result fewer memory reads.
 void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
 /// Populate `patterns` with the following patterns.
 ///
@@ -278,7 +283,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
 void populateVectorInsertExtractStridedSliceDecompositionPatterns(
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
 /// Populate `patterns` with the following patterns.
 ///
@@ -299,7 +304,7 @@ void populateVectorInsertExtractStridedSliceDecompositionPatterns(
 /// =========================================
 /// For such cases, we can lower it to a ShuffleOp.
 void populateVectorInsertExtractStridedSliceTransforms(
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
 /// `options` structure controls which operations are unrolled and the target
@@ -332,7 +337,8 @@ void populateVectorInsertExtractStridedSliceTransforms(
 /// Other local patterns then kick in iteratively (including DCE) and compose
 /// to combine the ExtractStridedSlice/InsertStridedSlice.
 void populateVectorUnrollPatterns(RewritePatternSet &patterns,
-                                  const UnrollVectorOptions &options);
+                                  const UnrollVectorOptions &options,
+                                  PatternBenefit benefit = 1);
 
 //===----------------------------------------------------------------------===//
 // Finer-grained patterns exposed for more control over individual lowerings.
@@ -377,7 +383,8 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
 class ContractionOpToMatmulOpLowering
     : public OpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
+
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
 
@@ -387,8 +394,9 @@ class ContractionOpToMatmulOpLowering
 
   ContractionOpToMatmulOpLowering(
       vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
+      MLIRContext *context, PatternBenefit benefit = 1,
+      FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
@@ -419,7 +427,8 @@ class ContractionOpToMatmulOpLowering
 class ContractionOpToOuterProductOpLowering
     : public OpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
+
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
 
@@ -429,8 +438,9 @@ class ContractionOpToOuterProductOpLowering
 
   ContractionOpToOuterProductOpLowering(
       vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
+      MLIRContext *context, PatternBenefit benefit = 1,
+      FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
@@ -464,7 +474,8 @@ class ContractionOpToOuterProductOpLowering
 class ContractionOpToDotLowering
     : public OpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
+
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
 
@@ -474,9 +485,9 @@ class ContractionOpToDotLowering
 
   ContractionOpToDotLowering(
       vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context,
+      MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
@@ -504,7 +515,7 @@ class ContractionOpToDotLowering
 /// to Dot or when other contraction patterns fail.
 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
 
@@ -513,9 +524,9 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
   }
 
   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
-                        MLIRContext *context,
+                        MLIRContext *context, PatternBenefit benefit = 1,
                         FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79c3719a8337b..574b4b977961a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1464,7 +1464,7 @@ namespace {
 // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
 public:
-  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
@@ -1494,7 +1494,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
 class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
 public:
-  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
@@ -1681,7 +1681,7 @@ namespace {
 
 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
-  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                 PatternRewriter &rewriter) const override {
@@ -1828,7 +1828,7 @@ namespace {
 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
 // to a broadcast.
 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
-  using OpRewritePattern<ShuffleOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
                                 PatternRewriter &rewriter) const override {
@@ -1852,7 +1852,7 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
 public:
-  using OpRewritePattern<ShuffleOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ShuffleOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1979,7 +1979,7 @@ namespace {
 // broadcast.
 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
 public:
-  using OpRewritePattern<InsertOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(InsertOp insertOp,
                                 PatternRewriter &rewriter) const override {
@@ -1996,7 +1996,7 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
 /// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
 public:
-  using OpRewritePattern<InsertOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(InsertOp op,
                                 PatternRewriter &rewriter) const override {
@@ -2202,7 +2202,7 @@ namespace {
 class FoldInsertStridedSliceSplat final
     : public OpRewritePattern<InsertStridedSliceOp> {
 public:
-  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
@@ -2227,7 +2227,7 @@ class FoldInsertStridedSliceSplat final
 class FoldInsertStridedSliceOfExtract final
     : public OpRewritePattern<InsertStridedSliceOp> {
 public:
-  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
@@ -2587,7 +2587,7 @@ namespace {
 class StridedSliceConstantMaskFolder final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
-  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
@@ -2640,7 +2640,7 @@ class StridedSliceConstantMaskFolder final
 class StridedSliceConstantFolder final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
-  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
@@ -2666,7 +2666,7 @@ class StridedSliceConstantFolder final
 class StridedSliceBroadcast final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
-  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
@@ -2709,7 +2709,7 @@ class StridedSliceBroadcast final
 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
-  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
@@ -3182,7 +3182,7 @@ namespace {
 struct FoldExtractSliceIntoTransferRead
     : public OpRewritePattern<TransferReadOp> {
 public:
-  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
@@ -3279,7 +3279,7 @@ struct FoldExtractSliceIntoTransferRead
 /// ```
 struct TransferReadAfterWriteToBroadcast
     : public OpRewritePattern<TransferReadOp> {
-  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
@@ -3628,7 +3628,7 @@ namespace {
 /// any other uses.
 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
 public:
-  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
     if (!writeOp.getShapedType().isa<RankedTensorType>())
@@ -3674,7 +3674,7 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
 struct FoldInsertSliceIntoTransferWrite
     : public OpRewritePattern<tensor::InsertSliceOp> {
 public:
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                 PatternRewriter &rewriter) const override {
@@ -3768,7 +3768,7 @@ struct FoldInsertSliceIntoTransferWrite
 struct SwapExtractSliceOfTransferWrite
     : public OpRewritePattern<tensor::InsertSliceOp> {
 public:
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                 PatternRewriter &rewriter) const override {
@@ -3947,7 +3947,7 @@ LogicalResult MaskedLoadOp::verify() {
 namespace {
 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
 public:
-  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedLoadOp load,
                                 PatternRewriter &rewriter) const override {
     switch (getMaskFormat(load.getMask())) {
@@ -3998,7 +3998,7 @@ LogicalResult MaskedStoreOp::verify() {
 namespace {
 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
 public:
-  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedStoreOp store,
                                 PatternRewriter &rewriter) const override {
     switch (getMaskFormat(store.getMask())) {
@@ -4056,7 +4056,7 @@ LogicalResult GatherOp::verify() {
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:
-  using OpRewritePattern<GatherOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherOp gather,
                                 PatternRewriter &rewriter) const override {
     switch (getMaskFormat(gather.getMask())) {
@@ -4102,7 +4102,7 @@ LogicalResult ScatterOp::verify() {
 namespace {
 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
 public:
-  using OpRewritePattern<ScatterOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
     switch (getMaskFormat(scatter.getMask())) {
@@ -4148,7 +4148,7 @@ LogicalResult ExpandLoadOp::verify() {
 namespace {
 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
 public:
-  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                 PatternRewriter &rewriter) const override {
     switch (getMaskFormat(expand.getMask())) {
@@ -4193,7 +4193,7 @@ LogicalResult CompressStoreOp::verify() {
 namespace {
 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
 public:
-  using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(CompressStoreOp compress,
                                 PatternRewriter &rewriter) const override {
     switch (getMaskFormat(compress.getMask())) {
@@ -4333,7 +4333,7 @@ namespace {
 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
-  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
@@ -4359,7 +4359,7 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
 /// enough to capture the result in a single op).
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
-  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
@@ -4589,7 +4589,7 @@ namespace {
 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
 public:
-  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
@@ -4651,7 +4651,7 @@ struct FoldTransposedScalarBroadcast final
 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
 public:
-  using OpRewritePattern<TransposeOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
@@ -4751,7 +4751,7 @@ namespace {
 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 public:
-  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                 PatternRewriter &rewriter) const override {
@@ -4850,12 +4850,12 @@ LogicalResult ScanOp::verify() {
 }
 
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
       .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
            ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
            StridedSliceConstantMaskFolder, TransposeFolder>(
-          patterns.getContext());
+          patterns.getContext(), benefit);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index a7667098125ea..f356248f18311 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1056,26 +1056,30 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
     RewritePatternSet &patterns,
-    const WarpExecuteOnLane0LoweringOptions &options) {
-  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
+    const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
+  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateDistributeTransferWriteOpPatterns(
-    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
-  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
+    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
+    PatternBenefit benefit) {
+  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
+                                    benefit);
 }
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
+               WarpOpScfForOp, WarpOpConstant>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
     RewritePatternSet &patterns,
-    const DistributedReductionFn &distributedReductionFn) {
-  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
+    const DistributedReductionFn &distributedReductionFn,
+    PatternBenefit benefit) {
+  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
+                                benefit);
 }
 
 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 07782eb48826e..47f00fa7ec495 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -401,8 +401,9 @@ struct CastAwayContractionLeadingOneDim
 
 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
 public:
-  CastAwayElementwiseLeadingOneDim(MLIRContext *context)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
+                                   PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
@@ -436,12 +437,12 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
 } // namespace
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
       .add<CastAwayExtractStridedSliceLeadingOneDim,
            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
            CastAwayTransferReadLeadingOneDim,
            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
-           CastAwayContractionLeadingOneDim>(patterns.getContext());
-  populateShapeCastFoldingPatterns(patterns);
+           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
+  populateShapeCastFoldingPatterns(patterns, benefit);
 }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 224c7fbc01df4..9582a9d525008 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -286,15 +286,17 @@ class DecomposeNDExtractStridedSlice
 };
 
 void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
-               DecomposeNDExtractStridedSlice>(patterns.getContext());
+               DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
 }
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
-    RewritePatternSet &patterns) {
-  populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns);
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
+                                                               benefit);
   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
-               Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext());
+               Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
+                                                        benefit);
 }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 2582781aaab08..f8c8f9e74fbc3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -29,11 +29,12 @@ using namespace mlir;
 class InnerOuterDimReductionConversion
     : public OpRewritePattern<vector::MultiDimReductionOp> {
 public:
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   explicit InnerOuterDimReductionConversion(
-      MLIRContext *context, vector::VectorMultiReductionLowering options)
-      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+      MLIRContext *context, vector::VectorMultiReductionLowering options,
+      PatternBenefit benefit = 1)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
         useInnerDimsForReduction(
             options == vector::VectorMultiReductionLowering::InnerReduction) {}
 
@@ -101,11 +102,12 @@ class InnerOuterDimReductionConversion
 class ReduceMultiDimReductionRank
     : public OpRewritePattern<vector::MultiDimReductionOp> {
 public:
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   explicit ReduceMultiDimReductionRank(
-      MLIRContext *context, vector::VectorMultiReductionLowering options)
-      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+      MLIRContext *context, vector::VectorMultiReductionLowering options,
+      PatternBenefit benefit = 1)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
         useInnerDimsForReduction(
             options == vector::VectorMultiReductionLowering::InnerReduction) {}
 
@@ -224,7 +226,7 @@ class ReduceMultiDimReductionRank
 /// and combines results
 struct TwoDimMultiReductionToElementWise
     : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
@@ -261,7 +263,7 @@ struct TwoDimMultiReductionToElementWise
 /// a sequence of vector.reduction ops.
 struct TwoDimMultiReductionToReduction
     : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
@@ -301,7 +303,7 @@ struct TwoDimMultiReductionToReduction
 /// separately.
 struct OneDimMultiReductionToTwoDim
     : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
@@ -338,12 +340,15 @@ struct OneDimMultiReductionToTwoDim
 };
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
-    RewritePatternSet &patterns, VectorMultiReductionLowering options) {
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
   patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
-      patterns.getContext(), options);
-  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
+      patterns.getContext(), options, benefit);
+  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
-  if (options == VectorMultiReductionLowering ::InnerReduction)
-    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
+  if (options == VectorMultiReductionLowering::InnerReduction)
+    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
+                                                  benefit);
   else
-    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
+    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
+                                                    benefit);
 }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 9125aae4ccb9b..5fe393b48b10f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -257,7 +257,7 @@ static bool isZero(Value v) {
 /// inserting a memref.subview dropping those unit dims.
 class TransferReadDropUnitDimsPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -300,7 +300,7 @@ class TransferReadDropUnitDimsPattern
 /// unit dims, by inserting a memref.subview dropping those unit dims.
 class TransferWriteDropUnitDimsPattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -412,7 +412,7 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -470,7 +470,7 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -543,17 +543,17 @@ void mlir::vector::transferOpflowOpt(Operation *rootOp) {
 }
 
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
-          patterns.getContext());
-  populateShapeCastFoldingPatterns(patterns);
+          patterns.getContext(), benefit);
+  populateShapeCastFoldingPatterns(patterns, benefit);
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext());
-  populateShapeCastFoldingPatterns(patterns);
+      patterns.getContext(), benefit);
+  populateShapeCastFoldingPatterns(patterns, benefit);
 }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index e419fc2db87a0..de72c6d2cfeac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -53,7 +53,7 @@ transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
 /// vector.transfer_read to do the transpose in memory instead.
 struct TransferReadPermutationLowering
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                 PatternRewriter &rewriter) const override {
@@ -142,7 +142,7 @@ struct TransferReadPermutationLowering
 ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
 struct TransferWritePermutationLowering
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                 PatternRewriter &rewriter) const override {
@@ -201,7 +201,7 @@ struct TransferWritePermutationLowering
 ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
 ///     vector.broadcast %v
 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                 PatternRewriter &rewriter) const override {
@@ -271,8 +271,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
 };
 
 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<TransferReadPermutationLowering,
                TransferWritePermutationLowering, TransferOpReduceRank>(
-      patterns.getContext());
+      patterns.getContext(), benefit);
 }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8c3d3e9bc6b15..3ff045031be10 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -216,7 +216,7 @@ namespace {
 //   %1 = user %0 : vector<5x4x2xf32>
 //
 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
@@ -250,7 +250,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
 /// Progressive lowering of BroadcastOp.
 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
 public:
-  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
@@ -381,11 +381,11 @@ void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
 ///   %x = vector.insert .., .. [.., ..]
 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 public:
-  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
-                      MLIRContext *context)
-      : OpRewritePattern<vector::TransposeOp>(context),
+                      MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions) {}
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
@@ -470,12 +470,12 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 class TransposeOp2DToShuffleLowering
     : public OpRewritePattern<vector::TransposeOp> {
 public:
-  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   TransposeOp2DToShuffleLowering(
       vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context)
-      : OpRewritePattern<vector::TransposeOp>(context),
+      MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions) {}
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
@@ -534,7 +534,7 @@ class TransposeOp2DToShuffleLowering
 ///
 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 public:
-  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                 PatternRewriter &rewriter) const override {
@@ -593,9 +593,9 @@ struct ContractOpToElementwise
   }
   ContractOpToElementwise(
       vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context,
+      MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
@@ -715,7 +715,7 @@ struct ContractOpToElementwise
 /// will be folded at LLVM IR level.
 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
 public:
-  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                 PatternRewriter &rewriter) const override {
@@ -789,7 +789,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
 /// until a one-dimensional vector is reached.
 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
 public:
-  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
@@ -835,7 +835,7 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
 class ShapeCastOp2DDownCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                 PatternRewriter &rewriter) const override {
@@ -868,7 +868,7 @@ class ShapeCastOp2DDownCastRewritePattern
 class ShapeCastOp2DUpCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                 PatternRewriter &rewriter) const override {
@@ -900,7 +900,7 @@ class ShapeCastOp2DUpCastRewritePattern
 // into the right place if we get here.
 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
 public:
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                 PatternRewriter &rewriter) const override {
@@ -974,7 +974,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
 ///  ```
 struct MultiReduceToContract
     : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                 PatternRewriter &rewriter) const override {
@@ -1030,7 +1030,7 @@ struct MultiReduceToContract
 ///  ```
 struct CombineContractTranspose
     : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
@@ -1087,7 +1087,7 @@ struct CombineContractTranspose
 ///  ```
 struct CombineContractBroadcast
     : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
@@ -2036,8 +2036,9 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
 struct TransferReadToVectorLoadLowering
     : public OpRewritePattern<vector::TransferReadOp> {
   TransferReadToVectorLoadLowering(MLIRContext *context,
-                                   llvm::Optional<unsigned> maxRank)
-      : OpRewritePattern<vector::TransferReadOp>(context),
+                                   llvm::Optional<unsigned> maxRank,
+                                   PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
         maxTransferRank(maxRank) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
@@ -2124,7 +2125,7 @@ struct TransferReadToVectorLoadLowering
 // trivial case (for architectures for which this matters).
 struct VectorLoadToMemrefLoadLowering
     : public OpRewritePattern<vector::LoadOp> {
-  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
@@ -2142,7 +2143,7 @@ struct VectorLoadToMemrefLoadLowering
 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
 struct VectorStoreToMemrefStoreLowering
     : public OpRewritePattern<vector::StoreOp> {
-  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
@@ -2177,8 +2178,9 @@ struct VectorStoreToMemrefStoreLowering
 struct TransferWriteToVectorStoreLowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   TransferWriteToVectorStoreLowering(MLIRContext *context,
-                                     llvm::Optional<unsigned> maxRank)
-      : OpRewritePattern<vector::TransferWriteOp>(context),
+                                     llvm::Optional<unsigned> maxRank,
+                                     PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
         maxTransferRank(maxRank) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
@@ -2415,6 +2417,7 @@ struct BubbleDownBitCastForStridedSliceExtract
 struct BubbleUpBitCastForStridedSliceInsert
     : public OpRewritePattern<vector::BitCastOp> {
   using OpRewritePattern::OpRewritePattern;
+
   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                 PatternRewriter &rewriter) const override {
     VectorType castSrcType = bitcastOp.getSourceVectorType();
@@ -2530,8 +2533,9 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
 template <typename ConcreteOp>
 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
 public:
-  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
-      : mlir::OpRewritePattern<ConcreteOp>(context),
+  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
+                                   PatternBenefit benefit = 1)
+      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
         force32BitVectorIndices(enableIndexOpt) {}
 
   LogicalResult matchAndRewrite(ConcreteOp xferOp,
@@ -2583,8 +2587,9 @@ class VectorCreateMaskOpConversion
     : public OpRewritePattern<vector::CreateMaskOp> {
 public:
   explicit VectorCreateMaskOpConversion(MLIRContext *context,
-                                        bool enableIndexOpt)
-      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
+                                        bool enableIndexOpt,
+                                        PatternBenefit benefit = 1)
+      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
         force32BitVectorIndices(enableIndexOpt) {}
 
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
@@ -2608,7 +2613,7 @@ class VectorCreateMaskOpConversion
 
 // Drop inner most contiguous unit dimensions from transfer_read operand.
 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
@@ -2815,7 +2820,7 @@ static Value genOperator(Location loc, Value x, Value y,
 ///   vector<2x3xi32>, vector<2xi32>
 /// ```
 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
-  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                 PatternRewriter &rewriter) const override {
@@ -2896,81 +2901,87 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
 } // namespace
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
-    RewritePatternSet &patterns, bool force32BitVectorIndices) {
+    RewritePatternSet &patterns, bool force32BitVectorIndices,
+    PatternBenefit benefit) {
   patterns.add<VectorCreateMaskOpConversion,
                MaterializeTransferMask<vector::TransferReadOp>,
                MaterializeTransferMask<vector::TransferWriteOp>>(
-      patterns.getContext(), force32BitVectorIndices);
+      patterns.getContext(), force32BitVectorIndices, benefit);
 }
 
-void mlir::vector::populateShapeCastFoldingPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ShapeCastOpFolder>(patterns.getContext());
+void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<BubbleDownVectorBitCastForExtract,
                BubbleDownBitCastForStridedSliceExtract,
-               BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
+               BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
+                                                     benefit);
 }
 
 void mlir::vector::populateVectorBroadcastLoweringPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<BroadcastOpLowering>(patterns.getContext());
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorMaskOpLoweringPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
-      patterns.getContext());
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ShapeCastOp2DDownCastRewritePattern,
                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
-      patterns.getContext());
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options) {
-  patterns.add<OuterProductOpLowering>(patterns.getContext());
+    RewritePatternSet &patterns, VectorTransformsOptions options,
+    PatternBenefit benefit) {
+  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
-               ContractionOpToOuterProductOpLowering>(options,
-                                                      patterns.getContext());
+               ContractionOpToOuterProductOpLowering>(
+      options, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorTransposeLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options) {
+    RewritePatternSet &patterns, VectorTransformsOptions options,
+    PatternBenefit benefit) {
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
-      options, patterns.getContext());
+      options, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorReductionToContractPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
                CombineContractTranspose, ReorderCastOpsOnBroadcast,
-               ReorderElementwiseOpsOnTranspose>(patterns.getContext());
+               ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
+                                                 benefit);
 }
 
 void mlir::vector::
     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
-        RewritePatternSet &patterns) {
-  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
+        RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorTransferLoweringPatterns(
-    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
+    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank,
+    PatternBenefit benefit) {
   patterns.add<TransferReadToVectorLoadLowering,
                TransferWriteToVectorStoreLowering>(patterns.getContext(),
-                                                   maxTransferRank);
+                                                   maxTransferRank, benefit);
   patterns
       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext());
+          patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorScanLoweringPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ScanToArithOps>(patterns.getContext());
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
 }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index a7461456682c3..dc314aa2141f4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -181,9 +181,11 @@ namespace {
 struct UnrollTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
   UnrollTransferReadPattern(MLIRContext *context,
-                            const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
+                            const vector::UnrollVectorOptions &options,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
         options(options) {}
+
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
@@ -236,9 +238,11 @@ struct UnrollTransferReadPattern
 struct UnrollTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
   UnrollTransferWritePattern(MLIRContext *context,
-                             const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
+                             const vector::UnrollVectorOptions &options,
+                             PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
         options(options) {}
+
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
@@ -306,8 +310,9 @@ struct OffsetMapInfo {
 struct UnrollContractionPattern
     : public OpRewritePattern<vector::ContractionOp> {
   UnrollContractionPattern(MLIRContext *context,
-                           const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
+                           const vector::UnrollVectorOptions &options,
+                           PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
         options(options) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
@@ -408,8 +413,9 @@ struct UnrollContractionPattern
 struct UnrollMultiReductionPattern
     : public OpRewritePattern<vector::MultiDimReductionOp> {
   UnrollMultiReductionPattern(MLIRContext *context,
-                              const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
+                              const vector::UnrollVectorOptions &options,
+                              PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
         options(options) {}
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
@@ -481,9 +487,11 @@ struct UnrollMultiReductionPattern
 
 struct UnrollElementwisePattern : public RewritePattern {
   UnrollElementwisePattern(MLIRContext *context,
-                           const vector::UnrollVectorOptions &options)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
+                           const vector::UnrollVectorOptions &options,
+                           PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
         options(options) {}
+
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
@@ -539,7 +547,8 @@ struct UnrollElementwisePattern : public RewritePattern {
 /// %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
 /// %dv = arith.addf %da, %db : vector<1xf32>
 struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
-  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
+
   LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                 PatternRewriter &rewriter) const override {
     Operation *definedOp = extract.getVector().getDefiningOp();
@@ -570,7 +579,8 @@ struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
 /// Canonicalize an extract_map using the result of a contract operation.
 /// This propagate the extract_map to operands.
 struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
-  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
+
   LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                 PatternRewriter &rewriter) const override {
     Operation *definedOp = extract.getVector().getDefiningOp();
@@ -631,8 +641,8 @@ struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
 /// ```
 struct TransferReadExtractPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  TransferReadExtractPattern(MLIRContext *context)
-      : OpRewritePattern<vector::TransferReadOp>(context) {}
+  using OpRewritePattern::OpRewritePattern;
+
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
@@ -682,8 +692,8 @@ struct TransferReadExtractPattern
 
 struct TransferWriteInsertPattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  TransferWriteInsertPattern(MLIRContext *context)
-      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  using OpRewritePattern::OpRewritePattern;
+
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
@@ -726,8 +736,9 @@ struct TransferWriteInsertPattern
 
 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
   UnrollReductionPattern(MLIRContext *context,
-                         const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ReductionOp>(context, benefit),
         options(options) {}
 
   LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
@@ -772,9 +783,11 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
 
 struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
   UnrollTranposePattern(MLIRContext *context,
-                        const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
+                        const vector::UnrollVectorOptions &options,
+                        PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
         options(options) {}
+
   LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
                                 PatternRewriter &rewriter) const override {
     if (tranposeOp.getResultType().getRank() == 0)
@@ -821,16 +834,17 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
-    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
+    RewritePatternSet &patterns, const UnrollVectorOptions &options,
+    PatternBenefit benefit) {
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
-               UnrollTranposePattern>(patterns.getContext(), options);
+               UnrollTranposePattern>(patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populatePropagateVectorDistributionPatterns(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<PointwiseExtractPattern, ContractExtractPattern,
                TransferReadExtractPattern, TransferWriteInsertPattern>(
-      patterns.getContext());
+      patterns.getContext(), benefit);
 }

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 20ce2471cf17b..eec8de0094477 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -160,7 +160,7 @@ struct TestVectorContractionLowering
       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
       VectorTransformsOptions options{lowering};
       patterns.add<ContractionOpToOuterProductOpLowering>(
-          options, &getContext(), [](vector::ContractionOp op) {
+          options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
             // where M is not 4.
             if (op.getRhsType().getShape()[0] == 4)


        


More information about the Mlir-commits mailing list