[Mlir-commits] [mlir] 138803e - [MLIR][PDL] Make predicate order deterministic.

Uday Bondhugula llvmlistbot at llvm.org
Mon Jan 3 18:39:53 PST 2022


Author: Stanislav Funiak
Date: 2022-01-04T08:03:44+05:30
New Revision: 138803e017739c81b43b73631c7096bfc4d097d8

URL: https://github.com/llvm/llvm-project/commit/138803e017739c81b43b73631c7096bfc4d097d8
DIFF: https://github.com/llvm/llvm-project/commit/138803e017739c81b43b73631c7096bfc4d097d8.diff

LOG: [MLIR][PDL] Make predicate order deterministic.

The tree merging of pattern predicates places the predicates in an unordered set. When the predicates are sorted, they are read out in the set's iteration order rather than the insertion order, so the relative order of predicates that compare equal is not fixed. This results in nondeterministic behavior.

One solution to this problem would be to use `SetVector`. However, `SetVector` does not provide a `find` function for fast O(1) lookups, and it stores the predicates twice -- once in the set and once in the vector -- which is undesirable because each predicate carries its own `patternToAnswer` map. A simpler solution is to store a tie-breaking ID that follows the insertion order, and to use this ID to break any ties when comparing predicates.
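
The tie-breaking idea can be illustrated outside of MLIR with a minimal standalone sketch. It is not the code from this commit: `Pred`, `PredHash`, and `orderPredicates` are hypothetical names, `std::unordered_set` stands in for the set used in PredicateTree.cpp, and a single `priority` field stands in for the full comparison key. The ID is assigned only when an element is newly inserted, and the comparison falls back to the lower ID when the priorities are equal, so the sorted order no longer depends on the set's iteration order.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for OrderedPredicate. `name` is the uniquing key,
// `priority` is the sort key, and `id` records the insertion order.
struct Pred {
  std::string name;
  unsigned priority;
  // Not part of hashing or equality, so it can be updated in place while the
  // element lives inside the set.
  mutable unsigned id = 0;

  Pred(std::string n, unsigned p) : name(std::move(n)), priority(p) {}

  bool operator==(const Pred &rhs) const { return name == rhs.name; }

  // Higher priority sorts first; ties go to the lower insertion ID.
  bool operator<(const Pred &rhs) const {
    return std::make_tuple(priority, rhs.id) >
           std::make_tuple(rhs.priority, id);
  }
};

struct PredHash {
  std::size_t operator()(const Pred &p) const {
    return std::hash<std::string>()(p.name);
  }
};

// Uniques the input, then returns the names in a deterministic order.
std::vector<std::string> orderPredicates(const std::vector<Pred> &input) {
  std::unordered_set<Pred, PredHash> uniqued;
  for (const Pred &p : input) {
    auto it = uniqued.insert(p);
    // Mark the insertion order (0-based) only for newly inserted elements.
    if (it.second)
      it.first->id = static_cast<unsigned>(uniqued.size() - 1);
  }

  std::vector<const Pred *> ordered;
  ordered.reserve(uniqued.size());
  for (const Pred &p : uniqued)
    ordered.push_back(&p);

  // IDs are unique, so no two elements compare equal and a stable sort is
  // not required to get a reproducible result.
  std::sort(ordered.begin(), ordered.end(),
            [](const Pred *lhs, const Pred *rhs) { return *lhs < *rhs; });

  std::vector<std::string> names;
  names.reserve(ordered.size());
  for (const Pred *p : ordered)
    names.push_back(p->name);
  return names;
}
```

In this sketch, the input `c, a, b, a, d` with priorities `1, 1, 2, 1, 1` is ordered as `b, c, a, d`: `b` wins on priority, and the remaining three keep their insertion order regardless of how the hash set happens to iterate.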

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116081

Added: 
    

Modified: 
    mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 24b2f19e58c2..9fd5de11a83d 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -721,6 +721,11 @@ struct OrderedPredicate {
   /// opposed to those shared across patterns.
   unsigned secondary = 0;
 
+  /// The tie breaking ID, used to preserve a deterministic (insertion) order
+  /// among all the predicates with the same priority, depth, and position /
+  /// predicate dependency.
+  unsigned id = 0;
+
   /// A map between a pattern operation and the answer to the predicate question
   /// within that pattern.
   DenseMap<Operation *, Qualifier *> patternToAnswer;
@@ -733,12 +738,13 @@ struct OrderedPredicate {
     // * lower depth
     // * lower position dependency
     // * lower predicate dependency
+    // * lower tie breaking ID
     auto *rhsPos = rhs.position;
     return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
-                           rhsPos->getKind(), rhs.question->getKind()) >
+                           rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
            std::make_tuple(rhs.primary, rhs.secondary,
                            position->getOperationDepth(), position->getKind(),
-                           question->getKind());
+                           question->getKind(), id);
   }
 };
 
@@ -903,6 +909,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
       auto it = uniqued.insert(predicate);
       it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
                                             predicate.answer);
+      // Mark the insertion order (0-based indexing).
+      if (it.second)
+        it.first->id = uniqued.size() - 1;
     }
   }
 
@@ -939,9 +948,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
   ordered.reserve(uniqued.size());
   for (auto &ip : uniqued)
     ordered.push_back(&ip);
-  std::stable_sort(
-      ordered.begin(), ordered.end(),
-      [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
+  llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
+    return *lhs < *rhs;
+  });
 
   // Build the matchers for each of the pattern predicate lists.
   std::unique_ptr<MatcherNode> root;


        

