[Mlir-commits] [mlir] 80d7ac3 - [mlir] Allow for patterns to match any root kind.

River Riddle llvmlistbot at llvm.org
Thu Jun 18 14:02:22 PDT 2020


Author: River Riddle
Date: 2020-06-18T13:58:47-07:00
New Revision: 80d7ac3bc7c04975fd444e9f2806e4db224f2416

URL: https://github.com/llvm/llvm-project/commit/80d7ac3bc7c04975fd444e9f2806e4db224f2416
DIFF: https://github.com/llvm/llvm-project/commit/80d7ac3bc7c04975fd444e9f2806e4db224f2416.diff

LOG: [mlir] Allow for patterns to match any root kind.

Traditionally, patterns have always had the root operation kind hardcoded to a specific operation name. This has worked well for quite some time, but it has limitations. For example, some lowerings share the same implementation across many different operation types, and a few lower entire dialects using a single pattern implementation. This problem has led to several "solutions":
a) Provide a template implementation to the user so that they can instantiate it for each operation combination, generally requiring the inclusion of the auto-generated operation definition file.
b) Use a non-templated pattern that takes the name of the operation to match as a parameter.
  - No one ever does this, because enumerating operation names is cumbersome, so this quickly devolves into solution a.

This revision removes the restriction that patterns have a hardcoded root type, and allows for a class of patterns that can match "any" operation type. The major downside of root-agnostic patterns is that they make certain pattern analyses more difficult (for example, dialect conversion cannot relate such a pattern's root to the operations it generates when building its legalization graph), so operation-specific patterns are still strongly encouraged whenever possible.
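The tag-type idea can be illustrated outside MLIR. What follows is a minimal, self-contained C++17 sketch, not the MLIR API. `PatternModel` and `isCandidateFor` are hypothetical names, operation names are plain strings instead of `OperationName`, the benefit is a bare `unsigned`, and `std::optional` stands in for LLVM's `Optional`. The sketch shows how an empty tag struct selects the root-agnostic constructor, and how an empty root makes a pattern a candidate for every operation.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-in for mlir::Pattern; not the MLIR class.
class PatternModel {
public:
  // Empty tag type: callers must spell out MatchAnyOpTypeTag() to opt in to
  // root-agnostic matching, so the choice is explicit at every call site.
  struct MatchAnyOpTypeTag {};

  // Pattern anchored on one specific operation name.
  PatternModel(std::string rootName, unsigned benefit)
      : rootKind(std::move(rootName)), benefit(benefit) {
    assert(!rootKind->empty() && "a specific pattern needs a root name");
  }

  // Pattern that may match any operation: no root name is recorded.
  PatternModel(unsigned benefit, MatchAnyOpTypeTag) : benefit(benefit) {}

  // Mirrors the new getRootKind(): empty for "match any" patterns.
  const std::optional<std::string> &getRootKind() const { return rootKind; }
  unsigned getBenefit() const { return benefit; }

  // A pattern is worth trying on `opName` if it has no root or the root
  // names exactly that operation.
  bool isCandidateFor(const std::string &opName) const {
    return !rootKind || *rootKind == opName;
  }

private:
  std::optional<std::string> rootKind;
  unsigned benefit;
};
```

Because `MatchAnyOpTypeTag` carries no data, passing it has no runtime cost. Its only purpose is to make the root-agnostic constructor impossible to select by accident.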

Differential Revision: https://reviews.llvm.org/D82066

Added: 
    mlir/test/Transforms/test-legalize-unknown-root.mlir

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8692d5ba0005..0f0228a3dad3 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -74,9 +74,9 @@ class Pattern {
   /// condition predicates.
   PatternBenefit getBenefit() const { return benefit; }
 
-  /// Return the root node that this pattern matches.  Patterns that can
-  /// match multiple root types are instantiated once per root.
-  OperationName getRootKind() const { return rootKind; }
+  /// Return the root node that this pattern matches. Patterns that can match
+  /// multiple root types return None.
+  Optional<OperationName> getRootKind() const { return rootKind; }
 
   //===--------------------------------------------------------------------===//
   // Implementation hooks for patterns to implement.
@@ -89,12 +89,30 @@ class Pattern {
   virtual ~Pattern() {}
 
 protected:
-  /// Patterns must specify the root operation name they match against, and can
-  /// also specify the benefit of the pattern matching.
+  /// This class acts as a special tag that makes the desire to match "any"
+  /// operation type explicit. This helps to avoid unnecessary usages of this
+  /// feature, and ensures that the user is making a conscious decision.
+  struct MatchAnyOpTypeTag {};
+
+  /// This constructor is used for patterns that match against a specific
+  /// operation type. The `benefit` is the expected benefit of matching this
+  /// pattern.
   Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
 
+  /// This constructor is used when a pattern may match against multiple
+  /// different types of operations. The `benefit` is the expected benefit of
+  /// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that
+  /// the "match any" behavior is what the user actually desired,
+  /// `MatchAnyOpTypeTag()` should always be supplied here.
+  Pattern(PatternBenefit benefit, MatchAnyOpTypeTag);
+
 private:
-  const OperationName rootKind;
+  /// The root operation of the pattern. If the pattern matches a specific
+  /// operation, this contains the name of that operation. Contains None
+  /// otherwise.
+  Optional<OperationName> rootKind;
+
+  /// The expected benefit of matching this pattern.
   const PatternBenefit benefit;
 
   virtual void anchor();
@@ -151,10 +169,24 @@ class RewritePattern : public Pattern {
                  MLIRContext *context)
       : Pattern(rootName, benefit, context) {}
   /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching. `MatchAnyOpTypeTag`
+  /// is just a tag to ensure that the "match any" behavior is what the user
+  /// actually desired, `MatchAnyOpTypeTag()` should always be supplied here.
+  RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
+      : Pattern(benefit, tag) {}
+  /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching. They can also specify
   /// the names of operations that may be generated during a successful rewrite.
   RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
                  PatternBenefit benefit, MLIRContext *context);
+  /// Patterns must specify the root operation name they match against, and can
+  /// also specify the benefit of the pattern matching. They can also specify
+  /// the names of operations that may be generated during a successful rewrite.
+  /// `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
+  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
+  /// always be supplied here.
+  RewritePattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
+                 MLIRContext *context, MatchAnyOpTypeTag tag);
 
   /// A list of the potential operations that may be generated when rewriting
   /// an op with this pattern.
@@ -431,6 +463,14 @@ class OwningRewritePatternList {
     return *this;
   }
 
+  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
+  /// `this` for chaining insertions.
+  template <typename... Ts> OwningRewritePatternList &insert() {
+    (void)std::initializer_list<int>{
+        0, (patterns.emplace_back(std::make_unique<Ts>()), 0)...};
+    return *this;
+  }
+
   /// Add the given pattern to the pattern list.
   void insert(std::unique_ptr<RewritePattern> pattern) {
     patterns.emplace_back(std::move(pattern));
@@ -485,11 +525,23 @@ class PatternApplicator {
   void walkAllPatterns(function_ref<void(const RewritePattern &)> walk);
 
 private:
+  /// Attempt to match and rewrite the given op with the given pattern, allowing
+  /// a predicate to decide if a pattern can be applied or not, and hooks for if
+  /// the pattern match was a success or failure.
+  LogicalResult matchAndRewrite(
+      Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
+      function_ref<bool(const RewritePattern &)> canApply,
+      function_ref<void(const RewritePattern &)> onFailure,
+      function_ref<LogicalResult(const RewritePattern &)> onSuccess);
+
   /// The list that owns the patterns used within this applicator.
   const OwningRewritePatternList &owningPatternList;
 
   /// The set of patterns to match for each operation, stable sorted by benefit.
   DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
+  /// The set of patterns that may match against any operation type, stable
+  /// sorted by benefit.
+  SmallVector<RewritePattern *, 1> anyOpPatterns;
 };
 
 //===----------------------------------------------------------------------===//
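The new `OwningRewritePatternList::insert<Ts...>()` overload default-constructs one instance of each listed type. This is what lets a pattern that takes no constructor arguments, such as `RemoveTestDialectOps`, be registered with `patterns.insert<RemoveTestDialectOps>()`. Below is a standalone sketch of the same pack-expansion trick. It uses hypothetical `PatternBase`, `PatternA`, `PatternB`, and `PatternListModel` types in place of `RewritePattern` and the real list, and shows the equivalent C++17 fold expression alongside for comparison.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <memory>
#include <vector>

// Hypothetical stand-ins for RewritePattern and two default-constructible
// patterns; only the insertion mechanics are modeled.
struct PatternBase {
  virtual ~PatternBase() = default;
  virtual int id() const = 0;
};
struct PatternA : PatternBase {
  int id() const override { return 1; }
};
struct PatternB : PatternBase {
  int id() const override { return 2; }
};

class PatternListModel {
public:
  // Same trick as the diff: the pack expansion inside a braced
  // initializer_list is evaluated left to right, so instances are appended in
  // the order the types are listed. The leading 0 keeps the list non-empty
  // when `Ts` is empty.
  template <typename... Ts> PatternListModel &insert() {
    (void)std::initializer_list<int>{
        0, (patterns.emplace_back(std::make_unique<Ts>()), 0)...};
    return *this;
  }

  // C++17 equivalent: a fold over the comma operator, also left to right.
  template <typename... Ts> PatternListModel &insertFold() {
    (patterns.emplace_back(std::make_unique<Ts>()), ...);
    return *this;
  }

  std::size_t size() const { return patterns.size(); }
  const PatternBase &operator[](std::size_t i) const {
    assert(i < patterns.size() && "index out of range");
    return *patterns[i];
  }

private:
  std::vector<std::unique_ptr<PatternBase>> patterns;
};
```

Returning `*this` is what allows chained calls such as `list.insert<PatternA>().insert<PatternB>()`.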

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 3e7d50380fe1..2ce95b10d607 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -274,12 +274,6 @@ class TypeConverter {
 /// below.
 class ConversionPattern : public RewritePattern {
 public:
-  /// Construct an ConversionPattern.  `rootName` must correspond to the
-  /// canonical name of the first operation matched by the pattern.
-  ConversionPattern(StringRef rootName, PatternBenefit benefit,
-                    MLIRContext *ctx)
-      : RewritePattern(rootName, benefit, ctx) {}
-
   /// Hook for derived classes to implement rewriting. `op` is the (first)
   /// operation matched by the pattern, `operands` is a list of rewritten values
   /// that are passed to this operation, `rewriter` can be used to emit the new
@@ -304,6 +298,9 @@ class ConversionPattern : public RewritePattern {
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final;
 
+protected:
+  using RewritePattern::RewritePattern;
+
 private:
   using RewritePattern::rewrite;
 };
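`ConversionPattern` now inherits all of `RewritePattern`'s constructors instead of redeclaring one. As a result, the new `MatchAnyOpTypeTag` overloads become available to conversion patterns without further changes. Below is a minimal C++17 model of that arrangement. `RewritePatternModel`, `ConversionPatternModel`, `AnyOpConversion`, and `NamedConversion` are hypothetical stand-ins, not the MLIR classes.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-in for RewritePattern: constructors are protected, one
// for a named root and one selected by the empty "match any" tag.
class RewritePatternModel {
public:
  struct MatchAnyOpTypeTag {};
  virtual ~RewritePatternModel() = default;
  const std::optional<std::string> &getRootKind() const { return root; }
  unsigned getBenefit() const { return benefit; }

protected:
  RewritePatternModel(std::string rootName, unsigned benefit)
      : root(std::move(rootName)), benefit(benefit) {
    assert(!root->empty() && "a specific pattern needs a root name");
  }
  RewritePatternModel(unsigned benefit, MatchAnyOpTypeTag) : benefit(benefit) {}

private:
  std::optional<std::string> root;
  unsigned benefit;
};

// Stand-in for ConversionPattern: instead of redeclaring a constructor, it
// inherits all of them, so the tag overload is picked up automatically.
class ConversionPatternModel : public RewritePatternModel {
protected:
  using RewritePatternModel::RewritePatternModel;
};

// Two concrete patterns built through the inherited constructors.
struct AnyOpConversion : ConversionPatternModel {
  AnyOpConversion() : ConversionPatternModel(1, MatchAnyOpTypeTag()) {}
};
struct NamedConversion : ConversionPatternModel {
  explicit NamedConversion(std::string name)
      : ConversionPatternModel(std::move(name), 2) {}
};
```

Under the C++ rules for inheriting constructors, the access specifier in front of the using-declaration is ignored. The inherited constructors keep the access they have in the base class, so in this model they remain protected either way.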

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index b3b152377e36..e05d234c4ad0 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -30,6 +30,8 @@ unsigned short PatternBenefit::getBenefit() const {
 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
                  MLIRContext *context)
     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
+Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag)
+    : benefit(benefit) {}
 
 // Out-of-line vtable anchor.
 void Pattern::anchor() {}
@@ -47,9 +49,6 @@ LogicalResult RewritePattern::match(Operation *op) const {
   llvm_unreachable("need to implement either match or matchAndRewrite!");
 }
 
-/// Patterns must specify the root operation name they match against, and can
-/// also specify the benefit of the pattern matching. They can also specify the
-/// names of operations that may be generated during a successful rewrite.
 RewritePattern::RewritePattern(StringRef rootName,
                                ArrayRef<StringRef> generatedNames,
                                PatternBenefit benefit, MLIRContext *context)
@@ -60,6 +59,16 @@ RewritePattern::RewritePattern(StringRef rootName,
                    return OperationName(name, context);
                  });
 }
+RewritePattern::RewritePattern(ArrayRef<StringRef> generatedNames,
+                               PatternBenefit benefit, MLIRContext *context,
+                               MatchAnyOpTypeTag tag)
+    : Pattern(benefit, tag) {
+  generatedOps.reserve(generatedNames.size());
+  std::transform(generatedNames.begin(), generatedNames.end(),
+                 std::back_inserter(generatedOps), [context](StringRef name) {
+                   return OperationName(name, context);
+                 });
+}
 
 PatternRewriter::~PatternRewriter() {
   // Out of line to provide a vtable anchor for the class.
@@ -173,22 +182,28 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
 void PatternApplicator::applyCostModel(CostModel model) {
   // Separate patterns by root kind to simplify lookup later on.
   patterns.clear();
-  for (const auto &pat : owningPatternList)
-    patterns[pat->getRootKind()].push_back(pat.get());
+  anyOpPatterns.clear();
+  for (const auto &pat : owningPatternList) {
+    // If the pattern is always impossible to match, just ignore it.
+    if (pat->getBenefit().isImpossibleToMatch())
+      continue;
+    if (Optional<OperationName> opName = pat->getRootKind())
+      patterns[*opName].push_back(pat.get());
+    else
+      anyOpPatterns.push_back(pat.get());
+  }
 
   // Sort the patterns using the provided cost model.
   llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
   auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
     return benefits[lhs] > benefits[rhs];
   };
-  for (auto &it : patterns) {
-    SmallVectorImpl<RewritePattern *> &list = it.second;
-
+  auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
     // Special case for one pattern in the list, which is the most common case.
     if (list.size() == 1) {
       if (model(*list.front()).isImpossibleToMatch())
         list.clear();
-      continue;
+      return;
     }
 
     // Collect the dynamic benefits for the current pattern list.
@@ -201,7 +216,10 @@ void PatternApplicator::applyCostModel(CostModel model) {
     std::stable_sort(list.begin(), list.end(), cmp);
     while (!list.empty() && benefits[list.back()].isImpossibleToMatch())
       list.pop_back();
-  }
+  };
+  for (auto &it : patterns)
+    processPatternList(it.second);
+  processPatternList(anyOpPatterns);
 }
 
 void PatternApplicator::walkAllPatterns(
@@ -210,32 +228,64 @@ void PatternApplicator::walkAllPatterns(
     walk(*it);
 }
 
-/// Try to match the given operation to a pattern and rewrite it.
 LogicalResult PatternApplicator::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter,
     function_ref<bool(const RewritePattern &)> canApply,
     function_ref<void(const RewritePattern &)> onFailure,
     function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
+  // Check to see if there are patterns matching this specific operation type.
+  MutableArrayRef<RewritePattern *> opPatterns;
   auto patternIt = patterns.find(op->getName());
-  if (patternIt == patterns.end())
-    return failure();
+  if (patternIt != patterns.end())
+    opPatterns = patternIt->second;
+
+  // Process the patterns that match the specific operation type, and any
+  // operation type in an interleaved fashion.
+  // FIXME: It'd be nice to just write an llvm::make_merge_range utility
+  // and pass in a comparison function. That would make this code trivial.
+  auto opIt = opPatterns.begin(), opE = opPatterns.end();
+  auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
+  while (opIt != opE && anyIt != anyE) {
+    // Try to match the pattern providing the most benefit.
+    RewritePattern *pattern;
+    if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
+      pattern = *(opIt++);
+    else
+      pattern = *(anyIt++);
+
+    // Try to match and rewrite with the selected pattern.
+    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
+                                  onSuccess)))
+      return success();
+  }
+  // If we break from the loop, then only one of the ranges can still have
+  // elements. Loop over both without checking given that we don't need to
+  // interleave anymore.
+  for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
+           llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
+    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
+                                  onSuccess)))
+      return success();
+  }
+  return failure();
+}
 
-  for (auto *pattern : patternIt->second) {
-    // Check that the pattern can be applied.
-    if (canApply && !canApply(*pattern))
-      continue;
+LogicalResult PatternApplicator::matchAndRewrite(
+    Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
+    function_ref<bool(const RewritePattern &)> canApply,
+    function_ref<void(const RewritePattern &)> onFailure,
+    function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
+  // Check that the pattern can be applied.
+  if (canApply && !canApply(pattern))
+    return failure();
 
-    // Try to match and rewrite this pattern. The patterns are sorted by
-    // benefit, so if we match we can immediately rewrite.
-    rewriter.setInsertionPoint(op);
-    if (succeeded(pattern->matchAndRewrite(op, rewriter))) {
-      if (!onSuccess || succeeded(onSuccess(*pattern)))
-        return success();
-      continue;
-    }
+  // Try to match and rewrite this pattern. The patterns are sorted by
+  // benefit, so if we match we can immediately rewrite.
+  rewriter.setInsertionPoint(op);
+  if (succeeded(pattern.matchAndRewrite(op, rewriter)))
+    return success(!onSuccess || succeeded(onSuccess(pattern)));
 
-    if (onFailure)
-      onFailure(*pattern);
-  }
+  if (onFailure)
+    onFailure(pattern);
   return failure();
 }
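In effect, the lookup in `PatternApplicator::matchAndRewrite` is a two-way merge of two lists, each already sorted by benefit. The sketch below models it using plain strings and `std::function` predicates. `PatternSpec`, `ApplicatorModel`, and `kNone` are hypothetical names. The impossible-to-match filtering and the cost-model hook are omitted, and the diff's two phases (interleave, then drain whichever list remains) are folded into a single loop. As with the `>=` comparison in the diff, ties go to the op-specific pattern.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical pattern description: an optional root op name (empty means
// "match any"), a static benefit, and a predicate standing in for match().
struct PatternSpec {
  std::optional<std::string> root;
  unsigned benefit;
  std::function<bool(const std::string &)> matches;
};

class ApplicatorModel {
public:
  // Mirrors applyCostModel(): split patterns by root kind, then stable-sort
  // every list by descending benefit so equal benefits keep insertion order.
  explicit ApplicatorModel(const std::vector<PatternSpec> &pats) {
    for (const PatternSpec &p : pats) {
      assert(p.matches && "every pattern needs a match predicate");
      if (p.root)
        byRoot[*p.root].push_back(&p);
      else
        anyOp.push_back(&p);
    }
    auto byBenefit = [](const PatternSpec *l, const PatternSpec *r) {
      return l->benefit > r->benefit;
    };
    for (auto &entry : byRoot)
      std::stable_sort(entry.second.begin(), entry.second.end(), byBenefit);
    std::stable_sort(anyOp.begin(), anyOp.end(), byBenefit);
  }

  // Mirrors matchAndRewrite(): interleave the op-specific list for `opName`
  // with the "match any" list by benefit, preferring the op-specific pattern
  // on ties, and return the first pattern whose predicate accepts the op.
  const PatternSpec *match(const std::string &opName) const {
    static const std::vector<const PatternSpec *> kNone;
    auto found = byRoot.find(opName);
    const std::vector<const PatternSpec *> &opPats =
        found == byRoot.end() ? kNone : found->second;

    std::size_t i = 0, j = 0;
    while (i < opPats.size() || j < anyOp.size()) {
      const PatternSpec *next;
      if (j == anyOp.size() ||
          (i < opPats.size() && opPats[i]->benefit >= anyOp[j]->benefit))
        next = opPats[i++];
      else
        next = anyOp[j++];
      if (next->matches(opName))
        return next;
    }
    return nullptr;
  }

private:
  std::map<std::string, std::vector<const PatternSpec *>> byRoot;
  std::vector<const PatternSpec *> anyOp;
};
```

The FIXME about an `llvm::make_merge_range` utility describes exactly this structure. `std::merge` with the comparator `l->benefit > r->benefit` also takes from the first range on ties, so it would produce the same visiting order.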

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index aa12c16a3d8c..b06524719577 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1164,10 +1164,11 @@ class OperationLegalizer {
                                       RewriterState &curState);
 
   /// Build an optimistic legalization graph given the provided patterns. This
-  /// function populates 'legalizerPatterns' with the operations that are not
-  /// directly legal, but may be transitively legal for the current target given
-  /// the provided patterns.
+  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
+  /// patterns for operations that are not directly legal, but may be
+  /// transitively legal for the current target given the provided patterns.
   void buildLegalizationGraph(
+      LegalizationPatterns &anyOpLegalizerPatterns,
       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
 
   /// Compute the benefit of each node within the computed legalization graph.
@@ -1179,6 +1180,21 @@ class OperationLegalizer {
   ///     pattern with the highest PatternBenefit. This allows for users to
   ///     prefer specific legalizations over others.
   void computeLegalizationGraphBenefit(
+      LegalizationPatterns &anyOpLegalizerPatterns,
+      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
+
+  /// Compute the legalization depth when legalizing an operation of the given
+  /// type.
+  unsigned computeOpLegalizationDepth(
+      OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
+      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
+
+  /// Apply the conversion cost model to the given set of patterns, and return
+  /// the smallest legalization depth of any of the patterns. See
+  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
+  unsigned applyCostModelToPatterns(
+      LegalizationPatterns &patterns,
+      DenseMap<OperationName, unsigned> &minOpPatternDepth,
       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
 
   /// The current set of patterns that have been applied.
@@ -1195,12 +1211,13 @@ class OperationLegalizer {
 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
                                        const OwningRewritePatternList &patterns)
     : target(targetInfo), applicator(patterns) {
-  // The set of legality information for operations transitively supported by
-  // the target.
+  // The set of patterns that can be applied to illegal operations to transform
+  // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
+  LegalizationPatterns anyOpLegalizerPatterns;
 
-  buildLegalizationGraph(legalizerPatterns);
-  computeLegalizationGraphBenefit(legalizerPatterns);
+  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
+  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
 }
 
 bool OperationLegalizer::isIllegal(Operation *op) const {
@@ -1365,7 +1382,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
   LLVM_DEBUG({
     auto &os = rewriter.getImpl().logger;
     os.getOStream() << "\n";
-    os.startLine() << "* Pattern : '" << pattern.getRootKind() << " -> (";
+    os.startLine() << "* Pattern : '" << op->getName() << " -> (";
     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
     os.getOStream() << ")' {\n";
     os.indent();
@@ -1466,6 +1483,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 }
 
 void OperationLegalizer::buildLegalizationGraph(
+    LegalizationPatterns &anyOpLegalizerPatterns,
     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
   // A mapping between an operation and a set of operations that can be used to
   // generate it.
@@ -1478,22 +1496,41 @@ void OperationLegalizer::buildLegalizationGraph(
 
   // Build the mapping from operations to the parent ops that may generate them.
   applicator.walkAllPatterns([&](const RewritePattern &pattern) {
-    OperationName root = pattern.getRootKind();
+    Optional<OperationName> root = pattern.getRootKind();
+
+    // If the pattern has no specific root, we can't analyze the relationship
+    // between the root op and generated operations. Given that, add all such
+    // patterns to the legalization set.
+    if (!root) {
+      anyOpLegalizerPatterns.push_back(&pattern);
+      return;
+    }
 
     // Skip operations that are always known to be legal.
-    if (target.getOpAction(root) == LegalizationAction::Legal)
+    if (target.getOpAction(*root) == LegalizationAction::Legal)
       return;
 
     // Add this pattern to the invalid set for the root op and record this root
     // as a parent for any generated operations.
-    invalidPatterns[root].insert(&pattern);
+    invalidPatterns[*root].insert(&pattern);
     for (auto op : pattern.getGeneratedOps())
-      parentOps[op].insert(root);
+      parentOps[op].insert(*root);
 
     // Add this pattern to the worklist.
     patternWorklist.insert(&pattern);
   });
 
+  // If there are any patterns that don't have a specific root kind, we can't
+  // make direct assumptions about what operations will never be legalized.
+  // Note: Technically we could, but it would require an analysis that may
+  // recurse into itself. It would be better to perform this kind of filtering
+  // at a higher level than here anyways.
+  if (!anyOpLegalizerPatterns.empty()) {
+    for (const RewritePattern *pattern : patternWorklist)
+      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
+    return;
+  }
+
   while (!patternWorklist.empty()) {
     auto *pattern = patternWorklist.pop_back_val();
 
@@ -1507,108 +1544,132 @@ void OperationLegalizer::buildLegalizationGraph(
 
     // Otherwise, if all of the generated operation are valid, this op is now
     // legal so add all of the child patterns to the worklist.
-    legalizerPatterns[pattern->getRootKind()].push_back(pattern);
-    invalidPatterns[pattern->getRootKind()].erase(pattern);
+    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
+    invalidPatterns[*pattern->getRootKind()].erase(pattern);
 
     // Add any invalid patterns of the parent operations to see if they have now
     // become legal.
-    for (auto op : parentOps[pattern->getRootKind()])
+    for (auto op : parentOps[*pattern->getRootKind()])
       patternWorklist.set_union(invalidPatterns[op]);
   }
 }
 
 void OperationLegalizer::computeLegalizationGraphBenefit(
+    LegalizationPatterns &anyOpLegalizerPatterns,
     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
   // The smallest pattern depth, when legalizing an operation.
-  DenseMap<OperationName, unsigned> minPatternDepth;
-
-  // Compute the minimum legalization depth for a given operation.
-  std::function<unsigned(OperationName)> computeDepth = [&](OperationName op) {
-    // Check for existing depth.
-    auto depthIt = minPatternDepth.find(op);
-    if (depthIt != minPatternDepth.end())
-      return depthIt->second;
-
-    // If a mapping for this operation does not exist, then this operation
-    // is always legal. Return 0 as the depth for a directly legal operation.
-    auto opPatternsIt = legalizerPatterns.find(op);
-    if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
-      return 0u;
-
-    // Initialize the depth to the maximum value.
-    unsigned minDepth = std::numeric_limits<unsigned>::max();
-
-    // Record this initial depth in case we encounter this op again when
-    // recursively computing the depth.
-    minPatternDepth.try_emplace(op, minDepth);
-
-    // Compute the depth for each pattern used to legalize this operation.
-    SmallVector<std::pair<const RewritePattern *, unsigned>, 4> patternsByDepth;
-    patternsByDepth.reserve(opPatternsIt->second.size());
-    for (const RewritePattern *pattern : opPatternsIt->second) {
-      unsigned depth = 0;
-      for (auto generatedOp : pattern->getGeneratedOps())
-        depth = std::max(depth, computeDepth(generatedOp) + 1);
-      patternsByDepth.emplace_back(pattern, depth);
-
-      // Update the min depth for this operation.
-      minDepth = std::min(minDepth, depth);
-    }
-
-    // Update the pattern depth.
-    minPatternDepth[op] = minDepth;
-
-    // If the operation only has one legalization pattern, there is no need to
-    // sort them.
-    if (patternsByDepth.size() == 1)
-      return minDepth;
-
-    // Sort the patterns by those likely to be the most beneficial.
-    llvm::array_pod_sort(
-        patternsByDepth.begin(), patternsByDepth.end(),
-        [](const std::pair<const RewritePattern *, unsigned> *lhs,
-           const std::pair<const RewritePattern *, unsigned> *rhs) {
-          // First sort by the smaller pattern legalization depth.
-          if (lhs->second != rhs->second)
-            return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
-                                                             &rhs->second);
-
-          // Then sort by the larger pattern benefit.
-          auto lhsBenefit = lhs->first->getBenefit();
-          auto rhsBenefit = rhs->first->getBenefit();
-          return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
-                                                                 &lhsBenefit);
-        });
-
-    // Update the legalization pattern to use the new sorted list.
-    opPatternsIt->second.clear();
-    for (auto &patternIt : patternsByDepth)
-      opPatternsIt->second.push_back(patternIt.first);
-
-    return minDepth;
-  };
+  DenseMap<OperationName, unsigned> minOpPatternDepth;
 
   // For each operation that is transitively legal, compute a cost for it.
   for (auto &opIt : legalizerPatterns)
-    if (!minPatternDepth.count(opIt.first))
-      computeDepth(opIt.first);
+    if (!minOpPatternDepth.count(opIt.first))
+      computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
+                                 legalizerPatterns);
+
+  // Apply the cost model to the patterns that can match any operation. Those
+  // with a specific operation type are already resolved when computing the op
+  // legalization depth.
+  if (!anyOpLegalizerPatterns.empty())
+    applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
+                             legalizerPatterns);
 
   // Apply a cost model to the pattern applicator. We order patterns first by
   // depth then benefit. `legalizerPatterns` contains per-op patterns by
   // decreasing benefit.
   applicator.applyCostModel([&](const RewritePattern &p) {
-    auto &list = legalizerPatterns[p.getRootKind()];
+    ArrayRef<const RewritePattern *> orderedPatternList;
+    if (Optional<OperationName> rootName = p.getRootKind())
+      orderedPatternList = legalizerPatterns[*rootName];
+    else
+      orderedPatternList = anyOpLegalizerPatterns;
 
     // If the pattern is not found, then it was removed and cannot be matched.
-    LegalizationPatterns::iterator it = llvm::find(list, &p);
-    if (it == list.end())
+    auto it = llvm::find(orderedPatternList, &p);
+    if (it == orderedPatternList.end())
       return PatternBenefit::impossibleToMatch();
 
     // Patterns found earlier in the list have higher benefit.
-    return PatternBenefit(std::distance(it, list.end()));
+    return PatternBenefit(std::distance(it, orderedPatternList.end()));
   });
 }
 
+unsigned OperationLegalizer::computeOpLegalizationDepth(
+    OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
+    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
+  // Check for existing depth.
+  auto depthIt = minOpPatternDepth.find(op);
+  if (depthIt != minOpPatternDepth.end())
+    return depthIt->second;
+
+  // If a mapping for this operation does not exist, then this operation
+  // is always legal. Return 0 as the depth for a directly legal operation.
+  auto opPatternsIt = legalizerPatterns.find(op);
+  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
+    return 0u;
+
+  // Record this initial depth in case we encounter this op again when
+  // recursively computing the depth.
+  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
+
+  // Apply the cost model to the operation patterns, and update the minimum
+  // depth.
+  unsigned minDepth = applyCostModelToPatterns(
+      opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
+  minOpPatternDepth[op] = minDepth;
+  return minDepth;
+}
+
+unsigned OperationLegalizer::applyCostModelToPatterns(
+    LegalizationPatterns &patterns,
+    DenseMap<OperationName, unsigned> &minOpPatternDepth,
+    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
+  unsigned minDepth = std::numeric_limits<unsigned>::max();
+
+  // Compute the depth for each pattern within the set.
+  SmallVector<std::pair<const RewritePattern *, unsigned>, 4> patternsByDepth;
+  patternsByDepth.reserve(patterns.size());
+  for (const RewritePattern *pattern : patterns) {
+    unsigned depth = 0;
+    for (auto generatedOp : pattern->getGeneratedOps()) {
+      unsigned generatedOpDepth = computeOpLegalizationDepth(
+          generatedOp, minOpPatternDepth, legalizerPatterns);
+      depth = std::max(depth, generatedOpDepth + 1);
+    }
+    patternsByDepth.emplace_back(pattern, depth);
+
+    // Update the minimum depth of the pattern list.
+    minDepth = std::min(minDepth, depth);
+  }
+
+  // If the operation only has one legalization pattern, there is no need to
+  // sort them.
+  if (patternsByDepth.size() == 1)
+    return minDepth;
+
+  // Sort the patterns by those likely to be the most beneficial.
+  llvm::array_pod_sort(
+      patternsByDepth.begin(), patternsByDepth.end(),
+      [](const std::pair<const RewritePattern *, unsigned> *lhs,
+         const std::pair<const RewritePattern *, unsigned> *rhs) {
+        // First sort by the smaller pattern legalization depth.
+        if (lhs->second != rhs->second)
+          return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
+                                                           &rhs->second);
+
+        // Then sort by the larger pattern benefit.
+        auto lhsBenefit = lhs->first->getBenefit();
+        auto rhsBenefit = rhs->first->getBenefit();
+        return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
+                                                               &lhsBenefit);
+      });
+
+  // Update the legalization pattern to use the new sorted list.
+  patterns.clear();
+  for (auto &patternIt : patternsByDepth)
+    patterns.push_back(patternIt.first);
+  return minDepth;
+}
+
 //===----------------------------------------------------------------------===//
 // OperationConverter
 //===----------------------------------------------------------------------===//
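The cost model implemented by `computeOpLegalizationDepth` and `applyCostModelToPatterns` works as follows. A pattern's depth is one plus the depth of the deepest operation it may generate (0 if it generates nothing). An operation's depth is the smallest depth among its patterns, and operations without patterns have depth 0. Each pattern list is then ordered by depth ascending and static benefit descending, and a pattern's final benefit is its distance from the end of its ordered list. Below is a standalone sketch. It assumes hypothetical `LegalizationPattern` and `DepthModel` types keyed by op-name strings, with helper names (`orderPatterns`, `ordered`, `rank`, `kSentinel`) that are also hypothetical. The sketch deviates from the diff in two deliberate ways: it uses `std::stable_sort` rather than `llvm::array_pod_sort`, and it saturates at the recursion sentinel instead of adding 1 to it, so cyclic patterns stay at the maximum depth.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical legalization pattern: a label, a static benefit, and the
// names of the operations it may generate.
struct LegalizationPattern {
  std::string name;
  unsigned benefit;
  std::vector<std::string> generatedOps;
};

using PatternList = std::vector<const LegalizationPattern *>;

class DepthModel {
public:
  static constexpr unsigned kSentinel = std::numeric_limits<unsigned>::max();

  explicit DepthModel(std::map<std::string, PatternList> patterns)
      : legalizerPatterns(std::move(patterns)) {}

  // Minimum legalization depth of `op`; ops without patterns count as
  // directly legal (depth 0). Results are memoized, and the sentinel is
  // stored before recursing so a cycle back to `op` terminates.
  unsigned computeOpDepth(const std::string &op) {
    auto depthIt = minOpDepth.find(op);
    if (depthIt != minOpDepth.end())
      return depthIt->second;
    auto patsIt = legalizerPatterns.find(op);
    if (patsIt == legalizerPatterns.end() || patsIt->second.empty())
      return 0;
    minOpDepth.emplace(op, kSentinel);
    unsigned depth = orderPatterns(patsIt->second);
    minOpDepth[op] = depth;
    return depth;
  }

  // Reorders `patterns` by (depth ascending, benefit descending) and returns
  // the smallest depth in the list.
  unsigned orderPatterns(PatternList &patterns) {
    unsigned minDepth = kSentinel;
    std::vector<std::pair<const LegalizationPattern *, unsigned>> byDepth;
    for (const LegalizationPattern *p : patterns) {
      assert(p && "null pattern");
      unsigned depth = 0;
      for (const std::string &gen : p->generatedOps) {
        unsigned genDepth = computeOpDepth(gen);
        // Saturate instead of wrapping when `gen` is still being computed.
        unsigned viaGen = genDepth == kSentinel ? kSentinel : genDepth + 1;
        depth = std::max(depth, viaGen);
      }
      byDepth.emplace_back(p, depth);
      minDepth = std::min(minDepth, depth);
    }
    std::stable_sort(byDepth.begin(), byDepth.end(),
                     [](const auto &l, const auto &r) {
                       if (l.second != r.second)
                         return l.second < r.second;
                       return l.first->benefit > r.first->benefit;
                     });
    patterns.clear();
    for (const auto &entry : byDepth)
      patterns.push_back(entry.first);
    return minDepth;
  }

  const PatternList &ordered(const std::string &op) const {
    return legalizerPatterns.at(op);
  }

  // Rank fed to the applicator: distance from the end of the ordered list,
  // so earlier patterns rank higher. 0 stands in for "not found" here.
  unsigned rank(const std::string &op, const LegalizationPattern *p) const {
    auto listIt = legalizerPatterns.find(op);
    if (listIt == legalizerPatterns.end())
      return 0;
    const PatternList &list = listIt->second;
    auto pos = std::find(list.begin(), list.end(), p);
    if (pos == list.end())
      return 0;
    return static_cast<unsigned>(std::distance(pos, list.end()));
  }

private:
  std::map<std::string, PatternList> legalizerPatterns;
  std::map<std::string, unsigned> minOpDepth;
};
```

`orderPatterns` is public, so it can also be run on a separate list. This mirrors how the diff applies the same cost model to `anyOpLegalizerPatterns` after the per-op depths are known.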

diff  --git a/mlir/test/Transforms/test-legalize-unknown-root.mlir b/mlir/test/Transforms/test-legalize-unknown-root.mlir
new file mode 100644
index 000000000000..c31840d31a93
--- /dev/null
+++ b/mlir/test/Transforms/test-legalize-unknown-root.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -test-legalize-unknown-root-patterns | FileCheck %s
+
+// Test that all `test` dialect operations are removed.
+// CHECK-LABEL: func @remove_all_ops
+func @remove_all_ops(%arg0: i32) {
+  // CHECK-NEXT: return
+  %0 = "test.illegal_op_a"() : () -> i32
+  %1 = "test.illegal_op_b"() : () -> i32
+  %2 = "test.illegal_op_c"() : () -> i32
+  %3 = "test.illegal_op_d"() : () -> i32
+  %4 = "test.illegal_op_e"() : () -> i32
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2f9577c5b4b4..cbab7d7494da 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -701,18 +701,50 @@ struct TestRemappedValue
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Test patterns without a specific root operation kind
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This pattern matches and removes any operation in the test dialect.
+struct RemoveTestDialectOps : public RewritePattern {
+  RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!isa<TestDialect>(op->getDialect()))
+      return failure();
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct TestUnknownRootOpDriver
+    : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
+  void runOnFunction() override {
+    mlir::OwningRewritePatternList patterns;
+    patterns.insert<RemoveTestDialectOps>();
+
+    mlir::ConversionTarget target(getContext());
+    target.addIllegalDialect<TestDialect>();
+    if (failed(applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
 namespace mlir {
 void registerPatternsTestPass() {
-  mlir::PassRegistration<TestReturnTypeDriver>("test-return-type",
-                                               "Run return type functions");
+  PassRegistration<TestReturnTypeDriver>("test-return-type",
+                                         "Run return type functions");
 
-  mlir::PassRegistration<TestDerivedAttributeDriver>(
-      "test-derived-attr", "Run test derived attributes");
+  PassRegistration<TestDerivedAttributeDriver>("test-derived-attr",
+                                               "Run test derived attributes");
 
-  mlir::PassRegistration<TestPatternDriver>("test-patterns",
-                                            "Run test dialect patterns");
+  PassRegistration<TestPatternDriver>("test-patterns",
+                                      "Run test dialect patterns");
 
-  mlir::PassRegistration<TestLegalizePatternDriver>(
+  PassRegistration<TestLegalizePatternDriver>(
       "test-legalize-patterns", "Run test dialect legalization patterns", [] {
         return std::make_unique<TestLegalizePatternDriver>(
             legalizerConversionMode);
@@ -721,5 +753,9 @@ void registerPatternsTestPass() {
   PassRegistration<TestRemappedValue>(
       "test-remapped-value",
       "Test public remapped value mechanism in ConversionPatternRewriter");
+
+  PassRegistration<TestUnknownRootOpDriver>(
+      "test-legalize-unknown-root-patterns",
+      "Test legalization patterns with an unknown root operation kind");
 }
 } // namespace mlir
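`RemoveTestDialectOps` accepts any root and decides inside `matchAndRewrite` whether the op belongs to the test dialect. Below is a string-level model of what `test-legalize-unknown-root.mlir` checks. It assumes each op is reduced to its full `dialect.opname` name (the real pattern uses `isa<TestDialect>(op->getDialect())` instead) and models the function's terminator as `std.return`. The helpers `dialectNamespace`, `removeIfInDialect`, and `applyToAll` are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Dialect namespace of a full operation name such as "test.illegal_op_a";
// empty when the name has no '.' separator.
std::string dialectNamespace(const std::string &opName) {
  std::string::size_type dot = opName.find('.');
  return dot == std::string::npos ? std::string() : opName.substr(0, dot);
}

// Model of RemoveTestDialectOps: offered every op regardless of its name,
// it succeeds (i.e. "erases" the op) only for ops in `dialect`.
bool removeIfInDialect(const std::string &opName, const std::string &dialect) {
  return dialectNamespace(opName) == dialect;
}

// Apply the root-agnostic pattern to each op once and return the survivors,
// in their original order.
std::vector<std::string> applyToAll(const std::vector<std::string> &ops,
                                    const std::string &dialect) {
  assert(!dialect.empty() && "expected a dialect namespace");
  std::vector<std::string> kept;
  for (const std::string &op : ops)
    if (!removeIfInDialect(op, dialect))
      kept.push_back(op);
  return kept;
}
```

Because the test dialect is marked illegal, any test-dialect op left over would cause the partial conversion to fail. In this model no such op remains, which matches the single `CHECK-NEXT: return` line in the test.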


        

