[Mlir-commits] [mlir] [mlir] Method to iterate over registered operations for a given dialect class. (PR #112344)

Rajveer Singh Bharadwaj llvmlistbot at llvm.org
Wed Oct 16 12:20:18 PDT 2024


https://github.com/Rajveer100 updated https://github.com/llvm/llvm-project/pull/112344

>From 469021b7630d607e049d8e359c9cf9e8ec5ef7b1 Mon Sep 17 00:00:00 2001
From: Rajveer <rajveer.developer at icloud.com>
Date: Tue, 15 Oct 2024 16:26:08 +0530
Subject: [PATCH] [mlir] Method to iterate over registered operations for a
 given dialect class.

Part of #111591

`MLIRContext::getRegisteredOperations` returns every operation registered in a
context. The new `MLIRContext::getRegisteredOperationsByDialect` returns only the
registered operations that belong to the dialect with the given namespace.
---
 mlir/include/mlir/IR/MLIRContext.h |  5 +++++
 mlir/lib/IR/MLIRContext.cpp        | 19 +++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d17bbac81655b5..ef8dab87f131a1 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -197,6 +197,11 @@ class MLIRContext {
   /// operations.
   ArrayRef<RegisteredOperationName> getRegisteredOperations();
 
+  /// Return a sorted array containing the information for the registered
+  /// operations whose dialect namespace is `dialectName`.
+  ArrayRef<RegisteredOperationName>
+  getRegisteredOperationsByDialect(StringRef dialectName);
+
   /// Return true if this operation name is registered in this context.
   bool isOperationRegistered(StringRef name);
 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f05666fcde207b..320ffc7140a4c5 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -711,6 +711,25 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
   return impl->sortedRegisteredOperations;
 }
 
+/// Return the registered operations that belong to the dialect with namespace
+/// `dialectName`, as a sorted sub-range of `sortedRegisteredOperations`.
+ArrayRef<RegisteredOperationName>
+MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
+  // Names are kept sorted and every op of a dialect is named "<ns>.<op>", so
+  // the dialect's ops form one contiguous run sharing the "<ns>." prefix.
+  std::string prefix = dialectName.str() + ".";
+  ArrayRef<RegisteredOperationName> ops = impl->sortedRegisteredOperations;
+  const auto *first =
+      llvm::partition_point(ops, [&](RegisteredOperationName op) {
+        return op.getStringRef() < prefix;
+      });
+  const auto *last =
+      std::partition_point(first, ops.end(), [&](RegisteredOperationName op) {
+        return op.getStringRef().starts_with(prefix);
+      });
+  return ArrayRef<RegisteredOperationName>(first, last);
+}
+
 bool MLIRContext::isOperationRegistered(StringRef name) {
   return RegisteredOperationName::lookup(name, this).has_value();
 }



More information about the Mlir-commits mailing list