[Mlir-commits] [mlir] fa51c5a - [mlir] Resolve TODO and use the pass argument instead of the TypeID for registration

River Riddle llvmlistbot at llvm.org
Wed Jun 2 12:17:55 PDT 2021


Author: River Riddle
Date: 2021-06-02T12:17:36-07:00
New Revision: fa51c5af5d5de25a7824a939e90734ae5ca5448d

URL: https://github.com/llvm/llvm-project/commit/fa51c5af5d5de25a7824a939e90734ae5ca5448d
DIFF: https://github.com/llvm/llvm-project/commit/fa51c5af5d5de25a7824a939e90734ae5ca5448d.diff

LOG: [mlir] Resolve TODO and use the pass argument instead of the TypeID for registration

This simplifies various pieces of code that interact with the pass registry, e.g. this removes the need to register passes to get accurate pass pipeline descriptions when generating crash reproducers.

Differential Revision: https://reviews.llvm.org/D101880

Added: 
    

Modified: 
    mlir/include/mlir/Pass/Pass.h
    mlir/include/mlir/Pass/PassRegistry.h
    mlir/lib/Pass/PassRegistry.cpp
    mlir/test/lib/Pass/TestPassManager.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 42df2dc8bd08..67c695467706 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -56,13 +56,12 @@ class Pass {
   TypeID getTypeID() const { return passID; }
 
   /// Returns the pass info for the specified pass class or null if unknown.
-  static const PassInfo *lookupPassInfo(TypeID passID);
-  template <typename PassT> static const PassInfo *lookupPassInfo() {
-    return lookupPassInfo(TypeID::get<PassT>());
-  }
+  static const PassInfo *lookupPassInfo(StringRef passArg);
 
-  /// Returns the pass info for this pass.
-  const PassInfo *lookupPassInfo() const { return lookupPassInfo(getTypeID()); }
+  /// Returns the pass info for this pass, or null if unknown.
+  const PassInfo *lookupPassInfo() const {
+    return lookupPassInfo(getArgument());
+  }
 
   /// Returns the derived pass name.
   virtual StringRef getName() const = 0;
@@ -76,11 +75,7 @@ class Pass {
 
   /// Returns the command line argument used when registering this pass. Return
   /// an empty string if one does not exist.
-  virtual StringRef getArgument() const {
-    if (const PassInfo *passInfo = lookupPassInfo())
-      return passInfo->getPassArgument();
-    return "";
-  }
+  virtual StringRef getArgument() const { return ""; }
 
   /// Returns the name of the operation that this pass operates on, or None if
   /// this is a generic OperationPass.

diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
index 8def0f31ad15..d03aaf8dfd25 100644
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -108,7 +108,7 @@ class PassInfo : public PassRegistryEntry {
 public:
   /// PassInfo constructor should not be invoked directly, instead use
   /// PassRegistration or registerPass.
-  PassInfo(StringRef arg, StringRef description, TypeID passID,
+  PassInfo(StringRef arg, StringRef description,
            const PassAllocatorFunction &allocator);
 };
 

diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index e53113eb968e..2c690a2659ac 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -19,7 +19,11 @@ using namespace mlir;
 using namespace detail;
 
 /// Static mapping of all of the registered passes.
-static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry;
+static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
+
+/// A mapping of the above pass registry entries to the corresponding TypeID
+/// of the pass that they generate.
+static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
 
 /// Static mapping of all of the registered pass pipelines.
 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
@@ -94,7 +98,7 @@ void mlir::registerPassPipeline(
 // PassInfo
 //===----------------------------------------------------------------------===//
 
-PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
+PassInfo::PassInfo(StringRef arg, StringRef description,
                    const PassAllocatorFunction &allocator)
     : PassRegistryEntry(
           arg, description, buildDefaultRegistryFn(allocator),
@@ -105,18 +109,23 @@ PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
 
 void mlir::registerPass(StringRef arg, StringRef description,
                         const PassAllocatorFunction &function) {
-  // TODO: We should use the 'arg' as the lookup key instead of the pass id.
-  TypeID passID = function()->getTypeID();
-  PassInfo passInfo(arg, description, passID, function);
-  passRegistry->try_emplace(passID, passInfo);
+  PassInfo passInfo(arg, description, function);
+  passRegistry->try_emplace(arg, passInfo);
+
+  // Verify that the registered pass has the same ID as any registered to this
+  // arg before it.
+  TypeID entryTypeID = function()->getTypeID();
+  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
+  if (it->second != entryTypeID) {
+    llvm_unreachable("pass allocator creates a different pass than previously "
+                     "registered");
+  }
 }
 
-/// Returns the pass info for the specified pass class or null if unknown.
-const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
-  auto it = passRegistry->find(passID);
-  if (it == passRegistry->end())
-    return nullptr;
-  return &it->getSecond();
+/// Returns the pass info for the specified pass argument or null if unknown.
+const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
+  auto it = passRegistry->find(passArg);
+  return it == passRegistry->end() ? nullptr : &it->second;
 }
 
 //===----------------------------------------------------------------------===//
@@ -433,12 +442,8 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
   }
 
   // If not, then this must be a specific pass name.
-  for (auto &passIt : *passRegistry) {
-    if (passIt.second.getPassArgument() == element.name) {
-      element.registryEntry = &passIt.second;
-      return success();
-    }
-  }
+  if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
+    return success();
 
   // Emit an error for the unknown pass.
   auto *rawLoc = element.name.data();

diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 937a5c2317c2..6e5a5b9de8ba 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -16,9 +16,11 @@ namespace {
 struct TestModulePass
     : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
   void runOnOperation() final {}
+  StringRef getArgument() const final { return "test-module-pass"; }
 };
 struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
   void runOnFunction() final {}
+  StringRef getArgument() const final { return "test-function-pass"; }
 };
 class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
 public:
@@ -41,6 +43,7 @@ class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
   }
 
   void runOnFunction() final {}
+  StringRef getArgument() const final { return "test-options-pass"; }
 
   ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
                              llvm::cl::desc("Example list option")};
@@ -56,6 +59,7 @@ class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
 class TestCrashRecoveryPass
     : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
   void runOnOperation() final { abort(); }
+  StringRef getArgument() const final { return "test-pass-crash"; }
 };
 
 /// A test pass that always fails to enable testing the failure recovery


        


More information about the Mlir-commits mailing list