[Mlir-commits] [mlir] 0557c6a - [mlir] Keep sorted vector of registered operation names for efficient lookup
Eugene Zhulenev
llvmlistbot at llvm.org
Thu Feb 3 14:19:39 PST 2022
Author: Eugene Zhulenev
Date: 2022-02-03T14:19:34-08:00
New Revision: 0557c6a7970d174d4c575b667590490d5c9a3539
URL: https://github.com/llvm/llvm-project/commit/0557c6a7970d174d4c575b667590490d5c9a3539
DIFF: https://github.com/llvm/llvm-project/commit/0557c6a7970d174d4c575b667590490d5c9a3539.diff
LOG: [mlir] Keep sorted vector of registered operation names for efficient lookup
I see a lot of array sorting in the stack traces of our compiler: the canonicalizer traverses this list every time it builds a pattern set, and re-sorting it on every call gets expensive very quickly.
Reviewed By: rriddle, mehdi_amini
Differential Revision: https://reviews.llvm.org/D118937
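To make the cost described in the log concrete, here is a minimal sketch of the shape of the getter this commit removes, using standard containers instead of the LLVM ones. `OpInfo` and `sortedSnapshot` are hypothetical names, and `std::unordered_map` / `std::sort` stand in for `llvm::StringMap` / `llvm::array_pod_sort`. The point is that every call copies all n entries out of a hash table with unspecified iteration order and sorts them again.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the per-operation info kept by the context;
// only the name matters for ordering here.
struct OpInfo {
  std::string name;
};

// Mirrors the removed getter's approach: snapshot the hash-table values
// (iteration order unspecified) into a fresh vector and sort it by name.
// Each call pays an O(n) copy plus an O(n log n) sort.
std::vector<OpInfo>
sortedSnapshot(const std::unordered_map<std::string, OpInfo> &registry) {
  std::vector<OpInfo> result;
  result.reserve(registry.size());
  for (const auto &entry : registry)
    result.push_back(entry.second);
  std::sort(result.begin(), result.end(),
            [](const OpInfo &lhs, const OpInfo &rhs) {
              return lhs.name < rhs.name;
            });
  return result;
}
```

A caller that builds k pattern sets this way repeats the copy and the sort k times, roughly O(k * n log n) work for a list that only changes when a new operation is registered.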
Added:
Modified:
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/IR/MLIRContext.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 659cd71265592..a66399dee71d7 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -176,10 +176,9 @@ class MLIRContext {
/// emitting diagnostics.
void printStackTraceOnDiagnostic(bool enable);
- /// Return information about all registered operations. This isn't very
- /// efficient: typically you should ask the operations about their properties
- /// directly.
- std::vector<RegisteredOperationName> getRegisteredOperations();
+ /// Return a sorted array containing the information about all registered
+ /// operations.
+ ArrayRef<RegisteredOperationName> getRegisteredOperations();
/// Return true if this operation name is registered in this context.
bool isOperationRegistered(StringRef name);
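The header change swaps a `std::vector` returned by value for an `ArrayRef`, which in LLVM is a non-owning (pointer, length) view. The sketch below uses a hypothetical minimal `ArrayView` in place of `llvm::ArrayRef` to contrast the two styles: the copy gets its own storage, while the view aliases the owner's storage and costs O(1). One assumption worth keeping in mind: because the view points into the owner's buffer, it is presumably only valid until that buffer is reallocated, for example when another operation is registered.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, minimal stand-in for llvm::ArrayRef: a non-owning
// (pointer, length) view over contiguous storage owned elsewhere.
template <typename T> class ArrayView {
public:
  explicit ArrayView(const std::vector<T> &vec)
      : data_(vec.data()), size_(vec.size()) {}
  const T *begin() const { return data_; }
  const T *end() const { return data_ + size_; }
  std::size_t size() const { return size_; }
  const T &operator[](std::size_t i) const { return data_[i]; }

private:
  const T *data_;
  std::size_t size_;
};

// Old style: return by value, so every call copies all n elements.
std::vector<std::string> copyNames(const std::vector<std::string> &owned) {
  return owned;
}

// New style: O(1) and allocation-free; the result aliases `owned`, so it
// must not outlive it or be used after `owned` reallocates.
ArrayView<std::string> viewNames(const std::vector<std::string> &owned) {
  return ArrayView<std::string>(owned);
}
```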
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index a144769a3fd3b..6c4dcaca311ef 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -184,6 +184,10 @@ class MLIRContextImpl {
/// A vector of operation info specifically for registered operations.
llvm::StringMap<RegisteredOperationName> registeredOperations;
+ /// This is a sorted container of registered operations for a deterministic
+ /// and efficient `getRegisteredOperations` implementation.
+ SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
+
/// A mutex used when accessing operation information.
llvm::sys::SmartRWMutex<true> operationInfoMutex;
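The new field pairs the existing hash map (for lookup by name) with a vector kept sorted by name (for ordered listing). A minimal sketch of that pairing with standard containers, under the assumption that names alone are enough to illustrate it; `OpRegistry` and its members are hypothetical, and the real context stores `RegisteredOperationName` handles and guards them with the mutex shown above, which this sketch omits.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical sketch: a hash map answers "is this name registered?" in
// O(1) on average, while a parallel vector kept sorted by name gives a
// deterministic listing without re-sorting on each query.
class OpRegistry {
public:
  // Returns false if `name` is already registered (MLIR aborts instead).
  bool insert(const std::string &name, const std::string &dialect) {
    auto emplaced = byName_.emplace(name, dialect);
    if (!emplaced.second)
      return false;
    // upper_bound finds the first element greater than `name`; inserting
    // there keeps the vector sorted. The element shift costs O(n), but it
    // is paid once per registration instead of on every listing.
    auto pos = std::upper_bound(sorted_.begin(), sorted_.end(), name);
    sorted_.insert(pos, name);
    return true;
  }

  bool isRegistered(const std::string &name) const {
    return byName_.count(name) != 0;
  }

  // O(1): hands out the already-sorted container; no copy, no sort.
  const std::vector<std::string> &getRegistered() const { return sorted_; }

private:
  std::unordered_map<std::string, std::string> byName_; // name -> dialect
  std::vector<std::string> sorted_;                     // names, ascending
};
```

Because the order comes from the names rather than from hash-table layout, two registries populated in different orders list their operations identically.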
@@ -569,24 +573,9 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
impl->printStackTraceOnDiagnostic = enable;
}
-/// Return information about all registered operations. This isn't very
-/// efficient, typically you should ask the operations about their properties
-/// directly.
-std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
- // We just have the operations in a non-deterministic hash table order. Dump
- // into a temporary array, then sort it by operation name to get a stable
- // ordering.
- auto unwrappedNames = llvm::make_second_range(impl->registeredOperations);
- std::vector<RegisteredOperationName> result(unwrappedNames.begin(),
- unwrappedNames.end());
- llvm::array_pod_sort(result.begin(), result.end(),
- [](const RegisteredOperationName *lhs,
- const RegisteredOperationName *rhs) {
- return lhs->getIdentifier().compare(
- rhs->getIdentifier());
- });
-
- return result;
+/// Return information about all registered operations.
+ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
+ return impl->sortedRegisteredOperations;
}
bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -736,8 +725,19 @@ void RegisteredOperationName::insert(
<< "' is already registered.\n";
abort();
}
- ctxImpl.registeredOperations.try_emplace(name,
- RegisteredOperationName(&impl));
+ auto emplaced = ctxImpl.registeredOperations.try_emplace(
+ name, RegisteredOperationName(&impl));
+ assert(emplaced.second && "operation name registration must be successful");
+
+ // Add emplaced operation name to the sorted operations container.
+ RegisteredOperationName &value = emplaced.first->getValue();
+ ctxImpl.sortedRegisteredOperations.insert(
+ llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
+ [](auto &lhs, auto &rhs) {
+ return lhs.getIdentifier().compare(
+ rhs.getIdentifier());
+ }),
+ value);
// Update the registered info for this operation.
impl.dialect = &dialect;
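The insertion above passes `llvm::upper_bound` a lambda built from a three-way `compare()`. Like `std::upper_bound`, it treats the comparator as a boolean strict-weak-ordering predicate ("does lhs order before rhs?"), and an `int` result converts to `true` for any non-zero value, i.e. for every unequal pair. The sketch below therefore adapts a three-way comparison explicitly with `< 0`. `Handle` and `insertSorted` are hypothetical names that only mimic the shape of an identifier with a `compare()` method returning a negative, zero, or positive `int`.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical handle with a three-way compare(): <0, 0 or >0, the same
// convention as std::string::compare.
struct Handle {
  std::string id;
  int compare(const Handle &other) const { return id.compare(other.id); }
};

// Insert `value` after any equal elements, keeping `sorted` ordered.
void insertSorted(std::vector<Handle> &sorted, const Handle &value) {
  auto pos = std::upper_bound(sorted.begin(), sorted.end(), value,
                              [](const Handle &lhs, const Handle &rhs) {
                                // Turn the three-way result into the
                                // boolean "lhs < rhs" upper_bound expects.
                                return lhs.compare(rhs) < 0;
                              });
  sorted.insert(pos, value);
}
```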