[Mlir-commits] [mlir] 7bf1e44 - Revert "Refactor OperationName to use virtual tables for dispatch (NFC)"
Mehdi Amini
llvmlistbot at llvm.org
Mon Jan 16 15:12:40 PST 2023
Author: Mehdi Amini
Date: 2023-01-16T23:11:38Z
New Revision: 7bf1e441da6b59a25495fde8e34939f93548cc6d
URL: https://github.com/llvm/llvm-project/commit/7bf1e441da6b59a25495fde8e34939f93548cc6d
DIFF: https://github.com/llvm/llvm-project/commit/7bf1e441da6b59a25495fde8e34939f93548cc6d.diff
LOG: Revert "Refactor OperationName to use virtual tables for dispatch (NFC)"
This reverts commit e055aad5ffb348472c65dfcbede85f39efe8f906.
This crashes on Windows at the moment for some reason.
Added:
Modified:
mlir/include/mlir/IR/ExtensibleDialect.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/AsmParser/Parser.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/ExtensibleDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 9820aa69892ec..662a7353bfd0e 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -336,15 +336,12 @@ class DynamicType
/// The definition of a dynamic op. A dynamic op is an op that is defined at
/// runtime, and that can be registered at runtime by an extensible dialect (a
-/// dialect inheriting ExtensibleDialect). This class implements the method
-/// exposed by the OperationName class, and in addition defines the TypeID of
-/// the op that will be defined. Each dynamic operation definition refers to one
-/// instance of this class.
-class DynamicOpDefinition : public OperationName::Impl {
+/// dialect inheriting ExtensibleDialect). This class stores the functions that
+/// are in the OperationName class, and in addition defines the TypeID of the op
+/// that will be defined.
+/// Each dynamic operation definition refers to one instance of this class.
+class DynamicOpDefinition {
public:
- using GetCanonicalizationPatternsFn =
- llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
-
/// Create a new op at runtime. The op is registered only after passing it to
/// the dialect using registerDynamicOp.
static std::unique_ptr<DynamicOpDefinition>
@@ -364,7 +361,8 @@ class DynamicOpDefinition : public OperationName::Impl {
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
+ OperationName::GetCanonicalizationPatternsFn
+ &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
/// Returns the op typeID.
@@ -402,8 +400,9 @@ class DynamicOpDefinition : public OperationName::Impl {
/// Set the hook returning any canonicalization pattern rewrites that the op
/// supports, for use by the canonicalization pass.
- void setGetCanonicalizationPatternsFn(
- GetCanonicalizationPatternsFn &&getCanonicalizationPatterns) {
+ void
+ setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn
+ &&getCanonicalizationPatterns) {
getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
}
@@ -413,29 +412,6 @@ class DynamicOpDefinition : public OperationName::Impl {
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
- LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
- SmallVectorImpl<OpFoldResult> &results) final {
- return foldHookFn(op, attrs, results);
- }
- void getCanonicalizationPatterns(RewritePatternSet &set,
- MLIRContext *context) final {
- getCanonicalizationPatternsFn(set, context);
- }
- bool hasTrait(TypeID id) final { return false; }
- OperationName::ParseAssemblyFn getParseAssemblyFn() final { return parseFn; }
- void populateDefaultAttrs(const OperationName &name,
- NamedAttrList &attrs) final {
- populateDefaultAttrsFn(name, attrs);
- }
- void printAssembly(Operation *op, OpAsmPrinter &printer,
- StringRef name) final {
- printFn(op, printer, name);
- }
- LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
- LogicalResult verifyRegionInvariants(Operation *op) final {
- return verifyRegionFn(op);
- }
-
private:
DynamicOpDefinition(
StringRef name, ExtensibleDialect *dialect,
@@ -444,18 +420,26 @@ class DynamicOpDefinition : public OperationName::Impl {
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
+ OperationName::GetCanonicalizationPatternsFn
+ &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
+ /// Unique identifier for this operation.
+ TypeID typeID;
+
+ /// Name of the operation.
+ /// The name is prefixed with the dialect name.
+ std::string name;
+
/// Dialect defining this operation.
- ExtensibleDialect *getdialect();
+ ExtensibleDialect *dialect;
OperationName::VerifyInvariantsFn verifyFn;
OperationName::VerifyRegionInvariantsFn verifyRegionFn;
OperationName::ParseAssemblyFn parseFn;
OperationName::PrintAssemblyFn printFn;
OperationName::FoldHookFn foldHookFn;
- GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+ OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
friend ExtensibleDialect;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 385597cc9a510..944f933d6af6b 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -184,7 +184,8 @@ class OpState {
MLIRContext *context) {}
/// This hook populates any unset default attrs.
- static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {}
+ static void populateDefaultAttrs(const RegisteredOperationName &,
+ NamedAttrList &) {}
protected:
/// If the concrete type didn't implement a custom verifier hook, just fall
@@ -1832,11 +1833,20 @@ class Op : public OpState, public Traits<ConcreteType>... {
return result;
}
+ /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
+ static OperationName::GetCanonicalizationPatternsFn
+ getGetCanonicalizationPatternsFn() {
+ return &ConcreteType::getCanonicalizationPatterns;
+ }
/// Implementation of `GetHasTraitFn`
static OperationName::HasTraitFn getHasTraitFn() {
return
[](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
}
+ /// Implementation of `ParseAssemblyFn` OperationName hook.
+ static OperationName::ParseAssemblyFn getParseAssemblyFn() {
+ return &ConcreteType::parse;
+ }
/// Implementation of `PrintAssemblyFn` OperationName hook.
static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
if constexpr (detect_has_print<ConcreteType>::value)
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 59d450ea97bb8..2dd5c5d5f79a5 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -506,9 +506,11 @@ class alignas(8) Operation final
/// Sets default attributes on unset attributes.
void populateDefaultAttrs() {
+ if (auto registered = getRegisteredInfo()) {
NamedAttrList attrs(getAttrDictionary());
- name.populateDefaultAttrs(attrs);
+ registered->populateDefaultAttrs(attrs);
setAttrs(attrs.getDictionary(getContext()));
+ }
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 8ec11c1c4e632..a6d8a358d29a1 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -23,7 +23,6 @@
#include "mlir/Support/InterfaceSupport.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/PointerUnion.h"
-#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "llvm/Support/TrailingObjects.h"
#include <memory>
@@ -65,15 +64,17 @@ class ValueTypeRange;
class OperationName {
public:
+ using GetCanonicalizationPatternsFn =
+ llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
using FoldHookFn = llvm::unique_function<LogicalResult(
Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>;
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
using ParseAssemblyFn =
- llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>;
+ llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
// Note: RegisteredOperationName is passed as reference here as the derived
// class is defined below.
- using PopulateDefaultAttrsFn =
- llvm::unique_function<void(const OperationName &, NamedAttrList &) const>;
+ using PopulateDefaultAttrsFn = llvm::unique_function<void(
+ const RegisteredOperationName &, NamedAttrList &) const>;
using PrintAssemblyFn =
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
using VerifyInvariantsFn =
@@ -81,132 +82,63 @@ class OperationName {
using VerifyRegionInvariantsFn =
llvm::unique_function<LogicalResult(Operation *) const>;
+protected:
/// This class represents a type erased version of an operation. It contains
/// all of the components necessary for opaquely interacting with an
/// operation. If the operation is not registered, some of these components
/// may not be populated.
- struct InterfaceConcept {
- virtual ~InterfaceConcept() = default;
- virtual LogicalResult foldHook(Operation *, ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) = 0;
- virtual void getCanonicalizationPatterns(RewritePatternSet &,
- MLIRContext *) = 0;
- virtual bool hasTrait(TypeID) = 0;
- virtual OperationName::ParseAssemblyFn getParseAssemblyFn() = 0;
- virtual void populateDefaultAttrs(const OperationName &,
- NamedAttrList &) = 0;
- virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0;
- virtual LogicalResult verifyInvariants(Operation *) = 0;
- virtual LogicalResult verifyRegionInvariants(Operation *) = 0;
- };
-
-public:
- class Impl : public InterfaceConcept {
- public:
- Impl(StringRef, Dialect *dialect, TypeID typeID,
- detail::InterfaceMap interfaceMap);
- Impl(StringAttr name, Dialect *dialect, TypeID typeID,
- detail::InterfaceMap interfaceMap)
- : name(name), typeID(typeID), dialect(dialect),
- interfaceMap(std::move(interfaceMap)) {}
-
- /// Returns true if this is a registered operation.
- bool isRegistered() const { return typeID != TypeID::get<void>(); }
- detail::InterfaceMap &getInterfaceMap() { return interfaceMap; }
- Dialect *getDialect() const { return dialect; }
- StringAttr getName() const { return name; }
- TypeID getTypeID() const { return typeID; }
- ArrayRef<StringAttr> getAttributeNames() const { return attributeNames; }
-
- protected:
- //===------------------------------------------------------------------===//
- // Registered Operation Info
+ struct Impl {
+ Impl(StringAttr name)
+ : name(name), dialect(nullptr), interfaceMap(std::nullopt) {}
/// The name of the operation.
StringAttr name;
- /// The unique identifier of the derived Op class.
- TypeID typeID;
+ //===------------------------------------------------------------------===//
+ // Registered Operation Info
/// The following fields are only populated when the operation is
/// registered.
+ /// Returns true if the operation has been registered, i.e. if the
+ /// registration info has been populated.
+ bool isRegistered() const { return dialect; }
+
/// This is the dialect that this operation belongs to.
Dialect *dialect;
+ /// The unique identifier of the derived Op class.
+ TypeID typeID;
+
/// A map of interfaces that were registered to this operation.
detail::InterfaceMap interfaceMap;
+ /// Internal callback hooks provided by the op implementation.
+ FoldHookFn foldHookFn;
+ GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+ HasTraitFn hasTraitFn;
+ ParseAssemblyFn parseAssemblyFn;
+ PopulateDefaultAttrsFn populateDefaultAttrsFn;
+ PrintAssemblyFn printAssemblyFn;
+ VerifyInvariantsFn verifyInvariantsFn;
+ VerifyRegionInvariantsFn verifyRegionInvariantsFn;
+
/// A list of attribute names registered to this operation in StringAttr
/// form. This allows for operation classes to use StringAttr for attribute
/// lookup/creation/etc., as opposed to raw strings.
ArrayRef<StringAttr> attributeNames;
-
- friend class RegisteredOperationName;
- };
-
-protected:
- /// Default implementation for unregistered operations.
- struct UnregisteredOpModel : public Impl {
- using Impl::Impl;
- LogicalResult foldHook(Operation *, ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) final;
- void getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) final;
- bool hasTrait(TypeID) final;
- virtual OperationName::ParseAssemblyFn getParseAssemblyFn() final;
- void populateDefaultAttrs(const OperationName &, NamedAttrList &) final;
- void printAssembly(Operation *, OpAsmPrinter &, StringRef) final;
- LogicalResult verifyInvariants(Operation *) final;
- LogicalResult verifyRegionInvariants(Operation *) final;
};
public:
OperationName(StringRef name, MLIRContext *context);
/// Return if this operation is registered.
- bool isRegistered() const { return getImpl()->isRegistered(); }
-
- /// Return the unique identifier of the derived Op class, or null if not
- /// registered.
- TypeID getTypeID() const { return getImpl()->getTypeID(); }
+ bool isRegistered() const { return impl->isRegistered(); }
/// If this operation is registered, returns the registered information,
/// std::nullopt otherwise.
std::optional<RegisteredOperationName> getRegisteredInfo() const;
- /// This hook implements a generalized folder for this operation. Operations
- /// can implement this to provide simplifications rules that are applied by
- /// the Builder::createOrFold API and the canonicalization pass.
- ///
- /// This is an intentionally limited interface - implementations of this
- /// hook can only perform the following changes to the operation:
- ///
- /// 1. They can leave the operation alone and without changing the IR, and
- /// return failure.
- /// 2. They can mutate the operation in place, without changing anything
- /// else
- /// in the IR. In this case, return success.
- /// 3. They can return a list of existing values that can be used instead
- /// of
- /// the operation. In this case, fill in the results list and return
- /// success. The caller will remove the operation and use those results
- /// instead.
- ///
- /// This allows expression of some simple in-place canonicalizations (e.g.
- /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
- /// generalized constant folding.
- LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) const {
- return getImpl()->foldHook(op, operands, results);
- }
-
- /// This hook returns any canonicalization pattern rewrites that the
- /// operation supports, for use by the canonicalization pass.
- void getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) const {
- return getImpl()->getCanonicalizationPatterns(results, context);
- }
-
/// Returns true if the operation was registered with a particular trait, e.g.
/// hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation
/// is unregistered.
@@ -214,7 +146,9 @@ class OperationName {
bool hasTrait() const {
return hasTrait(TypeID::get<Trait>());
}
- bool hasTrait(TypeID traitID) const { return getImpl()->hasTrait(traitID); }
+ bool hasTrait(TypeID traitID) const {
+ return isRegistered() && impl->hasTraitFn(traitID);
+ }
/// Returns true if the operation *might* have the provided trait. This
/// means that either the operation is unregistered, or it was registered with
@@ -224,54 +158,7 @@ class OperationName {
return mightHaveTrait(TypeID::get<Trait>());
}
bool mightHaveTrait(TypeID traitID) const {
- return !isRegistered() || getImpl()->hasTrait(traitID);
- }
-
- /// Return the static hook for parsing this operation assembly.
- ParseAssemblyFn getParseAssemblyFn() const {
- return getImpl()->getParseAssemblyFn();
- }
-
- /// This hook implements the method to populate defaults attributes that are
- /// unset.
- void populateDefaultAttrs(NamedAttrList &attrs) const {
- getImpl()->populateDefaultAttrs(*this, attrs);
- }
-
- /// This hook implements the AsmPrinter for this operation.
- void printAssembly(Operation *op, OpAsmPrinter &p,
- StringRef defaultDialect) const {
- return getImpl()->printAssembly(op, p, defaultDialect);
- }
-
- /// These hooks implement the verifiers for this operation. It should emits
- /// an error message and returns failure if a problem is detected, or
- /// returns success if everything is ok.
- LogicalResult verifyInvariants(Operation *op) const {
- return getImpl()->verifyInvariants(op);
- }
- LogicalResult verifyRegionInvariants(Operation *op) const {
- return getImpl()->verifyRegionInvariants(op);
- }
-
- /// Return the list of cached attribute names registered to this operation.
- /// The order of attributes cached here is unique to each type of operation,
- /// and the interpretation of this attribute list should generally be driven
- /// by the respective operation. In many cases, this caching removes the
- /// need to use the raw string name of a known attribute.
- ///
- /// For example the ODS generator, with an op defining the following
- /// attributes:
- ///
- /// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
- ///
- /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
- /// ODS generator to directly access the cached name for a known attribute,
- /// greatly simplifying the cost and complexity of attribute usage produced
- /// by the generator.
- ///
- ArrayRef<StringAttr> getAttributeNames() const {
- return getImpl()->getAttributeNames();
+ return !isRegistered() || impl->hasTraitFn(traitID);
}
/// Returns an instance of the concept object for the given interface if it
@@ -279,13 +166,7 @@ class OperationName {
/// directly.
template <typename T>
typename T::Concept *getInterface() const {
- return getImpl()->getInterfaceMap().lookup<T>();
- }
-
- /// Attach the given models as implementations of the corresponding
- /// interfaces for the concrete operation.
- template <typename... Models> void attachInterface() {
- getImpl()->getInterfaceMap().insert<Models...>();
+ return impl->interfaceMap.lookup<T>();
}
/// Returns true if this operation has the given interface registered to it.
@@ -294,7 +175,7 @@ class OperationName {
return hasInterface(TypeID::get<T>());
}
bool hasInterface(TypeID interfaceID) const {
- return getImpl()->getInterfaceMap().contains(interfaceID);
+ return impl->interfaceMap.contains(interfaceID);
}
/// Returns true if the operation *might* have the provided interface. This
@@ -311,8 +192,7 @@ class OperationName {
/// Return the dialect this operation is registered to if the dialect is
/// loaded in the context, or nullptr if the dialect isn't loaded.
Dialect *getDialect() const {
- return isRegistered() ? getImpl()->getDialect()
- : getImpl()->getName().getReferencedDialect();
+ return isRegistered() ? impl->dialect : impl->name.getReferencedDialect();
}
/// Return the name of the dialect this operation is registered to.
@@ -325,7 +205,7 @@ class OperationName {
StringRef getStringRef() const { return getIdentifier(); }
/// Return the name of this operation as a StringAttr.
- StringAttr getIdentifier() const { return getImpl()->getName(); }
+ StringAttr getIdentifier() const { return impl->name; }
void print(raw_ostream &os) const;
void dump() const;
@@ -343,17 +223,12 @@ class OperationName {
protected:
OperationName(Impl *impl) : impl(impl) {}
- Impl *getImpl() const { return impl; }
- void setImpl(Impl *rhs) { impl = rhs; }
-private:
/// The internal implementation of the operation name.
- Impl *impl = nullptr;
+ Impl *impl;
/// Allow access to the Impl struct.
friend MLIRContextImpl;
- friend DenseMapInfo<mlir::OperationName>;
- friend DenseMapInfo<mlir::RegisteredOperationName>;
};
inline raw_ostream &operator<<(raw_ostream &os, OperationName info) {
@@ -376,62 +251,137 @@ inline llvm::hash_code hash_value(OperationName arg) {
/// the concrete operation types.
class RegisteredOperationName : public OperationName {
public:
- /// Implementation of the InterfaceConcept for operation APIs that forwarded
- /// to a concrete op implementation.
- template <typename ConcreteOp> struct Model : public Impl {
- Model(Dialect *dialect)
- : Impl(ConcreteOp::getOperationName(), dialect,
- TypeID::get<ConcreteOp>(), ConcreteOp::getInterfaceMap()) {}
- LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
- SmallVectorImpl<OpFoldResult> &results) final {
- return ConcreteOp::getFoldHookFn()(op, attrs, results);
- }
- void getCanonicalizationPatterns(RewritePatternSet &set,
- MLIRContext *context) final {
- ConcreteOp::getCanonicalizationPatterns(set, context);
- }
- bool hasTrait(TypeID id) final { return ConcreteOp::getHasTraitFn()(id); }
- OperationName::ParseAssemblyFn getParseAssemblyFn() final {
- return ConcreteOp::parse;
- }
- void populateDefaultAttrs(const OperationName &name,
- NamedAttrList &attrs) final {
- ConcreteOp::populateDefaultAttrs(name, attrs);
- }
- void printAssembly(Operation *op, OpAsmPrinter &printer,
- StringRef name) final {
- ConcreteOp::getPrintAssemblyFn()(op, printer, name);
- }
- LogicalResult verifyInvariants(Operation *op) final {
- return ConcreteOp::getVerifyInvariantsFn()(op);
- }
- LogicalResult verifyRegionInvariants(Operation *op) final {
- return ConcreteOp::getVerifyRegionInvariantsFn()(op);
- }
- };
-
/// Lookup the registered operation information for the given operation.
/// Returns std::nullopt if the operation isn't registered.
static std::optional<RegisteredOperationName> lookup(StringRef name,
MLIRContext *ctx);
/// Register a new operation in a Dialect object.
- /// This constructor is used by Dialect objects when they register the list
- /// of operations they contain.
- template <typename T> static void insert(Dialect &dialect) {
- insert(std::make_unique<Model<T>>(&dialect), T::getAttributeNames());
+ /// This constructor is used by Dialect objects when they register the list of
+ /// operations they contain.
+ template <typename T>
+ static void insert(Dialect &dialect) {
+ insert(T::getOperationName(), dialect, TypeID::get<T>(),
+ T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
+ T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
+ T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
+ T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
+ T::getPopulateDefaultAttrsFn());
}
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
- static void insert(std::unique_ptr<OperationName::Impl> ownedImpl,
- ArrayRef<StringRef> attrNames);
+ static void
+ insert(StringRef name, Dialect &dialect, TypeID typeID,
+ ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+ VerifyInvariantsFn &&verifyInvariants,
+ VerifyRegionInvariantsFn &&verifyRegionInvariants,
+ FoldHookFn &&foldHook,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+ detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+ ArrayRef<StringRef> attrNames,
+ PopulateDefaultAttrsFn &&populateDefaultAttrs);
/// Return the dialect this operation is registered to.
- Dialect &getDialect() const { return *getImpl()->getDialect(); }
+ Dialect &getDialect() const { return *impl->dialect; }
+
+ /// Return the unique identifier of the derived Op class.
+ TypeID getTypeID() const { return impl->typeID; }
/// Use the specified object to parse this ops custom assembly format.
ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
+ /// Return the static hook for parsing this operation assembly.
+ const ParseAssemblyFn &getParseAssemblyFn() const {
+ return impl->parseAssemblyFn;
+ }
+
+ /// This hook implements the AsmPrinter for this operation.
+ void printAssembly(Operation *op, OpAsmPrinter &p,
+ StringRef defaultDialect) const {
+ return impl->printAssemblyFn(op, p, defaultDialect);
+ }
+
+ /// These hooks implement the verifiers for this operation. It should emits
+ /// an error message and returns failure if a problem is detected, or returns
+ /// success if everything is ok.
+ LogicalResult verifyInvariants(Operation *op) const {
+ return impl->verifyInvariantsFn(op);
+ }
+ LogicalResult verifyRegionInvariants(Operation *op) const {
+ return impl->verifyRegionInvariantsFn(op);
+ }
+
+ /// This hook implements a generalized folder for this operation. Operations
+ /// can implement this to provide simplifications rules that are applied by
+ /// the Builder::createOrFold API and the canonicalization pass.
+ ///
+ /// This is an intentionally limited interface - implementations of this hook
+ /// can only perform the following changes to the operation:
+ ///
+ /// 1. They can leave the operation alone and without changing the IR, and
+ /// return failure.
+ /// 2. They can mutate the operation in place, without changing anything else
+ /// in the IR. In this case, return success.
+ /// 3. They can return a list of existing values that can be used instead of
+ /// the operation. In this case, fill in the results list and return
+ /// success. The caller will remove the operation and use those results
+ /// instead.
+ ///
+ /// This allows expression of some simple in-place canonicalizations (e.g.
+ /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+ /// generalized constant folding.
+ LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) const {
+ return impl->foldHookFn(op, operands, results);
+ }
+
+ /// This hook returns any canonicalization pattern rewrites that the operation
+ /// supports, for use by the canonicalization pass.
+ void getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) const {
+ return impl->getCanonicalizationPatternsFn(results, context);
+ }
+
+ /// Attach the given models as implementations of the corresponding interfaces
+ /// for the concrete operation.
+ template <typename... Models>
+ void attachInterface() {
+ impl->interfaceMap.insert<Models...>();
+ }
+
+ /// Returns true if the operation has a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() const {
+ return hasTrait(TypeID::get<Trait>());
+ }
+
+ /// Returns true if the operation has a particular trait.
+ bool hasTrait(TypeID traitID) const { return impl->hasTraitFn(traitID); }
+
+ /// Return the list of cached attribute names registered to this operation.
+ /// The order of attributes cached here is unique to each type of operation,
+ /// and the interpretation of this attribute list should generally be driven
+ /// by the respective operation. In many cases, this caching removes the need
+ /// to use the raw string name of a known attribute.
+ ///
+ /// For example the ODS generator, with an op defining the following
+ /// attributes:
+ ///
+ /// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
+ ///
+ /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
+ /// ODS generator to directly access the cached name for a known attribute,
+ /// greatly simplifying the cost and complexity of attribute usage produced by
+ /// the generator.
+ ///
+ ArrayRef<StringAttr> getAttributeNames() const {
+ return impl->attributeNames;
+ }
+
+ /// This hook implements the method to populate defaults attributes that are
+ /// unset.
+ void populateDefaultAttrs(NamedAttrList &attrs) const;
+
/// Represent the operation name as an opaque pointer. (Used to support
/// PointerLikeTypeTraits).
static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index c56befe7b5e28..a1af48134f84f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1368,18 +1368,14 @@ Operation *OperationParser::parseGenericOperation() {
if (!result.name.isRegistered()) {
StringRef dialectName = StringRef(name).split('.').first;
if (!getContext()->getLoadedDialect(dialectName) &&
- !getContext()->getOrLoadDialect(dialectName)) {
- if (!getContext()->allowsUnregisteredDialects()) {
- // Emit an error if the dialect couldn't be loaded (i.e., it was not
- // registered) and unregistered dialects aren't allowed.
- emitError("operation being parsed with an unregistered dialect. If "
- "this is intended, please use -allow-unregistered-dialect "
- "with the MLIR tool used");
- return nullptr;
- }
- } else {
- // Reload the OperationName now that the dialect is loaded.
- result.name = OperationName(name, getContext());
+ !getContext()->getOrLoadDialect(dialectName) &&
+ !getContext()->allowsUnregisteredDialects()) {
+ // Emit an error if the dialect couldn't be loaded (i.e., it was not
+ // registered) and unregistered dialects aren't allowed.
+ emitError("operation being parsed with an unregistered dialect. If "
+ "this is intended, please use -allow-unregistered-dialect "
+ "with the MLIR tool used");
+ return nullptr;
}
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0981b3b1cbe27..76807a5d8c593 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -603,8 +603,12 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
// If requested, always print the generic form.
if (!printerFlags.shouldPrintGenericOpForm()) {
- op->getName().printAssembly(op, *this, /*defaultDialect=*/"");
- return;
+ // Check to see if this is a known operation. If so, use the registered
+ // custom printer hook.
+ if (auto opInfo = op->getRegisteredInfo()) {
+ opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
+ return;
+ }
}
// Otherwise print with the generic assembly form.
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index 5190b5f97af7a..fd169a8bde7f1 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -294,19 +294,16 @@ DynamicOpDefinition::DynamicOpDefinition(
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
+ OperationName::GetCanonicalizationPatternsFn
+ &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn)
- : Impl(StringAttr::get(dialect->getContext(),
- (dialect->getNamespace() + "." + name).str()),
- dialect, dialect->allocateTypeID(),
- /*interfaceMap=*/detail::InterfaceMap(std::nullopt)),
+ : typeID(dialect->allocateTypeID()),
+ name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
parseFn(std::move(parseFn)), printFn(std::move(printFn)),
foldHookFn(std::move(foldHookFn)),
getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)),
- populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {
- typeID = dialect->allocateTypeID();
-}
+ populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {}
std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
StringRef name, ExtensibleDialect *dialect,
@@ -341,7 +338,8 @@ std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
};
- auto populateDefaultAttrsFn = [](const OperationName &, NamedAttrList &) {};
+ auto populateDefaultAttrsFn = [](const RegisteredOperationName &,
+ NamedAttrList &) {};
return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
std::move(verifyRegionFn), std::move(parseFn),
@@ -357,7 +355,8 @@ std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
+ OperationName::GetCanonicalizationPatternsFn
+ &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) {
return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
@@ -449,7 +448,15 @@ void ExtensibleDialect::registerDynamicOp(
std::unique_ptr<DynamicOpDefinition> &&op) {
assert(op->dialect == this &&
"trying to register a dynamic op in the wrong dialect");
- RegisteredOperationName::insert(std::move(op), /*attrNames=*/{});
+ auto hasTraitFn = [](TypeID traitId) { return false; };
+
+ RegisteredOperationName::insert(
+ op->name, *op->dialect, op->typeID, std::move(op->parseFn),
+ std::move(op->printFn), std::move(op->verifyFn),
+ std::move(op->verifyRegionFn), std::move(op->foldHookFn),
+ std::move(op->getCanonicalizationPatternsFn),
+ detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
+ std::move(op->populateDefaultAttrsFn));
}
bool ExtensibleDialect::classof(const Dialect *dialect) {
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b0fe94f0c513a..8e3edc8613185 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -180,7 +180,7 @@ class MLIRContextImpl {
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
/// This is a mapping from operation name to the operation info describing it.
- llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
+ llvm::StringMap<OperationName::Impl> operations;
/// A vector of operation info specifically for registered operations.
llvm::StringMap<RegisteredOperationName> registeredOperations;
@@ -706,11 +706,6 @@ AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
// OperationName
//===----------------------------------------------------------------------===//
-OperationName::Impl::Impl(StringRef name, Dialect *dialect, TypeID typeID,
- detail::InterfaceMap interfaceMap)
- : Impl(StringAttr::get(dialect->getContext(), name), dialect, typeID,
- std::move(interfaceMap)) {}
-
OperationName::OperationName(StringRef name, MLIRContext *context) {
MLIRContextImpl &ctxImpl = context->getImpl();
@@ -729,7 +724,7 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
auto it = ctxImpl.operations.find(name);
if (it != ctxImpl.operations.end()) {
- impl = it->second.get();
+ impl = &it->second;
return;
}
}
@@ -737,14 +732,10 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
// Acquire a writer-lock so that we can safely create the new instance.
ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
- auto it = ctxImpl.operations.insert({name, nullptr});
- if (it.second) {
- auto nameAttr = StringAttr::get(context, name);
- it.first->second = std::make_unique<UnregisteredOpModel>(
- nameAttr, nameAttr.getReferencedDialect(), TypeID::get<void>(),
- detail::InterfaceMap(std::nullopt));
- }
- impl = it.first->second.get();
+ auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
+ if (it.second)
+ it.first->second.name = StringAttr::get(context, name);
+ impl = &it.first->second;
}
StringRef OperationName::getDialectNamespace() const {
@@ -753,34 +744,6 @@ StringRef OperationName::getDialectNamespace() const {
return getStringRef().split('.').first;
}
-LogicalResult
-OperationName::UnregisteredOpModel::foldHook(Operation *, ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return failure();
-}
-void OperationName::UnregisteredOpModel::getCanonicalizationPatterns(
- RewritePatternSet &, MLIRContext *) {}
-bool OperationName::UnregisteredOpModel::hasTrait(TypeID) { return false; }
-
-OperationName::ParseAssemblyFn
-OperationName::UnregisteredOpModel::getParseAssemblyFn() {
- llvm::report_fatal_error("getParseAssemblyFn hook called on unregistered op");
-}
-void OperationName::UnregisteredOpModel::populateDefaultAttrs(
- const OperationName &, NamedAttrList &) {}
-void OperationName::UnregisteredOpModel::printAssembly(
- Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
- p.printGenericOp(op);
-}
-LogicalResult
-OperationName::UnregisteredOpModel::verifyInvariants(Operation *) {
- return success();
-}
-LogicalResult
-OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
- return success();
-}
-
//===----------------------------------------------------------------------===//
// RegisteredOperationName
//===----------------------------------------------------------------------===//
@@ -794,11 +757,26 @@ RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
return std::nullopt;
}
+ParseResult
+RegisteredOperationName::parseAssembly(OpAsmParser &parser,
+ OperationState &result) const {
+ return impl->parseAssemblyFn(parser, result);
+}
+
+void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
+ impl->populateDefaultAttrsFn(*this, attrs);
+}
+
void RegisteredOperationName::insert(
- std::unique_ptr<RegisteredOperationName::Impl> ownedImpl,
- ArrayRef<StringRef> attrNames) {
- RegisteredOperationName::Impl *impl = ownedImpl.get();
- MLIRContext *ctx = impl->getDialect()->getContext();
+ StringRef name, Dialect &dialect, TypeID typeID,
+ ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+ VerifyInvariantsFn &&verifyInvariants,
+ VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+ detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+ ArrayRef<StringRef> attrNames,
+ PopulateDefaultAttrsFn &&populateDefaultAttrs) {
+ MLIRContext *ctx = dialect.getContext();
auto &ctxImpl = ctx->getImpl();
assert(ctxImpl.multiThreadedExecutionContext == 0 &&
"registering a new operation kind while in a multi-threaded execution "
@@ -813,16 +791,21 @@ void RegisteredOperationName::insert(
attrNames.size());
for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
- impl->attributeNames = cachedAttrNames;
}
- StringRef name = impl->getName().strref();
- // Insert the operation info if it doesn't exist yet.
- auto it = ctxImpl.operations.insert({name, nullptr});
- it.first->second = std::move(ownedImpl);
- // Update the registered info for this operation.
+ // Insert the operation info if it doesn't exist yet.
+ auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
+ if (it.second)
+ it.first->second.name = StringAttr::get(ctx, name);
+ OperationName::Impl &impl = it.first->second;
+
+ if (impl.isRegistered()) {
+ llvm::errs() << "error: operation named '" << name
+ << "' is already registered.\n";
+ abort();
+ }
auto emplaced = ctxImpl.registeredOperations.try_emplace(
- name, RegisteredOperationName(impl));
+ name, RegisteredOperationName(&impl));
assert(emplaced.second && "operation name registration must be successful");
// Add emplaced operation name to the sorted operations container.
@@ -834,6 +817,20 @@ void RegisteredOperationName::insert(
rhs.getIdentifier());
}),
value);
+
+ // Update the registered info for this operation.
+ impl.dialect = &dialect;
+ impl.typeID = typeID;
+ impl.interfaceMap = std::move(interfaceMap);
+ impl.foldHookFn = std::move(foldHook);
+ impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
+ impl.hasTraitFn = std::move(hasTrait);
+ impl.parseAssemblyFn = std::move(parseAssembly);
+ impl.printAssemblyFn = std::move(printAssembly);
+ impl.verifyInvariantsFn = std::move(verifyInvariants);
+ impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
+ impl.attributeNames = cachedAttrNames;
+ impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 50815b3738bfe..0c5869b49cc05 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -78,7 +78,8 @@ Operation *Operation::create(Location location, OperationName name,
void *rawMem = mallocMem + prefixByteSize;
// Populate default attributes.
- name.populateDefaultAttrs(attributes);
+ if (Optional<RegisteredOperationName> info = name.getRegisteredInfo())
+ info->populateDefaultAttrs(attributes);
// Create the new Operation.
Operation *op = ::new (rawMem) Operation(
@@ -490,7 +491,8 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
- if (succeeded(name.foldHook(this, operands, results)))
+ Optional<RegisteredOperationName> info = getRegisteredInfo();
+ if (info && succeeded(info->foldHook(this, operands, results)))
return success();
// Otherwise, fall back on the dialect hook to handle it.
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 1ea4bef5402a7..91441ced0625f 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1606,7 +1606,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
unsigned numResults = read();
if (numResults == kInferTypesMarker) {
InferTypeOpInterface::Concept *inferInterface =
- state.name.getInterface<InferTypeOpInterface>();
+ state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
assert(inferInterface &&
"expected operation to provide InferTypeOpInterface");
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 83937f46e6649..c9d056b56ea84 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -926,7 +926,7 @@ void OpEmitter::genAttrNameGetters() {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
- return name.getAttributeNames()[index];
+ return name.getRegisteredInfo()->getAttributeNames()[index];
)";
method->body() << formatv(getAttrName, attributes.size());
}
@@ -1739,7 +1739,7 @@ void OpEmitter::genPopulateDefaultAttributes() {
return;
SmallVector<MethodParameter> paramList;
- paramList.emplace_back("const ::mlir::OperationName &", "opName");
+ paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
diff --git a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
index 2fc8a43f7c04c..27b0978c37884 100644
--- a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
@@ -36,7 +36,6 @@ class ValueShapeRangeTest : public testing::Test {
registry.insert<func::FuncDialect, arith::ArithDialect>();
ctx.appendDialectRegistry(registry);
module = parseSourceString<ModuleOp>(ir, &ctx);
- assert(module);
mapFn = cast<func::FuncOp>(module->front());
}
More information about the Mlir-commits
mailing list