[Mlir-commits] [mlir] [MLIR] Add a second map for registered OperationName in MLIRContext (NFC) (PR #87170)

Mehdi Amini llvmlistbot at llvm.org
Sat Mar 30 16:23:03 PDT 2024


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/87170

>From 10a366397c651979a1b458155dff972a62883fea Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 30 Mar 2024 14:05:49 -0700
Subject: [PATCH] [MLIR] Add a second map for registered OperationName in
 MLIRContext (NFC)

This speeds up registered op creation by 10-11% by allowing lookup by
TypeID instead of StringRef.
---
 mlir/include/mlir/Dialect/Arith/IR/Arith.h    |  4 ++-
 .../Dialect/Transform/IR/TransformDialect.h   |  8 ++---
 mlir/include/mlir/IR/Builders.h               |  2 +-
 mlir/include/mlir/IR/OpDefinition.h           |  3 +-
 mlir/include/mlir/IR/OperationSupport.h       |  5 ++++
 mlir/lib/IR/MLIRContext.cpp                   | 29 ++++++++++++++-----
 6 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 971c78f4a86a75..00cdb13feb29bb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -53,6 +53,7 @@ namespace arith {
 class ConstantIntOp : public arith::ConstantOp {
 public:
   using arith::ConstantOp::ConstantOp;
+  static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
 
   /// Build a constant int op that produces an integer of the specified width.
   static void build(OpBuilder &builder, OperationState &result, int64_t value,
@@ -74,6 +75,7 @@ class ConstantIntOp : public arith::ConstantOp {
 class ConstantFloatOp : public arith::ConstantOp {
 public:
   using arith::ConstantOp::ConstantOp;
+  static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
 
   /// Build a constant float op that produces a float of the specified type.
   static void build(OpBuilder &builder, OperationState &result,
@@ -90,7 +92,7 @@ class ConstantFloatOp : public arith::ConstantOp {
 class ConstantIndexOp : public arith::ConstantOp {
 public:
   using arith::ConstantOp::ConstantOp;
-
+  static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
   /// Build a constant int op that produces an index.
   static void build(OpBuilder &builder, OperationState &result, int64_t value);
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index db27f2c6fc49b7..128eacdbe6ab7a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -252,21 +252,21 @@ class TransformDialectExtension
 
 template <typename OpTy>
 void TransformDialect::addOperationIfNotRegistered() {
-  StringRef name = OpTy::getOperationName();
   std::optional<RegisteredOperationName> opName =
-      RegisteredOperationName::lookup(name, getContext());
+      RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext());
   if (!opName) {
     addOperations<OpTy>();
 #ifndef NDEBUG
+    StringRef name = OpTy::getOperationName();
     detail::checkImplementsTransformOpInterface(name, getContext());
 #endif // NDEBUG
     return;
   }
 
-  if (opName->getTypeID() == TypeID::get<OpTy>())
+  if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
     return;
 
-  reportDuplicateOpRegistration(name);
+  reportDuplicateOpRegistration(OpTy::getOperationName());
 }
 
 template <typename Type>
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 43b6d2b3841690..3beade017d1ab9 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -490,7 +490,7 @@ class OpBuilder : public Builder {
   template <typename OpT>
   RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
     std::optional<RegisteredOperationName> opName =
-        RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
+        RegisteredOperationName::lookup(TypeID::get<OpT>(), ctx);
     if (LLVM_UNLIKELY(!opName)) {
       llvm::report_fatal_error(
           "Building op `" + OpT::getOperationName() +
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bd68c27445744e..c177ae3594d11f 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1729,8 +1729,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
   template <typename... Models>
   static void attachInterface(MLIRContext &context) {
     std::optional<RegisteredOperationName> info =
-        RegisteredOperationName::lookup(ConcreteType::getOperationName(),
-                                        &context);
+        RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context);
     if (!info)
       llvm::report_fatal_error(
           "Attempting to attach an interface to an unregistered operation " +
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index f2aa6cee840308..90e63ff8fcb38f 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -676,6 +676,11 @@ class RegisteredOperationName : public OperationName {
   static std::optional<RegisteredOperationName> lookup(StringRef name,
                                                        MLIRContext *ctx);
 
+  /// Lookup the registered operation information for the given TypeID.
+  /// Returns std::nullopt if the operation isn't registered.
+  static std::optional<RegisteredOperationName> lookup(TypeID typeID,
+                                                       MLIRContext *ctx);
+
   /// Register a new operation in a Dialect object.
   /// This constructor is used by Dialect objects when they register the list
   /// of operations they contain.
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index e1e6d14231d9f1..214b354c5347e9 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -183,7 +183,8 @@ class MLIRContextImpl {
   llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
 
   /// A vector of operation info specifically for registered operations.
-  llvm::StringMap<RegisteredOperationName> registeredOperations;
+  llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations;
+  llvm::StringMap<RegisteredOperationName> registeredOperationsByName;
 
   /// This is a sorted container of registered operations for a deterministic
   /// and efficient `getRegisteredOperations` implementation.
@@ -780,8 +781,8 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
     // Check the registered info map first. In the overwhelmingly common case,
     // the entry will be in here and it also removes the need to acquire any
     // locks.
-    auto registeredIt = ctxImpl.registeredOperations.find(name);
-    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
+    auto registeredIt = ctxImpl.registeredOperationsByName.find(name);
+    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
       impl = registeredIt->second.impl;
       return;
     }
@@ -909,10 +910,19 @@ OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
 //===----------------------------------------------------------------------===//
 
 std::optional<RegisteredOperationName>
-RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
+RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) {
   auto &impl = ctx->getImpl();
-  auto it = impl.registeredOperations.find(name);
+  auto it = impl.registeredOperations.find(typeID);
   if (it != impl.registeredOperations.end())
+    return it->second;
+  return std::nullopt;
+}
+
+std::optional<RegisteredOperationName>
+RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
+  auto &impl = ctx->getImpl();
+  auto it = impl.registeredOperationsByName.find(name);
+  if (it != impl.registeredOperationsByName.end())
     return it->getValue();
   return std::nullopt;
 }
@@ -945,11 +955,16 @@ void RegisteredOperationName::insert(
 
   // Update the registered info for this operation.
   auto emplaced = ctxImpl.registeredOperations.try_emplace(
-      name, RegisteredOperationName(impl));
+      impl->getTypeID(), RegisteredOperationName(impl));
   assert(emplaced.second && "operation name registration must be successful");
+  auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace(
+      name, RegisteredOperationName(impl));
+  (void)emplacedByName;
+  assert(emplacedByName.second &&
+         "operation name registration must be successful");
 
   // Add emplaced operation name to the sorted operations container.
-  RegisteredOperationName &value = emplaced.first->getValue();
+  RegisteredOperationName &value = emplaced.first->second;
   ctxImpl.sortedRegisteredOperations.insert(
       llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
                         [](auto &lhs, auto &rhs) {



More information about the Mlir-commits mailing list