[Mlir-commits] [mlir] 11067d7 - [mlir] Optimize OperationName construction and usage

River Riddle llvmlistbot at llvm.org
Thu Jan 13 21:21:53 PST 2022


Author: River Riddle
Date: 2022-01-13T21:14:36-08:00
New Revision: 11067d711bca10ce740d0673073576bb81f50e06

URL: https://github.com/llvm/llvm-project/commit/11067d711bca10ce740d0673073576bb81f50e06
DIFF: https://github.com/llvm/llvm-project/commit/11067d711bca10ce740d0673073576bb81f50e06.diff

LOG: [mlir] Optimize OperationName construction and usage

When constructing an OperationName, the overwhelming majority of
cases are from registered operations. This revision adds a lookup into
the currently registered operations that does not acquire the context
lock, which avoids locking in the common case. This revision also
optimizes several uses of RegisteredOperationName that expect the
operation to be registered, e.g. in OpBuilder.

These changes provide a reasonable speedup (5-10%) in some
compilations, especially on platforms where locking is expensive.

Differential Revision: https://reviews.llvm.org/D117187

Added: 
    

Modified: 
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
    mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/OperationSupport.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 84ec1a55d2383..cf4cdb9203e53 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -406,22 +406,27 @@ class OpBuilder : public Builder {
 
 private:
   /// Helper for sanity checking preconditions for create* methods below.
-  void checkHasRegisteredInfo(const OperationName &name) {
-    if (LLVM_UNLIKELY(!name.isRegistered()))
+  template <typename OpT>
+  RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
+    Optional<RegisteredOperationName> opName =
+        RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
+    if (LLVM_UNLIKELY(!opName)) {
       llvm::report_fatal_error(
-          "Building op `" + name.getStringRef() +
+          "Building op `" + OpT::getOperationName() +
           "` but it isn't registered in this MLIRContext: the dialect may not "
           "be loaded or this operation isn't registered by the dialect. See "
           "also https://mlir.llvm.org/getting_started/Faq/"
           "#registered-loaded-dependent-whats-up-with-dialects-management");
+    }
+    return *opName;
   }
 
 public:
   /// Create an operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
   OpTy create(Location location, Args &&...args) {
-    OperationState state(location, OpTy::getOperationName());
-    checkHasRegisteredInfo(state.name);
+    OperationState state(location,
+                         getCheckRegisteredInfo<OpTy>(location.getContext()));
     OpTy::build(*this, state, std::forward<Args>(args)...);
     auto *op = createOperation(state);
     auto result = dyn_cast<OpTy>(op);
@@ -437,8 +442,8 @@ class OpBuilder : public Builder {
                     Args &&...args) {
     // Create the operation without using 'createOperation' as we don't want to
     // insert it yet.
-    OperationState state(location, OpTy::getOperationName());
-    checkHasRegisteredInfo(state.name);
+    OperationState state(location,
+                         getCheckRegisteredInfo<OpTy>(location.getContext()));
     OpTy::build(*this, state, std::forward<Args>(args)...);
     Operation *op = Operation::create(state);
 

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 27c4a620f53b5..138bbf7967adc 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -231,9 +231,7 @@ class RegisteredOperationName : public OperationName {
   /// Lookup the registered operation information for the given operation.
   /// Returns None if the operation isn't registered.
   static Optional<RegisteredOperationName> lookup(StringRef name,
-                                                  MLIRContext *ctx) {
-    return OperationName(name, ctx).getRegisteredInfo();
-  }
+                                                  MLIRContext *ctx);
 
   /// Register a new operation in a Dialect object.
   /// This constructor is used by Dialect objects when they register the list of
@@ -582,9 +580,12 @@ struct OperationState {
 
 public:
   OperationState(Location location, StringRef name);
-
   OperationState(Location location, OperationName name);
 
+  OperationState(Location location, OperationName name, ValueRange operands,
+                 TypeRange types, ArrayRef<NamedAttribute> attributes,
+                 BlockRange successors = {},
+                 MutableArrayRef<std::unique_ptr<Region>> regions = {});
   OperationState(Location location, StringRef name, ValueRange operands,
                  TypeRange types, ArrayRef<NamedAttribute> attributes,
                  BlockRange successors = {},

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index dc698a1724a8d..974d7d1106c20 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1406,9 +1406,9 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
   // name that works both in scalar mode and vector mode.
   // TODO: Is it worth considering an Operation.clone operation which
   // changes the type so we can promote an Operation with less boilerplate?
-  OperationState vecOpState(op->getLoc(), op->getName().getStringRef(),
-                            vectorOperands, vectorTypes, op->getAttrs(),
-                            /*successors=*/{}, /*regions=*/{});
+  OperationState vecOpState(op->getLoc(), op->getName(), vectorOperands,
+                            vectorTypes, op->getAttrs(), /*successors=*/{},
+                            /*regions=*/{});
   Operation *vecOp = state.builder.createOperation(vecOpState);
   state.registerOpVectorReplacement(op, vecOp);
   return vecOp;

diff  --git a/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
index de5b2fdcfcebb..319a6ab6b710e 100644
--- a/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
@@ -70,8 +70,7 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                               Operation *op,
                                               ArrayRef<Value> operands,
                                               ArrayRef<Type> resultTypes) {
-  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
-                     op->getAttrs());
+  OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
   return builder.createOperation(res);
 }
 

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f4ffc297de994..a144769a3fd3b 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -182,7 +182,7 @@ class MLIRContextImpl {
   llvm::StringMap<OperationName::Impl> operations;
 
   /// A vector of operation info specifically for registered operations.
-  SmallVector<RegisteredOperationName> registeredOperations;
+  llvm::StringMap<RegisteredOperationName> registeredOperations;
 
   /// A mutex used when accessing operation information.
   llvm::sys::SmartRWMutex<true> operationInfoMutex;
@@ -576,8 +576,9 @@ std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
   // We just have the operations in a non-deterministic hash table order. Dump
   // into a temporary array, then sort it by operation name to get a stable
   // ordering.
-  std::vector<RegisteredOperationName> result(
-      impl->registeredOperations.begin(), impl->registeredOperations.end());
+  auto unwrappedNames = llvm::make_second_range(impl->registeredOperations);
+  std::vector<RegisteredOperationName> result(unwrappedNames.begin(),
+                                              unwrappedNames.end());
   llvm::array_pod_sort(result.begin(), result.end(),
                        [](const RegisteredOperationName *lhs,
                           const RegisteredOperationName *rhs) {
@@ -589,7 +590,7 @@ std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
-  return OperationName(name, this).isRegistered();
+  return RegisteredOperationName::lookup(name, this).hasValue();
 }
 
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
@@ -649,6 +650,15 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
   // Check for an existing name in read-only mode.
   bool isMultithreadingEnabled = context->isMultithreadingEnabled();
   if (isMultithreadingEnabled) {
+    // Check the registered info map first. In the overwhelmingly common case,
+    // the entry will be in here and it also removes the need to acquire any
+    // locks.
+    auto registeredIt = ctxImpl.registeredOperations.find(name);
+    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
+      impl = registeredIt->second.impl;
+      return;
+    }
+
     llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
     auto it = ctxImpl.operations.find(name);
     if (it != ctxImpl.operations.end()) {
@@ -676,6 +686,15 @@ StringRef OperationName::getDialectNamespace() const {
 // RegisteredOperationName
 //===----------------------------------------------------------------------===//
 
+Optional<RegisteredOperationName>
+RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
+  auto &impl = ctx->getImpl();
+  auto it = impl.registeredOperations.find(name);
+  if (it != impl.registeredOperations.end())
+    return it->getValue();
+  return llvm::None;
+}
+
 ParseResult
 RegisteredOperationName::parseAssembly(OpAsmParser &parser,
                                        OperationState &result) const {
@@ -717,7 +736,8 @@ void RegisteredOperationName::insert(
                  << "' is already registered.\n";
     abort();
   }
-  ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl));
+  ctxImpl.registeredOperations.try_emplace(name,
+                                           RegisteredOperationName(&impl));
 
   // Update the registered info for this operation.
   impl.dialect = &dialect;

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 747d27ad1b8d0..0cd63013d83f6 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -170,12 +170,12 @@ OperationState::OperationState(Location location, StringRef name)
 OperationState::OperationState(Location location, OperationName name)
     : location(location), name(name) {}
 
-OperationState::OperationState(Location location, StringRef name,
+OperationState::OperationState(Location location, OperationName name,
                                ValueRange operands, TypeRange types,
                                ArrayRef<NamedAttribute> attributes,
                                BlockRange successors,
                                MutableArrayRef<std::unique_ptr<Region>> regions)
-    : location(location), name(name, location->getContext()),
+    : location(location), name(name),
       operands(operands.begin(), operands.end()),
       types(types.begin(), types.end()),
       attributes(attributes.begin(), attributes.end()),
@@ -183,6 +183,13 @@ OperationState::OperationState(Location location, StringRef name,
   for (std::unique_ptr<Region> &r : regions)
     this->regions.push_back(std::move(r));
 }
+OperationState::OperationState(Location location, StringRef name,
+                               ValueRange operands, TypeRange types,
+                               ArrayRef<NamedAttribute> attributes,
+                               BlockRange successors,
+                               MutableArrayRef<std::unique_ptr<Region>> regions)
+    : OperationState(location, OperationName(name, location.getContext()),
+                     operands, types, attributes, successors, regions) {}
 
 void OperationState::addOperands(ValueRange newOperands) {
   operands.append(newOperands.begin(), newOperands.end());


        


More information about the Mlir-commits mailing list