[Mlir-commits] [mlir] [mlir] Method to iterate over registered operations for a given dialect class. (PR #112344)
Rajveer Singh Bharadwaj
llvmlistbot at llvm.org
Wed Oct 16 11:20:46 PDT 2024
https://github.com/Rajveer100 updated https://github.com/llvm/llvm-project/pull/112344
>From 859a8ac73791615d74514582e27178ebdde7068a Mon Sep 17 00:00:00 2001
From: Rajveer <rajveer.developer at icloud.com>
Date: Tue, 15 Oct 2024 16:26:08 +0530
Subject: [PATCH] [mlir] Method to iterate over registered operations for a
given dialect class.
Part of #111591
Currently, `MLIRContext::getRegisteredOperations` returns all operations registered
in the given context. The new `MLIRContext::getRegisteredOperationsByDialect`
returns only the registered operations that belong to a given dialect, identified by its namespace.
---
mlir/include/mlir/IR/MLIRContext.h | 5 +++++
mlir/lib/IR/MLIRContext.cpp | 29 +++++++++++++++++++++++++++--
2 files changed, 32 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d17bbac81655b5..ef8dab87f131a1 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -197,6 +197,11 @@ class MLIRContext {
/// operations.
ArrayRef<RegisteredOperationName> getRegisteredOperations();
+ /// Return a sorted array containing the information for registered operations
+ /// whose dialect namespace equals `dialectName`.
+ ArrayRef<RegisteredOperationName>
+ getRegisteredOperationsByDialect(StringRef dialectName);
+
/// Return true if this operation name is registered in this context.
bool isOperationRegistered(StringRef name);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f05666fcde207b..85cab594185f3f 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -190,6 +190,9 @@ class MLIRContextImpl {
/// and efficient `getRegisteredOperations` implementation.
SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
+ /// Per-dialect-namespace count of registered operations (a map, not a getter).
+ DenseMap<StringRef, size_t> getCountByDialectName;
+
/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects. These need to be declared after the
/// registered operations to ensure correct destruction order.
@@ -711,6 +714,21 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
return impl->sortedRegisteredOperations;
}
+/// Return the (possibly empty) contiguous run of `sortedRegisteredOperations`
+/// whose dialect namespace equals `dialectName`; assumes namespace-major order.
+ArrayRef<RegisteredOperationName>
+MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
+  ArrayRef<RegisteredOperationName> ops = impl->sortedRegisteredOperations;
+  auto ns = [](const RegisteredOperationName &op) {
+    return op.getDialect().getNamespace();
+  };
+  auto *first = llvm::partition_point(
+      ops, [&](const auto &op) { return ns(op) < dialectName; });
+  auto *last = llvm::partition_point(
+      ops, [&](const auto &op) { return ns(op) <= dialectName; });
+  return ArrayRef<RegisteredOperationName>(first, last);
+}
+
bool MLIRContext::isOperationRegistered(StringRef name) {
return RegisteredOperationName::lookup(name, this).has_value();
}
@@ -976,12 +994,19 @@ void RegisteredOperationName::insert(
"operation name registration must be successful");
// Add emplaced operation name to the sorted operations container.
+ StringRef dialectClass = impl->getDialect()->getNamespace();
+ ctxImpl.getCountByDialectName[dialectClass] += 1;
+
RegisteredOperationName &value = emplaced.first->second;
ctxImpl.sortedRegisteredOperations.insert(
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
[](auto &lhs, auto &rhs) {
- return lhs.getIdentifier().compare(
- rhs.getIdentifier());
+ if (lhs.getDialect().getNamespace() ==
+ rhs.getDialect().getNamespace())
+ return lhs.getIdentifier().compare(
+ rhs.getIdentifier());
+ return lhs.getDialect().getNamespace().compare(
+ rhs.getDialect().getNamespace());
}),
value);
}
More information about the Mlir-commits
mailing list