[Mlir-commits] [mlir] [mlir] Method to iterate over registered operations for a given dialect class. (PR #112344)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 15 04:04:03 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Rajveer Singh Bharadwaj (Rajveer100)
<details>
<summary>Changes</summary>
Part of #<!-- -->111591
Currently, `MLIRContext::getRegisteredOperations` returns all registered operations for the given context. With the addition of `MLIRContext::getRegisteredOperationsByDialect`, the registered operations of a single dialect can now be retrieved by its namespace.
---
Full diff: https://github.com/llvm/llvm-project/pull/112344.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/MLIRContext.h (+6-1)
- (modified) mlir/lib/IR/MLIRContext.cpp (+44-11)
- (modified) mlir/lib/Rewrite/FrozenRewritePatternSet.cpp (+1-1)
``````````diff
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d17bbac81655b5..cfad6874b8f4a9 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -195,7 +195,12 @@ class MLIRContext {
/// Return a sorted array containing the information about all registered
/// operations.
- ArrayRef<RegisteredOperationName> getRegisteredOperations();
+ SmallVector<RegisteredOperationName, 0> getRegisteredOperations();
+
+ /// Return a sorted array containing the information about the registered
+ /// operations whose dialect namespace equals `dialectName`.
+ SmallVector<RegisteredOperationName, 0>
+ getRegisteredOperationsByDialect(StringRef dialectName);
/// Return true if this operation name is registered in this context.
bool isOperationRegistered(StringRef name);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f05666fcde207b..bb0da94e985517 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -188,7 +188,11 @@ class MLIRContextImpl {
/// This is a sorted container of registered operations for a deterministic
/// and efficient `getRegisteredOperations` implementation.
- SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
+ SmallVector<std::pair<StringRef, RegisteredOperationName>, 0>
+ sortedRegisteredOperations;
+
+ /// Count of registered operations per dialect namespace (a map, not a getter).
+ llvm::DenseMap<StringRef, size_t> getCountByDialectName;
/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects. These need to be declared after the
@@ -707,8 +711,31 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
}
/// Return information about all registered operations.
-ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
- return impl->sortedRegisteredOperations;
+SmallVector<RegisteredOperationName, 0> MLIRContext::getRegisteredOperations() {
+ // Copy out the operation handles; the dialect key only drives the ordering.
+ SmallVector<RegisteredOperationName, 0> operations;
+ operations.reserve(impl->sortedRegisteredOperations.size());
+ for (const auto &entry : impl->sortedRegisteredOperations)
+ operations.push_back(entry.second);
+
+ return operations;
+}
+
+/// Return information for registered operations by dialect.
+SmallVector<RegisteredOperationName, 0>
+MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
+ SmallVector<RegisteredOperationName, 0> operations;
+
+ // Entries are ordered by dialect namespace first, so `dialectName`'s ops are
+ // one contiguous run. The comparator must be a bool "less than": an int from
+ // `compare()` is truthy on any mismatch and breaks the binary search.
+ auto [first, last] = std::equal_range(
+ impl->sortedRegisteredOperations.begin(),
+ impl->sortedRegisteredOperations.end(), std::make_pair(dialectName, ""),
+ [](const auto &lhs, const auto &rhs) { return lhs.first < rhs.first; });
+ std::transform(first, last, std::back_inserter(operations),
+ [](const auto &t) { return t.second; });
+ return operations;
+}
bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -976,14 +1003,20 @@ void RegisteredOperationName::insert(
"operation name registration must be successful");
// Add emplaced operation name to the sorted operations container.
- RegisteredOperationName &value = emplaced.first->second;
- ctxImpl.sortedRegisteredOperations.insert(
- llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
- [](auto &lhs, auto &rhs) {
- return lhs.getIdentifier().compare(
- rhs.getIdentifier());
- }),
- value);
+ StringRef dialectClass = impl->getDialect()->getNamespace();
+ ctxImpl.getCountByDialectName[dialectClass] += 1;
+
+ std::pair<StringRef, RegisteredOperationName> value = {
+ dialectClass, emplaced.first->second};
+
+ // Strict weak order: dialect namespace first, then full operation name.
+ auto upperBound = llvm::upper_bound(
+ ctxImpl.sortedRegisteredOperations, value, [](auto &lhs, auto &rhs) {
+ if (lhs.first != rhs.first)
+ return lhs.first < rhs.first;
+ return lhs.second.getIdentifier().compare(rhs.second.getIdentifier()) < 0;
+ });
+ ctxImpl.sortedRegisteredOperations.insert(upperBound, value);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index 17fe02df9f66cd..d3317fc6d4fe30 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -73,7 +73,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
// Functor used to walk all of the operations registered in the context. This
// is useful for patterns that get applied to multiple operations, such as
// interface and trait based patterns.
- std::vector<RegisteredOperationName> opInfos;
+ SmallVector<RegisteredOperationName> opInfos;
auto addToOpsWhen =
[&](std::unique_ptr<RewritePattern> &pattern,
function_ref<bool(RegisteredOperationName)> callbackFn) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/112344
More information about the Mlir-commits
mailing list