[Mlir-commits] [mlir] e055aad - Refactor OperationName to use virtual tables for dispatch (NFC)
Mehdi Amini
llvmlistbot at llvm.org
Fri Jan 13 17:27:48 PST 2023
Author: Mehdi Amini
Date: 2023-01-14T01:27:38Z
New Revision: e055aad5ffb348472c65dfcbede85f39efe8f906
URL: https://github.com/llvm/llvm-project/commit/e055aad5ffb348472c65dfcbede85f39efe8f906
DIFF: https://github.com/llvm/llvm-project/commit/e055aad5ffb348472c65dfcbede85f39efe8f906.diff
LOG: Refactor OperationName to use virtual tables for dispatch (NFC)
This streamlines the implementation: the virtual tables are now emitted
statically in the binary instead of being assembled dynamically during
initialization. The dynamic allocation size of op registration is also smaller
with this change.
Differential Revision: https://reviews.llvm.org/D141492
Added:
Modified:
mlir/include/mlir/IR/ExtensibleDialect.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/AsmParser/Parser.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/ExtensibleDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 662a7353bfd0e..9820aa69892ec 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -336,12 +336,15 @@ class DynamicType
/// The definition of a dynamic op. A dynamic op is an op that is defined at
/// runtime, and that can be registered at runtime by an extensible dialect (a
-/// dialect inheriting ExtensibleDialect). This class stores the functions that
-/// are in the OperationName class, and in addition defines the TypeID of the op
-/// that will be defined.
-/// Each dynamic operation definition refers to one instance of this class.
-class DynamicOpDefinition {
+/// dialect inheriting ExtensibleDialect). This class implements the method
+/// exposed by the OperationName class, and in addition defines the TypeID of
+/// the op that will be defined. Each dynamic operation definition refers to one
+/// instance of this class.
+class DynamicOpDefinition : public OperationName::Impl {
public:
+ using GetCanonicalizationPatternsFn =
+ llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
+
/// Create a new op at runtime. The op is registered only after passing it to
/// the dialect using registerDynamicOp.
static std::unique_ptr<DynamicOpDefinition>
@@ -361,8 +364,7 @@ class DynamicOpDefinition {
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- OperationName::GetCanonicalizationPatternsFn
- &&getCanonicalizationPatternsFn,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
/// Returns the op typeID.
@@ -400,9 +402,8 @@ class DynamicOpDefinition {
/// Set the hook returning any canonicalization pattern rewrites that the op
/// supports, for use by the canonicalization pass.
- void
- setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn
- &&getCanonicalizationPatterns) {
+ void setGetCanonicalizationPatternsFn(
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns) {
getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
}
@@ -412,6 +413,29 @@ class DynamicOpDefinition {
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
+ LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
+ SmallVectorImpl<OpFoldResult> &results) final {
+ return foldHookFn(op, attrs, results);
+ }
+ void getCanonicalizationPatterns(RewritePatternSet &set,
+ MLIRContext *context) final {
+ getCanonicalizationPatternsFn(set, context);
+ }
+ bool hasTrait(TypeID id) final { return false; }
+ OperationName::ParseAssemblyFn getParseAssemblyFn() final { return parseFn; }
+ void populateDefaultAttrs(const OperationName &name,
+ NamedAttrList &attrs) final {
+ populateDefaultAttrsFn(name, attrs);
+ }
+ void printAssembly(Operation *op, OpAsmPrinter &printer,
+ StringRef name) final {
+ printFn(op, printer, name);
+ }
+ LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
+ LogicalResult verifyRegionInvariants(Operation *op) final {
+ return verifyRegionFn(op);
+ }
+
private:
DynamicOpDefinition(
StringRef name, ExtensibleDialect *dialect,
@@ -420,26 +444,18 @@ class DynamicOpDefinition {
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- OperationName::GetCanonicalizationPatternsFn
- &&getCanonicalizationPatternsFn,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
- /// Unique identifier for this operation.
- TypeID typeID;
-
- /// Name of the operation.
- /// The name is prefixed with the dialect name.
- std::string name;
-
/// Dialect defining this operation.
- ExtensibleDialect *dialect;
+ ExtensibleDialect *getdialect();
OperationName::VerifyInvariantsFn verifyFn;
OperationName::VerifyRegionInvariantsFn verifyRegionFn;
OperationName::ParseAssemblyFn parseFn;
OperationName::PrintAssemblyFn printFn;
OperationName::FoldHookFn foldHookFn;
- OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+ GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
friend ExtensibleDialect;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 96b6e174f5081..34020e70f84ff 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -183,8 +183,7 @@ class OpState {
MLIRContext *context) {}
/// This hook populates any unset default attrs.
- static void populateDefaultAttrs(const RegisteredOperationName &,
- NamedAttrList &) {}
+ static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {}
protected:
/// If the concrete type didn't implement a custom verifier hook, just fall
@@ -1831,20 +1830,11 @@ class Op : public OpState, public Traits<ConcreteType>... {
return result;
}
- /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
- static OperationName::GetCanonicalizationPatternsFn
- getGetCanonicalizationPatternsFn() {
- return &ConcreteType::getCanonicalizationPatterns;
- }
/// Implementation of `GetHasTraitFn`
static OperationName::HasTraitFn getHasTraitFn() {
return
[](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
}
- /// Implementation of `ParseAssemblyFn` OperationName hook.
- static OperationName::ParseAssemblyFn getParseAssemblyFn() {
- return &ConcreteType::parse;
- }
/// Implementation of `PrintAssemblyFn` OperationName hook.
static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
if constexpr (detect_has_print<ConcreteType>::value)
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 18844000787c7..c4ef826e94366 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -505,11 +505,9 @@ class alignas(8) Operation final
/// Sets default attributes on unset attributes.
void populateDefaultAttrs() {
- if (auto registered = getRegisteredInfo()) {
NamedAttrList attrs(getAttrDictionary());
- registered->populateDefaultAttrs(attrs);
+ name.populateDefaultAttrs(attrs);
setAttrs(attrs.getDictionary(getContext()));
- }
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index cfd8844403a9a..629236046a1e0 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -23,6 +23,7 @@
#include "mlir/Support/InterfaceSupport.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "llvm/Support/TrailingObjects.h"
#include <memory>
@@ -63,17 +64,15 @@ class ValueTypeRange;
class OperationName {
public:
- using GetCanonicalizationPatternsFn =
- llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
using FoldHookFn = llvm::unique_function<LogicalResult(
Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>;
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
using ParseAssemblyFn =
- llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
+ llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>;
// Note: RegisteredOperationName is passed as reference here as the derived
// class is defined below.
- using PopulateDefaultAttrsFn = llvm::unique_function<void(
- const RegisteredOperationName &, NamedAttrList &) const>;
+ using PopulateDefaultAttrsFn =
+ llvm::unique_function<void(const OperationName &, NamedAttrList &) const>;
using PrintAssemblyFn =
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
using VerifyInvariantsFn =
@@ -81,63 +80,132 @@ class OperationName {
using VerifyRegionInvariantsFn =
llvm::unique_function<LogicalResult(Operation *) const>;
-protected:
/// This class represents a type erased version of an operation. It contains
/// all of the components necessary for opaquely interacting with an
/// operation. If the operation is not registered, some of these components
/// may not be populated.
- struct Impl {
- Impl(StringAttr name)
- : name(name), dialect(nullptr), interfaceMap(std::nullopt) {}
+ struct InterfaceConcept {
+ virtual ~InterfaceConcept() = default;
+ virtual LogicalResult foldHook(Operation *, ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) = 0;
+ virtual void getCanonicalizationPatterns(RewritePatternSet &,
+ MLIRContext *) = 0;
+ virtual bool hasTrait(TypeID) = 0;
+ virtual OperationName::ParseAssemblyFn getParseAssemblyFn() = 0;
+ virtual void populateDefaultAttrs(const OperationName &,
+ NamedAttrList &) = 0;
+ virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0;
+ virtual LogicalResult verifyInvariants(Operation *) = 0;
+ virtual LogicalResult verifyRegionInvariants(Operation *) = 0;
+ };
+
+public:
+ class Impl : public InterfaceConcept {
+ public:
+ Impl(StringRef, Dialect *dialect, TypeID typeID,
+ detail::InterfaceMap interfaceMap);
+ Impl(StringAttr name, Dialect *dialect, TypeID typeID,
+ detail::InterfaceMap interfaceMap)
+ : name(name), typeID(typeID), dialect(dialect),
+ interfaceMap(std::move(interfaceMap)) {}
+
+ /// Returns true if this is a registered operation.
+ bool isRegistered() const { return typeID != TypeID::get<void>(); }
+ detail::InterfaceMap &getInterfaceMap() { return interfaceMap; }
+ Dialect *getDialect() const { return dialect; }
+ StringAttr getName() const { return name; }
+ TypeID getTypeID() const { return typeID; }
+ ArrayRef<StringAttr> getAttributeNames() const { return attributeNames; }
+
+ protected:
+ //===------------------------------------------------------------------===//
+ // Registered Operation Info
/// The name of the operation.
StringAttr name;
- //===------------------------------------------------------------------===//
- // Registered Operation Info
+ /// The unique identifier of the derived Op class.
+ TypeID typeID;
/// The following fields are only populated when the operation is
/// registered.
- /// Returns true if the operation has been registered, i.e. if the
- /// registration info has been populated.
- bool isRegistered() const { return dialect; }
-
/// This is the dialect that this operation belongs to.
Dialect *dialect;
- /// The unique identifier of the derived Op class.
- TypeID typeID;
-
/// A map of interfaces that were registered to this operation.
detail::InterfaceMap interfaceMap;
- /// Internal callback hooks provided by the op implementation.
- FoldHookFn foldHookFn;
- GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
- HasTraitFn hasTraitFn;
- ParseAssemblyFn parseAssemblyFn;
- PopulateDefaultAttrsFn populateDefaultAttrsFn;
- PrintAssemblyFn printAssemblyFn;
- VerifyInvariantsFn verifyInvariantsFn;
- VerifyRegionInvariantsFn verifyRegionInvariantsFn;
-
/// A list of attribute names registered to this operation in StringAttr
/// form. This allows for operation classes to use StringAttr for attribute
/// lookup/creation/etc., as opposed to raw strings.
ArrayRef<StringAttr> attributeNames;
+
+ friend class RegisteredOperationName;
+ };
+
+protected:
+ /// Default implementation for unregistered operations.
+ struct UnregisteredOpModel : public Impl {
+ using Impl::Impl;
+ LogicalResult foldHook(Operation *, ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) final;
+ void getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) final;
+ bool hasTrait(TypeID) final;
+ virtual OperationName::ParseAssemblyFn getParseAssemblyFn() final;
+ void populateDefaultAttrs(const OperationName &, NamedAttrList &) final;
+ void printAssembly(Operation *, OpAsmPrinter &, StringRef) final;
+ LogicalResult verifyInvariants(Operation *) final;
+ LogicalResult verifyRegionInvariants(Operation *) final;
};
public:
OperationName(StringRef name, MLIRContext *context);
/// Return if this operation is registered.
- bool isRegistered() const { return impl->isRegistered(); }
+ bool isRegistered() const { return getImpl()->isRegistered(); }
+
+ /// Return the unique identifier of the derived Op class, or null if not
+ /// registered.
+ TypeID getTypeID() const { return getImpl()->getTypeID(); }
/// If this operation is registered, returns the registered information,
/// std::nullopt otherwise.
Optional<RegisteredOperationName> getRegisteredInfo() const;
+ /// This hook implements a generalized folder for this operation. Operations
+ /// can implement this to provide simplifications rules that are applied by
+ /// the Builder::createOrFold API and the canonicalization pass.
+ ///
+ /// This is an intentionally limited interface - implementations of this
+ /// hook can only perform the following changes to the operation:
+ ///
+ /// 1. They can leave the operation alone and without changing the IR, and
+ /// return failure.
+ /// 2. They can mutate the operation in place, without changing anything
+ /// else
+ /// in the IR. In this case, return success.
+ /// 3. They can return a list of existing values that can be used instead
+ /// of
+ /// the operation. In this case, fill in the results list and return
+ /// success. The caller will remove the operation and use those results
+ /// instead.
+ ///
+ /// This allows expression of some simple in-place canonicalizations (e.g.
+ /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+ /// generalized constant folding.
+ LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) const {
+ return getImpl()->foldHook(op, operands, results);
+ }
+
+ /// This hook returns any canonicalization pattern rewrites that the
+ /// operation supports, for use by the canonicalization pass.
+ void getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) const {
+ return getImpl()->getCanonicalizationPatterns(results, context);
+ }
+
/// Returns true if the operation was registered with a particular trait, e.g.
/// hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation
/// is unregistered.
@@ -145,9 +213,7 @@ class OperationName {
bool hasTrait() const {
return hasTrait(TypeID::get<Trait>());
}
- bool hasTrait(TypeID traitID) const {
- return isRegistered() && impl->hasTraitFn(traitID);
- }
+ bool hasTrait(TypeID traitID) const { return getImpl()->hasTrait(traitID); }
/// Returns true if the operation *might* have the provided trait. This
/// means that either the operation is unregistered, or it was registered with
@@ -157,7 +223,54 @@ class OperationName {
return mightHaveTrait(TypeID::get<Trait>());
}
bool mightHaveTrait(TypeID traitID) const {
- return !isRegistered() || impl->hasTraitFn(traitID);
+ return !isRegistered() || getImpl()->hasTrait(traitID);
+ }
+
+ /// Return the static hook for parsing this operation assembly.
+ ParseAssemblyFn getParseAssemblyFn() const {
+ return getImpl()->getParseAssemblyFn();
+ }
+
+ /// This hook implements the method to populate defaults attributes that are
+ /// unset.
+ void populateDefaultAttrs(NamedAttrList &attrs) const {
+ getImpl()->populateDefaultAttrs(*this, attrs);
+ }
+
+ /// This hook implements the AsmPrinter for this operation.
+ void printAssembly(Operation *op, OpAsmPrinter &p,
+ StringRef defaultDialect) const {
+ return getImpl()->printAssembly(op, p, defaultDialect);
+ }
+
+ /// These hooks implement the verifiers for this operation. It should emits
+ /// an error message and returns failure if a problem is detected, or
+ /// returns success if everything is ok.
+ LogicalResult verifyInvariants(Operation *op) const {
+ return getImpl()->verifyInvariants(op);
+ }
+ LogicalResult verifyRegionInvariants(Operation *op) const {
+ return getImpl()->verifyRegionInvariants(op);
+ }
+
+ /// Return the list of cached attribute names registered to this operation.
+ /// The order of attributes cached here is unique to each type of operation,
+ /// and the interpretation of this attribute list should generally be driven
+ /// by the respective operation. In many cases, this caching removes the
+ /// need to use the raw string name of a known attribute.
+ ///
+ /// For example the ODS generator, with an op defining the following
+ /// attributes:
+ ///
+ /// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
+ ///
+ /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
+ /// ODS generator to directly access the cached name for a known attribute,
+ /// greatly simplifying the cost and complexity of attribute usage produced
+ /// by the generator.
+ ///
+ ArrayRef<StringAttr> getAttributeNames() const {
+ return getImpl()->getAttributeNames();
}
/// Returns an instance of the concept object for the given interface if it
@@ -165,7 +278,13 @@ class OperationName {
/// directly.
template <typename T>
typename T::Concept *getInterface() const {
- return impl->interfaceMap.lookup<T>();
+ return getImpl()->getInterfaceMap().lookup<T>();
+ }
+
+ /// Attach the given models as implementations of the corresponding
+ /// interfaces for the concrete operation.
+ template <typename... Models> void attachInterface() {
+ getImpl()->getInterfaceMap().insert<Models...>();
}
/// Returns true if this operation has the given interface registered to it.
@@ -174,7 +293,7 @@ class OperationName {
return hasInterface(TypeID::get<T>());
}
bool hasInterface(TypeID interfaceID) const {
- return impl->interfaceMap.contains(interfaceID);
+ return getImpl()->getInterfaceMap().contains(interfaceID);
}
/// Returns true if the operation *might* have the provided interface. This
@@ -191,7 +310,8 @@ class OperationName {
/// Return the dialect this operation is registered to if the dialect is
/// loaded in the context, or nullptr if the dialect isn't loaded.
Dialect *getDialect() const {
- return isRegistered() ? impl->dialect : impl->name.getReferencedDialect();
+ return isRegistered() ? getImpl()->getDialect()
+ : getImpl()->getName().getReferencedDialect();
}
/// Return the name of the dialect this operation is registered to.
@@ -204,7 +324,7 @@ class OperationName {
StringRef getStringRef() const { return getIdentifier(); }
/// Return the name of this operation as a StringAttr.
- StringAttr getIdentifier() const { return impl->name; }
+ StringAttr getIdentifier() const { return getImpl()->getName(); }
void print(raw_ostream &os) const;
void dump() const;
@@ -222,12 +342,17 @@ class OperationName {
protected:
OperationName(Impl *impl) : impl(impl) {}
+ Impl *getImpl() const { return impl; }
+ void setImpl(Impl *rhs) { impl = rhs; }
+private:
/// The internal implementation of the operation name.
- Impl *impl;
+ Impl *impl = nullptr;
/// Allow access to the Impl struct.
friend MLIRContextImpl;
+ friend DenseMapInfo<mlir::OperationName>;
+ friend DenseMapInfo<mlir::RegisteredOperationName>;
};
inline raw_ostream &operator<<(raw_ostream &os, OperationName info) {
@@ -250,137 +375,62 @@ inline llvm::hash_code hash_value(OperationName arg) {
/// the concrete operation types.
class RegisteredOperationName : public OperationName {
public:
+ /// Implementation of the InterfaceConcept for operation APIs that forwarded
+ /// to a concrete op implementation.
+ template <typename ConcreteOp> struct Model : public Impl {
+ Model(Dialect *dialect)
+ : Impl(ConcreteOp::getOperationName(), dialect,
+ TypeID::get<ConcreteOp>(), ConcreteOp::getInterfaceMap()) {}
+ LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
+ SmallVectorImpl<OpFoldResult> &results) final {
+ return ConcreteOp::getFoldHookFn()(op, attrs, results);
+ }
+ void getCanonicalizationPatterns(RewritePatternSet &set,
+ MLIRContext *context) final {
+ ConcreteOp::getCanonicalizationPatterns(set, context);
+ }
+ bool hasTrait(TypeID id) final { return ConcreteOp::getHasTraitFn()(id); }
+ OperationName::ParseAssemblyFn getParseAssemblyFn() final {
+ return ConcreteOp::parse;
+ }
+ void populateDefaultAttrs(const OperationName &name,
+ NamedAttrList &attrs) final {
+ ConcreteOp::populateDefaultAttrs(name, attrs);
+ }
+ void printAssembly(Operation *op, OpAsmPrinter &printer,
+ StringRef name) final {
+ ConcreteOp::getPrintAssemblyFn()(op, printer, name);
+ }
+ LogicalResult verifyInvariants(Operation *op) final {
+ return ConcreteOp::getVerifyInvariantsFn()(op);
+ }
+ LogicalResult verifyRegionInvariants(Operation *op) final {
+ return ConcreteOp::getVerifyRegionInvariantsFn()(op);
+ }
+ };
+
/// Lookup the registered operation information for the given operation.
/// Returns std::nullopt if the operation isn't registered.
static Optional<RegisteredOperationName> lookup(StringRef name,
MLIRContext *ctx);
/// Register a new operation in a Dialect object.
- /// This constructor is used by Dialect objects when they register the list of
- /// operations they contain.
- template <typename T>
- static void insert(Dialect &dialect) {
- insert(T::getOperationName(), dialect, TypeID::get<T>(),
- T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
- T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
- T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
- T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
- T::getPopulateDefaultAttrsFn());
+ /// This constructor is used by Dialect objects when they register the list
+ /// of operations they contain.
+ template <typename T> static void insert(Dialect &dialect) {
+ insert(std::make_unique<Model<T>>(&dialect), T::getAttributeNames());
}
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
- static void
- insert(StringRef name, Dialect &dialect, TypeID typeID,
- ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
- VerifyInvariantsFn &&verifyInvariants,
- VerifyRegionInvariantsFn &&verifyRegionInvariants,
- FoldHookFn &&foldHook,
- GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
- detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
- ArrayRef<StringRef> attrNames,
- PopulateDefaultAttrsFn &&populateDefaultAttrs);
+ static void insert(std::unique_ptr<OperationName::Impl> ownedImpl,
+ ArrayRef<StringRef> attrNames);
/// Return the dialect this operation is registered to.
- Dialect &getDialect() const { return *impl->dialect; }
-
- /// Return the unique identifier of the derived Op class.
- TypeID getTypeID() const { return impl->typeID; }
+ Dialect &getDialect() const { return *getImpl()->getDialect(); }
/// Use the specified object to parse this ops custom assembly format.
ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
- /// Return the static hook for parsing this operation assembly.
- const ParseAssemblyFn &getParseAssemblyFn() const {
- return impl->parseAssemblyFn;
- }
-
- /// This hook implements the AsmPrinter for this operation.
- void printAssembly(Operation *op, OpAsmPrinter &p,
- StringRef defaultDialect) const {
- return impl->printAssemblyFn(op, p, defaultDialect);
- }
-
- /// These hooks implement the verifiers for this operation. It should emits
- /// an error message and returns failure if a problem is detected, or returns
- /// success if everything is ok.
- LogicalResult verifyInvariants(Operation *op) const {
- return impl->verifyInvariantsFn(op);
- }
- LogicalResult verifyRegionInvariants(Operation *op) const {
- return impl->verifyRegionInvariantsFn(op);
- }
-
- /// This hook implements a generalized folder for this operation. Operations
- /// can implement this to provide simplifications rules that are applied by
- /// the Builder::createOrFold API and the canonicalization pass.
- ///
- /// This is an intentionally limited interface - implementations of this hook
- /// can only perform the following changes to the operation:
- ///
- /// 1. They can leave the operation alone and without changing the IR, and
- /// return failure.
- /// 2. They can mutate the operation in place, without changing anything else
- /// in the IR. In this case, return success.
- /// 3. They can return a list of existing values that can be used instead of
- /// the operation. In this case, fill in the results list and return
- /// success. The caller will remove the operation and use those results
- /// instead.
- ///
- /// This allows expression of some simple in-place canonicalizations (e.g.
- /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
- /// generalized constant folding.
- LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) const {
- return impl->foldHookFn(op, operands, results);
- }
-
- /// This hook returns any canonicalization pattern rewrites that the operation
- /// supports, for use by the canonicalization pass.
- void getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) const {
- return impl->getCanonicalizationPatternsFn(results, context);
- }
-
- /// Attach the given models as implementations of the corresponding interfaces
- /// for the concrete operation.
- template <typename... Models>
- void attachInterface() {
- impl->interfaceMap.insert<Models...>();
- }
-
- /// Returns true if the operation has a particular trait.
- template <template <typename T> class Trait>
- bool hasTrait() const {
- return hasTrait(TypeID::get<Trait>());
- }
-
- /// Returns true if the operation has a particular trait.
- bool hasTrait(TypeID traitID) const { return impl->hasTraitFn(traitID); }
-
- /// Return the list of cached attribute names registered to this operation.
- /// The order of attributes cached here is unique to each type of operation,
- /// and the interpretation of this attribute list should generally be driven
- /// by the respective operation. In many cases, this caching removes the need
- /// to use the raw string name of a known attribute.
- ///
- /// For example the ODS generator, with an op defining the following
- /// attributes:
- ///
- /// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
- ///
- /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
- /// ODS generator to directly access the cached name for a known attribute,
- /// greatly simplifying the cost and complexity of attribute usage produced by
- /// the generator.
- ///
- ArrayRef<StringAttr> getAttributeNames() const {
- return impl->attributeNames;
- }
-
- /// This hook implements the method to populate defaults attributes that are
- /// unset.
- void populateDefaultAttrs(NamedAttrList &attrs) const;
-
/// Represent the operation name as an opaque pointer. (Used to support
/// PointerLikeTypeTraits).
static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 864ed2d5ca336..aad6b450ec038 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1367,14 +1367,18 @@ Operation *OperationParser::parseGenericOperation() {
if (!result.name.isRegistered()) {
StringRef dialectName = StringRef(name).split('.').first;
if (!getContext()->getLoadedDialect(dialectName) &&
- !getContext()->getOrLoadDialect(dialectName) &&
- !getContext()->allowsUnregisteredDialects()) {
- // Emit an error if the dialect couldn't be loaded (i.e., it was not
- // registered) and unregistered dialects aren't allowed.
- emitError("operation being parsed with an unregistered dialect. If "
- "this is intended, please use -allow-unregistered-dialect "
- "with the MLIR tool used");
- return nullptr;
+ !getContext()->getOrLoadDialect(dialectName)) {
+ if (!getContext()->allowsUnregisteredDialects()) {
+ // Emit an error if the dialect couldn't be loaded (i.e., it was not
+ // registered) and unregistered dialects aren't allowed.
+ emitError("operation being parsed with an unregistered dialect. If "
+ "this is intended, please use -allow-unregistered-dialect "
+ "with the MLIR tool used");
+ return nullptr;
+ }
+ } else {
+ // Reload the OperationName now that the dialect is loaded.
+ result.name = OperationName(name, getContext());
}
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index bd2bc2c935356..d92bde34b705d 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -602,12 +602,8 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
// If requested, always print the generic form.
if (!printerFlags.shouldPrintGenericOpForm()) {
- // Check to see if this is a known operation. If so, use the registered
- // custom printer hook.
- if (auto opInfo = op->getRegisteredInfo()) {
- opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
- return;
- }
+ op->getName().printAssembly(op, *this, /*defaultDialect=*/"");
+ return;
}
// Otherwise print with the generic assembly form.
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index fd169a8bde7f1..5190b5f97af7a 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -294,16 +294,19 @@ DynamicOpDefinition::DynamicOpDefinition(
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- OperationName::GetCanonicalizationPatternsFn
- &&getCanonicalizationPatternsFn,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn)
- : typeID(dialect->allocateTypeID()),
- name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
+ : Impl(StringAttr::get(dialect->getContext(),
+ (dialect->getNamespace() + "." + name).str()),
+ dialect, dialect->allocateTypeID(),
+ /*interfaceMap=*/detail::InterfaceMap(std::nullopt)),
verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
parseFn(std::move(parseFn)), printFn(std::move(printFn)),
foldHookFn(std::move(foldHookFn)),
getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)),
- populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {}
+ populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {
+ typeID = dialect->allocateTypeID();
+}
std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
StringRef name, ExtensibleDialect *dialect,
@@ -338,8 +341,7 @@ std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
};
- auto populateDefaultAttrsFn = [](const RegisteredOperationName &,
- NamedAttrList &) {};
+ auto populateDefaultAttrsFn = [](const OperationName &, NamedAttrList &) {};
return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
std::move(verifyRegionFn), std::move(parseFn),
@@ -355,8 +357,7 @@ std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- OperationName::GetCanonicalizationPatternsFn
- &&getCanonicalizationPatternsFn,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) {
return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
@@ -448,15 +449,7 @@ void ExtensibleDialect::registerDynamicOp(
std::unique_ptr<DynamicOpDefinition> &&op) {
assert(op->dialect == this &&
"trying to register a dynamic op in the wrong dialect");
- auto hasTraitFn = [](TypeID traitId) { return false; };
-
- RegisteredOperationName::insert(
- op->name, *op->dialect, op->typeID, std::move(op->parseFn),
- std::move(op->printFn), std::move(op->verifyFn),
- std::move(op->verifyRegionFn), std::move(op->foldHookFn),
- std::move(op->getCanonicalizationPatternsFn),
- detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
- std::move(op->populateDefaultAttrsFn));
+ RegisteredOperationName::insert(std::move(op), /*attrNames=*/{});
}
bool ExtensibleDialect::classof(const Dialect *dialect) {
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 18b45137dc3d7..4415605a1273e 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -179,7 +179,7 @@ class MLIRContextImpl {
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
/// This is a mapping from operation name to the operation info describing it.
- llvm::StringMap<OperationName::Impl> operations;
+ llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
/// A vector of operation info specifically for registered operations.
llvm::StringMap<RegisteredOperationName> registeredOperations;
@@ -705,6 +705,11 @@ AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
// OperationName
//===----------------------------------------------------------------------===//
+OperationName::Impl::Impl(StringRef name, Dialect *dialect, TypeID typeID,
+ detail::InterfaceMap interfaceMap)
+ : Impl(StringAttr::get(dialect->getContext(), name), dialect, typeID,
+ std::move(interfaceMap)) {}
+
OperationName::OperationName(StringRef name, MLIRContext *context) {
MLIRContextImpl &ctxImpl = context->getImpl();
@@ -723,7 +728,7 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
auto it = ctxImpl.operations.find(name);
if (it != ctxImpl.operations.end()) {
- impl = &it->second;
+ impl = it->second.get();
return;
}
}
@@ -731,10 +736,14 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
// Acquire a writer-lock so that we can safely create the new instance.
ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
- auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
- if (it.second)
- it.first->second.name = StringAttr::get(context, name);
- impl = &it.first->second;
+ auto it = ctxImpl.operations.insert({name, nullptr});
+ if (it.second) {
+ auto nameAttr = StringAttr::get(context, name);
+ it.first->second = std::make_unique<UnregisteredOpModel>(
+ nameAttr, nameAttr.getReferencedDialect(), TypeID::get<void>(),
+ detail::InterfaceMap(std::nullopt));
+ }
+ impl = it.first->second.get();
}
StringRef OperationName::getDialectNamespace() const {
@@ -743,6 +752,34 @@ StringRef OperationName::getDialectNamespace() const {
return getStringRef().split('.').first;
}
+LogicalResult
+OperationName::UnregisteredOpModel::foldHook(Operation *, ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return failure();
+}
+void OperationName::UnregisteredOpModel::getCanonicalizationPatterns(
+ RewritePatternSet &, MLIRContext *) {}
+bool OperationName::UnregisteredOpModel::hasTrait(TypeID) { return false; }
+
+OperationName::ParseAssemblyFn
+OperationName::UnregisteredOpModel::getParseAssemblyFn() {
+ llvm::report_fatal_error("getParseAssemblyFn hook called on unregistered op");
+}
+void OperationName::UnregisteredOpModel::populateDefaultAttrs(
+ const OperationName &, NamedAttrList &) {}
+void OperationName::UnregisteredOpModel::printAssembly(
+ Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
+ p.printGenericOp(op);
+}
+LogicalResult
+OperationName::UnregisteredOpModel::verifyInvariants(Operation *) {
+ return success();
+}
+LogicalResult
+OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// RegisteredOperationName
//===----------------------------------------------------------------------===//
@@ -756,26 +793,11 @@ RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
return std::nullopt;
}
-ParseResult
-RegisteredOperationName::parseAssembly(OpAsmParser &parser,
- OperationState &result) const {
- return impl->parseAssemblyFn(parser, result);
-}
-
-void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
- impl->populateDefaultAttrsFn(*this, attrs);
-}
-
void RegisteredOperationName::insert(
- StringRef name, Dialect &dialect, TypeID typeID,
- ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
- VerifyInvariantsFn &&verifyInvariants,
- VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
- GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
- detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
- ArrayRef<StringRef> attrNames,
- PopulateDefaultAttrsFn &&populateDefaultAttrs) {
- MLIRContext *ctx = dialect.getContext();
+ std::unique_ptr<RegisteredOperationName::Impl> ownedImpl,
+ ArrayRef<StringRef> attrNames) {
+ RegisteredOperationName::Impl *impl = ownedImpl.get();
+ MLIRContext *ctx = impl->getDialect()->getContext();
auto &ctxImpl = ctx->getImpl();
assert(ctxImpl.multiThreadedExecutionContext == 0 &&
"registering a new operation kind while in a multi-threaded execution "
@@ -790,21 +812,16 @@ void RegisteredOperationName::insert(
attrNames.size());
for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
+ impl->attributeNames = cachedAttrNames;
}
-
+ StringRef name = impl->getName().strref();
// Insert the operation info if it doesn't exist yet.
- auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
- if (it.second)
- it.first->second.name = StringAttr::get(ctx, name);
- OperationName::Impl &impl = it.first->second;
-
- if (impl.isRegistered()) {
- llvm::errs() << "error: operation named '" << name
- << "' is already registered.\n";
- abort();
- }
+ auto it = ctxImpl.operations.insert({name, nullptr});
+ it.first->second = std::move(ownedImpl);
+
+ // Update the registered info for this operation.
auto emplaced = ctxImpl.registeredOperations.try_emplace(
- name, RegisteredOperationName(&impl));
+ name, RegisteredOperationName(impl));
assert(emplaced.second && "operation name registration must be successful");
// Add emplaced operation name to the sorted operations container.
@@ -816,20 +833,6 @@ void RegisteredOperationName::insert(
rhs.getIdentifier());
}),
value);
-
- // Update the registered info for this operation.
- impl.dialect = &dialect;
- impl.typeID = typeID;
- impl.interfaceMap = std::move(interfaceMap);
- impl.foldHookFn = std::move(foldHook);
- impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
- impl.hasTraitFn = std::move(hasTrait);
- impl.parseAssemblyFn = std::move(parseAssembly);
- impl.printAssemblyFn = std::move(printAssembly);
- impl.verifyInvariantsFn = std::move(verifyInvariants);
- impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
- impl.attributeNames = cachedAttrNames;
- impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 0c5869b49cc05..50815b3738bfe 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -78,8 +78,7 @@ Operation *Operation::create(Location location, OperationName name,
void *rawMem = mallocMem + prefixByteSize;
// Populate default attributes.
- if (Optional<RegisteredOperationName> info = name.getRegisteredInfo())
- info->populateDefaultAttrs(attributes);
+ name.populateDefaultAttrs(attributes);
// Create the new Operation.
Operation *op = ::new (rawMem) Operation(
@@ -491,8 +490,7 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
- Optional<RegisteredOperationName> info = getRegisteredInfo();
- if (info && succeeded(info->foldHook(this, operands, results)))
+ if (succeeded(name.foldHook(this, operands, results)))
return success();
// Otherwise, fall back on the dialect hook to handle it.
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 90530c95fb539..89735f068069e 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1605,7 +1605,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
unsigned numResults = read();
if (numResults == kInferTypesMarker) {
InferTypeOpInterface::Concept *inferInterface =
- state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
+ state.name.getInterface<InferTypeOpInterface>();
assert(inferInterface &&
"expected operation to provide InferTypeOpInterface");
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 3e7166e4ccd8c..8dd1bb50b30be 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -926,7 +926,7 @@ void OpEmitter::genAttrNameGetters() {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
- return name.getRegisteredInfo()->getAttributeNames()[index];
+ return name.getAttributeNames()[index];
)";
method->body() << formatv(getAttrName, attributes.size());
}
@@ -1739,7 +1739,7 @@ void OpEmitter::genPopulateDefaultAttributes() {
return;
SmallVector<MethodParameter> paramList;
- paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
+ paramList.emplace_back("const ::mlir::OperationName &", "opName");
paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
diff --git a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
index 27b0978c37884..2fc8a43f7c04c 100644
--- a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
@@ -36,6 +36,7 @@ class ValueShapeRangeTest : public testing::Test {
registry.insert<func::FuncDialect, arith::ArithDialect>();
ctx.appendDialectRegistry(registry);
module = parseSourceString<ModuleOp>(ir, &ctx);
+ assert(module);
mapFn = cast<func::FuncOp>(module->front());
}
More information about the Mlir-commits mailing list