[Mlir-commits] [mlir] 76f3c2f - [mlir][Pattern] Add better support for using interfaces/traits to match root operations in rewrite patterns

River Riddle llvmlistbot at llvm.org
Tue Mar 23 14:12:26 PDT 2021


Author: River Riddle
Date: 2021-03-23T14:05:33-07:00
New Revision: 76f3c2f3f34a2be05e5970f7413c6929541fd219

URL: https://github.com/llvm/llvm-project/commit/76f3c2f3f34a2be05e5970f7413c6929541fd219
DIFF: https://github.com/llvm/llvm-project/commit/76f3c2f3f34a2be05e5970f7413c6929541fd219.diff

LOG: [mlir][Pattern] Add better support for using interfaces/traits to match root operations in rewrite patterns

To match an interface or trait, users currently have to use the `MatchAny` tag. This tag can significantly hurt compile time in drivers like the canonicalizer, as `MatchAny` patterns may get applied to *every* operation. This revision adds better support by bucketing interface/trait patterns according to which registered operations implement the given interface or trait. This means that, moving forward, we will only attempt to match these patterns against operations that have the corresponding interface or trait registered. To simplify defining patterns that match traits and interfaces, two new utility classes have been added: OpTraitRewritePattern and OpInterfaceRewritePattern.

Differential Revision: https://reviews.llvm.org/D98986

Added: 
    

Modified: 
    flang/lib/Optimizer/Dialect/FIROps.cpp
    mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
    mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
    mlir/lib/Rewrite/PatternApplicator.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp
    mlir/unittests/Rewrite/PatternBenefit.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 6d2d78d5825f..38390d801134 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -697,7 +697,7 @@ static bool isOne(mlir::Value v) { return checkIsIntegerConstant(v, 1); }
 template <typename FltOp, typename CpxOp>
 struct UndoComplexPattern : public mlir::RewritePattern {
   UndoComplexPattern(mlir::MLIRContext *ctx)
-      : mlir::RewritePattern("fir.insert_value", {}, 2, ctx) {}
+      : mlir::RewritePattern("fir.insert_value", 2, ctx) {}
 
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,

diff  --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index f66a29250aa2..eeb20c4806b9 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -30,12 +30,12 @@ namespace linalg {
 // or in an externally linked library.
 // This is a generic entry point for all LinalgOp, except for CopyOp and
 // IndexedGenericOp, for which omre specialized patterns are provided.
-class LinalgOpToLibraryCallRewrite : public RewritePattern {
+class LinalgOpToLibraryCallRewrite
+    : public OpInterfaceRewritePattern<LinalgOp> {
 public:
-  LinalgOpToLibraryCallRewrite()
-      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override;
 };
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index d005cc310abe..bee0d5a12800 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -60,7 +60,8 @@ void enqueue(RewritePatternSet &patternList, OptionsType options,
   if (!opName.empty())
     patternList.add<PatternType>(opName, patternList.getContext(), options, m);
   else
-    patternList.add<PatternType>(m.addOpFilter<OpType>(), options);
+    patternList.add<PatternType>(patternList.getContext(),
+                                 m.addOpFilter<OpType>(), options);
 }
 
 /// Promotion transformation enqueues a particular stage-1 pattern for

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 71dbe9fb24cd..21e6cba9dc3c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -452,7 +452,7 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 struct LinalgBaseTilingPattern : public RewritePattern {
   // Entry point to match any LinalgOp OpInterface.
   LinalgBaseTilingPattern(
-      LinalgTilingOptions options,
+      MLIRContext *context, LinalgTilingOptions options,
       LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
   // Entry point to match a specific Linalg op.
@@ -644,7 +644,8 @@ struct LinalgVectorizationOptions {};
 
 struct LinalgBaseVectorizationPattern : public RewritePattern {
   /// MatchAnyOpTag-based constructor with a mandatory `filter`.
-  LinalgBaseVectorizationPattern(LinalgTransformationFilter filter,
+  LinalgBaseVectorizationPattern(MLIRContext *context,
+                                 LinalgTransformationFilter filter,
                                  PatternBenefit benefit = 1);
   /// Name-based constructor with an optional `filter`.
   LinalgBaseVectorizationPattern(
@@ -663,10 +664,10 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
   /// These constructors are available to anyone.
   /// MatchAnyOpTag-based constructor with a mandatory `filter`.
   LinalgVectorizationPattern(
-      LinalgTransformationFilter filter,
+      MLIRContext *context, LinalgTransformationFilter filter,
       LinalgVectorizationOptions options = LinalgVectorizationOptions(),
       PatternBenefit benefit = 1)
-      : LinalgBaseVectorizationPattern(filter, benefit) {}
+      : LinalgBaseVectorizationPattern(context, filter, benefit) {}
   /// Name-based constructor with an optional `filter`.
   LinalgVectorizationPattern(
       StringRef opName, MLIRContext *context,
@@ -702,8 +703,8 @@ template <typename OpType, typename = std::enable_if_t<
 void insertVectorizationPatternImpl(RewritePatternSet &patternList,
                                     linalg::LinalgVectorizationOptions options,
                                     linalg::LinalgTransformationFilter f) {
-  patternList.add<linalg::LinalgVectorizationPattern>(f.addOpFilter<OpType>(),
-                                                      options);
+  patternList.add<linalg::LinalgVectorizationPattern>(
+      patternList.getContext(), f.addOpFilter<OpType>(), options);
 }
 
 /// Variadic helper function to insert vectorization patterns for C++ ops.
@@ -737,7 +738,7 @@ struct LinalgLoweringPattern : public RewritePattern {
       MLIRContext *context, LinalgLoweringType loweringType,
       LinalgTransformationFilter filter = LinalgTransformationFilter(),
       ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
-      : RewritePattern(OpTy::getOperationName(), {}, benefit, context),
+      : RewritePattern(OpTy::getOperationName(), benefit, context),
         filter(filter), loweringType(loweringType),
         interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
 

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 35eb83d8f03a..b765dafbce46 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -123,7 +123,8 @@ struct UnrollVectorOptions {
 struct UnrollVectorPattern : public RewritePattern {
   using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
   UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
-      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {}
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
+        options(options) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (options.filterConstraint && failed(options.filterConstraint(op)))
@@ -216,7 +217,7 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
       FilterConstraintType filter =
           [](VectorTransferOpInterface op) { return success(); },
       PatternBenefit benefit = 1)
-      : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
         filter(filter) {}
 
   /// Performs the rewrite.

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index b27e1e0e4a78..ec3884e58fc3 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1516,6 +1516,13 @@ class Op : public OpState, public Traits<ConcreteType>... {
 #endif
     return false;
   }
+  /// Provide `classof` support for other OpBase derived classes, such as
+  /// Interfaces.
+  template <typename T>
+  static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
+  classof(const T *op) {
+    return classof(const_cast<T *>(op)->getOperation());
+  }
 
   /// Expose the type we are instantiated on to template machinery that may want
   /// to introspect traits on this operation.

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 8cc97d9c02ee..cb82ec9c4714 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -142,12 +142,20 @@ class AbstractOperation {
     return interfaceMap.lookup<T>();
   }
 
+  /// Returns true if this operation has the given interface registered to it.
+  bool hasInterface(TypeID interfaceID) const {
+    return interfaceMap.contains(interfaceID);
+  }
+
   /// Returns true if the operation has a particular trait.
   template <template <typename T> class Trait>
   bool hasTrait() const {
     return hasTraitFn(TypeID::get<Trait>());
   }
 
+  /// Returns true if the operation has a particular trait.
+  bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+
   /// Look up the specified operation in the specified MLIRContext and return a
   /// pointer to it if present.  Otherwise, return a null pointer.
   static const AbstractOperation *lookup(StringRef opName,

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 115ad5f039bc..5ee9418efa38 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -68,6 +68,19 @@ class PatternBenefit {
 /// used to interface with the metadata of a pattern, such as the benefit or
 /// root operation.
 class Pattern {
+  /// This enum represents the kind of value used to select the root operations
+  /// that match this pattern.
+  enum class RootKind {
+    /// The pattern root matches "any" operation.
+    Any,
+    /// The pattern root is matched using a concrete operation name.
+    OperationName,
+    /// The pattern root is matched using an interface ID.
+    InterfaceID,
+    /// The pattern root is matched using a trait ID.
+    TraitID
+  };
+
 public:
   /// Return a list of operations that may be generated when rewriting an
   /// operation instance with this pattern.
@@ -75,7 +88,29 @@ class Pattern {
 
   /// Return the root node that this pattern matches. Patterns that can match
   /// multiple root types return None.
-  Optional<OperationName> getRootKind() const { return rootKind; }
+  Optional<OperationName> getRootKind() const {
+    if (rootKind == RootKind::OperationName)
+      return OperationName::getFromOpaquePointer(rootValue);
+    return llvm::None;
+  }
+
+  /// Return the interface ID used to match the root operation of this pattern.
+  /// If the pattern does not use an interface ID for deciding the root match,
+  /// this returns None.
+  Optional<TypeID> getRootInterfaceID() const {
+    if (rootKind == RootKind::InterfaceID)
+      return TypeID::getFromOpaquePointer(rootValue);
+    return llvm::None;
+  }
+
+  /// Return the trait ID used to match the root operation of this pattern.
+  /// If the pattern does not use a trait ID for deciding the root match, this
+  /// returns None.
+  Optional<TypeID> getRootTraitID() const {
+    if (rootKind == RootKind::TraitID)
+      return TypeID::getFromOpaquePointer(rootValue);
+    return llvm::None;
+  }
 
   /// Return the benefit (the inverse of "cost") of matching this pattern.  The
   /// benefit of a Pattern is always static - rewrites that may have dynamic
@@ -88,56 +123,85 @@ class Pattern {
   /// i.e. this pattern may generate IR that also matches this pattern, but is
   /// known to bound the recursion. This signals to a rewrite driver that it is
   /// safe to apply this pattern recursively to generated IR.
-  bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
+  bool hasBoundedRewriteRecursion() const {
+    return contextAndHasBoundedRecursion.getInt();
+  }
+
+  /// Return the MLIRContext used to create this pattern.
+  MLIRContext *getContext() const {
+    return contextAndHasBoundedRecursion.getPointer();
+  }
 
 protected:
   /// This class acts as a special tag that makes the desire to match "any"
   /// operation type explicit. This helps to avoid unnecessary usages of this
   /// feature, and ensures that the user is making a conscious decision.
   struct MatchAnyOpTypeTag {};
+  /// This class acts as a special tag that makes the desire to match any
+  /// operation that implements a given interface explicit. This helps to avoid
+  /// unnecessary usages of this feature, and ensures that the user is making a
+  /// conscious decision.
+  struct MatchInterfaceOpTypeTag {};
+  /// This class acts as a special tag that makes the desire to match any
+  /// operation that implements a given trait explicit. This helps to avoid
+  /// unnecessary usages of this feature, and ensures that the user is making a
+  /// conscious decision.
+  struct MatchTraitOpTypeTag {};
 
   /// Construct a pattern with a certain benefit that matches the operation
   /// with the given root name.
-  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
-  /// Construct a pattern with a certain benefit that matches any operation
-  /// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
-  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
-  /// always be supplied here.
-  Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag);
-  /// Construct a pattern with a certain benefit that matches the operation with
-  /// the given root name. `generatedNames` contains the names of operations
-  /// that may be generated during a successful rewrite.
-  Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
-          PatternBenefit benefit, MLIRContext *context);
+  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
+          ArrayRef<StringRef> generatedNames = {});
   /// Construct a pattern that may match any operation type. `generatedNames`
   /// contains the names of operations that may be generated during a successful
   /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
   /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
   /// always be supplied here.
-  Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
-          MLIRContext *context, MatchAnyOpTypeTag tag);
+  Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
+          ArrayRef<StringRef> generatedNames = {});
+  /// Construct a pattern that may match any operation that implements the
+  /// interface defined by the provided `interfaceID`. `generatedNames` contains
+  /// the names of operations that may be generated during a successful rewrite.
+  /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
+  /// interface" behavior is what the user actually desired,
+  /// `MatchInterfaceOpTypeTag()` should always be supplied here.
+  Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
+          PatternBenefit benefit, MLIRContext *context,
+          ArrayRef<StringRef> generatedNames = {});
+  /// Construct a pattern that may match any operation that implements the
+  /// trait defined by the provided `traitID`. `generatedNames` contains the
+  /// names of operations that may be generated during a successful rewrite.
+  /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
+  /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
+  /// always be supplied here.
+  Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
+          MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
 
   /// Set the flag detailing if this pattern has bounded rewrite recursion or
   /// not.
   void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
-    hasBoundedRecursion = hasBoundedRecursionArg;
+    contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
   }
 
 private:
-  /// A list of the potential operations that may be generated when rewriting
-  /// an op with this pattern.
-  SmallVector<OperationName, 2> generatedOps;
+  Pattern(const void *rootValue, RootKind rootKind,
+          ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
+          MLIRContext *context);
 
-  /// The root operation of the pattern. If the pattern matches a specific
-  /// operation, this contains the name of that operation. Contains None
-  /// otherwise.
-  Optional<OperationName> rootKind;
+  /// The value used to match the root operation of the pattern.
+  const void *rootValue;
+  RootKind rootKind;
 
   /// The expected benefit of matching this pattern.
   const PatternBenefit benefit;
 
-  /// A boolean flag of whether this pattern has bounded recursion or not.
-  bool hasBoundedRecursion = false;
+  /// The context this pattern was created from, and a boolean flag indicating
+  /// whether this pattern has bounded recursion or not.
+  llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
+
+  /// A list of the potential operations that may be generated when rewriting
+  /// an op with this pattern.
+  SmallVector<OperationName, 2> generatedOps;
 };
 
 //===----------------------------------------------------------------------===//
@@ -188,15 +252,13 @@ class RewritePattern : public Pattern {
   virtual void anchor();
 };
 
-/// OpRewritePattern is a wrapper around RewritePattern that allows for
-/// matching and rewriting against an instance of a derived operation class as
-/// opposed to a raw Operation.
+namespace detail {
+/// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
+/// allows for matching and rewriting against an instance of a derived operation
+/// class or Interface.
 template <typename SourceOp>
-struct OpRewritePattern : public RewritePattern {
-  /// Patterns must specify the root operation name they match against, and can
-  /// also specify the benefit of the pattern matching.
-  OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
-      : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
+struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+  using RewritePattern::RewritePattern;
 
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, PatternRewriter &rewriter) const final {
@@ -227,6 +289,43 @@ struct OpRewritePattern : public RewritePattern {
     return failure();
   }
 };
+} // namespace detail
+
+/// OpRewritePattern is a wrapper around RewritePattern that allows for
+/// matching and rewriting against an instance of a derived operation class as
+/// opposed to a raw Operation.
+template <typename SourceOp>
+struct OpRewritePattern
+    : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching.
+  OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
+            SourceOp::getOperationName(), benefit, context) {}
+};
+
+/// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
+/// matching and rewriting against an instance of an operation interface instead
+/// of a raw Operation.
+template <typename SourceOp>
+struct OpInterfaceRewritePattern
+    : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
+            Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
+            benefit, context) {}
+};
+
+/// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
+/// matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitRewritePattern : public RewritePattern {
+public:
+  OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
+                       benefit, context) {}
+};
 
 //===----------------------------------------------------------------------===//
 // PDLPatternModule

diff  --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
index 554bfd217534..6791fbd7e3c0 100644
--- a/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
+++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
@@ -25,6 +25,10 @@ class FrozenRewritePatternSet {
   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
 public:
+  /// A map of operation specific native patterns.
+  using OpSpecificNativePatternListT =
+      DenseMap<OperationName, std::vector<RewritePattern *>>;
+
   /// Freeze the patterns held in `patterns`, and take ownership.
   FrozenRewritePatternSet();
   FrozenRewritePatternSet(RewritePatternSet &&patterns);
@@ -36,11 +40,16 @@ class FrozenRewritePatternSet {
   operator=(FrozenRewritePatternSet &&patterns) = default;
   ~FrozenRewritePatternSet();
 
-  /// Return the native patterns held by this list.
+  /// Return the op specific native patterns held by this list.
+  const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const {
+    return impl->nativeOpSpecificPatternMap;
+  }
+
+  /// Return the "match any" native patterns held by this list.
   iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
-  getNativePatterns() const {
-    const NativePatternListT &nativePatterns = impl->nativePatterns;
-    return llvm::make_pointee_range(nativePatterns);
+  getMatchAnyOpNativePatterns() const {
+    const NativePatternListT &nativeList = impl->nativeAnyOpPatterns;
+    return llvm::make_pointee_range(nativeList);
   }
 
   /// Return the compiled PDL bytecode held by this list. Returns null if
@@ -52,8 +61,17 @@ class FrozenRewritePatternSet {
 private:
   /// The internal implementation of the frozen pattern list.
   struct Impl {
-    /// The set of native C++ rewrite patterns.
-    NativePatternListT nativePatterns;
+    /// The set of native C++ rewrite patterns that are matched to specific
+    /// operation kinds.
+    OpSpecificNativePatternListT nativeOpSpecificPatternMap;
+
+    /// The full op-specific native rewrite list. This allows for the map above
+    /// to contain duplicate patterns, e.g. for interfaces and traits.
+    NativePatternListT nativeOpSpecificPatternList;
+
+    /// The set of native C++ rewrite patterns that are matched to "any"
+    /// operation.
+    NativePatternListT nativeAnyOpPatterns;
 
     /// The bytecode containing the compiled PDL patterns.
     std::unique_ptr<detail::PDLByteCode> pdlByteCode;

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index b618e8effd4a..d7a455722c77 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -183,6 +183,9 @@ class InterfaceMap {
     return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
   }
 
+  /// Returns true if the interface map contains an interface for the given id.
+  bool contains(TypeID interfaceID) const { return lookup(interfaceID); }
+
 private:
   /// Compare two TypeID instances by comparing the underlying pointer.
   static bool compare(TypeID lhs, TypeID rhs) {

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 7ebd07d8cb42..d1bb6bc1033d 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -351,20 +351,12 @@ class ConversionPattern : public RewritePattern {
   /// See `RewritePattern::RewritePattern` for information on the other
   /// available constructors.
   using RewritePattern::RewritePattern;
-  /// Construct a conversion pattern that matches an operation with the given
-  /// root name. This constructor allows for providing a type converter to use
-  /// within the pattern.
-  ConversionPattern(StringRef rootName, PatternBenefit benefit,
-                    TypeConverter &typeConverter, MLIRContext *ctx)
-      : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
-  /// Construct a conversion pattern that matches any operation type. This
-  /// constructor allows for providing a type converter to use within the
-  /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
-  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
-  /// always be supplied here.
-  ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
-                    MatchAnyOpTypeTag tag)
-      : RewritePattern(benefit, tag), typeConverter(&typeConverter) {}
+  /// Construct a conversion pattern with the given converter, and forward the
+  /// remaining arguments to RewritePattern.
+  template <typename... Args>
+  ConversionPattern(TypeConverter &typeConverter, Args &&... args)
+      : RewritePattern(std::forward<Args>(args)...),
+        typeConverter(&typeConverter) {}
 
 protected:
   /// An optional type converter for use by this pattern.
@@ -374,17 +366,13 @@ class ConversionPattern : public RewritePattern {
   using RewritePattern::rewrite;
 };
 
-/// OpConversionPattern is a wrapper around ConversionPattern that allows for
-/// matching and rewriting against an instance of a derived operation class as
-/// opposed to a raw Operation.
+namespace detail {
+/// OpOrInterfaceConversionPatternBase is a wrapper around ConversionPattern
+/// that allows for matching and rewriting against an instance of a derived
+/// operation class or an Interface as opposed to a raw Operation.
 template <typename SourceOp>
-struct OpConversionPattern : public ConversionPattern {
-  OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
-      : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
-  OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
-                      PatternBenefit benefit = 1)
-      : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter,
-                          context) {}
+struct OpOrInterfaceConversionPatternBase : public ConversionPattern {
+  using ConversionPattern::ConversionPattern;
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
@@ -419,6 +407,39 @@ struct OpConversionPattern : public ConversionPattern {
 private:
   using ConversionPattern::matchAndRewrite;
 };
+} // namespace detail
+
+/// OpConversionPattern is a wrapper around ConversionPattern that allows for
+/// matching and rewriting against an instance of a derived operation class as
+/// opposed to a raw Operation.
+template <typename SourceOp>
+struct OpConversionPattern
+    : public detail::OpOrInterfaceConversionPatternBase<SourceOp> {
+  OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : detail::OpOrInterfaceConversionPatternBase<SourceOp>(
+            SourceOp::getOperationName(), benefit, context) {}
+  OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : detail::OpOrInterfaceConversionPatternBase<SourceOp>(
+            typeConverter, SourceOp::getOperationName(), benefit, context) {}
+};
+
+/// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
+/// allows for matching and rewriting against an instance of an OpInterface
+/// class as opposed to a raw Operation.
+template <typename SourceOp>
+struct OpInterfaceConversionPattern
+    : public detail::OpOrInterfaceConversionPatternBase<SourceOp> {
+  OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : detail::OpOrInterfaceConversionPatternBase<SourceOp>(
+            Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
+            benefit, context) {}
+  OpInterfaceConversionPattern(TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit = 1)
+      : detail::OpOrInterfaceConversionPatternBase<SourceOp>(
+            typeConverter, Pattern::MatchInterfaceOpTypeTag(),
+            SourceOp::getInterfaceID(), benefit, context) {}
+};
 
 /// Add a pattern to the given pattern list to convert the signature of a
 /// FunctionLike op with the given type converter. This only supports

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 72237fdafada..36d484fafe66 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -101,9 +101,9 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
 }
 
 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+    LinalgOp op, PatternRewriter &rewriter) const {
   // Only LinalgOp for which there is no specialized pattern go through this.
-  if (!isa<LinalgOp>(op) || isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
+  if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
     return failure();
 
   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
@@ -199,8 +199,8 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
   patterns.add<
       CopyOpToLibraryCallRewrite,
       CopyTransposeRewrite,
-      IndexedGenericOpToLibraryCallRewrite>(patterns.getContext());
-  patterns.add<LinalgOpToLibraryCallRewrite>();
+      IndexedGenericOpToLibraryCallRewrite,
+      LinalgOpToLibraryCallRewrite>(patterns.getContext());
   // clang-format on
 }
 

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 5ac7fdd6f5ef..03251098d5c9 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -450,7 +450,7 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
                                            PatternBenefit benefit)
-    : ConversionPattern(rootOpName, benefit, typeConverter, context) {}
+    : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
 
 //===----------------------------------------------------------------------===//
 // StructBuilder implementation

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5d9bb1f5cf03..9f8ade5b5fbc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2366,16 +2366,12 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct EraseDeadLinalgOp : public RewritePattern {
-  EraseDeadLinalgOp(PatternBenefit benefit = 1)
-      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = dyn_cast<LinalgOp>(op);
-    if (!linalgOp)
-      return failure();
-    for (Value v : linalgOp.getShapedOperands()) {
+    for (Value v : op.getShapedOperands()) {
       // Linalg "inputs" may be either tensor or memref type.
       // tensor<0xelt_type> is a convention that may not always mean
       // "0 iterations". Only erase in cases we see memref<...x0x...>.
@@ -2383,7 +2379,7 @@ struct EraseDeadLinalgOp : public RewritePattern {
       if (!mt)
         continue;
       if (llvm::is_contained(mt.getShape(), 0)) {
-        rewriter.eraseOp(linalgOp);
+        rewriter.eraseOp(op);
         return success();
       }
     }
@@ -2391,19 +2387,14 @@ struct EraseDeadLinalgOp : public RewritePattern {
   }
 };
 
-struct FoldTensorCastOp : public RewritePattern {
-  FoldTensorCastOp(PatternBenefit benefit = 1)
-      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = dyn_cast<LinalgOp>(op);
-    if (!linalgOp)
-      return failure();
-
     // If no operand comes from a tensor::CastOp and can be folded then fail.
     bool hasTensorCastOperand =
-        llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
+        llvm::any_of(op.getShapedOperands(), [&](Value v) {
           if (v.isa<BlockArgument>())
             return false;
           auto castOp = v.getDefiningOp<tensor::CastOp>();
@@ -2417,23 +2408,23 @@ struct FoldTensorCastOp : public RewritePattern {
     SmallVector<Value, 4> newOperands;
     newOperands.reserve(op->getNumOperands());
     // Inputs may fold.
-    for (Value v : linalgOp.getInputs()) {
+    for (Value v : op.getInputs()) {
       auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
       newOperands.push_back(
           canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
     }
     // Init tensors may fold, in which case the resultType must also change.
-    for (Value v : linalgOp.getOutputs()) {
+    for (Value v : op.getOutputs()) {
       auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
       newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
       newResultTypes.push_back(newOperands.back().getType());
     }
-    auto extraOperands = linalgOp.getAssumedNonShapedOperands();
+    auto extraOperands = op.getAssumedNonShapedOperands();
     newOperands.append(extraOperands.begin(), extraOperands.end());
     // Clone op.
     Operation *newOp =
-        linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
+        op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
     SmallVector<Value, 4> replacements;
     replacements.reserve(newOp->getNumResults());
     for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
@@ -2500,17 +2491,15 @@ struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<memref::DimOp> {
 namespace {
 // Deduplicate redundant args of a linalg op.
 // An arg is redundant if it has the same Value and indexing map as another.
-struct DeduplicateInputs : public RewritePattern {
-  DeduplicateInputs(PatternBenefit benefit = 1)
-      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
     // This pattern reduces the number of arguments of an op, which breaks
     // the invariants of semantically charged named ops.
     if (!isa<GenericOp, IndexedGenericOp>(op))
       return failure();
-    auto linalgOp = cast<LinalgOp>(op);
 
     // Associate each input to an equivalent "canonical" input that has the same
     // Value and indexing map.
@@ -2524,9 +2513,9 @@ struct DeduplicateInputs : public RewritePattern {
     // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
     // convenient.
     SmallVector<int, 6> canonicalInputIndices;
-    for (int i = 0, e = linalgOp.getNumInputs(); i != e; i++) {
-      Value input = linalgOp.getInput(i);
-      AffineMap indexingMap = linalgOp.getInputIndexingMap(i);
+    for (int i = 0, e = op.getNumInputs(); i != e; i++) {
+      Value input = op.getInput(i);
+      AffineMap indexingMap = op.getInputIndexingMap(i);
       // STL-like maps have a convenient behavior for our use case here. In the
       // case of duplicate keys, the insertion is rejected, and the returned
       // iterator gives access to the value already in the map.
@@ -2535,20 +2524,20 @@ struct DeduplicateInputs : public RewritePattern {
     }
 
     // If there are no duplicate args, then bail out.
-    if (canonicalInput.size() == linalgOp.getNumInputs())
+    if (canonicalInput.size() == op.getNumInputs())
       return failure();
 
     // The operands for the newly canonicalized op.
     SmallVector<Value, 6> newOperands;
-    for (auto v : llvm::enumerate(linalgOp.getInputs()))
+    for (auto v : llvm::enumerate(op.getInputs()))
       if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
         newOperands.push_back(v.value());
-    llvm::append_range(newOperands, linalgOp.getOutputs());
-    llvm::append_range(newOperands, linalgOp.getAssumedNonShapedOperands());
+    llvm::append_range(newOperands, op.getOutputs());
+    llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
 
     // Clone the old op with new operands.
-    Operation *newOp = linalgOp.clone(rewriter, op->getLoc(),
-                                      op->getResultTypes(), newOperands);
+    Operation *newOp =
+        op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
     auto newLinalgOp = cast<LinalgOp>(newOp);
 
     // Repair the indexing maps by filtering out the ones that have been
@@ -2573,7 +2562,7 @@ struct DeduplicateInputs : public RewritePattern {
     // Repair the payload entry block by RAUW'ing redundant arguments and
     // erasing them.
     Block &payload = newOp->getRegion(0).front();
-    for (int i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
+    for (int i = 0, e = op.getNumInputs(); i < e; i++) {
       // Iterate in reverse, so that we erase later args first, preventing the
       // argument list from shifting unexpectedly and invalidating all our
       // indices.
@@ -2597,13 +2586,12 @@ struct DeduplicateInputs : public RewritePattern {
 /// 1) All iterator types are parallel
 /// 2) The body contains just a yield operation with the yielded values being
 ///    the arguments corresponding to the operands.
-struct RemoveIdentityLinalgOps : public RewritePattern {
-  RemoveIdentityLinalgOps(PatternBenefit benefit = 1)
-      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    if (auto copyOp = dyn_cast<CopyOp>(op)) {
+    if (auto copyOp = dyn_cast<CopyOp>(*op)) {
       assert(copyOp.hasBufferSemantics());
       if (copyOp.input() == copyOp.output() &&
           copyOp.inputPermutation() == copyOp.outputPermutation()) {
@@ -2614,11 +2602,10 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
 
     if (!isa<GenericOp, IndexedGenericOp>(op))
       return failure();
-    LinalgOp genericOp = cast<LinalgOp>(op);
-    if (!genericOp.hasTensorSemantics())
+    if (!op.hasTensorSemantics())
       return failure();
     // Check all indexing maps are identity.
-    if (llvm::any_of(genericOp.getIndexingMaps(),
+    if (llvm::any_of(op.getIndexingMaps(),
                      [](AffineMap map) { return !map.isIdentity(); }))
       return failure();
 
@@ -2633,7 +2620,7 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
 
     // Get the argument number of the returned values. That is the operand
     // number to use for replacing uses of this operation.
-    unsigned numIndexArgs = genericOp.getNumPayloadInductionVariables();
+    unsigned numIndexArgs = op.getNumPayloadInductionVariables();
     SmallVector<Value, 4> returnedArgs;
     for (Value yieldVal : yieldOp.values()) {
       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
@@ -2644,9 +2631,9 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
         return failure();
       returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
     }
-    if (returnedArgs.size() != genericOp.getOperation()->getNumResults())
+    if (returnedArgs.size() != op.getOperation()->getNumResults())
       return failure();
-    rewriter.replaceOp(genericOp, returnedArgs);
+    rewriter.replaceOp(op, returnedArgs);
     return success();
   }
 };
@@ -2656,8 +2643,7 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
   void XXX::getCanonicalizationPatterns(RewritePatternSet &results,            \
                                         MLIRContext *context) {                \
     results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,        \
-                RemoveIdentityLinalgOps>();                                    \
-    results.add<ReplaceDimOfLinalgOpResult>(context);                          \
+                RemoveIdentityLinalgOps, ReplaceDimOfLinalgOpResult>(context); \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 0f50e13b0acd..3ab86beb9367 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -175,17 +175,15 @@ class BufferizeFillOp : public OpConversionPattern<FillOp> {
 
 /// Generic conversion pattern that matches any LinalgOp. This avoids template
 /// instantiating one pattern for each LinalgOp.
-class BufferizeAnyLinalgOp : public ConversionPattern {
+class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
 public:
-  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
-      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
+  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-
-    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
-    if (!linalgOp)
+    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
+    if (!op->hasAttr("operand_segment_sizes"))
       return failure();
 
     // We abuse the GenericOpAdaptor here.
@@ -193,32 +191,30 @@ class BufferizeAnyLinalgOp : public ConversionPattern {
     // linalg::LinalgOp interface ops.
     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
 
-    Location loc = linalgOp.getLoc();
+    Location loc = op.getLoc();
     SmallVector<Value, 2> newOutputBuffers;
 
-    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor.outputs(),
+    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                          newOutputBuffers, rewriter))) {
-      linalgOp.emitOpError()
-          << "Failed to allocate buffers for tensor results.";
-      return failure();
+      return op.emitOpError()
+             << "Failed to allocate buffers for tensor results.";
     }
 
     // Delegate to the linalg generic pattern.
-    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+    if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
       finalizeBufferAllocationForGenericOp<GenericOp>(
           rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
       return success();
     }
 
     // Delegate to the linalg indexed generic pattern.
-    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
+    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(*op)) {
       finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
           rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
       return success();
     }
 
-    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
-                             newOutputBuffers);
+    finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
     return success();
   }
 };
@@ -338,10 +334,10 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
 
 void mlir::linalg::populateLinalgBufferizePatterns(
     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<BufferizeAnyLinalgOp>(typeConverter);
   // TODO: Drop this once tensor constants work in standard.
   // clang-format off
   patterns.add<
+      BufferizeAnyLinalgOp,
       BufferizeFillOp,
       BufferizeInitTensorOp,
       SubTensorOpConverter,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index aece769721ca..85b9836d5d36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -83,7 +83,7 @@ class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
 struct FunctionNonEntryBlockConversion : public ConversionPattern {
   FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
                                   MLIRContext *ctx, TypeConverter &converter)
-      : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
+      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 321961d2deac..86b7eafa4ecc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -75,8 +75,8 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
 
 namespace {
 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
-  ConvertAnyElementwiseMappableOpOnRankedTensors()
-      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final {
     if (!isElementwiseMappableOpOnRankedTensors(op))
@@ -117,7 +117,8 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
 
 void mlir::populateElementwiseToLinalgConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>();
+  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
+      patterns.getContext());
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index cb959a866935..af3f393997e7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -104,7 +104,7 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
   LinalgNamedOpGeneralizationPattern(MLIRContext *context,
                                      linalg::LinalgTransformationFilter marker,
                                      PatternBenefit benefit = 1)
-      : RewritePattern(benefit, MatchAnyOpTypeTag()),
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
         marker(std::move(marker)) {}
 
   LogicalResult matchAndRewrite(Operation *rootOp,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 5bc6cefe489a..a6e296b0ea11 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -520,8 +520,9 @@ namespace {
 template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
 public:
-  LinalgRewritePattern(ArrayRef<unsigned> interchangeVector)
-      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()),
+  LinalgRewritePattern(MLIRContext *context,
+                       ArrayRef<unsigned> interchangeVector)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
         interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
 
   LogicalResult matchAndRewrite(Operation *op,
@@ -546,7 +547,7 @@ static void lowerLinalgToLoopsImpl(FuncOp funcOp,
                                    ArrayRef<unsigned> interchangeVector) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet patterns(context);
-  patterns.add<LinalgRewritePattern<LoopType>>(interchangeVector);
+  patterns.add<LinalgRewritePattern<LoopType>>(context, interchangeVector);
   memref::DimOp::getCanonicalizationPatterns(patterns, context);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4202cb268576..e7095a9f0b34 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -234,13 +234,13 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
     LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context), filter(filter),
+    : RewritePattern(opName, benefit, context), filter(filter),
       options(options) {}
 
 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
-    LinalgTilingOptions options, LinalgTransformationFilter filter,
-    PatternBenefit benefit)
-    : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter),
+    MLIRContext *context, LinalgTilingOptions options,
+    LinalgTransformationFilter filter, PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
       options(options) {}
 
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
@@ -306,7 +306,7 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
     LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
     LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
     LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context),
+    : RewritePattern(opName, benefit, context, {}),
       dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
       fusionOptions(fusionOptions), filter(filter),
       fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
@@ -401,7 +401,7 @@ mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
     StringRef opName, MLIRContext *context,
     ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter,
     PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context), filter(filter),
+    : RewritePattern(opName, benefit, context, {}), filter(filter),
       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
 
 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
@@ -427,7 +427,7 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
     LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context), filter(filter),
+    : RewritePattern(opName, benefit, context, {}), filter(filter),
       options(options) {}
 
 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
@@ -453,13 +453,14 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
 }
 
 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
-    LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
+    MLIRContext *context, LinalgTransformationFilter filter,
+    PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
 
 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
     StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
     PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context), filter(filter) {}
+    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
 
 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
index b0283fe2601b..bf2dcd69e9ca 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -46,25 +46,21 @@ namespace {
 /// Only needed to support partial conversion of functions where this pattern
 /// ensures that the branch operation arguments matches up with the succesor
 /// block arguments.
-class BranchOpInterfaceTypeConversion : public ConversionPattern {
+class BranchOpInterfaceTypeConversion
+    : public OpInterfaceConversionPattern<BranchOpInterface> {
 public:
-  BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
-                                  MLIRContext *ctx)
-      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
+  using OpInterfaceConversionPattern<
+      BranchOpInterface>::OpInterfaceConversionPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto branchOp = dyn_cast<BranchOpInterface>(op);
-    if (!branchOp)
-      return failure();
-
     // For a branch operation, only some operands go to the target blocks, so
     // only rewrite those.
     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
          succIdx < succEnd; ++succIdx) {
-      auto successorOperands = branchOp.getSuccessorOperands(succIdx);
+      auto successorOperands = op.getSuccessorOperands(succIdx);
       if (!successorOperands)
         continue;
       for (int idx = successorOperands->getBeginOperandIndex(),

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 354d5f31bf74..4482b5cb219a 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -29,23 +29,49 @@ unsigned short PatternBenefit::getBenefit() const {
 // Pattern
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// OperationName Root Constructors
+
 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
+                 MLIRContext *context, ArrayRef<StringRef> generatedNames)
+    : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
+              RootKind::OperationName, generatedNames, benefit, context) {}
+
+//===----------------------------------------------------------------------===//
+// MatchAnyOpTypeTag Root Constructors
+
+Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
+                 MLIRContext *context, ArrayRef<StringRef> generatedNames)
+    : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
+
+//===----------------------------------------------------------------------===//
+// MatchInterfaceOpTypeTag Root Constructors
+
+Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
+                 PatternBenefit benefit, MLIRContext *context,
+                 ArrayRef<StringRef> generatedNames)
+    : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
+              generatedNames, benefit, context) {}
+
+//===----------------------------------------------------------------------===//
+// MatchTraitOpTypeTag Root Constructors
+
+Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
+                 PatternBenefit benefit, MLIRContext *context,
+                 ArrayRef<StringRef> generatedNames)
+    : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
+              benefit, context) {}
+
+//===----------------------------------------------------------------------===//
+// General Constructors
+
+Pattern::Pattern(const void *rootValue, RootKind rootKind,
+                 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
                  MLIRContext *context)
-    : rootKind(OperationName(rootName, context)), benefit(benefit) {}
-Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
-    : benefit(benefit) {}
-Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
-                 PatternBenefit benefit, MLIRContext *context)
-    : Pattern(rootName, benefit, context) {
-  generatedOps.reserve(generatedNames.size());
-  std::transform(generatedNames.begin(), generatedNames.end(),
-                 std::back_inserter(generatedOps), [context](StringRef name) {
-                   return OperationName(name, context);
-                 });
-}
-Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
-                 MLIRContext *context, MatchAnyOpTypeTag tag)
-    : Pattern(benefit, tag) {
+    : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
+      contextAndHasBoundedRecursion(context, false) {
+  if (generatedNames.empty())
+    return;
   generatedOps.reserve(generatedNames.size());
   std::transform(generatedNames.begin(), generatedNames.end(),
                  std::back_inserter(generatedOps), [context](StringRef name) {

diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index ea17f99deb9c..a81387f3f58e 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -45,10 +45,10 @@ PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
 
   // Check to see if this is pattern matches a specific operation type.
   if (Optional<StringRef> rootKind = matchOp.rootKind())
-    return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
-                              ctx);
-  return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
-                            MatchAnyOpTypeTag());
+    return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
+                              generatedOps);
+  return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
+                            generatedOps);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index 9c81363f13f2..0b6a1cf2cdf3 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -55,7 +55,43 @@ FrozenRewritePatternSet::FrozenRewritePatternSet()
 
 FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns)
     : impl(std::make_shared<Impl>()) {
-  impl->nativePatterns = std::move(patterns.getNativePatterns());
+  // Functor used to walk all of the operations registered in the context. This
+  // is useful for patterns that get applied to multiple operations, such as
+  // interface and trait based patterns.
+  std::vector<AbstractOperation *> abstractOps;
+  auto addToOpsWhen = [&](std::unique_ptr<RewritePattern> &pattern,
+                          function_ref<bool(AbstractOperation *)> callbackFn) {
+    if (abstractOps.empty())
+      abstractOps = pattern->getContext()->getRegisteredOperations();
+    for (AbstractOperation *absOp : abstractOps) {
+      if (callbackFn(absOp)) {
+        OperationName opName(absOp);
+        impl->nativeOpSpecificPatternMap[opName].push_back(pattern.get());
+      }
+    }
+    impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
+  };
+
+  for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
+    if (Optional<OperationName> rootName = pat->getRootKind()) {
+      impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
+      impl->nativeOpSpecificPatternList.push_back(std::move(pat));
+      continue;
+    }
+    if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
+      addToOpsWhen(pat, [&](AbstractOperation *absOp) {
+        return absOp->hasInterface(*interfaceID);
+      });
+      continue;
+    }
+    if (Optional<TypeID> traitID = pat->getRootTraitID()) {
+      addToOpsWhen(pat, [&](AbstractOperation *absOp) {
+        return absOp->hasTrait(*traitID);
+      });
+      continue;
+    }
+    impl->nativeAnyOpPatterns.push_back(std::move(pat));
+  }
 
   // Generate the bytecode for the PDL patterns if any were provided.
   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();

diff  --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 3db598883360..0ece814bca47 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -15,6 +15,8 @@
 #include "ByteCode.h"
 #include "llvm/Support/Debug.h"
 
+#define DEBUG_TYPE "pattern-match"
+
 using namespace mlir;
 using namespace mlir::detail;
 
@@ -28,7 +30,14 @@ PatternApplicator::PatternApplicator(
 }
 PatternApplicator::~PatternApplicator() {}
 
-#define DEBUG_TYPE "pattern-match"
+/// Log a message for a pattern that is impossible to match.
+static void logImpossibleToMatch(const Pattern &pattern) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
+                 << "' because it is impossible to match or cannot lead "
+                    "to legal IR (by cost model)\n";
+  });
+}
 
 void PatternApplicator::applyCostModel(CostModel model) {
   // Apply the cost model to the bytecode patterns first, and then the native
@@ -38,23 +47,24 @@ void PatternApplicator::applyCostModel(CostModel model) {
       mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
   }
 
-  // Separate patterns by root kind to simplify lookup later on.
+  // Copy over the patterns so that we can sort by benefit based on the cost
+  // model. Patterns that are already impossible to match are ignored.
   patterns.clear();
-  anyOpPatterns.clear();
-  for (const auto &pat : frozenPatternList.getNativePatterns()) {
-    // If the pattern is always impossible to match, just ignore it.
-    if (pat.getBenefit().isImpossibleToMatch()) {
-      LLVM_DEBUG({
-        llvm::dbgs()
-            << "Ignoring pattern '" << pat.getRootKind()
-            << "' because it is impossible to match (by pattern benefit)\n";
-      });
-      continue;
+  for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
+    for (const RewritePattern *pattern : it.second) {
+      if (pattern->getBenefit().isImpossibleToMatch())
+        logImpossibleToMatch(*pattern);
+      else
+        patterns[it.first].push_back(pattern);
     }
-    if (Optional<OperationName> opName = pat.getRootKind())
-      patterns[*opName].push_back(&pat);
+  }
+  anyOpPatterns.clear();
+  for (const RewritePattern &pattern :
+       frozenPatternList.getMatchAnyOpNativePatterns()) {
+    if (pattern.getBenefit().isImpossibleToMatch())
+      logImpossibleToMatch(pattern);
     else
-      anyOpPatterns.push_back(&pat);
+      anyOpPatterns.push_back(&pattern);
   }
 
   // Sort the patterns using the provided cost model.
@@ -66,11 +76,7 @@ void PatternApplicator::applyCostModel(CostModel model) {
     // Special case for one pattern in the list, which is the most common case.
     if (list.size() == 1) {
       if (model(*list.front()).isImpossibleToMatch()) {
-        LLVM_DEBUG({
-          llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
-                       << "' because it is impossible to match or cannot lead "
-                          "to legal IR (by cost model)\n";
-        });
+        logImpossibleToMatch(*list.front());
         list.clear();
       }
       return;
@@ -84,14 +90,8 @@ void PatternApplicator::applyCostModel(CostModel model) {
     // Sort patterns with highest benefit first, and remove those that are
     // impossible to match.
     std::stable_sort(list.begin(), list.end(), cmp);
-    while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
-      LLVM_DEBUG({
-        llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
-                     << "' because it is impossible to match or cannot lead to "
-                        "legal IR (by cost model)\n";
-      });
-      list.pop_back();
-    }
+    while (!list.empty() && benefits[list.back()].isImpossibleToMatch())
+      logImpossibleToMatch(*list.pop_back_val());
   };
   for (auto &it : patterns)
     processPatternList(it.second);
@@ -100,7 +100,10 @@ void PatternApplicator::applyCostModel(CostModel model) {
 
 void PatternApplicator::walkAllPatterns(
     function_ref<void(const Pattern &)> walk) {
-  for (const Pattern &it : frozenPatternList.getNativePatterns())
+  for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
+    for (const auto &pattern : it.second)
+      walk(*pattern);
+  for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
     walk(it);
   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
     for (const Pattern &it : bytecode->getPatterns())

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 41d3eabb07ea..821d867b7d1d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2582,7 +2582,7 @@ namespace {
 struct FunctionLikeSignatureConversion : public ConversionPattern {
   FunctionLikeSignatureConversion(StringRef functionLikeOpName,
                                   MLIRContext *ctx, TypeConverter &converter)
-      : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
+      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
 
   /// Hook to implement combined matching and rewriting for FunctionLike ops.
   LogicalResult

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 530318cbb53f..11cd05aa9bec 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -149,8 +149,8 @@ void ConvertToTargetEnv::runOnFunction() {
 }
 
 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
-    : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
-                     {"spv.AtomicCompareExchangeWeak"}, 1, context) {}
+    : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
+                     context, {"spv.AtomicCompareExchangeWeak"}) {}
 
 LogicalResult
 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
@@ -170,8 +170,8 @@ ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
 }
 
 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
-    : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
-                     context) {}
+    : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
+                     {"spv.BitReverse"}) {}
 
 LogicalResult
 ConvertToBitReverse::matchAndRewrite(Operation *op,
@@ -185,8 +185,8 @@ ConvertToBitReverse::matchAndRewrite(Operation *op,
 
 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
     MLIRContext *context)
-    : RewritePattern("test.convert_to_group_non_uniform_ballot_op",
-                     {"spv.GroupNonUniformBallot"}, 1, context) {}
+    : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context,
+                     {"spv.GroupNonUniformBallot"}) {}
 
 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
@@ -198,7 +198,7 @@ LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
 }
 
 ConvertToModule::ConvertToModule(MLIRContext *context)
-    : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
+    : RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {}
 
 LogicalResult
 ConvertToModule::matchAndRewrite(Operation *op,
@@ -210,8 +210,8 @@ ConvertToModule::matchAndRewrite(Operation *op,
 }
 
 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
-    : RewritePattern("test.convert_to_subgroup_ballot_op",
-                     {"spv.SubgroupBallotKHR"}, 1, context) {}
+    : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
+                     {"spv.SubgroupBallotKHR"}) {}
 
 LogicalResult
 ConvertToSubgroupBallot::matchAndRewrite(Operation *op,

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e34e52a9ef4c..ec85b7e38c43 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -325,7 +325,7 @@ struct TestUndoBlockErase : public ConversionPattern {
 /// This patterns erases a region operation that has had a type conversion.
 struct TestDropOpSignatureConversion : public ConversionPattern {
   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
-      : ConversionPattern("test.drop_region_op", 1, converter, ctx) {}
+      : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -726,7 +726,8 @@ struct TestRemappedValue
 namespace {
 /// This pattern matches and removes any operation in the test dialect.
 struct RemoveTestDialectOps : public RewritePattern {
-  RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  RemoveTestDialectOps(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
@@ -741,7 +742,7 @@ struct TestUnknownRootOpDriver
     : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
   void runOnFunction() override {
     mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<RemoveTestDialectOps>();
+    patterns.add<RemoveTestDialectOps>(&getContext());
 
     mlir::ConversionTarget target(getContext());
     target.addIllegalDialect<TestDialect>();

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 0bb46455b9ca..276a9f7c7fc3 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -183,8 +183,8 @@ static void applyPatterns(FuncOp funcOp) {
   // Linalg to vector contraction patterns.
   //===--------------------------------------------------------------------===//
   patterns.add<LinalgVectorizationPattern>(
-      LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
-          .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
+      ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
+               .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
 
   //===--------------------------------------------------------------------===//
   // Linalg generic permutation patterns.
@@ -258,8 +258,8 @@ static void fillL1TilingAndMatmulToVectorPatterns(
                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
                LinalgTransformationFilter(Identifier::get("VEC", ctx))));
   patternsVector.back().add<LinalgVectorizationPattern>(
-      LinalgTransformationFilter().addFilter(
-          [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
+      ctx, LinalgTransformationFilter().addFilter(
+               [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
 }
 
 //===----------------------------------------------------------------------===//
@@ -496,6 +496,7 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<LinalgVectorizationPattern>(
+      funcOp.getContext(),
       LinalgTransformationFilter()
           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
   patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index c1fa63c00eb7..548a2e3bbe7a 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -2075,8 +2075,8 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
     void {0}::getCanonicalizationPatterns(
         RewritePatternSet &results,
         MLIRContext *context) {{
-      results.add<EraseDeadLinalgOp>();
-      results.add<FoldTensorCastOp>();
+      results.add<EraseDeadLinalgOp>(context);
+      results.add<FoldTensorCastOp>(context);
     }
     LogicalResult {0}::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {{

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index e38e71bbd926..53a5807bd179 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -521,8 +521,8 @@ const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
 void {0}::getCanonicalizationPatterns(
     RewritePatternSet &results,
     MLIRContext *context) {{
-  results.add<EraseDeadLinalgOp>();
-  results.add<FoldTensorCastOp>();
+  results.add<EraseDeadLinalgOp>(context);
+  results.add<FoldTensorCastOp>(context);
 }
 LogicalResult {0}::fold(ArrayRef<Attribute>,
                         SmallVectorImpl<OpFoldResult> &) {{

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 68dddc285f26..28889de1ea60 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -626,8 +626,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
                 make_range(locs.rbegin(), locs.rend()));
   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
   {0}(::mlir::MLIRContext *context)
-      : ::mlir::RewritePattern("{1}", {{)",
-                rewriteName, rootName);
+      : ::mlir::RewritePattern("{1}", {2}, context, {{)",
+                rewriteName, rootName, pattern.getBenefit());
   // Sort result operators by name.
   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                          resultOps.end());
@@ -637,7 +637,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
   llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
     os << '"' << op->getOperationName() << '"';
   });
-  os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
+  os << "}) {}\n";
 
   // Emit matchAndRewrite() function.
   {

diff  --git a/mlir/unittests/Rewrite/PatternBenefit.cpp b/mlir/unittests/Rewrite/PatternBenefit.cpp
index 0d2f74ae4890..86b1aa13a8ca 100644
--- a/mlir/unittests/Rewrite/PatternBenefit.cpp
+++ b/mlir/unittests/Rewrite/PatternBenefit.cpp
@@ -38,8 +38,9 @@ TEST(PatternBenefitTest, BenefitOrder) {
   };
 
   struct Pattern2 : public RewritePattern {
-    Pattern2(bool *called)
-        : RewritePattern(/*benefit*/ 2, MatchAnyOpTypeTag{}), called(called) {}
+    Pattern2(MLIRContext *context, bool *called)
+        : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context),
+          called(called) {}
 
     mlir::LogicalResult
     matchAndRewrite(Operation * /*op*/,
@@ -58,7 +59,7 @@ TEST(PatternBenefitTest, BenefitOrder) {
   bool called2 = false;
 
   patterns.add<Pattern1>(&context, &called1);
-  patterns.add<Pattern2>(&called2);
+  patterns.add<Pattern2>(&context, &called2);
 
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   PatternApplicator pa(frozenPatterns);


        


More information about the Mlir-commits mailing list