[Mlir-commits] [mlir] b091701 - [mlir] Add a method on MLIRContext to retrieve the operations for a given dialect (#112344)
llvmlistbot at llvm.org
Thu Oct 17 03:02:28 PDT 2024
Author: Rajveer Singh Bharadwaj
Date: 2024-10-17T12:02:24+02:00
New Revision: b091701d0190912578ac3fe91ee8fd29e9b6de6e
URL: https://github.com/llvm/llvm-project/commit/b091701d0190912578ac3fe91ee8fd29e9b6de6e
DIFF: https://github.com/llvm/llvm-project/commit/b091701d0190912578ac3fe91ee8fd29e9b6de6e.diff
LOG: [mlir] Add a method on MLIRContext to retrieve the operations for a given dialect (#112344)
Currently we have `MLIRContext::getRegisteredOperations`, which returns
all registered operations for the given context. With the addition of
`MLIRContext::getRegisteredOperationsByDialect`, we can now retrieve only
the registered operations that belong to a given dialect, identified by
its namespace string.
Closes #111591
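As a rough model of what the new method computes, here is a plain C++17 sketch with no MLIR types. It assumes, as the implementation does, that registered operations are kept sorted so that all operations of one dialect form a contiguous run. It returns that run with `std::equal_range`. `OpInfo`, `ByDialect`, and `opsForDialect` are hypothetical names for illustration, not MLIR API.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

// Hypothetical stand-in for RegisteredOperationName: the dialect namespace
// and the full operation name (e.g. "arith" / "arith.addi").
struct OpInfo {
  std::string dialect;
  std::string name;
};

// Hypothetical heterogeneous comparator on the dialect namespace. It provides
// both argument orders (element < key and key < element), as
// std::equal_range requires.
struct ByDialect {
  bool operator()(const OpInfo &op, std::string_view ns) const {
    return std::string_view(op.dialect) < ns;
  }
  bool operator()(std::string_view ns, const OpInfo &op) const {
    return ns < std::string_view(op.dialect);
  }
};

using OpIter = std::vector<OpInfo>::const_iterator;

// Hypothetical helper: returns the [first, last) run of ops whose dialect
// equals `ns`; the run is empty when the dialect has no ops.
// Precondition: `ops` is sorted by dialect.
std::pair<OpIter, OpIter> opsForDialect(const std::vector<OpInfo> &ops,
                                        std::string_view ns) {
  assert(std::is_sorted(ops.begin(), ops.end(),
                        [](const OpInfo &a, const OpInfo &b) {
                          return a.dialect < b.dialect;
                        }));
  return std::equal_range(ops.begin(), ops.end(), ns, ByDialect{});
}
```

Like the `ArrayRef` returned by the MLIR method, the iterator pair is a non-owning view into the sorted storage, so no operations are copied.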
Added:
Modified:
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/IR/MLIRContext.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d17bbac81655b5..ef8dab87f131a1 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -197,6 +197,11 @@ class MLIRContext {
   /// operations.
   ArrayRef<RegisteredOperationName> getRegisteredOperations();
 
+  /// Return a sorted array containing the information for registered operations
+  /// filtered by dialect name.
+  ArrayRef<RegisteredOperationName>
+  getRegisteredOperationsByDialect(StringRef dialectName);
+
   /// Return true if this operation name is registered in this context.
   bool isOperationRegistered(StringRef name);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f05666fcde207b..d33340f4aefc85 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -711,6 +711,30 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
   return impl->sortedRegisteredOperations;
 }
 
+/// Return information for registered operations by dialect.
+ArrayRef<RegisteredOperationName>
+MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
+  auto lowerBound =
+      std::lower_bound(impl->sortedRegisteredOperations.begin(),
+                       impl->sortedRegisteredOperations.end(), dialectName,
+                       [](auto &lhs, auto &rhs) {
+                         return lhs.getDialect().getNamespace().compare(rhs) < 0;
+                       });
+
+  if (lowerBound == impl->sortedRegisteredOperations.end() ||
+      lowerBound->getDialect().getNamespace() != dialectName)
+    return ArrayRef<RegisteredOperationName>();
+
+  auto upperBound =
+      std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(),
+                       dialectName, [](auto &lhs, auto &rhs) {
+                         return lhs.compare(rhs.getDialect().getNamespace()) < 0;
+                       });
+
+  size_t count = std::distance(lowerBound, upperBound);
+  return ArrayRef(&*lowerBound, count);
+}
+
 bool MLIRContext::isOperationRegistered(StringRef name) {
   return RegisteredOperationName::lookup(name, this).has_value();
 }
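The comparators passed to `std::lower_bound` and `std::upper_bound` must return `true` only when the first argument orders strictly before the second. `StringRef::compare` returns a three-way `int` (negative, zero, or positive). Returning that `int` directly converts it to `bool` as "not equal", which is not an ordering. The following standalone sketch uses `std::string::compare` as a stand-in, with hypothetical helper names, to show why the result needs a `< 0` test. It checks the partition precondition with `std::is_partitioned` rather than running `std::lower_bound` with an invalid comparator.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helpers for illustration, not MLIR API. std::string::compare
// stands in for llvm::StringRef::compare: both return an int that is
// negative, zero, or positive.
using KeyComp = bool (*)(const std::string &, const std::string &);

// Correct comparator: true only when `elem` orders strictly before `key`.
bool lessThanKey(const std::string &elem, const std::string &key) {
  return elem.compare(key) < 0;
}

// Pitfall: returning the raw int makes any nonzero result `true`, so this is
// equivalent to testing "elem != key", which is not an ordering.
bool notEqualToKey(const std::string &elem, const std::string &key) {
  return elem.compare(key) != 0;
}

// std::lower_bound requires the range to be partitioned by comp(elem, key):
// every `true` result must come before every `false` result.
bool isPartitionedBy(const std::vector<std::string> &sorted,
                     const std::string &key, KeyComp comp) {
  return std::is_partitioned(
      sorted.begin(), sorted.end(),
      [&](const std::string &elem) { return comp(elem, key); });
}

// Index of the first element that does not order before `key`.
std::size_t lowerBoundIndex(const std::vector<std::string> &sorted,
                            const std::string &key) {
  assert(isPartitionedBy(sorted, key, lessThanKey));
  auto it = std::lower_bound(sorted.begin(), sorted.end(), key, lessThanKey);
  return static_cast<std::size_t>(it - sorted.begin());
}
```

For the sorted namespaces `{"arith", "func", "scf"}` and key `"func"`, the `< 0` form yields `true, false, false`, a valid partition. The raw-`int` form yields `true, false, true`, which breaks the precondition the binary search relies on.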