[Mlir-commits] [mlir] [mlir] Method to iterate over registered operations for a given dialect class. (PR #112344)

Rajveer Singh Bharadwaj llvmlistbot at llvm.org
Wed Oct 16 03:32:30 PDT 2024


https://github.com/Rajveer100 updated https://github.com/llvm/llvm-project/pull/112344

>From 41ec7af3e1d842bbd003e7776940a1b2fe9426f1 Mon Sep 17 00:00:00 2001
From: Rajveer <rajveer.developer at icloud.com>
Date: Tue, 15 Oct 2024 16:26:08 +0530
Subject: [PATCH] [mlir] Method to iterate over registered operations for a
 given dialect class.

Part of #111591

Currently, `MLIRContext::getRegisteredOperations` returns all registered operations
for a given context. This patch adds `MLIRContext::getRegisteredOperationsByDialect`,
which returns only the registered operations that belong to a given dialect namespace.
---
 mlir/include/mlir/IR/MLIRContext.h |  5 +++
 mlir/lib/IR/MLIRContext.cpp        | 60 +++++++++++++++++++++++++-----
 2 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d17bbac81655b5..ef8dab87f131a1 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -197,6 +197,13 @@ class MLIRContext {
   /// operations.
   ArrayRef<RegisteredOperationName> getRegisteredOperations();
 
+  /// Return a sorted array containing the information for the registered
+  /// operations whose dialect namespace is `dialectName`. The result is empty
+  /// if no such operation is registered. The array refers to storage owned by
+  /// the context; copy it if it must outlive later operation registrations.
+  ArrayRef<RegisteredOperationName>
+  getRegisteredOperationsByDialect(StringRef dialectName);
+
   /// Return true if this operation name is registered in this context.
   bool isOperationRegistered(StringRef name);
 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f05666fcde207b..e02a323ac2387f 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -188,6 +188,9 @@ class MLIRContextImpl {
 
   /// This is a sorted container of registered operations for a deterministic
   /// and efficient `getRegisteredOperations` implementation.
+  /// Operations are ordered by dialect namespace first and by operation name
+  /// second, so the operations of any one dialect form a contiguous range;
+  /// `getRegisteredOperationsByDialect` relies on this.
   SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
 
   /// This is a list of dialects that are created referring to this context.
@@ -709,6 +712,33 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
 /// Return information about all registered operations.
 ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
   return impl->sortedRegisteredOperations;
+}
+
+/// Return the registered operations of the dialect whose namespace is
+/// `dialectName`, as a view into the context's sorted operation storage.
+ArrayRef<RegisteredOperationName>
+MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
+  // `sortedRegisteredOperations` is ordered by dialect namespace first (see
+  // `RegisteredOperationName::insert`), so the operations of `dialectName`
+  // form one contiguous range that a binary search can locate. Both
+  // overloads are strict "less than" predicates, as `std::equal_range`
+  // requires.
+  struct NamespaceLess {
+    bool operator()(const RegisteredOperationName &op, StringRef ns) const {
+      return op.getDialect().getNamespace() < ns;
+    }
+    bool operator()(StringRef ns, const RegisteredOperationName &op) const {
+      return ns < op.getDialect().getNamespace();
+    }
+  };
+  ArrayRef<RegisteredOperationName> operations =
+      impl->sortedRegisteredOperations;
+  auto range = std::equal_range(operations.begin(), operations.end(),
+                                dialectName, NamespaceLess());
+  // Return a view rather than a copy. It stays valid until the next
+  // operation registration; an unknown dialect yields an empty range
+  // without modifying any context state.
+  return ArrayRef<RegisteredOperationName>(range.first, range.second);
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -976,14 +1006,21 @@ void RegisteredOperationName::insert(
          "operation name registration must be successful");
 
   // Add emplaced operation name to the sorted operations container.
-  RegisteredOperationName &value = emplaced.first->second;
-  ctxImpl.sortedRegisteredOperations.insert(
-      llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
-                        [](auto &lhs, auto &rhs) {
-                          return lhs.getIdentifier().compare(
-                              rhs.getIdentifier());
-                        }),
-      value);
+  // The container is ordered by (dialect namespace, operation name) so that
+  // `getRegisteredOperationsByDialect` can binary-search a dialect's range.
+  // The comparator must be a strict "less than" predicate; returning the
+  // int result of `compare` would report every unequal pair as "less".
+  RegisteredOperationName &value = emplaced.first->second;
+  auto sortKey = [](const RegisteredOperationName &op) {
+    return std::make_pair(op.getDialect().getNamespace(), op.getStringRef());
+  };
+  ctxImpl.sortedRegisteredOperations.insert(
+      llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
+                        [&](const RegisteredOperationName &lhs,
+                            const RegisteredOperationName &rhs) {
+                          return sortKey(lhs) < sortKey(rhs);
+                        }),
+      value);
 }
 
 //===----------------------------------------------------------------------===//
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list