[Mlir-commits] [mlir] b99bd77 - [mlir][Pattern] Refactor the Pattern class into a "metadata only" class

River Riddle llvmlistbot at llvm.org
Mon Oct 26 18:05:23 PDT 2020


Author: River Riddle
Date: 2020-10-26T18:01:06-07:00
New Revision: b99bd771626fbbf8b9b29ce312d4151968796826

URL: https://github.com/llvm/llvm-project/commit/b99bd771626fbbf8b9b29ce312d4151968796826
DIFF: https://github.com/llvm/llvm-project/commit/b99bd771626fbbf8b9b29ce312d4151968796826.diff

LOG: [mlir][Pattern] Refactor the Pattern class into a "metadata only" class

The Pattern class was originally intended to be used solely for matching operations, but that use never materialized. All of the pattern infrastructure uses RewritePattern, and the infrastructure for pure matching (Matchers.h) is implemented inline. This means that the class isn't a useful abstraction at the moment, so this revision refactors it to solely encapsulate the "metadata" of a pattern, i.e. the state that describes it: benefit, root operation, etc. The API on PatternApplicator is updated to operate on `Pattern`s, as nothing specific to `RewritePattern` is necessary.
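
The split described above can be illustrated with a minimal, standalone C++17 sketch. It is not the MLIR code: the classes below are simplified stand-ins that mirror only the accessor names visible in the diff, `std::string` and `unsigned` replace `OperationName` and `PatternBenefit`, and `ShrinkingRewrite`, `scorePattern`, and the op name `test.recursive_rewrite` are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>

// Simplified stand-in for the refactored Pattern: plain data plus accessors,
// with no virtual functions and no matching logic.
class Pattern {
public:
  const std::optional<std::string> &getRootKind() const { return rootKind; }
  unsigned getBenefit() const { return benefit; }
  bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }

protected:
  Pattern(std::string rootName, unsigned benefit)
      : rootKind(std::move(rootName)), benefit(benefit) {}

  // The recursion flag is stored state set by subclasses, not a virtual hook.
  void setHasBoundedRewriteRecursion(bool value = true) {
    hasBoundedRecursion = value;
  }

private:
  std::optional<std::string> rootKind;
  unsigned benefit;
  bool hasBoundedRecursion = false;
};

// Stand-in for RewritePattern: the virtual hooks live only at this level.
class RewritePattern : public Pattern {
public:
  virtual ~RewritePattern() = default;
  virtual bool matchAndRewrite(const std::string &opName) const = 0;

protected:
  // Inherit the base constructors, as the diff does with `using Pattern::Pattern`.
  using Pattern::Pattern;
};

// Hypothetical pattern that declares bounded recursion in its constructor
// instead of overriding a virtual method.
class ShrinkingRewrite : public RewritePattern {
public:
  ShrinkingRewrite() : RewritePattern("test.recursive_rewrite", /*benefit=*/2) {
    setHasBoundedRewriteRecursion();
  }

  bool matchAndRewrite(const std::string &opName) const override {
    return getRootKind() && opName == *getRootKind();
  }
};

// A cost model needs only the metadata, so it can take `const Pattern &`.
using CostModel = std::function<unsigned(const Pattern &)>;

unsigned scorePattern(const Pattern &pattern, const CostModel &model) {
  assert(model && "cost model must be callable");
  return model(pattern);
}
```

In this sketch a driver-side callback such as `scorePattern(p, [](const Pattern &q) { return q.getBenefit(); })` never sees `matchAndRewrite`, which is the property that lets the same callback signature serve patterns that are not C++ rewrite patterns.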

This refactoring is also necessary for the upcoming use of PDL patterns alongside C++ rewrite patterns.

Differential Revision: https://reviews.llvm.org/D86258

Added: 
    

Modified: 
    mlir/docs/PatternRewriter.md
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 2a2c30d98e04..ab93245395a3 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -174,10 +174,10 @@ Each driver is responsible for defining its own operation visitation order as
 well as pattern cost model, but the final application is performed via a
 `PatternApplicator` class. This class takes as input the
 `OwningRewritePatternList` and transforms the patterns based upon a provided
-cost model. This cost model computes a final benefit for a given rewrite
-pattern, using whatever driver specific information necessary. After a cost
-model has been computed, the driver may begin to match patterns against
-operations using `PatternApplicator::matchAndRewrite`.
+cost model. This cost model computes a final benefit for a given pattern, using
+whatever driver specific information necessary. After a cost model has been
+computed, the driver may begin to match patterns against operations using
+`PatternApplicator::matchAndRewrite`.
 
 An example is shown below:
 
@@ -209,7 +209,7 @@ void applyMyPatternDriver(Operation *op,
 
   // Create the applicator and apply our cost model.
   PatternApplicator applicator(patterns);
-  applicator.applyCostModel([](const RewritePattern &pattern) {
+  applicator.applyCostModel([](const Pattern &pattern) {
     // Apply a default cost model.
     // Note: This is just for demonstration, if the default cost model is truly
     //       desired `applicator.applyDefaultCostModel()` should be used

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index ea8f410460c5..ef6e3bd86258 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -58,15 +58,23 @@ class PatternBenefit {
 };
 
 //===----------------------------------------------------------------------===//
-// Pattern class
+// Pattern
 //===----------------------------------------------------------------------===//
 
-/// Instances of Pattern can be matched against SSA IR.  These matches get used
-/// in ways dependent on their subclasses and the driver doing the matching.
-/// For example, RewritePatterns implement a rewrite from one matched pattern
-/// to a replacement DAG tile.
+/// This class contains all of the data related to a pattern, but does not
+/// contain any methods or logic for the actual matching. This class is solely
+/// used to interface with the metadata of a pattern, such as the benefit or
+/// root operation.
 class Pattern {
 public:
+  /// Return a list of operations that may be generated when rewriting an
+  /// operation instance with this pattern.
+  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
+
+  /// Return the root node that this pattern matches. Patterns that can match
+  /// multiple root types return None.
+  Optional<OperationName> getRootKind() const { return rootKind; }
+
   /// Return the benefit (the inverse of "cost") of matching this pattern.  The
   /// benefit of a Pattern is always static - rewrites that may have dynamic
   /// benefit can be instantiated multiple times (different Pattern instances)
@@ -74,19 +82,11 @@ class Pattern {
   /// condition predicates.
   PatternBenefit getBenefit() const { return benefit; }
 
-  /// Return the root node that this pattern matches. Patterns that can match
-  /// multiple root types return None.
-  Optional<OperationName> getRootKind() const { return rootKind; }
-
-  //===--------------------------------------------------------------------===//
-  // Implementation hooks for patterns to implement.
-  //===--------------------------------------------------------------------===//
-
-  /// Attempt to match against code rooted at the specified operation,
-  /// which is the same operation code as getRootKind().
-  virtual LogicalResult match(Operation *op) const = 0;
-
-  virtual ~Pattern() {}
+  /// Returns true if this pattern is known to result in recursive application,
+  /// i.e. this pattern may generate IR that also matches this pattern, but is
+  /// known to bound the recursion. This signals to a rewrite driver that it is
+  /// safe to apply this pattern recursively to generated IR.
+  bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
 
 protected:
   /// This class acts as a special tag that makes the desire to match "any"
@@ -94,19 +94,38 @@ class Pattern {
   /// feature, and ensures that the user is making a conscious decision.
   struct MatchAnyOpTypeTag {};
 
-  /// This constructor is used for patterns that match against a specific
-  /// operation type. The `benefit` is the expected benefit of matching this
-  /// pattern.
+  /// Construct a pattern with a certain benefit that matches the operation
+  /// with the given root name.
   Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
-
-  /// This constructor is used when a pattern may match against multiple
-  /// different types of operations. The `benefit` is the expected benefit of
-  /// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that
-  /// the "match any" behavior is what the user actually desired,
-  /// `MatchAnyOpTypeTag()` should always be supplied here.
-  Pattern(PatternBenefit benefit, MatchAnyOpTypeTag);
+  /// Construct a pattern with a certain benefit that matches any operation
+  /// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
+  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
+  /// always be supplied here.
+  Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag);
+  /// Construct a pattern with a certain benefit that matches the operation with
+  /// the given root name. `generatedNames` contains the names of operations
+  /// that may be generated during a successful rewrite.
+  Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
+          PatternBenefit benefit, MLIRContext *context);
+  /// Construct a pattern that may match any operation type. `generatedNames`
+  /// contains the names of operations that may be generated during a successful
+  /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
+  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
+  /// always be supplied here.
+  Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
+          MLIRContext *context, MatchAnyOpTypeTag tag);
+
+  /// Set the flag detailing if this pattern has bounded rewrite recursion or
+  /// not.
+  void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
+    hasBoundedRecursion = hasBoundedRecursionArg;
+  }
 
 private:
+  /// A list of the potential operations that may be generated when rewriting
+  /// an op with this pattern.
+  SmallVector<OperationName, 2> generatedOps;
+
   /// The root operation of the pattern. If the pattern matches a specific
   /// operation, this contains the name of that operation. Contains None
   /// otherwise.
@@ -115,9 +134,14 @@ class Pattern {
   /// The expected benefit of matching this pattern.
   const PatternBenefit benefit;
 
-  virtual void anchor();
+  /// A boolean flag of whether this pattern has bounded recursion or not.
+  bool hasBoundedRecursion = false;
 };
 
+//===----------------------------------------------------------------------===//
+// RewritePattern
+//===----------------------------------------------------------------------===//
+
 /// RewritePattern is the common base class for all DAG to DAG replacements.
 /// There are two possible usages of this class:
 ///   * Multi-step RewritePattern with "match" and "rewrite"
@@ -129,6 +153,8 @@ class Pattern {
 ///
 class RewritePattern : public Pattern {
 public:
+  virtual ~RewritePattern() {}
+
   /// Rewrite the IR rooted at the specified operation with the result of
   /// this pattern, generating any new operations with the specified
   /// builder.  If an unexpected error is encountered (an internal
@@ -138,7 +164,7 @@ class RewritePattern : public Pattern {
 
   /// Attempt to match against code rooted at the specified operation,
   /// which is the same operation code as getRootKind().
-  LogicalResult match(Operation *op) const override;
+  virtual LogicalResult match(Operation *op) const;
 
   /// Attempt to match against code rooted at the specified operation,
   /// which is the same operation code as getRootKind(). If successful, this
@@ -152,44 +178,12 @@ class RewritePattern : public Pattern {
     return failure();
   }
 
-  /// Returns true if this pattern is known to result in recursive application,
-  /// i.e. this pattern may generate IR that also matches this pattern, but is
-  /// known to bound the recursion. This signals to a rewriter that it is safe
-  /// to apply this pattern recursively to generated IR.
-  virtual bool hasBoundedRewriteRecursion() const { return false; }
-
-  /// Return a list of operations that may be generated when rewriting an
-  /// operation instance with this pattern.
-  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
-
 protected:
-  /// Construct a rewrite pattern with a certain benefit that matches the
-  /// operation with the given root name.
-  RewritePattern(StringRef rootName, PatternBenefit benefit,
-                 MLIRContext *context)
-      : Pattern(rootName, benefit, context) {}
-  /// Construct a rewrite pattern with a certain benefit that matches any
-  /// operation type. `MatchAnyOpTypeTag` is just a tag to ensure that the
-  /// "match any" behavior is what the user actually desired,
-  /// `MatchAnyOpTypeTag()` should always be supplied here.
-  RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
-      : Pattern(benefit, tag) {}
-  /// Construct a rewrite pattern with a certain benefit that matches the
-  /// operation with the given root name. `generatedNames` contains the names of
-  /// operations that may be generated during a successful rewrite.
-  RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
-                 PatternBenefit benefit, MLIRContext *context);
-  /// Construct a rewrite pattern that may match any operation type.
-  /// `generatedNames` contains the names of operations that may be generated
-  /// during a successful rewrite. `MatchAnyOpTypeTag` is just a tag to ensure
-  /// that the "match any" behavior is what the user actually desired,
-  /// `MatchAnyOpTypeTag()` should always be supplied here.
-  RewritePattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
-                 MLIRContext *context, MatchAnyOpTypeTag tag);
+  /// Inherit the base constructors from `Pattern`.
+  using Pattern::Pattern;
 
-  /// A list of the potential operations that may be generated when rewriting
-  /// an op with this pattern.
-  SmallVector<OperationName, 2> generatedOps;
+  /// An anchor for the virtual table.
+  virtual void anchor();
 };
 
 /// OpRewritePattern is a wrapper around RewritePattern that allows for
@@ -232,7 +226,7 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
 };
 
 //===----------------------------------------------------------------------===//
-// PatternRewriter class
+// PatternRewriter
 //===----------------------------------------------------------------------===//
 
 /// This class coordinates the application of a pattern to the current function,
@@ -498,7 +492,7 @@ class PatternApplicator {
   /// pattern. Users can query contained patterns and pass analysis results to
   /// applyCostModel. Patterns to be discarded should have a benefit of
   /// `impossibleToMatch`.
-  using CostModel = function_ref<PatternBenefit(const RewritePattern &)>;
+  using CostModel = function_ref<PatternBenefit(const Pattern &)>;
 
   explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
       : owningPatternList(owningPatternList) {}
@@ -512,11 +506,11 @@ class PatternApplicator {
   /// onFailure: called when a pattern fails to match to perform cleanup.
   /// onSuccess: called when a pattern match succeeds; return failure() to
   ///            invalidate the match and try another pattern.
-  LogicalResult matchAndRewrite(
-      Operation *op, PatternRewriter &rewriter,
-      function_ref<bool(const RewritePattern &)> canApply = {},
-      function_ref<void(const RewritePattern &)> onFailure = {},
-      function_ref<LogicalResult(const RewritePattern &)> onSuccess = {});
+  LogicalResult
+  matchAndRewrite(Operation *op, PatternRewriter &rewriter,
+                  function_ref<bool(const Pattern &)> canApply = {},
+                  function_ref<void(const Pattern &)> onFailure = {},
+                  function_ref<LogicalResult(const Pattern &)> onSuccess = {});
 
   /// Apply a cost model to the patterns within this applicator.
   void applyCostModel(CostModel model);
@@ -524,22 +518,22 @@ class PatternApplicator {
   /// Apply the default cost model that solely uses the pattern's static
   /// benefit.
   void applyDefaultCostModel() {
-    applyCostModel(
-        [](const RewritePattern &pattern) { return pattern.getBenefit(); });
+    applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
   }
 
-  /// Walk all of the rewrite patterns within the applicator.
-  void walkAllPatterns(function_ref<void(const RewritePattern &)> walk);
+  /// Walk all of the patterns within the applicator.
+  void walkAllPatterns(function_ref<void(const Pattern &)> walk);
 
 private:
   /// Attempt to match and rewrite the given op with the given pattern, allowing
   /// a predicate to decide if a pattern can be applied or not, and hooks for if
   /// the pattern match was a success or failure.
-  LogicalResult matchAndRewrite(
-      Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
-      function_ref<bool(const RewritePattern &)> canApply,
-      function_ref<void(const RewritePattern &)> onFailure,
-      function_ref<LogicalResult(const RewritePattern &)> onSuccess);
+  LogicalResult
+  matchAndRewrite(Operation *op, const RewritePattern &pattern,
+                  PatternRewriter &rewriter,
+                  function_ref<bool(const Pattern &)> canApply,
+                  function_ref<void(const Pattern &)> onFailure,
+                  function_ref<LogicalResult(const Pattern &)> onSuccess);
 
   /// The list that owns the patterns used within this applicator.
   const OwningRewritePatternList &owningPatternList;

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a95e1006e381..71eaf0d59e68 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1042,7 +1042,12 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
 class VectorInsertStridedSliceOpSameRankRewritePattern
     : public OpRewritePattern<InsertStridedSliceOp> {
 public:
-  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
+    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
+    // bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
 
   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1093,9 +1098,6 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
     rewriter.replaceOp(op, res);
     return success();
   }
-  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
-  /// bounded as the rank is strictly decreasing.
-  bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
 /// Returns the strides if the memory underlying `memRefType` has a contiguous
@@ -1505,7 +1507,12 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
 class VectorExtractStridedSliceOpConversion
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
-  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
+      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
+    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
+    // is bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1552,9 +1559,6 @@ class VectorExtractStridedSliceOpConversion
     rewriter.replaceOp(op, res);
     return success();
   }
-  /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
-  /// bounded as the rank is strictly decreasing.
-  bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
 } // namespace

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index d1da8d1d8f26..136d01966688 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -16,6 +16,10 @@ using namespace mlir;
 
 #define DEBUG_TYPE "pattern-match"
 
+//===----------------------------------------------------------------------===//
+// PatternBenefit
+//===----------------------------------------------------------------------===//
+
 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
          "This pattern match benefit is too large to represent");
@@ -27,34 +31,16 @@ unsigned short PatternBenefit::getBenefit() const {
 }
 
 //===----------------------------------------------------------------------===//
-// Pattern implementation
+// Pattern
 //===----------------------------------------------------------------------===//
 
 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
                  MLIRContext *context)
     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
-Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag)
+Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
     : benefit(benefit) {}
-
-// Out-of-line vtable anchor.
-void Pattern::anchor() {}
-
-//===----------------------------------------------------------------------===//
-// RewritePattern and PatternRewriter implementation
-//===----------------------------------------------------------------------===//
-
-void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
-  llvm_unreachable("need to implement either matchAndRewrite or one of the "
-                   "rewrite functions!");
-}
-
-LogicalResult RewritePattern::match(Operation *op) const {
-  llvm_unreachable("need to implement either match or matchAndRewrite!");
-}
-
-RewritePattern::RewritePattern(StringRef rootName,
-                               ArrayRef<StringRef> generatedNames,
-                               PatternBenefit benefit, MLIRContext *context)
+Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
+                 PatternBenefit benefit, MLIRContext *context)
     : Pattern(rootName, benefit, context) {
   generatedOps.reserve(generatedNames.size());
   std::transform(generatedNames.begin(), generatedNames.end(),
@@ -62,9 +48,8 @@ RewritePattern::RewritePattern(StringRef rootName,
                    return OperationName(name, context);
                  });
 }
-RewritePattern::RewritePattern(ArrayRef<StringRef> generatedNames,
-                               PatternBenefit benefit, MLIRContext *context,
-                               MatchAnyOpTypeTag tag)
+Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
+                 MLIRContext *context, MatchAnyOpTypeTag tag)
     : Pattern(benefit, tag) {
   generatedOps.reserve(generatedNames.size());
   std::transform(generatedNames.begin(), generatedNames.end(),
@@ -73,6 +58,26 @@ RewritePattern::RewritePattern(ArrayRef<StringRef> generatedNames,
                  });
 }
 
+//===----------------------------------------------------------------------===//
+// RewritePattern
+//===----------------------------------------------------------------------===//
+
+void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
+  llvm_unreachable("need to implement either matchAndRewrite or one of the "
+                   "rewrite functions!");
+}
+
+LogicalResult RewritePattern::match(Operation *op) const {
+  llvm_unreachable("need to implement either match or matchAndRewrite!");
+}
+
+/// Out-of-line vtable anchor.
+void RewritePattern::anchor() {}
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter
+//===----------------------------------------------------------------------===//
+
 PatternRewriter::~PatternRewriter() {
   // Out of line to provide a vtable anchor for the class.
 }
@@ -201,7 +206,7 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
 }
 
 //===----------------------------------------------------------------------===//
-// PatternMatcher implementation
+// PatternApplicator
 //===----------------------------------------------------------------------===//
 
 void PatternApplicator::applyCostModel(CostModel model) {
@@ -266,16 +271,16 @@ void PatternApplicator::applyCostModel(CostModel model) {
 }
 
 void PatternApplicator::walkAllPatterns(
-    function_ref<void(const RewritePattern &)> walk) {
+    function_ref<void(const Pattern &)> walk) {
   for (auto &it : owningPatternList)
     walk(*it);
 }
 
 LogicalResult PatternApplicator::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter,
-    function_ref<bool(const RewritePattern &)> canApply,
-    function_ref<void(const RewritePattern &)> onFailure,
-    function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
+    function_ref<bool(const Pattern &)> canApply,
+    function_ref<void(const Pattern &)> onFailure,
+    function_ref<LogicalResult(const Pattern &)> onSuccess) {
   // Check to see if there are patterns matching this specific operation type.
   MutableArrayRef<RewritePattern *> opPatterns;
   auto patternIt = patterns.find(op->getName());
@@ -315,9 +320,9 @@ LogicalResult PatternApplicator::matchAndRewrite(
 
 LogicalResult PatternApplicator::matchAndRewrite(
     Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
-    function_ref<bool(const RewritePattern &)> canApply,
-    function_ref<void(const RewritePattern &)> onFailure,
-    function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
+    function_ref<bool(const Pattern &)> canApply,
+    function_ref<void(const Pattern &)> onFailure,
+    function_ref<LogicalResult(const Pattern &)> onSuccess) {
   // Check that the pattern can be applied.
   if (canApply && !canApply(pattern))
     return failure();

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 5f6c97272922..692cd494324e 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1452,7 +1452,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
 
 namespace {
 /// A set of rewrite patterns that can be used to legalize a given operation.
-using LegalizationPatterns = SmallVector<const RewritePattern *, 1>;
+using LegalizationPatterns = SmallVector<const Pattern *, 1>;
 
 /// This class defines a recursive operation legalizer.
 class OperationLegalizer {
@@ -1484,12 +1484,11 @@ class OperationLegalizer {
 
   /// Return true if the given pattern may be applied to the given operation,
   /// false otherwise.
-  bool canApplyPattern(Operation *op, const RewritePattern &pattern,
+  bool canApplyPattern(Operation *op, const Pattern &pattern,
                        ConversionPatternRewriter &rewriter);
 
   /// Legalize the resultant IR after successfully applying the given pattern.
-  LogicalResult legalizePatternResult(Operation *op,
-                                      const RewritePattern &pattern,
+  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                       ConversionPatternRewriter &rewriter,
                                       RewriterState &curState);
 
@@ -1546,7 +1545,7 @@ class OperationLegalizer {
       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
 
   /// The current set of patterns that have been applied.
-  SmallPtrSet<const RewritePattern *, 8> appliedPatterns;
+  SmallPtrSet<const Pattern *, 8> appliedPatterns;
 
   /// The legalization information provided by the target.
   ConversionTarget &target;
@@ -1697,13 +1696,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   auto &rewriterImpl = rewriter.getImpl();
 
   // Functor that returns if the given pattern may be applied.
-  auto canApply = [&](const RewritePattern &pattern) {
+  auto canApply = [&](const Pattern &pattern) {
     return canApplyPattern(op, pattern, rewriter);
   };
 
   // Functor that cleans up the rewriter state after a pattern failed to match.
   RewriterState curState = rewriterImpl.getCurrentState();
-  auto onFailure = [&](const RewritePattern &pattern) {
+  auto onFailure = [&](const Pattern &pattern) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
     rewriterImpl.resetState(curState);
     appliedPatterns.erase(&pattern);
@@ -1711,7 +1710,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that performs additional legalization when a pattern is
   // successfully applied.
-  auto onSuccess = [&](const RewritePattern &pattern) {
+  auto onSuccess = [&](const Pattern &pattern) {
     auto result = legalizePatternResult(op, pattern, rewriter, curState);
     appliedPatterns.erase(&pattern);
     if (failed(result))
@@ -1724,8 +1723,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
                                     onSuccess);
 }
 
-bool OperationLegalizer::canApplyPattern(Operation *op,
-                                         const RewritePattern &pattern,
+bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
                                          ConversionPatternRewriter &rewriter) {
   LLVM_DEBUG({
     auto &os = rewriter.getImpl().logger;
@@ -1747,9 +1745,10 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
   return true;
 }
 
-LogicalResult OperationLegalizer::legalizePatternResult(
-    Operation *op, const RewritePattern &pattern,
-    ConversionPatternRewriter &rewriter, RewriterState &curState) {
+LogicalResult
+OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
+                                          ConversionPatternRewriter &rewriter,
+                                          RewriterState &curState) {
   auto &impl = rewriter.getImpl();
 
 #ifndef NDEBUG
@@ -1877,13 +1876,12 @@ void OperationLegalizer::buildLegalizationGraph(
   // generate it.
   DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
   // A mapping between an operation and any currently invalid patterns it has.
-  DenseMap<OperationName, SmallPtrSet<const RewritePattern *, 2>>
-      invalidPatterns;
+  DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
   // A worklist of patterns to consider for legality.
-  llvm::SetVector<const RewritePattern *> patternWorklist;
+  llvm::SetVector<const Pattern *> patternWorklist;
 
   // Build the mapping from operations to the parent ops that may generate them.
-  applicator.walkAllPatterns([&](const RewritePattern &pattern) {
+  applicator.walkAllPatterns([&](const Pattern &pattern) {
     Optional<OperationName> root = pattern.getRootKind();
 
     // If the pattern has no specific root, we can't analyze the relationship
@@ -1914,7 +1912,7 @@ void OperationLegalizer::buildLegalizationGraph(
   // recurse into itself. It would be better to perform this kind of filtering
   // at a higher level than here anyways.
   if (!anyOpLegalizerPatterns.empty()) {
-    for (const RewritePattern *pattern : patternWorklist)
+    for (const Pattern *pattern : patternWorklist)
       legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
     return;
   }
@@ -1964,15 +1962,15 @@ void OperationLegalizer::computeLegalizationGraphBenefit(
   // Apply a cost model to the pattern applicator. We order patterns first by
   // depth then benefit. `legalizerPatterns` contains per-op patterns by
   // decreasing benefit.
-  applicator.applyCostModel([&](const RewritePattern &p) {
-    ArrayRef<const RewritePattern *> orderedPatternList;
-    if (Optional<OperationName> rootName = p.getRootKind())
+  applicator.applyCostModel([&](const Pattern &pattern) {
+    ArrayRef<const Pattern *> orderedPatternList;
+    if (Optional<OperationName> rootName = pattern.getRootKind())
       orderedPatternList = legalizerPatterns[*rootName];
     else
       orderedPatternList = anyOpLegalizerPatterns;
 
     // If the pattern is not found, then it was removed and cannot be matched.
-    auto it = llvm::find(orderedPatternList, &p);
+    auto it = llvm::find(orderedPatternList, &pattern);
     if (it == orderedPatternList.end())
       return PatternBenefit::impossibleToMatch();
 
@@ -2014,9 +2012,9 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
   unsigned minDepth = std::numeric_limits<unsigned>::max();
 
   // Compute the depth for each pattern within the set.
-  SmallVector<std::pair<const RewritePattern *, unsigned>, 4> patternsByDepth;
+  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
   patternsByDepth.reserve(patterns.size());
-  for (const RewritePattern *pattern : patterns) {
+  for (const Pattern *pattern : patterns) {
     unsigned depth = 0;
     for (auto generatedOp : pattern->getGeneratedOps()) {
       unsigned generatedOpDepth = computeOpLegalizationDepth(
@@ -2037,8 +2035,8 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
   // Sort the patterns by those likely to be the most beneficial.
   llvm::array_pod_sort(
       patternsByDepth.begin(), patternsByDepth.end(),
-      [](const std::pair<const RewritePattern *, unsigned> *lhs,
-         const std::pair<const RewritePattern *, unsigned> *rhs) {
+      [](const std::pair<const Pattern *, unsigned> *lhs,
+         const std::pair<const Pattern *, unsigned> *rhs) {
         // First sort by the smaller pattern legalization depth.
         if (lhs->second != rhs->second)
           return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 282d31065549..04a21ec01fe6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -452,7 +452,11 @@ struct TestNonRootReplacement : public RewritePattern {
 /// bounded recursion.
 struct TestBoundedRecursiveRewrite
     : public OpRewritePattern<TestRecursiveRewriteOp> {
-  using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
+  TestBoundedRecursiveRewrite(MLIRContext *ctx)
+      : OpRewritePattern<TestRecursiveRewriteOp>(ctx) {
+    // The conversion target handles bounding the recursion of this pattern.
+    setHasBoundedRewriteRecursion();
+  }
 
   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
                                 PatternRewriter &rewriter) const final {
@@ -462,9 +466,6 @@ struct TestBoundedRecursiveRewrite
     });
     return success();
   }
-
-  /// The conversion target handles bounding the recursion of this pattern.
-  bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
 struct TestNestedOpCreationUndoRewrite


        

