[Mlir-commits] [mlir] 0557c6a - [mlir] Keep sorted vector of registered operation names for efficient lookup

Eugene Zhulenev llvmlistbot at llvm.org
Thu Feb 3 14:19:39 PST 2022


Author: Eugene Zhulenev
Date: 2022-02-03T14:19:34-08:00
New Revision: 0557c6a7970d174d4c575b667590490d5c9a3539

URL: https://github.com/llvm/llvm-project/commit/0557c6a7970d174d4c575b667590490d5c9a3539
DIFF: https://github.com/llvm/llvm-project/commit/0557c6a7970d174d4c575b667590490d5c9a3539.diff

LOG: [mlir] Keep sorted vector of registered operation names for efficient lookup

I see a lot of array sorting in the stack traces of our compiler: the canonicalizer traverses this list every time it builds a pattern set, and re-sorting it on every call gets expensive very quickly.
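The idea in this log is to pay the ordering cost once per registration instead of once per query. A minimal plain C++17 sketch of it follows. `OpRegistry`, `OpInfo`, `insert`, and `getSorted` are hypothetical names for illustration. They are not MLIR's `MLIRContextImpl` or `RegisteredOperationName`.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for per-operation registration data; this is not
// MLIR's RegisteredOperationName.
struct OpInfo {
  std::string name;
};

// Hypothetical registry: a hash map for lookup by name plus a vector that is
// kept sorted by name at insertion time, so reads never need to sort.
class OpRegistry {
public:
  // Registers `name` once; returns false if it is already registered.
  bool insert(const std::string &name) {
    auto [it, inserted] = byName.try_emplace(name, OpInfo{name});
    if (!inserted)
      return false;
    // Pointers to unordered_map elements stay valid across rehashing, so the
    // sorted vector can refer to the map's entries directly.
    const OpInfo *info = &it->second;
    // Binary-search for the first entry whose name is greater than `name`
    // and insert in front of it, so the vector stays in ascending order.
    auto pos = std::upper_bound(
        sorted.begin(), sorted.end(), name,
        [](const std::string &lhs, const OpInfo *rhs) {
          return lhs < rhs->name;
        });
    sorted.insert(pos, info);
    return true;
  }

  // Returns the already-sorted entries without copying or sorting.
  const std::vector<const OpInfo *> &getSorted() const { return sorted; }

private:
  std::unordered_map<std::string, OpInfo> byName;
  std::vector<const OpInfo *> sorted;
};
```

In this sketch, each registration costs an O(log n) search plus an O(n) shift of the vector. That is presumably acceptable because operations are registered once per context, while the sorted list may be requested many times.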

Reviewed By: rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D118937

Added: 
    

Modified: 
    mlir/include/mlir/IR/MLIRContext.h
    mlir/lib/IR/MLIRContext.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 659cd71265592..a66399dee71d7 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -176,10 +176,9 @@ class MLIRContext {
   /// emitting diagnostics.
   void printStackTraceOnDiagnostic(bool enable);
 
-  /// Return information about all registered operations.  This isn't very
-  /// efficient: typically you should ask the operations about their properties
-  /// directly.
-  std::vector<RegisteredOperationName> getRegisteredOperations();
+  /// Return a sorted array containing the information about all registered
+  /// operations.
+  ArrayRef<RegisteredOperationName> getRegisteredOperations();
 
   /// Return true if this operation name is registered in this context.
   bool isOperationRegistered(StringRef name);

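With the new signature, the caller receives a sorted, non-owning `ArrayRef` instead of a freshly sorted `std::vector` copy. Sorted order also makes name lookups cheap. Every name that shares a prefix, such as a dialect namespace like `arith.`, sits in one contiguous run, and two binary searches can locate that run. The sketch below uses `std::vector<std::string>` as a stand-in for the array of `RegisteredOperationName`s. `prefixRange` is a hypothetical helper, not an MLIR API.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

// Hypothetical helper: returns the half-open index range [first, last) of the
// names in `sorted` that start with `prefix`. Requires `sorted` to be in
// ascending lexicographic order; both boundaries are found by binary search.
std::pair<std::size_t, std::size_t>
prefixRange(const std::vector<std::string> &sorted, std::string_view prefix) {
  // Every name that starts with `prefix` compares >= `prefix`, so the run
  // begins at the first element that is not less than `prefix`.
  auto first = std::lower_bound(
      sorted.begin(), sorted.end(), prefix,
      [](const std::string &elem, std::string_view p) {
        return std::string_view(elem) < p;
      });
  // From there, names starting with `prefix` come first and all others
  // follow, so partition_point finds the end of the run.
  auto last = std::partition_point(
      first, sorted.end(), [prefix](const std::string &elem) {
        return std::string_view(elem).substr(0, prefix.size()) == prefix;
      });
  return {static_cast<std::size_t>(std::distance(sorted.begin(), first)),
          static_cast<std::size_t>(std::distance(sorted.begin(), last))};
}
```

Because the `ArrayRef` views storage owned by the context, it is reasonable to assume that it can be invalidated when further operations are registered, since inserting into the underlying vector may reallocate. A caller that needs a stable snapshot would therefore copy it first.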
diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index a144769a3fd3b..6c4dcaca311ef 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -184,6 +184,10 @@ class MLIRContextImpl {
   /// A vector of operation info specifically for registered operations.
   llvm::StringMap<RegisteredOperationName> registeredOperations;
 
+  /// This is a sorted container of registered operations for a deterministic
+  /// and efficient `getRegisteredOperations` implementation.
+  SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
+
   /// A mutex used when accessing operation information.
   llvm::sys::SmartRWMutex<true> operationInfoMutex;
 
@@ -569,24 +573,9 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
   impl->printStackTraceOnDiagnostic = enable;
 }
 
-/// Return information about all registered operations.  This isn't very
-/// efficient, typically you should ask the operations about their properties
-/// directly.
-std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
-  // We just have the operations in a non-deterministic hash table order. Dump
-  // into a temporary array, then sort it by operation name to get a stable
-  // ordering.
-  auto unwrappedNames = llvm::make_second_range(impl->registeredOperations);
-  std::vector<RegisteredOperationName> result(unwrappedNames.begin(),
-                                              unwrappedNames.end());
-  llvm::array_pod_sort(result.begin(), result.end(),
-                       [](const RegisteredOperationName *lhs,
-                          const RegisteredOperationName *rhs) {
-                         return lhs->getIdentifier().compare(
-                             rhs->getIdentifier());
-                       });
-
-  return result;
+/// Return information about all registered operations.
+ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
+  return impl->sortedRegisteredOperations;
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -736,8 +725,19 @@ void RegisteredOperationName::insert(
                  << "' is already registered.\n";
     abort();
   }
-  ctxImpl.registeredOperations.try_emplace(name,
-                                           RegisteredOperationName(&impl));
+  auto emplaced = ctxImpl.registeredOperations.try_emplace(
+      name, RegisteredOperationName(&impl));
+  assert(emplaced.second && "operation name registration must be successful");
+
+  // Add emplaced operation name to the sorted operations container.
+  RegisteredOperationName &value = emplaced.first->getValue();
+  ctxImpl.sortedRegisteredOperations.insert(
+      llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
+                        [](auto &lhs, auto &rhs) {
+                          return lhs.getIdentifier().compare(
+                              rhs.getIdentifier());
+                        }),
+      value);
 
   // Update the registered info for this operation.
   impl.dialect = &dialect;

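The sorted insertion relies on the `upper_bound` predicate answering "is lhs strictly less than rhs?". Assume `getIdentifier().compare(...)` returns a three-way `int`, as `StringRef::compare` does. A lambda that returns that `int` directly then converts it to `true` for any two distinct names, which is "not equal" rather than a strict weak ordering. The sketch below uses `std::string` and a hypothetical `insertSorted` helper to show the predicate form that keeps the container ordered, by testing the three-way result against zero.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: inserts `name` into the ascending vector `sorted`,
// keeping it ordered by name.
void insertSorted(std::vector<std::string> &sorted, const std::string &name) {
  auto pos = std::upper_bound(
      sorted.begin(), sorted.end(), name,
      [](const std::string &lhs, const std::string &rhs) {
        // compare() returns <0, 0, or >0. Returning it unchanged would turn
        // every nonzero result into `true`; `< 0` yields a proper "less than".
        return lhs.compare(rhs) < 0;
      });
  sorted.insert(pos, name);
}
```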
