[Mlir-commits] [mlir] 82c6eee - [MLIR] Add a second map for registered OperationName in MLIRContext (NFC) (#87170)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 31 12:28:09 PDT 2024
Author: Mehdi Amini
Date: 2024-03-31T21:28:05+02:00
New Revision: 82c6eeed08b1c8267f6e92d594c910fe57a9775e
URL: https://github.com/llvm/llvm-project/commit/82c6eeed08b1c8267f6e92d594c910fe57a9775e
DIFF: https://github.com/llvm/llvm-project/commit/82c6eeed08b1c8267f6e92d594c910fe57a9775e.diff
LOG: [MLIR] Add a second map for registered OperationName in MLIRContext (NFC) (#87170)
This speeds up registered op creation by 10-11% by allowing lookup by
TypeID instead of StringRef.
This can break your build/tests at runtime with an error saying that you're
creating an unregistered operation, even though you have registered it. If so,
you are likely using a class inheriting from the "real" operation. See, for
example, in this patch the case of:
class ConstantIndexOp : public arith::ConstantOp {
If one is using `builder.create<ConstantIndexOp>()`, they actually create an
`arith.constant` operation, but the builder will fetch the TypeID for
the `ConstantIndexOp` class, which does not correspond to any registered
operation. To fix it, the `ConstantIndexOp` class got this addition:
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/Arith.h
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/MLIRContext.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 971c78f4a86a75..00cdb13feb29bb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -53,6 +53,7 @@ namespace arith {
class ConstantIntOp : public arith::ConstantOp {
public:
using arith::ConstantOp::ConstantOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
/// Build a constant int op that produces an integer of the specified width.
static void build(OpBuilder &builder, OperationState &result, int64_t value,
@@ -74,6 +75,7 @@ class ConstantIntOp : public arith::ConstantOp {
class ConstantFloatOp : public arith::ConstantOp {
public:
using arith::ConstantOp::ConstantOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
/// Build a constant float op that produces a float of the specified type.
static void build(OpBuilder &builder, OperationState &result,
@@ -90,7 +92,7 @@ class ConstantFloatOp : public arith::ConstantOp {
class ConstantIndexOp : public arith::ConstantOp {
public:
using arith::ConstantOp::ConstantOp;
-
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
/// Build a constant int op that produces an index.
static void build(OpBuilder &builder, OperationState &result, int64_t value);
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index db27f2c6fc49b7..128eacdbe6ab7a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -252,21 +252,21 @@ class TransformDialectExtension
template <typename OpTy>
void TransformDialect::addOperationIfNotRegistered() {
- StringRef name = OpTy::getOperationName();
std::optional<RegisteredOperationName> opName =
- RegisteredOperationName::lookup(name, getContext());
+ RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext());
if (!opName) {
addOperations<OpTy>();
#ifndef NDEBUG
+ StringRef name = OpTy::getOperationName();
detail::checkImplementsTransformOpInterface(name, getContext());
#endif // NDEBUG
return;
}
- if (opName->getTypeID() == TypeID::get<OpTy>())
+ if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
return;
- reportDuplicateOpRegistration(name);
+ reportDuplicateOpRegistration(OpTy::getOperationName());
}
template <typename Type>
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 43b6d2b3841690..3beade017d1ab9 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -490,7 +490,7 @@ class OpBuilder : public Builder {
template <typename OpT>
RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
std::optional<RegisteredOperationName> opName =
- RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
+ RegisteredOperationName::lookup(TypeID::get<OpT>(), ctx);
if (LLVM_UNLIKELY(!opName)) {
llvm::report_fatal_error(
"Building op `" + OpT::getOperationName() +
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bd68c27445744e..c177ae3594d11f 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1729,8 +1729,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
template <typename... Models>
static void attachInterface(MLIRContext &context) {
std::optional<RegisteredOperationName> info =
- RegisteredOperationName::lookup(ConcreteType::getOperationName(),
- &context);
+ RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context);
if (!info)
llvm::report_fatal_error(
"Attempting to attach an interface to an unregistered operation " +
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index f2aa6cee840308..90e63ff8fcb38f 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -676,6 +676,11 @@ class RegisteredOperationName : public OperationName {
static std::optional<RegisteredOperationName> lookup(StringRef name,
MLIRContext *ctx);
+ /// Lookup the registered operation information for the given operation.
+ /// Returns std::nullopt if the operation isn't registered.
+ static std::optional<RegisteredOperationName> lookup(TypeID typeID,
+ MLIRContext *ctx);
+
/// Register a new operation in a Dialect object.
/// This constructor is used by Dialect objects when they register the list
/// of operations they contain.
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index e1e6d14231d9f1..214b354c5347e9 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -183,7 +183,8 @@ class MLIRContextImpl {
llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
/// A vector of operation info specifically for registered operations.
- llvm::StringMap<RegisteredOperationName> registeredOperations;
+ llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations;
+ llvm::StringMap<RegisteredOperationName> registeredOperationsByName;
/// This is a sorted container of registered operations for a deterministic
/// and efficient `getRegisteredOperations` implementation.
@@ -780,8 +781,8 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
// Check the registered info map first. In the overwhelmingly common case,
// the entry will be in here and it also removes the need to acquire any
// locks.
- auto registeredIt = ctxImpl.registeredOperations.find(name);
- if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
+ auto registeredIt = ctxImpl.registeredOperationsByName.find(name);
+ if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
impl = registeredIt->second.impl;
return;
}
@@ -909,10 +910,19 @@ OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
//===----------------------------------------------------------------------===//
std::optional<RegisteredOperationName>
-RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
+RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) {
auto &impl = ctx->getImpl();
- auto it = impl.registeredOperations.find(name);
+ auto it = impl.registeredOperations.find(typeID);
if (it != impl.registeredOperations.end())
+ return it->second;
+ return std::nullopt;
+}
+
+std::optional<RegisteredOperationName>
+RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
+ auto &impl = ctx->getImpl();
+ auto it = impl.registeredOperationsByName.find(name);
+ if (it != impl.registeredOperationsByName.end())
return it->getValue();
return std::nullopt;
}
@@ -945,11 +955,16 @@ void RegisteredOperationName::insert(
// Update the registered info for this operation.
auto emplaced = ctxImpl.registeredOperations.try_emplace(
- name, RegisteredOperationName(impl));
+ impl->getTypeID(), RegisteredOperationName(impl));
assert(emplaced.second && "operation name registration must be successful");
+ auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace(
+ name, RegisteredOperationName(impl));
+ (void)emplacedByName;
+ assert(emplacedByName.second &&
+ "operation name registration must be successful");
// Add emplaced operation name to the sorted operations container.
- RegisteredOperationName &value = emplaced.first->getValue();
+ RegisteredOperationName &value = emplaced.first->second;
ctxImpl.sortedRegisteredOperations.insert(
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
[](auto &lhs, auto &rhs) {
More information about the Mlir-commits
mailing list