[Mlir-commits] [mlir] 3e98fbf - [mlir] Refactor RewritePatternMatcher into a new PatternApplicator class.

River Riddle llvmlistbot at llvm.org
Thu Jun 18 14:02:20 PDT 2020


Author: River Riddle
Date: 2020-06-18T13:58:47-07:00
New Revision: 3e98fbf4f522bf475ae2199a29f5411c86973ce5

URL: https://github.com/llvm/llvm-project/commit/3e98fbf4f522bf475ae2199a29f5411c86973ce5
DIFF: https://github.com/llvm/llvm-project/commit/3e98fbf4f522bf475ae2199a29f5411c86973ce5.diff

LOG: [mlir] Refactor RewritePatternMatcher into a new PatternApplicator class.

This class abstracts more of the details of the rewrite process, and will allow clients to apply specific cost models to the pattern list. This lets DialectConversion and the GreedyPatternRewriter share the same underlying matcher implementation. It also simplifies the plumbing necessary to support dynamic patterns.

Differential Revision: https://reviews.llvm.org/D81985

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8178f71ec43d..8692d5ba0005 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -30,7 +30,8 @@ class PatternBenefit {
   enum { ImpossibleToMatchSentinel = 65535 };
 
 public:
-  /*implicit*/ PatternBenefit(unsigned benefit);
+  PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
+  PatternBenefit(unsigned benefit);
   PatternBenefit(const PatternBenefit &) = default;
   PatternBenefit &operator=(const PatternBenefit &) = default;
 
@@ -48,9 +49,11 @@ class PatternBenefit {
   bool operator<(const PatternBenefit &rhs) const {
     return representation < rhs.representation;
   }
+  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
+  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
+  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }
 
 private:
-  PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
   unsigned short representation;
 };
 
@@ -384,6 +387,9 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
 // Pattern-driven rewriters
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// OwningRewritePatternList
+
 class OwningRewritePatternList {
   using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
@@ -401,6 +407,7 @@ class OwningRewritePatternList {
   PatternListT::iterator end() { return patterns.end(); }
   PatternListT::const_iterator begin() const { return patterns.begin(); }
   PatternListT::const_iterator end() const { return patterns.end(); }
+  PatternListT::size_type size() const { return patterns.size(); }
   void clear() { patterns.clear(); }
 
   //===--------------------------------------------------------------------===//
@@ -419,60 +426,100 @@ class OwningRewritePatternList {
     // types 'Ts'. This magic is necessary due to a limitation in the places
     // that a parameter pack can be expanded in c++11.
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    using dummy = int[];
-    (void)dummy{
+    (void)std::initializer_list<int>{
         0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
     return *this;
   }
 
+  /// Add the given pattern to the pattern list.
+  void insert(std::unique_ptr<RewritePattern> pattern) {
+    patterns.emplace_back(std::move(pattern));
+  }
+
 private:
   PatternListT patterns;
 };
 
-/// This class manages optimization and execution of a group of rewrite
-/// patterns, providing an API for finding and applying, the best match against
-/// a given node.
-///
-class RewritePatternMatcher {
+//===----------------------------------------------------------------------===//
+// PatternApplicator
+
+/// This class manages the application of a group of rewrite patterns, with a
+/// user-provided cost model.
+class PatternApplicator {
 public:
-  /// Create a RewritePatternMatcher with the specified set of patterns.
-  explicit RewritePatternMatcher(const OwningRewritePatternList &patterns);
+  /// The cost model dynamically assigns a PatternBenefit to a particular
+  /// pattern. Users can query contained patterns and pass analysis results to
+  /// applyCostModel. Patterns to be discarded should have a benefit of
+  /// `impossibleToMatch`.
+  using CostModel = function_ref<PatternBenefit(const RewritePattern &)>;
+
+  explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
+      : owningPatternList(owningPatternList) {}
+
+  /// Attempt to match and rewrite the given op with any pattern, allowing a
+  /// predicate to decide if a pattern can be applied or not, and hooks for if
+  /// the pattern match was a success or failure.
+  ///
+  /// canApply:  called before each match and rewrite attempt; return false to
+  ///            skip pattern.
+  /// onFailure: called when a pattern fails to match to perform cleanup.
+  /// onSuccess: called when a pattern match succeeds; return failure() to
+  ///            invalidate the match and try another pattern.
+  LogicalResult matchAndRewrite(
+      Operation *op, PatternRewriter &rewriter,
+      function_ref<bool(const RewritePattern &)> canApply = {},
+      function_ref<void(const RewritePattern &)> onFailure = {},
+      function_ref<LogicalResult(const RewritePattern &)> onSuccess = {});
+
+  /// Apply a cost model to the patterns within this applicator.
+  void applyCostModel(CostModel model);
+
+  /// Apply the default cost model that solely uses the pattern's static
+  /// benefit.
+  void applyDefaultCostModel() {
+    applyCostModel(
+        [](const RewritePattern &pattern) { return pattern.getBenefit(); });
+  }
 
-  /// Try to match the given operation to a pattern and rewrite it. Return
-  /// true if any pattern matches.
-  bool matchAndRewrite(Operation *op, PatternRewriter &rewriter);
+  /// Walk all of the rewrite patterns within the applicator.
+  void walkAllPatterns(function_ref<void(const RewritePattern &)> walk);
 
 private:
-  RewritePatternMatcher(const RewritePatternMatcher &) = delete;
-  void operator=(const RewritePatternMatcher &) = delete;
+  /// The list that owns the patterns used within this applicator.
+  const OwningRewritePatternList &owningPatternList;
 
-  /// The group of patterns that are matched for optimization through this
-  /// matcher.
-  std::vector<RewritePattern *> patterns;
+  /// The set of patterns to match for each operation, stable sorted by benefit.
+  DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
 };
 
+//===----------------------------------------------------------------------===//
+// applyPatternsGreedily
+//===----------------------------------------------------------------------===//
+
 /// Rewrite the regions of the specified operation, which must be isolated from
 /// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return true if no more patterns can be matched in
-/// the result operation regions.
-/// Note: This does not apply patterns to the top-level operation itself.
-/// Note: These methods also perform folding and simple dead-code elimination
+/// work-list driven manner. Return success if no more patterns can be matched
+/// in the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself. Note:
+///       These methods also perform folding and simple dead-code elimination
 ///       before attempting to match any of the provided patterns.
 ///
-bool applyPatternsAndFoldGreedily(Operation *op,
-                                  const OwningRewritePatternList &patterns);
+LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+                             const OwningRewritePatternList &patterns);
 /// Rewrite the given regions, which must be isolated from above.
-bool applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                  const OwningRewritePatternList &patterns);
+LogicalResult
+applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+                             const OwningRewritePatternList &patterns);
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns true
-/// if no more patterns can be matched. `erased` is set to true if `op` was
-/// folded away or erased as a result of becoming dead. Note: This does not
+/// by selecting the highest benefits patterns in a greedy manner. Returns
+/// success if no more patterns can be matched. `erased` is set to true if `op`
+/// was folded away or erased as a result of becoming dead. Note: This does not
 /// apply any patterns recursively to the regions of `op`.
-bool applyOpPatternsAndFold(Operation *op,
-                            const OwningRewritePatternList &patterns,
-                            bool *erased = nullptr);
+LogicalResult applyOpPatternsAndFold(Operation *op,
+                                     const OwningRewritePatternList &patterns,
+                                     bool *erased = nullptr);
 } // end namespace mlir
 
 #endif // MLIR_PATTERN_MATCH_H

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 402200f6b1ca..6af4de7e9d83 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -214,13 +214,13 @@ LogicalResult mlir::linalg::applyStagedPatterns(
   for (const auto &patterns : stage1Patterns) {
     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                       << *op);
-    if (!applyPatternsAndFoldGreedily(op, patterns)) {
+    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
       return failure();
     }
     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                       << *op);
-    if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
+    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
       return failure();
     }

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index faf30dca4cfe..b3b152377e36 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -19,8 +19,7 @@ PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
 }
 
 unsigned short PatternBenefit::getBenefit() const {
-  assert(representation != ImpossibleToMatchSentinel &&
-         "Pattern doesn't match");
+  assert(!isImpossibleToMatch() && "Pattern doesn't match");
   return representation;
 }
 
@@ -171,31 +170,72 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
 // PatternMatcher implementation
 //===----------------------------------------------------------------------===//
 
-RewritePatternMatcher::RewritePatternMatcher(
-    const OwningRewritePatternList &patterns) {
-  for (auto &pattern : patterns)
-    this->patterns.push_back(pattern.get());
+void PatternApplicator::applyCostModel(CostModel model) {
+  // Separate patterns by root kind to simplify lookup later on.
+  patterns.clear();
+  for (const auto &pat : owningPatternList)
+    patterns[pat->getRootKind()].push_back(pat.get());
+
+  // Sort the patterns using the provided cost model.
+  llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
+  auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
+    return benefits[lhs] > benefits[rhs];
+  };
+  for (auto &it : patterns) {
+    SmallVectorImpl<RewritePattern *> &list = it.second;
+
+    // Special case for one pattern in the list, which is the most common case.
+    if (list.size() == 1) {
+      if (model(*list.front()).isImpossibleToMatch())
+        list.clear();
+      continue;
+    }
+
+    // Collect the dynamic benefits for the current pattern list.
+    benefits.clear();
+    for (RewritePattern *pat : list)
+      benefits.try_emplace(pat, model(*pat));
+
+    // Sort patterns with highest benefit first, and remove those that are
+    // impossible to match.
+    std::stable_sort(list.begin(), list.end(), cmp);
+    while (!list.empty() && benefits[list.back()].isImpossibleToMatch())
+      list.pop_back();
+  }
+}
 
-  // Sort the patterns by benefit to simplify the matching logic.
-  std::stable_sort(this->patterns.begin(), this->patterns.end(),
-                   [](RewritePattern *l, RewritePattern *r) {
-                     return r->getBenefit() < l->getBenefit();
-                   });
+void PatternApplicator::walkAllPatterns(
+    function_ref<void(const RewritePattern &)> walk) {
+  for (auto &it : owningPatternList)
+    walk(*it);
 }
 
 /// Try to match the given operation to a pattern and rewrite it.
-bool RewritePatternMatcher::matchAndRewrite(Operation *op,
-                                            PatternRewriter &rewriter) {
-  for (auto *pattern : patterns) {
-    // Ignore patterns that are for the wrong root or are impossible to match.
-    if (pattern->getRootKind() != op->getName() ||
-        pattern->getBenefit().isImpossibleToMatch())
+LogicalResult PatternApplicator::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter,
+    function_ref<bool(const RewritePattern &)> canApply,
+    function_ref<void(const RewritePattern &)> onFailure,
+    function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
+  auto patternIt = patterns.find(op->getName());
+  if (patternIt == patterns.end())
+    return failure();
+
+  for (auto *pattern : patternIt->second) {
+    // Check that the pattern can be applied.
+    if (canApply && !canApply(*pattern))
       continue;
 
     // Try to match and rewrite this pattern. The patterns are sorted by
-    // benefit, so if we match we can immediately rewrite and return.
-    if (succeeded(pattern->matchAndRewrite(op, rewriter)))
-      return true;
+    // benefit, so if we match we can immediately rewrite.
+    rewriter.setInsertionPoint(op);
+    if (succeeded(pattern->matchAndRewrite(op, rewriter))) {
+      if (!onSuccess || succeeded(onSuccess(*pattern)))
+        return success();
+      continue;
+    }
+
+    if (onFailure)
+      onFailure(*pattern);
   }
-  return false;
+  return failure();
 }

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index fdac287d5d69..aa12c16a3d8c 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1122,7 +1122,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
 
 namespace {
 /// A set of rewrite patterns that can be used to legalize a given operation.
-using LegalizationPatterns = SmallVector<RewritePattern *, 1>;
+using LegalizationPatterns = SmallVector<const RewritePattern *, 1>;
 
 /// This class defines a recursive operation legalizer.
 class OperationLegalizer {
@@ -1130,11 +1130,7 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(ConversionTarget &targetInfo,
-                     const OwningRewritePatternList &patterns)
-      : target(targetInfo) {
-    buildLegalizationGraph(patterns);
-    computeLegalizationGraphBenefit();
-  }
+                     const OwningRewritePatternList &patterns);
 
   /// Returns if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1151,16 +1147,28 @@ class OperationLegalizer {
   LogicalResult legalizeWithFold(Operation *op,
                                  ConversionPatternRewriter &rewriter);
 
-  /// Attempt to legalize the given operation by applying the provided pattern.
-  /// Returns success if the operation was legalized, failure otherwise.
-  LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
-                                ConversionPatternRewriter &rewriter);
+  /// Attempt to legalize the given operation by applying a pattern. Returns
+  /// success if the operation was legalized, failure otherwise.
+  LogicalResult legalizeWithPattern(Operation *op,
+                                    ConversionPatternRewriter &rewriter);
+
+  /// Return true if the given pattern may be applied to the given operation,
+  /// false otherwise.
+  bool canApplyPattern(Operation *op, const RewritePattern &pattern,
+                       ConversionPatternRewriter &rewriter);
+
+  /// Legalize the resultant IR after successfully applying the given pattern.
+  LogicalResult legalizePatternResult(Operation *op,
+                                      const RewritePattern &pattern,
+                                      ConversionPatternRewriter &rewriter,
+                                      RewriterState &curState);
 
   /// Build an optimistic legalization graph given the provided patterns. This
   /// function populates 'legalizerPatterns' with the operations that are not
   /// directly legal, but may be transitively legal for the current target given
   /// the provided patterns.
-  void buildLegalizationGraph(const OwningRewritePatternList &patterns);
+  void buildLegalizationGraph(
+      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
 
   /// Compute the benefit of each node within the computed legalization graph.
   /// This orders the patterns within 'legalizerPatterns' based upon two
@@ -1170,20 +1178,31 @@ class OperationLegalizer {
   ///  2) When comparing patterns with the same legalization depth, prefer the
   ///     pattern with the highest PatternBenefit. This allows for users to
   ///     prefer specific legalizations over others.
-  void computeLegalizationGraphBenefit();
+  void computeLegalizationGraphBenefit(
+      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
 
   /// The current set of patterns that have been applied.
-  SmallPtrSet<RewritePattern *, 8> appliedPatterns;
-
-  /// The set of legality information for operations transitively supported by
-  /// the target.
-  DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
+  SmallPtrSet<const RewritePattern *, 8> appliedPatterns;
 
   /// The legalization information provided by the target.
   ConversionTarget &target;
+
+  /// The pattern applicator to use for conversions.
+  PatternApplicator applicator;
 };
 } // namespace
 
+OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
+                                       const OwningRewritePatternList &patterns)
+    : target(targetInfo), applicator(patterns) {
+  // The set of legality information for operations transitively supported by
+  // the target.
+  DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
+
+  buildLegalizationGraph(legalizerPatterns);
+  computeLegalizationGraphBenefit(legalizerPatterns);
+}
+
 bool OperationLegalizer::isIllegal(Operation *op) const {
   // Check if the target explicitly marked this operation as illegal.
   return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
@@ -1253,24 +1272,12 @@ OperationLegalizer::legalize(Operation *op,
   }
 
   // Otherwise, we need to apply a legalization pattern to this operation.
-  auto it = legalizerPatterns.find(op->getName());
-  if (it == legalizerPatterns.end()) {
+  if (succeeded(legalizeWithPattern(op, rewriter))) {
     LLVM_DEBUG({
-      logFailure(rewriterImpl.logger, "no known legalization path");
+      logSuccess(rewriterImpl.logger, "");
       rewriterImpl.logger.startLine() << logLineComment;
     });
-    return failure();
-  }
-
-  // The patterns are sorted by expected benefit, so try to apply each in-order.
-  for (auto *pattern : it->second) {
-    if (succeeded(legalizePattern(op, pattern, rewriter))) {
-      LLVM_DEBUG({
-        logSuccess(rewriterImpl.logger, "");
-        rewriterImpl.logger.startLine() << logLineComment;
-      });
-      return success();
-    }
+    return success();
   }
 
   LLVM_DEBUG({
@@ -1320,46 +1327,70 @@ OperationLegalizer::legalizeWithFold(Operation *op,
 }
 
 LogicalResult
-OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
-                                    ConversionPatternRewriter &rewriter) {
+OperationLegalizer::legalizeWithPattern(Operation *op,
+                                        ConversionPatternRewriter &rewriter) {
   auto &rewriterImpl = rewriter.getImpl();
+
+  // Functor that returns if the given pattern may be applied.
+  auto canApply = [&](const RewritePattern &pattern) {
+    return canApplyPattern(op, pattern, rewriter);
+  };
+
+  // Functor that cleans up the rewriter state after a pattern failed to match.
+  RewriterState curState = rewriterImpl.getCurrentState();
+  auto onFailure = [&](const RewritePattern &pattern) {
+    LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
+    rewriterImpl.resetState(curState);
+    appliedPatterns.erase(&pattern);
+  };
+
+  // Functor that performs additional legalization when a pattern is
+  // successfully applied.
+  auto onSuccess = [&](const RewritePattern &pattern) {
+    auto result = legalizePatternResult(op, pattern, rewriter, curState);
+    appliedPatterns.erase(&pattern);
+    if (failed(result))
+      rewriterImpl.resetState(curState);
+    return result;
+  };
+
+  // Try to match and rewrite a pattern on this operation.
+  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
+                                    onSuccess);
+}
+
+bool OperationLegalizer::canApplyPattern(Operation *op,
+                                         const RewritePattern &pattern,
+                                         ConversionPatternRewriter &rewriter) {
   LLVM_DEBUG({
-    auto &os = rewriterImpl.logger;
+    auto &os = rewriter.getImpl().logger;
     os.getOStream() << "\n";
-    os.startLine() << "* Pattern : '" << pattern->getRootKind() << " -> (";
-    llvm::interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
+    os.startLine() << "* Pattern : '" << pattern.getRootKind() << " -> (";
+    llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
     os.getOStream() << ")' {\n";
     os.indent();
   });
 
   // Ensure that we don't cycle by not allowing the same pattern to be
   // applied twice in the same recursion stack if it is not known to be safe.
-  if (!pattern->hasBoundedRewriteRecursion() &&
-      !appliedPatterns.insert(pattern).second) {
-    LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern was already applied"));
-    return failure();
+  if (!pattern.hasBoundedRewriteRecursion() &&
+      !appliedPatterns.insert(&pattern).second) {
+    LLVM_DEBUG(
+        logFailure(rewriter.getImpl().logger, "pattern was already applied"));
+    return false;
   }
+  return true;
+}
 
-  RewriterState curState = rewriterImpl.getCurrentState();
-  auto cleanupFailure = [&] {
-    // Reset the rewriter state and pop this pattern.
-    rewriterImpl.resetState(curState);
-    appliedPatterns.erase(pattern);
-    return failure();
-  };
+LogicalResult OperationLegalizer::legalizePatternResult(
+    Operation *op, const RewritePattern &pattern,
+    ConversionPatternRewriter &rewriter, RewriterState &curState) {
+  auto &rewriterImpl = rewriter.getImpl();
 
-  // Try to rewrite with the given pattern.
-  rewriter.setInsertionPoint(op);
-  LogicalResult matchedPattern = pattern->matchAndRewrite(op, rewriter);
 #ifndef NDEBUG
   assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
 #endif
 
-  if (failed(matchedPattern)) {
-    LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
-    return cleanupFailure();
-  }
-
   // If the pattern moved or created any blocks, try to legalize their types.
   // This ensures that the types of the block arguments are legal for the region
   // they were moved into.
@@ -1376,7 +1407,7 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
     if (failed(rewriterImpl.convertBlockSignature(action.block))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to convert types of moved block"));
-      return cleanupFailure();
+      return failure();
     }
   }
 
@@ -1414,7 +1445,7 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "operation updated in-place '{0}' was illegal",
                             op->getName()));
-      return cleanupFailure();
+      return failure();
     }
   }
 
@@ -1426,42 +1457,42 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "generated operation '{0}'({1}) was illegal",
                             op->getName(), op));
-      return cleanupFailure();
+      return failure();
     }
   }
 
   LLVM_DEBUG(logSuccess(rewriterImpl.logger, "pattern applied successfully"));
-  appliedPatterns.erase(pattern);
   return success();
 }
 
 void OperationLegalizer::buildLegalizationGraph(
-    const OwningRewritePatternList &patterns) {
+    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
   // A mapping between an operation and a set of operations that can be used to
   // generate it.
   DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
   // A mapping between an operation and any currently invalid patterns it has.
-  DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns;
+  DenseMap<OperationName, SmallPtrSet<const RewritePattern *, 2>>
+      invalidPatterns;
   // A worklist of patterns to consider for legality.
-  llvm::SetVector<RewritePattern *> patternWorklist;
+  llvm::SetVector<const RewritePattern *> patternWorklist;
 
   // Build the mapping from operations to the parent ops that may generate them.
-  for (auto &pattern : patterns) {
-    auto root = pattern->getRootKind();
+  applicator.walkAllPatterns([&](const RewritePattern &pattern) {
+    OperationName root = pattern.getRootKind();
 
     // Skip operations that are always known to be legal.
     if (target.getOpAction(root) == LegalizationAction::Legal)
-      continue;
+      return;
 
     // Add this pattern to the invalid set for the root op and record this root
     // as a parent for any generated operations.
-    invalidPatterns[root].insert(pattern.get());
-    for (auto op : pattern->getGeneratedOps())
+    invalidPatterns[root].insert(&pattern);
+    for (auto op : pattern.getGeneratedOps())
       parentOps[op].insert(root);
 
     // Add this pattern to the worklist.
-    patternWorklist.insert(pattern.get());
-  }
+    patternWorklist.insert(&pattern);
+  });
 
   while (!patternWorklist.empty()) {
     auto *pattern = patternWorklist.pop_back_val();
@@ -1486,7 +1517,8 @@ void OperationLegalizer::buildLegalizationGraph(
   }
 }
 
-void OperationLegalizer::computeLegalizationGraphBenefit() {
+void OperationLegalizer::computeLegalizationGraphBenefit(
+    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
   // The smallest pattern depth, when legalizing an operation.
   DenseMap<OperationName, unsigned> minPatternDepth;
 
@@ -1511,9 +1543,9 @@ void OperationLegalizer::computeLegalizationGraphBenefit() {
     minPatternDepth.try_emplace(op, minDepth);
 
     // Compute the depth for each pattern used to legalize this operation.
-    SmallVector<std::pair<RewritePattern *, unsigned>, 4> patternsByDepth;
+    SmallVector<std::pair<const RewritePattern *, unsigned>, 4> patternsByDepth;
     patternsByDepth.reserve(opPatternsIt->second.size());
-    for (RewritePattern *pattern : opPatternsIt->second) {
+    for (const RewritePattern *pattern : opPatternsIt->second) {
       unsigned depth = 0;
       for (auto generatedOp : pattern->getGeneratedOps())
         depth = std::max(depth, computeDepth(generatedOp) + 1);
@@ -1534,8 +1566,8 @@ void OperationLegalizer::computeLegalizationGraphBenefit() {
     // Sort the patterns by those likely to be the most beneficial.
     llvm::array_pod_sort(
         patternsByDepth.begin(), patternsByDepth.end(),
-        [](const std::pair<RewritePattern *, unsigned> *lhs,
-           const std::pair<RewritePattern *, unsigned> *rhs) {
+        [](const std::pair<const RewritePattern *, unsigned> *lhs,
+           const std::pair<const RewritePattern *, unsigned> *rhs) {
           // First sort by the smaller pattern legalization depth.
           if (lhs->second != rhs->second)
             return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
@@ -1560,6 +1592,21 @@ void OperationLegalizer::computeLegalizationGraphBenefit() {
   for (auto &opIt : legalizerPatterns)
     if (!minPatternDepth.count(opIt.first))
       computeDepth(opIt.first);
+
+  // Apply a cost model to the pattern applicator. We order patterns first by
+  // depth then benefit. `legalizerPatterns` contains per-op patterns by
+  // decreasing benefit.
+  applicator.applyCostModel([&](const RewritePattern &p) {
+    auto &list = legalizerPatterns[p.getRootKind()];
+
+    // If the pattern is not found, then it was removed and cannot be matched.
+    LegalizationPatterns::iterator it = llvm::find(list, &p);
+    if (it == list.end())
+      return PatternBenefit::impossibleToMatch();
+
+    // Patterns found earlier in the list have higher benefit.
+    return PatternBenefit(std::distance(it, list.end()));
+  });
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 0974992e079c..ea420733e5ff 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,6 +39,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
                                       const OwningRewritePatternList &patterns)
       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
     worklist.reserve(64);
+
+    // Apply a simple cost model based solely on pattern benefit.
+    matcher.applyDefaultCostModel();
   }
 
   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
@@ -103,8 +106,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   // be re-added to the worklist. This function should be called when an
   // operation is modified or removed, as it may trigger further
   // simplifications.
-  template <typename Operands>
-  void addToWorklist(Operands &&operands) {
+  template <typename Operands> void addToWorklist(Operands &&operands) {
     for (Value operand : operands) {
       // If the use count of this operand is now < 2, we re-add the defining
       // operation to the worklist.
@@ -118,8 +120,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
     }
   }
 
-  /// The low-level pattern matcher.
-  RewritePatternMatcher matcher;
+  /// The low-level pattern applicator.
+  PatternApplicator matcher;
 
   /// The worklist for this transformation keeps track of the operations that
   /// need to be revisited, plus their index in the worklist.  This allows us to
@@ -192,12 +194,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
           continue;
       }
 
-      // Make sure that any new operations are inserted at this point.
-      setInsertionPoint(op);
-
       // Try to match one of the patterns. The rewriter is automatically
       // notified of any necessary changes, so there is nothing else to do here.
-      changed |= matcher.matchAndRewrite(op, *this);
+      changed |= succeeded(matcher.matchAndRewrite(op, *this));
     }
 
     // After applying patterns, make sure that the CFG of each of the regions is
@@ -213,20 +212,21 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 
 /// Rewrite the regions of the specified operation, which must be isolated from
 /// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return true if no more patterns can be matched in
-/// the result operation regions.
-/// Note: This does not apply patterns to the top-level operation itself.
+/// work-list driven manner. Return success if no more patterns can be matched
+/// in the result operation regions. Note: This does not apply patterns to the
+/// top-level operation itself.
 ///
-bool mlir::applyPatternsAndFoldGreedily(
-    Operation *op, const OwningRewritePatternList &patterns) {
+LogicalResult
+mlir::applyPatternsAndFoldGreedily(Operation *op,
+                                   const OwningRewritePatternList &patterns) {
   return applyPatternsAndFoldGreedily(op->getRegions(), patterns);
 }
-
 /// Rewrite the given regions, which must be isolated from above.
-bool mlir::applyPatternsAndFoldGreedily(
-    MutableArrayRef<Region> regions, const OwningRewritePatternList &patterns) {
+LogicalResult
+mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+                                   const OwningRewritePatternList &patterns) {
   if (regions.empty())
-    return true;
+    return success();
 
   // The top-level operation must be known to be isolated from above to
   // prevent performing canonicalizations on operations defined at or above
@@ -245,7 +245,7 @@ bool mlir::applyPatternsAndFoldGreedily(
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
                  << maxPatternMatchIterations << " times";
   });
-  return converged;
+  return success(converged);
 }
 
 //===----------------------------------------------------------------------===//
@@ -259,9 +259,17 @@ class OpPatternRewriteDriver : public PatternRewriter {
 public:
   explicit OpPatternRewriteDriver(MLIRContext *ctx,
                                   const OwningRewritePatternList &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {}
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
+    // Apply a simple cost model based solely on pattern benefit.
+    matcher.applyDefaultCostModel();
+  }
 
-  bool simplifyLocally(Operation *op, int maxIterations, bool &erased);
+  /// Performs the rewrites and folding only on `op`. The simplification
+  /// converges if the op is erased as a result of being folded, replaced, or
+  /// dead, or no more changes happen in an iteration. Returns success if the
+  /// rewrite converges in `maxIterations`. `erased` is set to true if `op` gets
+  /// erased.
+  LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
 
   // These are hooks implemented for PatternRewriter.
 protected:
@@ -276,8 +284,8 @@ class OpPatternRewriteDriver : public PatternRewriter {
   void notifyRootReplaced(Operation *op) override {}
 
 private:
-  /// The low-level pattern matcher.
-  RewritePatternMatcher matcher;
+  /// The low-level pattern applicator.
+  PatternApplicator matcher;
 
   /// Non-pattern based folder for operations.
   OperationFolder folder;
@@ -288,12 +296,9 @@ class OpPatternRewriteDriver : public PatternRewriter {
 
 } // anonymous namespace
 
-/// Performs the rewrites and folding only on `op`. The simplification converges
-/// if the op is erased as a result of being folded, replaced, or dead, or no
-/// more changes happen in an iteration. Returns true if the rewrite converges
-/// in `maxIterations`. `erased` is set to true if `op` gets erased.
-bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations,
-                                             bool &erased) {
+LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
+                                                      int maxIterations,
+                                                      bool &erased) {
   bool changed = false;
   erased = false;
   opErasedViaPatternRewrites = false;
@@ -305,7 +310,7 @@ bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations,
     if (isOpTriviallyDead(op)) {
       op->erase();
       erased = true;
-      return true;
+      return success();
     }
 
     // Try to fold this op.
@@ -316,38 +321,34 @@ bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations,
       changed = true;
       if (!inPlaceUpdate) {
         erased = true;
-        return true;
+        return success();
       }
     }
 
-    // Make sure that any new operations are inserted at this point.
-    setInsertionPoint(op);
-
     // Try to match one of the patterns. The rewriter is automatically
     // notified of any necessary changes, so there is nothing else to do here.
-    changed |= matcher.matchAndRewrite(op, *this);
+    changed |= succeeded(matcher.matchAndRewrite(op, *this));
     if ((erased = opErasedViaPatternRewrites))
-      return true;
+      return success();
   } while (changed && ++i < maxIterations);
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
-  return !changed;
+  return failure(changed);
 }
 
 /// Rewrites only `op` using the supplied canonicalization patterns and
 /// folding. `erased` is set to true if the op is erased as a result of being
 /// folded, replaced, or dead.
-bool mlir::applyOpPatternsAndFold(Operation *op,
-                                  const OwningRewritePatternList &patterns,
-                                  bool *erased) {
+LogicalResult mlir::applyOpPatternsAndFold(
+    Operation *op, const OwningRewritePatternList &patterns, bool *erased) {
   // Start the pattern driver.
   OpPatternRewriteDriver driver(op->getContext(), patterns);
   bool opErased;
-  bool converged =
+  LogicalResult converged =
       driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
   if (erased)
     *erased = opErased;
-  LLVM_DEBUG(if (!converged) {
+  LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
                  << maxPatternMatchIterations << " times";
   });


        


More information about the Mlir-commits mailing list