[Mlir-commits] [mlir] edc6c0e - [mlir] Refactor AbstractOperation and OperationName

River Riddle llvmlistbot at llvm.org
Wed Nov 17 14:30:07 PST 2021


Author: River Riddle
Date: 2021-11-17T22:29:57Z
New Revision: edc6c0ecb9627c7c57fdb8e0ca8267295dd77bcd

URL: https://github.com/llvm/llvm-project/commit/edc6c0ecb9627c7c57fdb8e0ca8267295dd77bcd
DIFF: https://github.com/llvm/llvm-project/commit/edc6c0ecb9627c7c57fdb8e0ca8267295dd77bcd.diff

LOG: [mlir] Refactor AbstractOperation and OperationName

The current implementation is quite clunky; OperationName stores either an Identifier
or an AbstractOperation that corresponds to an operation. This has several problems:

* OperationNames created before and after an operation is registered are different
* Accessing the identifier name/dialect/etc. from an OperationName is overly branchy
  - they need to dyn_cast a PointerUnion to check the state

This commit refactors this such that we create a single information struct for every
operation name, even operations that aren't registered yet. When an OperationName is
created for an unregistered operation, we only populate the name field. When the
operation is registered, we populate the remaining fields. With this we now have two
new classes: OperationName and RegisteredOperationName. These both point to the
same underlying operation information struct, but only RegisteredOperationName can
assume that the operation is actually registered. This leads to a much cleaner API, and
we can also move some AbstractOperation functionality directly to OperationName.

Differential Revision: https://reviews.llvm.org/D114049

Added: 
    

Modified: 
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/Dialect.h
    mlir/include/mlir/IR/MLIRContext.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/CAPI/Interfaces/Interfaces.cpp
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/IR/Verifier.cpp
    mlir/lib/Parser/AsmParserState.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
    mlir/lib/Transforms/Canonicalizer.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index d0e12b7a1457d..31f223972accd 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -408,8 +408,8 @@ class OpBuilder : public Builder {
 
 private:
   /// Helper for sanity checking preconditions for create* methods below.
-  void checkHasAbstractOperation(const OperationName &name) {
-    if (LLVM_UNLIKELY(!name.getAbstractOperation()))
+  void checkHasRegisteredInfo(const OperationName &name) {
+    if (LLVM_UNLIKELY(!name.isRegistered()))
       llvm::report_fatal_error(
           "Building op `" + name.getStringRef() +
           "` but it isn't registered in this MLIRContext: the dialect may not "
@@ -423,7 +423,7 @@ class OpBuilder : public Builder {
   template <typename OpTy, typename... Args>
   OpTy create(Location location, Args &&...args) {
     OperationState state(location, OpTy::getOperationName());
-    checkHasAbstractOperation(state.name);
+    checkHasRegisteredInfo(state.name);
     OpTy::build(*this, state, std::forward<Args>(args)...);
     auto *op = createOperation(state);
     auto result = dyn_cast<OpTy>(op);
@@ -440,7 +440,7 @@ class OpBuilder : public Builder {
     // Create the operation without using 'createOperation' as we don't want to
     // insert it yet.
     OperationState state(location, OpTy::getOperationName());
-    checkHasAbstractOperation(state.name);
+    checkHasRegisteredInfo(state.name);
     OpTy::build(*this, state, std::forward<Args>(args)...);
     Operation *op = Operation::create(state);
 

diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 14114e379f78b..c668df03c4745 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -114,7 +114,7 @@ class Dialect {
 
   /// Return the hook to parse an operation registered to this dialect, if any.
   /// By default this will lookup for registered operations and return the
-  /// `parse()` method registered on the AbstractOperation. Dialects can
+  /// `parse()` method registered on the RegisteredOperationName. Dialects can
   /// override this behavior and handle unregistered operations as well.
   virtual Optional<ParseOpHook> getParseOperationHook(StringRef opName) const;
 
@@ -194,7 +194,7 @@ class Dialect {
   ///
   template <typename... Args> void addOperations() {
     (void)std::initializer_list<int>{
-        0, (AbstractOperation::insert<Args>(*this), 0)...};
+        0, (RegisteredOperationName::insert<Args>(*this), 0)...};
   }
 
   /// Register a set of type classes with this dialect.

diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 6d45fbce7b48f..bc9a8a95e4a09 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -20,7 +20,6 @@ class ThreadPool;
 } // end namespace llvm
 
 namespace mlir {
-class AbstractOperation;
 class DebugActionManager;
 class DiagnosticEngine;
 class Dialect;
@@ -28,6 +27,7 @@ class DialectRegistry;
 class InFlightDiagnostic;
 class Location;
 class MLIRContextImpl;
+class RegisteredOperationName;
 class StorageUniquer;
 
 /// MLIRContext is the top-level object for a collection of MLIR operations. It
@@ -172,7 +172,7 @@ class MLIRContext {
   /// Return information about all registered operations.  This isn't very
   /// efficient: typically you should ask the operations about their properties
   /// directly.
-  std::vector<AbstractOperation *> getRegisteredOperations();
+  std::vector<RegisteredOperationName> getRegisteredOperations();
 
   /// Return true if this operation name is registered in this context.
   bool isOperationRegistered(StringRef name);

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index f2b16779da3e6..5b4a936611895 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -191,7 +191,7 @@ class OpState {
   Operation *state;
 
   /// Allow access to internal hook implementation methods.
-  friend AbstractOperation;
+  friend RegisteredOperationName;
 };
 
 // Allow comparing operators.
@@ -1585,8 +1585,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
 
   /// Return true if this "op class" can match against the specified operation.
   static bool classof(Operation *op) {
-    if (auto *abstractOp = op->getAbstractOperation())
-      return TypeID::get<ConcreteType>() == abstractOp->typeID;
+    if (auto info = op->getRegisteredInfo())
+      return TypeID::get<ConcreteType>() == info->getTypeID();
 #ifndef NDEBUG
     if (op->getName().getStringRef() == ConcreteType::getOperationName())
       llvm::report_fatal_error(
@@ -1628,13 +1628,13 @@ class Op : public OpState, public Traits<ConcreteType>... {
   /// for the concrete operation.
   template <typename... Models>
   static void attachInterface(MLIRContext &context) {
-    AbstractOperation *abstract = AbstractOperation::lookupMutable(
+    Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
         ConcreteType::getOperationName(), &context);
-    if (!abstract)
+    if (!info)
       llvm::report_fatal_error(
           "Attempting to attach an interface to an unregistered operation " +
           ConcreteType::getOperationName() + ".");
-    abstract->interfaceMap.insert<Models...>();
+    info->attachInterface<Models...>();
   }
 
 private:
@@ -1673,10 +1673,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
     return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
   }
 
-  /// Return the internal implementations of each of the AbstractOperation
+  /// Return the internal implementations of each of the OperationName
   /// hooks.
-  /// Implementation of `FoldHookFn` AbstractOperation hook.
-  static AbstractOperation::FoldHookFn getFoldHookFn() {
+  /// Implementation of `FoldHookFn` OperationName hook.
+  static OperationName::FoldHookFn getFoldHookFn() {
     return getFoldHookFnImpl<ConcreteType>();
   }
   /// The internal implementation of `getFoldHookFn` above that is invoked if
@@ -1685,7 +1685,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
   static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                           Traits<ConcreteOpT>...>::value &&
                               detect_has_single_result_fold<ConcreteOpT>::value,
-                          AbstractOperation::FoldHookFn>
+                          OperationName::FoldHookFn>
   getFoldHookFnImpl() {
     return [](Operation *op, ArrayRef<Attribute> operands,
               SmallVectorImpl<OpFoldResult> &results) {
@@ -1698,7 +1698,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
   static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                            Traits<ConcreteOpT>...>::value &&
                               detect_has_fold<ConcreteOpT>::value,
-                          AbstractOperation::FoldHookFn>
+                          OperationName::FoldHookFn>
   getFoldHookFnImpl() {
     return [](Operation *op, ArrayRef<Attribute> operands,
               SmallVectorImpl<OpFoldResult> &results) {
@@ -1710,7 +1710,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
   template <typename ConcreteOpT>
   static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
                               !detect_has_fold<ConcreteOpT>::value,
-                          AbstractOperation::FoldHookFn>
+                          OperationName::FoldHookFn>
   getFoldHookFnImpl() {
     return [](Operation *op, ArrayRef<Attribute> operands,
               SmallVectorImpl<OpFoldResult> &results) {
@@ -1754,29 +1754,29 @@ class Op : public OpState, public Traits<ConcreteType>... {
     return result;
   }
 
-  /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
-  static AbstractOperation::GetCanonicalizationPatternsFn
+  /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
+  static OperationName::GetCanonicalizationPatternsFn
   getGetCanonicalizationPatternsFn() {
     return &ConcreteType::getCanonicalizationPatterns;
   }
   /// Implementation of `GetHasTraitFn`
-  static AbstractOperation::HasTraitFn getHasTraitFn() {
+  static OperationName::HasTraitFn getHasTraitFn() {
     return
         [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
   }
-  /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
-  static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
+  /// Implementation of `ParseAssemblyFn` OperationName hook.
+  static OperationName::ParseAssemblyFn getParseAssemblyFn() {
     return &ConcreteType::parse;
   }
-  /// Implementation of `PrintAssemblyFn` AbstractOperation hook.
-  static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
+  /// Implementation of `PrintAssemblyFn` OperationName hook.
+  static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
     return getPrintAssemblyFnImpl<ConcreteType>();
   }
   /// The internal implementation of `getPrintAssemblyFn` that is invoked when
   /// the concrete operation does not define a `print` method.
   template <typename ConcreteOpT>
   static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
-                          AbstractOperation::PrintAssemblyFn>
+                          OperationName::PrintAssemblyFn>
   getPrintAssemblyFnImpl() {
     return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
       return OpState::print(op, printer);
@@ -1786,7 +1786,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
   /// the concrete operation defines a `print` method.
   template <typename ConcreteOpT>
   static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
-                          AbstractOperation::PrintAssemblyFn>
+                          OperationName::PrintAssemblyFn>
   getPrintAssemblyFnImpl() {
     return &printAssembly;
   }
@@ -1795,8 +1795,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
     OpState::printOpName(op, p, defaultDialect);
     return cast<ConcreteType>(op).print(p);
   }
-  /// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
-  static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
+  /// Implementation of `VerifyInvariantsFn` OperationName hook.
+  static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
     return &verifyInvariants;
   }
 
@@ -1816,7 +1816,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
   }
 
   /// Allow access to internal implementation methods.
-  friend AbstractOperation;
+  friend RegisteredOperationName;
 };
 
 /// This class represents the base of an operation interface. See the definition
@@ -1836,22 +1836,22 @@ class OpInterface
 protected:
   /// Returns the impl interface instance for the given operation.
   static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
-    // Access the raw interface from the abstract operation.
-    auto *abstractOp = op->getAbstractOperation();
-    if (abstractOp) {
-      if (auto *opIface = abstractOp->getInterface<ConcreteType>())
+    OperationName name = op->getName();
+
+    // Access the raw interface from the operation info.
+    if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) {
+      if (auto *opIface = rInfo->getInterface<ConcreteType>())
         return opIface;
       // Fallback to the dialect to provide it with a chance to implement this
       // interface for this operation.
-      return abstractOp->dialect.getRegisteredInterfaceForOp<ConcreteType>(
+      return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
           op->getName());
     }
     // Fallback to the dialect to provide it with a chance to implement this
     // interface for this operation.
-    Dialect *dialect = op->getName().getDialect();
-    return dialect ? dialect->getRegisteredInterfaceForOp<ConcreteType>(
-                         op->getName())
-                   : nullptr;
+    if (Dialect *dialect = name.getDialect())
+      return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
+    return nullptr;
   }
 
   /// Allow access to `getInterfaceFor`.

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 112284a8f10dc..243870a7b030d 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -57,14 +57,14 @@ class alignas(8) Operation final
   OperationName getName() { return name; }
 
   /// If this operation has a registered operation description, return it.
-  /// Otherwise return null.
-  const AbstractOperation *getAbstractOperation() {
-    return getName().getAbstractOperation();
+  /// Otherwise return None.
+  Optional<RegisteredOperationName> getRegisteredInfo() {
+    return getName().getRegisteredInfo();
   }
 
   /// Returns true if this operation has a registered operation description,
   /// otherwise false.
-  bool isRegistered() { return getAbstractOperation(); }
+  bool isRegistered() { return getName().isRegistered(); }
 
   /// Remove this operation from its parent block and delete it.
   void erase();
@@ -468,16 +468,14 @@ class alignas(8) Operation final
   /// Returns true if the operation was registered with a particular trait, e.g.
   /// hasTrait<OperandsAreSignlessIntegerLike>().
   template <template <typename T> class Trait> bool hasTrait() {
-    const AbstractOperation *abstractOp = getAbstractOperation();
-    return abstractOp ? abstractOp->hasTrait<Trait>() : false;
+    return name.hasTrait<Trait>();
   }
 
-  /// Returns true if the operation is *might* have the provided trait. This
+  /// Returns true if the operation *might* have the provided trait. This
   /// means that either the operation is unregistered, or it was registered with
   /// the provide trait.
   template <template <typename T> class Trait> bool mightHaveTrait() {
-    const AbstractOperation *abstractOp = getAbstractOperation();
-    return abstractOp ? abstractOp->hasTrait<Trait>() : true;
+    return name.mightHaveTrait<Trait>();
   }
 
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index a18d2afbd7b4a..ba81329867b6e 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -59,14 +59,10 @@ class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
 //===----------------------------------------------------------------------===//
-// AbstractOperation
+// OperationName
 //===----------------------------------------------------------------------===//
 
-/// This is a "type erased" representation of a registered operation.  This
-/// should only be used by things like the AsmPrinter and other things that need
-/// to be parameterized by generic operation hooks.  Most user code should use
-/// the concrete operation types.
-class AbstractOperation {
+class OperationName {
 public:
   using GetCanonicalizationPatternsFn =
       llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
@@ -80,32 +76,211 @@ class AbstractOperation {
   using VerifyInvariantsFn =
       llvm::unique_function<LogicalResult(Operation *) const>;
 
-  /// This is the name of the operation.
-  const StringAttr name;
+protected:
+  /// This class represents a type erased version of an operation. It contains
+  /// all of the components necessary for opaquely interacting with an
+  /// operation. If the operation is not registered, some of these components
+  /// may not be populated.
+  struct Impl {
+    Impl(StringAttr name)
+        : name(name), dialect(nullptr), interfaceMap(llvm::None) {}
+
+    /// The name of the operation.
+    StringAttr name;
+
+    //===------------------------------------------------------------------===//
+    // Registered Operation Info
+
+    /// The following fields are only populated when the operation is
+    /// registered.
+
+    /// Returns true if the operation has been registered, i.e. if the
+    /// registration info has been populated.
+    bool isRegistered() const { return dialect; }
+
+    /// This is the dialect that this operation belongs to.
+    Dialect *dialect;
+
+    /// The unique identifier of the derived Op class.
+    TypeID typeID;
+
+    /// A map of interfaces that were registered to this operation.
+    detail::InterfaceMap interfaceMap;
+
+    /// Internal callback hooks provided by the op implementation.
+    FoldHookFn foldHookFn;
+    GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+    HasTraitFn hasTraitFn;
+    ParseAssemblyFn parseAssemblyFn;
+    PrintAssemblyFn printAssemblyFn;
+    VerifyInvariantsFn verifyInvariantsFn;
+
+    /// A list of attribute names registered to this operation in StringAttr
+    /// form. This allows for operation classes to use StringAttr for attribute
+    /// lookup/creation/etc., as opposed to raw strings.
+    ArrayRef<StringAttr> attributeNames;
+  };
+
+public:
+  OperationName(StringRef name, MLIRContext *context);
+
+  /// Return if this operation is registered.
+  bool isRegistered() const { return impl->isRegistered(); }
+
+  /// If this operation is registered, returns the registered information, None
+  /// otherwise.
+  Optional<RegisteredOperationName> getRegisteredInfo() const;
+
+  /// Returns true if the operation was registered with a particular trait, e.g.
+  /// hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation
+  /// is unregistered.
+  template <template <typename T> class Trait> bool hasTrait() const {
+    return hasTrait(TypeID::get<Trait>());
+  }
+  bool hasTrait(TypeID traitID) const {
+    return isRegistered() && impl->hasTraitFn(traitID);
+  }
+
+  /// Returns true if the operation *might* have the provided trait. This
+  /// means that either the operation is unregistered, or it was registered with
+  /// the provide trait.
+  template <template <typename T> class Trait> bool mightHaveTrait() const {
+    return mightHaveTrait(TypeID::get<Trait>());
+  }
+  bool mightHaveTrait(TypeID traitID) const {
+    return !isRegistered() || impl->hasTraitFn(traitID);
+  }
+
+  /// Returns an instance of the concept object for the given interface if it
+  /// was registered to this operation, null otherwise. This should not be used
+  /// directly.
+  template <typename T> typename T::Concept *getInterface() const {
+    return impl->interfaceMap.lookup<T>();
+  }
+
+  /// Returns true if this operation has the given interface registered to it.
+  template <typename T> bool hasInterface() const {
+    return hasInterface(TypeID::get<T>());
+  }
+  bool hasInterface(TypeID interfaceID) const {
+    return impl->interfaceMap.contains(interfaceID);
+  }
+
+  /// Return the dialect this operation is registered to if the dialect is
+  /// loaded in the context, or nullptr if the dialect isn't loaded.
+  Dialect *getDialect() const {
+    return isRegistered() ? impl->dialect : impl->name.getReferencedDialect();
+  }
+
+  /// Return the name of the dialect this operation is registered to.
+  StringRef getDialectNamespace() const;
+
+  /// Return the operation name with dialect name stripped, if it has one.
+  StringRef stripDialect() const { return getStringRef().split('.').second; }
+
+  /// Return the name of this operation. This always succeeds.
+  StringRef getStringRef() const { return getIdentifier(); }
+
+  /// Return the name of this operation as a StringAttr.
+  StringAttr getIdentifier() const { return impl->name; }
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  /// Represent the operation name as an opaque pointer. (Used to support
+  /// PointerLikeTypeTraits).
+  void *getAsOpaquePointer() const { return const_cast<Impl *>(impl); }
+  static OperationName getFromOpaquePointer(const void *pointer) {
+    return OperationName(
+        const_cast<Impl *>(reinterpret_cast<const Impl *>(pointer)));
+  }
+
+  bool operator==(const OperationName &rhs) const { return impl == rhs.impl; }
+  bool operator!=(const OperationName &rhs) const { return !(*this == rhs); }
 
-  /// This is the dialect that this operation belongs to.
-  Dialect &dialect;
+protected:
+  OperationName(Impl *impl) : impl(impl) {}
 
-  /// The unique identifier of the derived Op class.
-  TypeID typeID;
+  /// The internal implementation of the operation name.
+  Impl *impl;
+
+  /// Allow access to the Impl struct.
+  friend MLIRContextImpl;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, OperationName info) {
+  info.print(os);
+  return os;
+}
+
+// Make operation names hashable.
+inline llvm::hash_code hash_value(OperationName arg) {
+  return llvm::hash_value(arg.getAsOpaquePointer());
+}
+
+//===----------------------------------------------------------------------===//
+// RegisteredOperationName
+//===----------------------------------------------------------------------===//
+
+/// This is a "type erased" representation of a registered operation. This
+/// should only be used by things like the AsmPrinter and other things that need
+/// to be parameterized by generic operation hooks. Most user code should use
+/// the concrete operation types.
+class RegisteredOperationName : public OperationName {
+public:
+  /// Lookup the registered operation information for the given operation.
+  /// Returns None if the operation isn't registered.
+  static Optional<RegisteredOperationName> lookup(StringRef name,
+                                                  MLIRContext *ctx) {
+    return OperationName(name, ctx).getRegisteredInfo();
+  }
+
+  /// Register a new operation in a Dialect object.
+  /// This constructor is used by Dialect objects when they register the list of
+  /// operations they contain.
+  template <typename T>
+  static void insert(Dialect &dialect) {
+    insert(T::getOperationName(), dialect, TypeID::get<T>(),
+           T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
+           T::getVerifyInvariantsFn(), T::getFoldHookFn(),
+           T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
+           T::getHasTraitFn(), T::getAttributeNames());
+  }
+  /// The use of this method is in general discouraged in favor of
+  /// 'insert<CustomOp>(dialect)'.
+  static void
+  insert(StringRef name, Dialect &dialect, TypeID typeID,
+         ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+         VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+         GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+         detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+         ArrayRef<StringRef> attrNames);
+
+  /// Return the dialect this operation is registered to.
+  Dialect &getDialect() const { return *impl->dialect; }
+
+  /// Return the unique identifier of the derived Op class.
+  TypeID getTypeID() const { return impl->typeID; }
 
   /// Use the specified object to parse this ops custom assembly format.
   ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
 
   /// Return the static hook for parsing this operation assembly.
-  const ParseAssemblyFn &getParseAssemblyFn() const { return parseAssemblyFn; }
+  const ParseAssemblyFn &getParseAssemblyFn() const {
+    return impl->parseAssemblyFn;
+  }
 
   /// This hook implements the AsmPrinter for this operation.
   void printAssembly(Operation *op, OpAsmPrinter &p,
                      StringRef defaultDialect) const {
-    return printAssemblyFn(op, p, defaultDialect);
+    return impl->printAssemblyFn(op, p, defaultDialect);
   }
 
   /// This hook implements the verifier for this operation.  It should emits an
   /// error message and returns failure if a problem is detected, or returns
   /// success if everything is ok.
   LogicalResult verifyInvariants(Operation *op) const {
-    return verifyInvariantsFn(op);
+    return impl->verifyInvariantsFn(op);
   }
 
   /// This hook implements a generalized folder for this operation.  Operations
@@ -129,66 +304,30 @@ class AbstractOperation {
   /// generalized constant folding.
   LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                          SmallVectorImpl<OpFoldResult> &results) const {
-    return foldHookFn(op, operands, results);
+    return impl->foldHookFn(op, operands, results);
   }
 
   /// This hook returns any canonicalization pattern rewrites that the operation
   /// supports, for use by the canonicalization pass.
   void getCanonicalizationPatterns(RewritePatternSet &results,
                                    MLIRContext *context) const {
-    return getCanonicalizationPatternsFn(results, context);
+    return impl->getCanonicalizationPatternsFn(results, context);
   }
 
-  /// Returns an instance of the concept object for the given interface if it
-  /// was registered to this operation, null otherwise. This should not be used
-  /// directly.
-  template <typename T>
-  typename T::Concept *getInterface() const {
-    return interfaceMap.lookup<T>();
-  }
-
-  /// Returns true if this operation has the given interface registered to it.
-  bool hasInterface(TypeID interfaceID) const {
-    return interfaceMap.contains(interfaceID);
+  /// Attach the given models as implementations of the corresponding interfaces
+  /// for the concrete operation.
+  template <typename... Models>
+  void attachInterface() {
+    impl->interfaceMap.insert<Models...>();
   }
 
   /// Returns true if the operation has a particular trait.
-  template <template <typename T> class Trait>
-  bool hasTrait() const {
-    return hasTraitFn(TypeID::get<Trait>());
+  template <template <typename T> class Trait> bool hasTrait() const {
+    return hasTrait(TypeID::get<Trait>());
   }
 
   /// Returns true if the operation has a particular trait.
-  bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
-
-  /// Look up the specified operation in the specified MLIRContext and return a
-  /// pointer to it if present.  Otherwise, return a null pointer.
-  static const AbstractOperation *lookup(StringRef opName,
-                                         MLIRContext *context) {
-    return lookupMutable(opName, context);
-  }
-
-  /// This constructor is used by Dialect objects when they register the list of
-  /// operations they contain.
-  template <typename T>
-  static void insert(Dialect &dialect) {
-    insert(T::getOperationName(), dialect, TypeID::get<T>(),
-           T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
-           T::getVerifyInvariantsFn(), T::getFoldHookFn(),
-           T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
-           T::getHasTraitFn(), T::getAttributeNames());
-  }
-
-  /// Register a new operation in a Dialect object.
-  /// The use of this method is in general discouraged in favor of
-  /// 'insert<CustomOp>(dialect)'.
-  static void
-  insert(StringRef name, Dialect &dialect, TypeID typeID,
-         ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
-         VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
-         GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
-         detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
-         ArrayRef<StringRef> attrNames);
+  bool hasTrait(TypeID traitID) const { return impl->hasTraitFn(traitID); }
 
   /// Return the list of cached attribute names registered to this operation.
   /// The order of attributes cached here is unique to each type of operation,
@@ -206,44 +345,30 @@ class AbstractOperation {
   /// greatly simplifying the cost and complexity of attribute usage produced by
   /// the generator.
   ///
-  ArrayRef<StringAttr> getAttributeNames() const { return attributeNames; }
+  ArrayRef<StringAttr> getAttributeNames() const {
+    return impl->attributeNames;
+  }
+
+  /// Represent the operation name as an opaque pointer. (Used to support
+  /// PointerLikeTypeTraits).
+  static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
+    return RegisteredOperationName(
+        const_cast<Impl *>(reinterpret_cast<const Impl *>(pointer)));
+  }
 
 private:
-  AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID,
-                    ParseAssemblyFn &&parseAssembly,
-                    PrintAssemblyFn &&printAssembly,
-                    VerifyInvariantsFn &&verifyInvariants,
-                    FoldHookFn &&foldHook,
-                    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
-                    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
-                    ArrayRef<StringAttr> attrNames);
-
-  /// Give Op access to lookupMutable.
-  template <typename ConcreteType, template <typename T> class... Traits>
-  friend class Op;
-
-  /// Look up the specified operation in the specified MLIRContext and return a
-  /// pointer to it if present.  Otherwise, return a null pointer.
-  static AbstractOperation *lookupMutable(StringRef opName,
-                                          MLIRContext *context);
-
-  /// A map of interfaces that were registered to this operation.
-  detail::InterfaceMap interfaceMap;
-
-  /// Internal callback hooks provided by the op implementation.
-  FoldHookFn foldHookFn;
-  GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
-  HasTraitFn hasTraitFn;
-  ParseAssemblyFn parseAssemblyFn;
-  PrintAssemblyFn printAssemblyFn;
-  VerifyInvariantsFn verifyInvariantsFn;
-
-  /// A list of attribute names registered to this operation in identifier form.
-  /// This allows for operation classes to use identifiers for attribute
-  /// lookup/creation/etc., as opposed to strings.
-  ArrayRef<StringAttr> attributeNames;
+  RegisteredOperationName(Impl *impl) : OperationName(impl) {}
+
+  /// Allow access to the constructor.
+  friend OperationName;
 };
 
+inline Optional<RegisteredOperationName>
+OperationName::getRegisteredInfo() const {
+  return isRegistered() ? RegisteredOperationName(impl)
+                        : Optional<RegisteredOperationName>();
+}
+
 //===----------------------------------------------------------------------===//
 // Attribute Dictionary-Like Interface
 //===----------------------------------------------------------------------===//
@@ -435,76 +560,6 @@ class NamedAttrList {
   mutable llvm::PointerIntPair<Attribute, 1, bool> dictionarySorted;
 };
 
-//===----------------------------------------------------------------------===//
-// OperationName
-//===----------------------------------------------------------------------===//
-
-class OperationName {
-public:
-  using RepresentationUnion =
-      PointerUnion<StringAttr, const AbstractOperation *>;
-
-  OperationName(AbstractOperation *op) : representation(op) {}
-  OperationName(StringRef name, MLIRContext *context);
-
-  /// Return the name of the dialect this operation is registered to.
-  StringRef getDialectNamespace() const;
-
-  /// Return the dialect this operation is registered to if the dialect is
-  /// loaded in the context, or nullptr if the dialect isn't loaded.
-  Dialect *getDialect() const {
-    if (const auto *abstractOp = getAbstractOperation())
-      return &abstractOp->dialect;
-    return representation.get<StringAttr>().getReferencedDialect();
-  }
-
-  /// Return the operation name with dialect name stripped, if it has one.
-  StringRef stripDialect() const;
-
-  /// Return the name of this operation. This always succeeds.
-  StringRef getStringRef() const;
-
-  /// Return the name of this operation as an identifier. This always succeeds.
-  StringAttr getIdentifier() const;
-
-  /// If this operation has a registered operation description, return it.
-  /// Otherwise return null.
-  const AbstractOperation *getAbstractOperation() const {
-    return representation.dyn_cast<const AbstractOperation *>();
-  }
-
-  void print(raw_ostream &os) const;
-  void dump() const;
-
-  void *getAsOpaquePointer() const {
-    return static_cast<void *>(representation.getOpaqueValue());
-  }
-  static OperationName getFromOpaquePointer(const void *pointer);
-
-private:
-  RepresentationUnion representation;
-  OperationName(RepresentationUnion representation)
-      : representation(representation) {}
-};
-
-inline raw_ostream &operator<<(raw_ostream &os, OperationName identifier) {
-  identifier.print(os);
-  return os;
-}
-
-inline bool operator==(OperationName lhs, OperationName rhs) {
-  return lhs.getAsOpaquePointer() == rhs.getAsOpaquePointer();
-}
-
-inline bool operator!=(OperationName lhs, OperationName rhs) {
-  return lhs.getAsOpaquePointer() != rhs.getAsOpaquePointer();
-}
-
-// Make operation names hashable.
-inline llvm::hash_code hash_value(OperationName arg) {
-  return llvm::hash_value(arg.getAsOpaquePointer());
-}
-
 //===----------------------------------------------------------------------===//
 // OperationState
 //===----------------------------------------------------------------------===//
@@ -1119,39 +1174,53 @@ LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
 } // end namespace mlir
 
 namespace llvm {
-// Identifiers hash just like pointers, there is no need to hash the bytes.
 template <>
 struct DenseMapInfo<mlir::OperationName> {
   static mlir::OperationName getEmptyKey() {
-    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
     return mlir::OperationName::getFromOpaquePointer(pointer);
   }
   static mlir::OperationName getTombstoneKey() {
-    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
     return mlir::OperationName::getFromOpaquePointer(pointer);
   }
-  static unsigned getHashValue(mlir::OperationName Val) {
-    return DenseMapInfo<void *>::getHashValue(Val.getAsOpaquePointer());
+  static unsigned getHashValue(mlir::OperationName val) {
+    return DenseMapInfo<void *>::getHashValue(val.getAsOpaquePointer());
   }
-  static bool isEqual(mlir::OperationName LHS, mlir::OperationName RHS) {
-    return LHS == RHS;
+  static bool isEqual(mlir::OperationName lhs, mlir::OperationName rhs) {
+    return lhs == rhs;
+  }
+};
+template <>
+struct DenseMapInfo<mlir::RegisteredOperationName>
+    : public DenseMapInfo<mlir::OperationName> {
+  static mlir::RegisteredOperationName getEmptyKey() {
+    void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::RegisteredOperationName::getFromOpaquePointer(pointer);
+  }
+  static mlir::RegisteredOperationName getTombstoneKey() {
+    void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::RegisteredOperationName::getFromOpaquePointer(pointer);
   }
 };
 
-/// The pointer inside of an identifier comes from a StringMap, so its alignment
-/// is always at least 4 and probably 8 (on 64-bit machines).  Allow LLVM to
-/// steal the low bits.
 template <>
 struct PointerLikeTypeTraits<mlir::OperationName> {
-public:
   static inline void *getAsVoidPointer(mlir::OperationName I) {
     return const_cast<void *>(I.getAsOpaquePointer());
   }
   static inline mlir::OperationName getFromVoidPointer(void *P) {
     return mlir::OperationName::getFromOpaquePointer(P);
   }
-  static constexpr int NumLowBitsAvailable = PointerLikeTypeTraits<
-      mlir::OperationName::RepresentationUnion>::NumLowBitsAvailable;
+  static constexpr int NumLowBitsAvailable =
+      PointerLikeTypeTraits<void *>::NumLowBitsAvailable;
+};
+template <>
+struct PointerLikeTypeTraits<mlir::RegisteredOperationName>
+    : public PointerLikeTypeTraits<mlir::OperationName> {
+  static inline mlir::RegisteredOperationName getFromVoidPointer(void *P) {
+    return mlir::RegisteredOperationName::getFromOpaquePointer(P);
+  }
 };
 
 } // end namespace llvm

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index 2c474fac59d87..1a8c33e200139 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -176,6 +176,12 @@ class InterfaceMap {
 
 public:
   InterfaceMap(InterfaceMap &&) = default;
+  InterfaceMap &operator=(InterfaceMap &&rhs) {
+    for (auto &it : interfaces)
+      free(it.second);
+    interfaces = std::move(rhs.interfaces);
+    return *this;
+  }
   ~InterfaceMap() {
     for (auto &it : interfaces)
       free(it.second);

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index a339aef90455a..11b0157ae688c 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -264,9 +264,8 @@ void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
 
 static LogicalResult inferOperationTypes(OperationState &state) {
   MLIRContext *context = state.getContext();
-  const AbstractOperation *abstractOp =
-      AbstractOperation::lookup(state.name.getStringRef(), context);
-  if (!abstractOp) {
+  Optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
+  if (!info) {
     emitError(state.location)
         << "type inference was requested for the operation " << state.name
         << ", but the operation was not registered. Ensure that the dialect "
@@ -276,7 +275,7 @@ static LogicalResult inferOperationTypes(OperationState &state) {
   }
 
   // Fallback to inference via an op interface.
-  auto *inferInterface = abstractOp->getInterface<InferTypeOpInterface>();
+  auto *inferInterface = info->getInterface<InferTypeOpInterface>();
   if (!inferInterface) {
     emitError(state.location)
         << "type inference was requested for the operation " << state.name
@@ -353,9 +352,8 @@ MlirLocation mlirOperationGetLocation(MlirOperation op) {
 }
 
 MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
-  if (const auto *abstractOp = unwrap(op)->getAbstractOperation()) {
-    return wrap(abstractOp->typeID);
-  }
+  if (auto info = unwrap(op)->getRegisteredInfo())
+    return wrap(info->getTypeID());
   return {nullptr};
 }
 

diff  --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
index 315adb5fbaf68..f752a57b58576 100644
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -17,17 +17,17 @@ using namespace mlir;
 
 bool mlirOperationImplementsInterface(MlirOperation operation,
                                       MlirTypeID interfaceTypeID) {
-  const AbstractOperation *abstractOp =
-      unwrap(operation)->getAbstractOperation();
-  return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
+  Optional<RegisteredOperationName> info =
+      unwrap(operation)->getRegisteredInfo();
+  return info && info->hasInterface(unwrap(interfaceTypeID));
 }
 
 bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
                                             MlirContext context,
                                             MlirTypeID interfaceTypeID) {
-  const AbstractOperation *abstractOp = AbstractOperation::lookup(
+  Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
       StringRef(operationName.data, operationName.length), unwrap(context));
-  return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
+  return info && info->hasInterface(unwrap(interfaceTypeID));
 }
 
 MlirTypeID mlirInferTypeOpInterfaceTypeID() {
@@ -40,9 +40,9 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
     intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
     void *userData) {
   StringRef name(opName.data, opName.length);
-  const AbstractOperation *abstractOp =
-      AbstractOperation::lookup(name, unwrap(context));
-  if (!abstractOp)
+  Optional<RegisteredOperationName> info =
+      RegisteredOperationName::lookup(name, unwrap(context));
+  if (!info)
     return mlirLogicalResultFailure();
 
   llvm::Optional<Location> maybeLocation = llvm::None;
@@ -68,7 +68,7 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
   });
 
   SmallVector<Type> inferredTypes;
-  if (failed(abstractOp->getInterface<InferTypeOpInterface>()->inferReturnTypes(
+  if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
           unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
           unwrappedRegions, inferredTypes)))
     return mlirLogicalResultFailure();

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 39c25c31f1526..856297444be13 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -245,9 +245,8 @@ bool OperationOp::hasTypeInference() {
   if (!opName)
     return false;
 
-  OperationName name(*opName, getContext());
-  if (const AbstractOperation *op = name.getAbstractOperation())
-    return op->getInterface<InferTypeOpInterface>();
+  if (auto rInfo = RegisteredOperationName::lookup(*opName, getContext()))
+    return rInfo->hasInterface<InferTypeOpInterface>();
   return false;
 }
 

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0dd935557d121..50b10bf1e4e01 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -383,7 +383,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
     if (!printerFlags.shouldPrintGenericOpForm()) {
       // Check to see if this is a known operation.  If so, use the registered
       // custom printer hook.
-      if (auto *opInfo = op->getAbstractOperation()) {
+      if (auto opInfo = op->getRegisteredInfo()) {
         opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
         return;
       }
@@ -2517,9 +2517,9 @@ void OperationPrinter::printOperation(Operation *op) {
 
   // If requested, always print the generic form.
   if (!printerFlags.shouldPrintGenericOpForm()) {
-    // Check to see if this is a known operation.  If so, use the registered
+    // Check to see if this is a known operation. If so, use the registered
     // custom printer hook.
-    if (auto *opInfo = op->getAbstractOperation()) {
+    if (auto opInfo = op->getRegisteredInfo()) {
       opInfo->printAssembly(op, *this, defaultDialectStack.back());
       return;
     }

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 91f51d9f0497f..12d9166d28875 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -96,20 +96,6 @@ void mlir::registerMLIRContextCLOptions() {
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Utility reader lock that takes a runtime flag that specifies if we really
-/// need to lock.
-struct ScopedReaderLock {
-  ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
-      : mutex(shouldLock ? &mutexParam : nullptr) {
-    if (mutex)
-      mutex->lock_shared();
-  }
-  ~ScopedReaderLock() {
-    if (mutex)
-      mutex->unlock_shared();
-  }
-  llvm::sys::SmartRWMutex<true> *mutex;
-};
 /// Utility writer lock that takes a runtime flag that specifies if we really
 /// need to lock.
 struct ScopedWriterLock {
@@ -277,13 +263,18 @@ class MLIRContextImpl {
   DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
   DialectRegistry dialectsRegistry;
 
-  /// This is a mapping from operation name to AbstractOperation for registered
-  /// operations.
-  llvm::StringMap<AbstractOperation> registeredOperations;
-
   /// An allocator used for AbstractAttribute and AbstractType objects.
   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
 
+  /// This is a mapping from operation name to the operation info describing it.
+  llvm::StringMap<OperationName::Impl> operations;
+
+  /// A vector of operation info specifically for registered operations.
+  SmallVector<RegisteredOperationName> registeredOperations;
+
+  /// A mutex used when accessing operation information.
+  llvm::sys::SmartRWMutex<true> operationInfoMutex;
+
   //===--------------------------------------------------------------------===//
   // Affine uniquing
   //===--------------------------------------------------------------------===//
@@ -667,28 +658,24 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
 /// Return information about all registered operations.  This isn't very
 /// efficient, typically you should ask the operations about their properties
 /// directly.
-std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
+std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
   // We just have the operations in a non-deterministic hash table order. Dump
   // into a temporary array, then sort it by operation name to get a stable
   // ordering.
-  llvm::StringMap<AbstractOperation> &registeredOps =
-      impl->registeredOperations;
-
-  std::vector<AbstractOperation *> result;
-  result.reserve(registeredOps.size());
-  for (auto &elt : registeredOps)
-    result.push_back(&elt.second);
-  llvm::array_pod_sort(
-      result.begin(), result.end(),
-      [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
-        return (*lhs)->name.compare((*rhs)->name);
-      });
+  std::vector<RegisteredOperationName> result(
+      impl->registeredOperations.begin(), impl->registeredOperations.end());
+  llvm::array_pod_sort(result.begin(), result.end(),
+                       [](const RegisteredOperationName *lhs,
+                          const RegisteredOperationName *rhs) {
+                         return lhs->getIdentifier().compare(
+                             rhs->getIdentifier());
+                       });
 
   return result;
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
-  return impl->registeredOperations.count(name);
+  return OperationName(name, this).isRegistered();
 }
 
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
@@ -739,26 +726,49 @@ AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
 }
 
 //===----------------------------------------------------------------------===//
-// AbstractOperation
+// OperationName
 //===----------------------------------------------------------------------===//
 
-ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser,
-                                             OperationState &result) const {
-  return parseAssemblyFn(parser, result);
+OperationName::OperationName(StringRef name, MLIRContext *context) {
+  MLIRContextImpl &ctxImpl = context->getImpl();
+
+  // Check for an existing name in read-only mode.
+  bool isMultithreadingEnabled = context->isMultithreadingEnabled();
+  if (isMultithreadingEnabled) {
+    llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
+    auto it = ctxImpl.operations.find(name);
+    if (it != ctxImpl.operations.end()) {
+      impl = &it->second;
+      return;
+    }
+  }
+
+  // Acquire a writer-lock so that we can safely create the new instance.
+  ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
+
+  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
+  if (it.second)
+    it.first->second.name = StringAttr::get(context, name);
+  impl = &it.first->second;
 }
 
-/// Look up the specified operation in the operation set and return a pointer
-/// to it if present. Otherwise, return a null pointer.
-AbstractOperation *AbstractOperation::lookupMutable(StringRef opName,
-                                                    MLIRContext *context) {
-  auto &impl = context->getImpl();
-  auto it = impl.registeredOperations.find(opName);
-  if (it != impl.registeredOperations.end())
-    return &it->second;
-  return nullptr;
+StringRef OperationName::getDialectNamespace() const {
+  if (Dialect *dialect = getDialect())
+    return dialect->getNamespace();
+  return getStringRef().split('.').first;
+}
+
+//===----------------------------------------------------------------------===//
+// RegisteredOperationName
+//===----------------------------------------------------------------------===//
+
+ParseResult
+RegisteredOperationName::parseAssembly(OpAsmParser &parser,
+                                       OperationState &result) const {
+  return impl->parseAssemblyFn(parser, result);
 }
 
-void AbstractOperation::insert(
+void RegisteredOperationName::insert(
     StringRef name, Dialect &dialect, TypeID typeID,
     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
     VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
@@ -766,52 +776,48 @@ void AbstractOperation::insert(
     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
     ArrayRef<StringRef> attrNames) {
   MLIRContext *ctx = dialect.getContext();
-  auto &impl = ctx->getImpl();
-  assert(impl.multiThreadedExecutionContext == 0 &&
-         "Registering a new operation kind while in a multi-threaded execution "
+  auto &ctxImpl = ctx->getImpl();
+  assert(ctxImpl.multiThreadedExecutionContext == 0 &&
+         "registering a new operation kind while in a multi-threaded execution "
          "context");
 
   // Register the attribute names of this operation.
   MutableArrayRef<StringAttr> cachedAttrNames;
   if (!attrNames.empty()) {
     cachedAttrNames = MutableArrayRef<StringAttr>(
-        impl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
+        ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
             attrNames.size()),
         attrNames.size());
     for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
       new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
   }
 
-  // Register the information for this operation.
-  AbstractOperation opInfo(
-      name, dialect, typeID, std::move(parseAssembly), std::move(printAssembly),
-      std::move(verifyInvariants), std::move(foldHook),
-      std::move(getCanonicalizationPatterns), std::move(interfaceMap),
-      std::move(hasTrait), cachedAttrNames);
-  if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
+  // Insert the operation info if it doesn't exist yet.
+  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
+  if (it.second)
+    it.first->second.name = StringAttr::get(ctx, name);
+  OperationName::Impl &impl = it.first->second;
+
+  if (impl.isRegistered()) {
     llvm::errs() << "error: operation named '" << name
                  << "' is already registered.\n";
     abort();
   }
+  ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl));
+
+  // Update the registered info for this operation.
+  impl.dialect = &dialect;
+  impl.typeID = typeID;
+  impl.interfaceMap = std::move(interfaceMap);
+  impl.foldHookFn = std::move(foldHook);
+  impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
+  impl.hasTraitFn = std::move(hasTrait);
+  impl.parseAssemblyFn = std::move(parseAssembly);
+  impl.printAssemblyFn = std::move(printAssembly);
+  impl.verifyInvariantsFn = std::move(verifyInvariants);
+  impl.attributeNames = cachedAttrNames;
 }
 
-AbstractOperation::AbstractOperation(
-    StringRef name, Dialect &dialect, TypeID typeID,
-    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
-    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
-    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
-    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
-    ArrayRef<StringAttr> attrNames)
-    : name(StringAttr::get(dialect.getContext(), name)), dialect(dialect),
-      typeID(typeID), interfaceMap(std::move(interfaceMap)),
-      foldHookFn(std::move(foldHook)),
-      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
-      hasTraitFn(std::move(hasTrait)),
-      parseAssemblyFn(std::move(parseAssembly)),
-      printAssemblyFn(std::move(printAssembly)),
-      verifyInvariantsFn(std::move(verifyInvariants)),
-      attributeNames(attrNames) {}
-
 //===----------------------------------------------------------------------===//
 // AbstractType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index cfbb9a97e632e..b1a23a225732b 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -19,49 +19,6 @@
 
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// OperationName
-//===----------------------------------------------------------------------===//
-
-/// Form the OperationName for an op with the specified string.  This either is
-/// a reference to an AbstractOperation if one is known, or a uniqued Identifier
-/// if not.
-OperationName::OperationName(StringRef name, MLIRContext *context) {
-  if (auto *op = AbstractOperation::lookup(name, context))
-    representation = op;
-  else
-    representation = StringAttr::get(name, context);
-}
-
-/// Return the name of the dialect this operation is registered to.
-StringRef OperationName::getDialectNamespace() const {
-  if (Dialect *dialect = getDialect())
-    return dialect->getNamespace();
-  return getStringRef().split('.').first;
-}
-
-/// Return the operation name with dialect name stripped, if it has one.
-StringRef OperationName::stripDialect() const {
-  return getStringRef().split('.').second;
-}
-
-/// Return the name of this operation. This always succeeds.
-StringRef OperationName::getStringRef() const {
-  return getIdentifier().strref();
-}
-
-/// Return the name of this operation as an identifier. This always succeeds.
-StringAttr OperationName::getIdentifier() const {
-  if (auto *op = representation.dyn_cast<const AbstractOperation *>())
-    return op->name;
-  return representation.get<StringAttr>();
-}
-
-OperationName OperationName::getFromOpaquePointer(const void *pointer) {
-  return OperationName(
-      RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer)));
-}
-
 //===----------------------------------------------------------------------===//
 // Operation
 //===----------------------------------------------------------------------===//
@@ -115,11 +72,8 @@ Operation *Operation::create(Location location, OperationName name,
 
   // If the operation is known to have no operands, don't allocate an operand
   // storage.
-  bool needsOperandStorage = true;
-  if (operands.empty()) {
-    if (const AbstractOperation *abstractOp = name.getAbstractOperation())
-      needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>();
-  }
+  bool needsOperandStorage =
+      operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true;
 
   // Compute the byte size for the operation and the operand storage. This takes
   // into account the size of the operation, its trailing objects, and its
@@ -543,8 +497,8 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
                               SmallVectorImpl<OpFoldResult> &results) {
   // If we have a registered operation definition matching this one, use it to
   // try to constant fold the operation.
-  auto *abstractOp = getAbstractOperation();
-  if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
+  Optional<RegisteredOperationName> info = getRegisteredInfo();
+  if (info && succeeded(info->foldHook(this, operands, results)))
     return success();
 
   // Otherwise, fall back on the dialect hook to handle it.

diff  --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 90acbecb626d8..c0e4c7974c6d0 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -177,8 +177,8 @@ LogicalResult OperationVerifier::verifyOperation(
 
   // If we can get operation info for this, check the custom hook.
   OperationName opName = op.getName();
-  auto *opInfo = opName.getAbstractOperation();
-  if (opInfo && failed(opInfo->verifyInvariants(&op)))
+  Optional<RegisteredOperationName> registeredInfo = opName.getRegisteredInfo();
+  if (registeredInfo && failed(registeredInfo->verifyInvariants(&op)))
     return failure();
 
   if (unsigned numRegions = op.getNumRegions()) {
@@ -218,7 +218,7 @@ LogicalResult OperationVerifier::verifyOperation(
   }
 
   // If this is a registered operation, there is nothing left to do.
-  if (opInfo)
+  if (registeredInfo)
     return success();
 
   // Otherwise, verify that the parent dialect allows un-registered operations.

diff  --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp
index 5d40f50a3a36a..95f64910541fe 100644
--- a/mlir/lib/Parser/AsmParserState.cpp
+++ b/mlir/lib/Parser/AsmParserState.cpp
@@ -23,8 +23,7 @@ struct AsmParserState::Impl {
 
   struct PartialOpDef {
     explicit PartialOpDef(const OperationName &opName) {
-      const auto *abstractOp = opName.getAbstractOperation();
-      if (abstractOp && abstractOp->hasTrait<OpTrait::SymbolTable>())
+      if (opName.hasTrait<OpTrait::SymbolTable>())
         symbolTable = std::make_unique<SymbolUseMap>();
     }
 

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index d69f16f840040..5f14292ace81f 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -987,20 +987,17 @@ Operation *OperationParser::parseGenericOperation() {
   OperationState result(srcLocation, name);
 
   // Lazy load dialects in the context as needed.
-  if (!result.name.getAbstractOperation()) {
+  if (!result.name.isRegistered()) {
     StringRef dialectName = StringRef(name).split('.').first;
-    if (!getContext()->getLoadedDialect(dialectName)) {
-      if (getContext()->getOrLoadDialect(dialectName)) {
-        result.name = OperationName(name, getContext());
-      } else if (!getContext()->allowsUnregisteredDialects()) {
-        // Emit an error if the dialect couldn't be loaded (i.e., it was not
-        // registered) and unregistered dialects aren't allowed.
-        return emitError(
-                   "operation being parsed with an unregistered dialect. If "
-                   "this is intended, please use -allow-unregistered-dialect "
-                   "with the MLIR tool used"),
-               nullptr;
-      }
+    if (!getContext()->getLoadedDialect(dialectName) &&
+        !getContext()->getOrLoadDialect(dialectName) &&
+        !getContext()->allowsUnregisteredDialects()) {
+      // Emit an error if the dialect couldn't be loaded (i.e., it was not
+      // registered) and unregistered dialects aren't allowed.
+      emitError("operation being parsed with an unregistered dialect. If "
+                "this is intended, please use -allow-unregistered-dialect "
+                "with the MLIR tool used");
+      return nullptr;
     }
   }
 
@@ -1018,9 +1015,8 @@ Operation *OperationParser::parseGenericOperation() {
 
   // Parse the successor list.
   if (getToken().is(Token::l_square)) {
-    // Check if the operation is a known terminator.
-    const AbstractOperation *abstractOp = result.name.getAbstractOperation();
-    if (abstractOp && !abstractOp->hasTrait<OpTrait::IsTerminator>())
+    // Check if the operation is not a known terminator.
+    if (!result.name.mightHaveTrait<OpTrait::IsTerminator>())
       return emitError("successors in non-terminator"), nullptr;
 
     SmallVector<Block *, 2> successors;
@@ -1514,11 +1510,12 @@ Operation *
 OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   llvm::SMLoc opLoc = getToken().getLoc();
   std::string opName = getTokenSpelling().str();
-  auto *opDefinition = AbstractOperation::lookup(opName, getContext());
+  Optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(opName, getContext());
   StringRef defaultDialect = getState().defaultDialectStack.back();
   Dialect *dialect = nullptr;
-  if (opDefinition) {
-    dialect = &opDefinition->dialect;
+  if (opInfo) {
+    dialect = &opInfo->getDialect();
   } else {
     if (StringRef(opName).contains('.')) {
       // This op has a dialect, we try to check if we can register it in the
@@ -1526,19 +1523,19 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
       StringRef dialectName = StringRef(opName).split('.').first;
       dialect = getContext()->getLoadedDialect(dialectName);
       if (!dialect && (dialect = getContext()->getOrLoadDialect(dialectName)))
-        opDefinition = AbstractOperation::lookup(opName, getContext());
+        opInfo = RegisteredOperationName::lookup(opName, getContext());
     } else {
       // If the operation name has no namespace prefix we lookup the current
       // default dialect (set through OpAsmOpInterface).
-      opDefinition = AbstractOperation::lookup(
+      opInfo = RegisteredOperationName::lookup(
           Twine(defaultDialect + "." + opName).str(), getContext());
-      if (!opDefinition && getContext()->getOrLoadDialect("std")) {
-        opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
+      if (!opInfo && getContext()->getOrLoadDialect("std")) {
+        opInfo = RegisteredOperationName::lookup(Twine("std." + opName).str(),
                                                  getContext());
       }
-      if (opDefinition) {
-        dialect = &opDefinition->dialect;
-        opName = opDefinition->name.str();
+      if (opInfo) {
+        dialect = &opInfo->getDialect();
+        opName = opInfo->getStringRef().str();
       } else if (!defaultDialect.empty()) {
         dialect = getContext()->getOrLoadDialect(defaultDialect);
         opName = (defaultDialect + "." + opName).str();
@@ -1548,16 +1545,15 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
 
   // This is the actual hook for the custom op parsing, usually implemented by
   // the op itself (`Op::parse()`). We retrieve it either from the
-  // AbstractOperation or from the Dialect.
+  // RegisteredOperationName or from the Dialect.
   function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
   bool isIsolatedFromAbove = false;
 
   defaultDialect = "";
-  if (opDefinition) {
-    parseAssemblyFn = opDefinition->getParseAssemblyFn();
-    isIsolatedFromAbove =
-        opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>();
-    auto *iface = opDefinition->getInterface<OpAsmOpInterface>();
+  if (opInfo) {
+    parseAssemblyFn = opInfo->getParseAssemblyFn();
+    isIsolatedFromAbove = opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>();
+    auto *iface = opInfo->getInterface<OpAsmOpInterface>();
     if (iface && !iface->getDefaultDialect().empty())
       defaultDialect = iface->getDefaultDialect();
   } else {

diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 98854564da735..810bcf67bccb0 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1323,7 +1323,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
 
     // Handle the case where the operation has inferred types.
     InferTypeOpInterface::Concept *concept =
-        state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
+        state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
 
     // TODO: Handle failure.
     state.types.clear();

diff  --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index a26dee0d448f5..20c71b51c5f81 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -66,19 +66,17 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
   // Functor used to walk all of the operations registered in the context. This
   // is useful for patterns that get applied to multiple operations, such as
   // interface and trait based patterns.
-  std::vector<AbstractOperation *> abstractOps;
-  auto addToOpsWhen = [&](std::unique_ptr<RewritePattern> &pattern,
-                          function_ref<bool(AbstractOperation *)> callbackFn) {
-    if (abstractOps.empty())
-      abstractOps = pattern->getContext()->getRegisteredOperations();
-    for (AbstractOperation *absOp : abstractOps) {
-      if (callbackFn(absOp)) {
-        OperationName opName(absOp);
-        impl->nativeOpSpecificPatternMap[opName].push_back(pattern.get());
-      }
-    }
-    impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
-  };
+  std::vector<RegisteredOperationName> opInfos;
+  auto addToOpsWhen =
+      [&](std::unique_ptr<RewritePattern> &pattern,
+          function_ref<bool(RegisteredOperationName)> callbackFn) {
+        if (opInfos.empty())
+          opInfos = pattern->getContext()->getRegisteredOperations();
+        for (RegisteredOperationName info : opInfos)
+          if (callbackFn(info))
+            impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
+        impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
+      };
 
   for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
     // Don't add patterns that haven't been enabled by the user.
@@ -106,14 +104,14 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
       continue;
     }
     if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
-      addToOpsWhen(pat, [&](AbstractOperation *absOp) {
-        return absOp->hasInterface(*interfaceID);
+      addToOpsWhen(pat, [&](RegisteredOperationName info) {
+        return info.hasInterface(*interfaceID);
       });
       continue;
     }
     if (Optional<TypeID> traitID = pat->getRootTraitID()) {
-      addToOpsWhen(pat, [&](AbstractOperation *absOp) {
-        return absOp->hasTrait(*traitID);
+      addToOpsWhen(pat, [&](RegisteredOperationName info) {
+        return info.hasTrait(*traitID);
       });
       continue;
     }

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 457ad414ccdfd..22cc32fa728ab 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -37,8 +37,8 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
       dialect->getCanonicalizationPatterns(owningPatterns);
-    for (auto *op : context->getRegisteredOperations())
-      op->getCanonicalizationPatterns(owningPatterns, context);
+    for (RegisteredOperationName op : context->getRegisteredOperations())
+      op.getCanonicalizationPatterns(owningPatterns, context);
 
     patterns = FrozenRewritePatternSet(std::move(owningPatterns),
                                        disabledPatterns, enabledPatterns);

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index b93a9e1c47546..50708ad3d2b97 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -669,7 +669,7 @@ void OpEmitter::genAttrNameGetters() {
     ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
     method->body() << "assert(index < " << attributeNames.size()
                    << " && \"invalid attribute index\");\n"
-                      "  return name.getAbstractOperation()"
+                      "  return name.getRegisteredInfo()"
                       "->getAttributeNames()[index];";
   }
 


        


More information about the Mlir-commits mailing list