[Mlir-commits] [mlir] 06c3b9c - [mlir:PDL] Fix bugs in PDLPatternModule merging

River Riddle llvmlistbot at llvm.org
Fri Dec 10 11:50:35 PST 2021


Author: River Riddle
Date: 2021-12-10T19:38:43Z
New Revision: 06c3b9c7be727af0edb0ebe73375776b491c6b0b

URL: https://github.com/llvm/llvm-project/commit/06c3b9c7be727af0edb0ebe73375776b491c6b0b
DIFF: https://github.com/llvm/llvm-project/commit/06c3b9c7be727af0edb0ebe73375776b491c6b0b.diff

LOG: [mlir:PDL] Fix bugs in PDLPatternModule merging

* Constraints/Rewrites registered before a pattern was added were dropped
* Constraints/Rewrites may now be registered multiple times (if different pattern sets depend on them); previously a duplicate registration triggered an assertion
* ModuleOp no longer has a terminator, so we shouldn't be removing the terminator from it

Differential Revision: https://reviews.llvm.org/D114816

Added: 
    

Modified: 
    mlir/lib/IR/PatternMatch.cpp
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 39d8bad2bbd26..56063d05e0e14 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -157,22 +157,21 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
   // Ignore the other module if it has no patterns.
   if (!other.pdlModule)
     return;
+
+  // Steal the functions of the other module.
+  for (auto &it : other.constraintFunctions)
+    registerConstraintFunction(it.first(), std::move(it.second));
+  for (auto &it : other.rewriteFunctions)
+    registerRewriteFunction(it.first(), std::move(it.second));
+
   // Steal the other state if we have no patterns.
   if (!pdlModule) {
-    constraintFunctions = std::move(other.constraintFunctions);
-    rewriteFunctions = std::move(other.rewriteFunctions);
     pdlModule = std::move(other.pdlModule);
     return;
   }
-  // Steal the functions of the other module.
-  for (auto &it : constraintFunctions)
-    registerConstraintFunction(it.first(), std::move(it.second));
-  for (auto &it : rewriteFunctions)
-    registerRewriteFunction(it.first(), std::move(it.second));
 
   // Merge the pattern operations from the other module into this one.
   Block *block = pdlModule->getBody();
-  block->getTerminator()->erase();
   block->getOperations().splice(block->end(),
                                 other.pdlModule->getBody()->getOperations());
 }
@@ -182,18 +181,20 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
 
 void PDLPatternModule::registerConstraintFunction(
     StringRef name, PDLConstraintFunction constraintFn) {
-  auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
-  (void)it;
-  assert(it.second &&
-         "constraint with the given name has already been registered");
+  // TODO: Is it possible to diagnose when `name` is already registered to
+  // a function that is not equivalent to `constraintFn`?
+  // Allow existing mappings in the case multiple patterns depend on the same
+  // constraint.
+  constraintFunctions.try_emplace(name, std::move(constraintFn));
 }
 
 void PDLPatternModule::registerRewriteFunction(StringRef name,
                                                PDLRewriteFunction rewriteFn) {
-  auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
-  (void)it;
-  assert(it.second && "native rewrite function with the given name has "
-                      "already been registered");
+  // TODO: Is it possible to diagnose when `name` is already registered to
+  // a function that is not equivalent to `rewriteFn`?
+  // Allow existing mappings in the case multiple patterns depend on the same
+  // rewrite.
+  rewriteFunctions.try_emplace(name, std::move(rewriteFn));
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 1f893e6d82304..ef62d73978d8b 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -87,13 +87,27 @@ struct TestPDLByteCodePass
     if (!patternModule || !irModule)
       return;
 
+    RewritePatternSet patternList(module->getContext());
+
+    // Register ahead of time to test when functions are registered without a
+    // pattern.
+    patternList.getPDLPatterns().registerConstraintFunction(
+        "multi_entity_constraint", customMultiEntityConstraint);
+    patternList.getPDLPatterns().registerConstraintFunction(
+        "single_entity_constraint", customSingleEntityConstraint);
+
     // Process the pattern module.
     patternModule.getOperation()->remove();
     PDLPatternModule pdlPattern(patternModule);
+
+    // Note: This constraint was already registered, but we re-register here to
+    // ensure that duplicate registration is allowed (the duplicate mapping
+    // will be ignored). This tests that we support separating the registration
+    // of library functions from the construction of patterns, and also that we
+    // allow multiple patterns to depend on the same library functions (without
+    // asserting/crashing).
     pdlPattern.registerConstraintFunction("multi_entity_constraint",
                                           customMultiEntityConstraint);
-    pdlPattern.registerConstraintFunction("single_entity_constraint",
-                                          customSingleEntityConstraint);
     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
                                           customMultiEntityVariadicConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);
@@ -101,8 +115,7 @@ struct TestPDLByteCodePass
                                        customVariadicResultCreate);
     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
-
-    RewritePatternSet patternList(std::move(pdlPattern));
+    patternList.add(std::move(pdlPattern));
 
     // Invoke the pattern driver with the provided patterns.
     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),


        


More information about the Mlir-commits mailing list