[Mlir-commits] [mlir] 94662ee - [mlir] Add support for adding attribute+type traits/interfaces to tablegen defs
River Riddle
llvmlistbot at llvm.org
Thu Apr 15 11:44:12 PDT 2021
Author: River Riddle
Date: 2021-04-15T11:41:51-07:00
New Revision: 94662ee0c175165e60bc09fc73396a3e344829d4
URL: https://github.com/llvm/llvm-project/commit/94662ee0c175165e60bc09fc73396a3e344829d4
DIFF: https://github.com/llvm/llvm-project/commit/94662ee0c175165e60bc09fc73396a3e344829d4.diff
LOG: [mlir] Add support for adding attribute+type traits/interfaces to tablegen defs
This matches the current support provided to operations, and allows attaching traits, interfaces, and using the DeclareInterfaceMethods utility. This was missed when attribute/type generation was first added.
Differential Revision: https://reviews.llvm.org/D100233
Added:
mlir/include/mlir/TableGen/Trait.h
mlir/lib/TableGen/Trait.cpp
Modified:
mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/BuiltinLocationAttributes.td
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/include/mlir/TableGen/Operator.h
mlir/include/mlir/TableGen/SideEffects.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/lib/TableGen/CMakeLists.txt
mlir/lib/TableGen/Operator.cpp
mlir/lib/TableGen/SideEffects.cpp
mlir/test/lib/Dialect/Test/TestInterfaces.td
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/test/mlir-tblgen/attrdefs.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/DialectGen.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
mlir/include/mlir/TableGen/OpTrait.h
mlir/lib/TableGen/OpTrait.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
index 1e0578339ad86..10b6130afd148 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
@@ -20,7 +20,7 @@ include "mlir/Dialect/PDL/IR/PDLDialect.td"
//===----------------------------------------------------------------------===//
class PDL_Type<string name, string typeMnemonic>
- : TypeDef<PDL_Dialect, name, "::mlir::pdl::PDLType"> {
+ : TypeDef<PDL_Dialect, name, [], "::mlir::pdl::PDLType"> {
let mnemonic = typeMnemonic;
}
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 45214535b1f8c..c248ad5822f87 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -23,7 +23,7 @@ include "mlir/IR/BuiltinDialect.td"
// Base class for Builtin dialect attributes.
class Builtin_Attr<string name, string baseCppClass = "::mlir::Attribute">
- : AttrDef<Builtin_Dialect, name, baseCppClass> {
+ : AttrDef<Builtin_Dialect, name, [], baseCppClass> {
let mnemonic = ?;
}
diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index d6347441d09c6..3858b11e05215 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -17,7 +17,7 @@ include "mlir/IR/BuiltinDialect.td"
// Base class for Builtin dialect location attributes.
class Builtin_LocationAttr<string name>
- : AttrDef<Builtin_Dialect, name, "::mlir::LocationAttr"> {
+ : AttrDef<Builtin_Dialect, name, [], "::mlir::LocationAttr"> {
let cppClassName = name;
let mnemonic = ?;
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index f266f61e182ea..16afcfdfd00a2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -22,7 +22,7 @@ include "mlir/IR/BuiltinDialect.td"
// Base class for Builtin dialect types.
class Builtin_Type<string name, string baseCppClass = "::mlir::Type">
- : TypeDef<Builtin_Dialect, name, baseCppClass> {
+ : TypeDef<Builtin_Dialect, name, [], baseCppClass> {
let mnemonic = ?;
}
@@ -65,8 +65,7 @@ def Builtin_Complex : Builtin_Type<"Complex"> {
//===----------------------------------------------------------------------===//
// Base class for Builtin dialect float types.
-class Builtin_FloatType<string name> : TypeDef<Builtin_Dialect, name,
- "::mlir::FloatType"> {
+class Builtin_FloatType<string name> : Builtin_Type<name, "::mlir::FloatType"> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 20f3c7fa84130..11038ef12d2ad 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1734,42 +1734,58 @@ def AnySuccessor : Successor<?, "any successor">;
class VariadicSuccessor<Successor successor>
: Successor<successor.predicate, successor.summary>;
+
//===----------------------------------------------------------------------===//
-// OpTrait definitions
+// Trait definitions
//===----------------------------------------------------------------------===//
-// OpTrait represents a trait regarding an op.
-class OpTrait;
+// Trait represents a trait regarding an attribute, operation, or type.
+class Trait;
-// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The
-// purpose to wrap around C++ symbol string with this class is to make
-// traits specified for ops in TableGen less alien and more integrated.
-class NativeOpTrait<string name> : OpTrait {
+// NativeTrait corresponds to the MLIR C++ trait mechanism. This class wraps
+// the C++ symbol string so that traits specified for entities in TableGen are
+// less alien and more integrated. `entityType` selects the C++ namespace.
+class NativeTrait<string name, string entityType> : Trait {
string trait = name;
- string cppNamespace = "::mlir::OpTrait";
+ string cppNamespace = "::mlir::" # entityType # "Trait";
}
-// ParamNativeOpTrait corresponds to the template-parameterized traits in the
-// C++ implementation. MLIR uses nested class templates to implement such
-// traits leading to constructs of the form "TraitName<Parameters>::Impl". Use
-// the value in `prop` as the trait name and the value in `params` as
-// parameters to construct the native trait class name.
-class ParamNativeOpTrait<string prop, string params>
- : NativeOpTrait<prop # "<" # params # ">::Impl">;
+// ParamNativeTrait corresponds to the template-parameterized traits in the C++
+// implementation. MLIR uses nested class templates to implement such traits
+// leading to constructs of the form "TraitName<Parameters>::Impl". Use the
+// value in `prop` as the trait name and the value in `params` as parameters to
+// construct the native trait class name.
+class ParamNativeTrait<string prop, string params, string entityType>
+ : NativeTrait<prop # "<" # params # ">::Impl", entityType>;
-// GenInternalOpTrait is an op trait that does not have direct C++ mapping but
-// affects op definition generator internals, like how op builders and
+// GenInternalTrait is a trait that does not have direct C++ mapping but affects
+// an entity's definition generator internals, like how operation builders and
// operand/attribute/result getters are generated.
-class GenInternalOpTrait<string prop> : OpTrait {
- string trait = "::mlir::OpTrait::" # prop;
+class GenInternalTrait<string prop, string entityType> : Trait {
+ string trait = "::mlir::" # entityType # "Trait::" # prop;
}
-// PredOpTrait is an op trait implemented by way of a predicate on the op.
-class PredOpTrait<string descr, Pred pred> : OpTrait {
+// PredTrait is a trait implemented by way of a predicate on an entity.
+class PredTrait<string descr, Pred pred> : Trait {
string summary = descr;
Pred predicate = pred;
}
+//===----------------------------------------------------------------------===//
+// OpTrait definitions
+//===----------------------------------------------------------------------===//
+
+// OpTrait represents a trait regarding an operation.
+// TODO: Remove this class in favor of using Trait.
+class OpTrait;
+
+// These classes are used to define operation specific traits.
+class NativeOpTrait<string name> : NativeTrait<name, "Op">, OpTrait;
+class ParamNativeOpTrait<string prop, string params>
+ : ParamNativeTrait<prop, params, "Op">, OpTrait;
+class GenInternalOpTrait<string prop> : GenInternalTrait<prop, "Op">, OpTrait;
+class PredOpTrait<string descr, Pred pred> : PredTrait<descr, pred>, OpTrait;
+
// Op defines an affine scope.
def AffineScope : NativeOpTrait<"AffineScope">;
// Op defines an automatic allocation scope.
@@ -1895,23 +1911,28 @@ class CArg<string ty, string value = ""> {
string defaultValue = value;
}
-// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
-// C++. The purpose to wrap around C++ symbol string with this class is to make
+// InterfaceTrait corresponds to a specific 'Interface' class defined in C++.
+// This class wraps the C++ symbol string in order to make
// interfaces specified for ops in TableGen less alien and more integrated.
-class OpInterfaceTrait<string name, code verifyBody = [{}]>
- : NativeOpTrait<""> {
+class InterfaceTrait<string name> : NativeTrait<"", ""> {
let trait = name # "::Trait";
let cppNamespace = "";
- // Specify the body of the verification function. `$_op` will be replaced with
- // the operation being verified.
- code verify = verifyBody;
-
// An optional code block containing extra declarations to place in the
// interface trait declaration.
code extraTraitClassDeclaration = "";
}
+// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
+// C++. This class wraps the C++ symbol string in order to make
+// interfaces specified for ops in TableGen less alien and more integrated.
+class OpInterfaceTrait<string name, code verifyBody = [{}]>
+ : InterfaceTrait<name>, OpTrait {
+ // Specify the body of the verification function. `$_op` will be replaced with
+ // the operation being verified.
+ code verify = verifyBody;
+}
+
// This class represents a single, optionally static, interface method.
// Note: non-static interface methods have an implicit parameter, either
// $_op/$_attr/$_type corresponding to an instance of the derived value.
@@ -1967,39 +1988,52 @@ class Interface<string name> {
}
// AttrInterface represents an interface registered to an attribute.
-class AttrInterface<string name> : Interface<name> {
- // An optional code block containing extra declarations to place in the
- // interface trait declaration.
- code extraTraitClassDeclaration = "";
-}
+class AttrInterface<string name> : Interface<name>, InterfaceTrait<name>;
// OpInterface represents an interface registered to an operation.
class OpInterface<string name> : Interface<name>, OpInterfaceTrait<name>;
// TypeInterface represents an interface registered to a type.
-class TypeInterface<string name> : Interface<name> {
- // An optional code block containing extra declarations to place in the
- // interface trait declaration.
- code extraTraitClassDeclaration = "";
-}
+class TypeInterface<string name> : Interface<name>, InterfaceTrait<name>;
-// Whether to declare the op interface methods in the op's header. This class
-// simply wraps an OpInterface but is used to indicate that the method
+// Whether to declare the interface methods in the user entity's header. This
+// class simply wraps an Interface but is used to indicate that the method
// declarations should be generated. This class takes an optional set of methods
// that should have declarations generated even if the method has a default
// implementation.
+class DeclareInterfaceMethods<Interface interface,
+ list<string> overridenMethods = []> {
+ // This field contains a set of method names that should always have their
+ // declarations generated. This allows for generating declarations for
+ // methods with default implementations that need to be overridden.
+ list<string> alwaysOverriddenMethods = overridenMethods;
+}
+class DeclareAttrInterfaceMethods<AttrInterface interface,
+ list<string> overridenMethods = []>
+ : DeclareInterfaceMethods<interface, overridenMethods>,
+ AttrInterface<interface.cppClassName> {
+ let description = interface.description;
+ let cppClassName = interface.cppClassName;
+ let cppNamespace = interface.cppNamespace;
+ let methods = interface.methods;
+}
class DeclareOpInterfaceMethods<OpInterface interface,
list<string> overridenMethods = []>
- : OpInterface<interface.cppClassName> {
+ : DeclareInterfaceMethods<interface, overridenMethods>,
+ OpInterface<interface.cppClassName> {
+ let description = interface.description;
+ let cppClassName = interface.cppClassName;
+ let cppNamespace = interface.cppNamespace;
+ let methods = interface.methods;
+}
+class DeclareTypeInterfaceMethods<TypeInterface interface,
+ list<string> overridenMethods = []>
+ : DeclareInterfaceMethods<interface, overridenMethods>,
+ TypeInterface<interface.cppClassName> {
let description = interface.description;
let cppClassName = interface.cppClassName;
let cppNamespace = interface.cppNamespace;
let methods = interface.methods;
-
- // This field contains a set of method names that should always have their
- // declarations generated. This allows for generating declarations for
- // methods with default implementations that need to be overridden.
- list<string> alwaysOverriddenMethods = overridenMethods;
}
//===----------------------------------------------------------------------===//
@@ -2609,7 +2643,8 @@ class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
// Define a new attribute or type, named `name`, that inherits from the given
// C++ base class.
-class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
+class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
+ string baseCppClass> {
// The name of the C++ base class to use for this def.
string cppBaseClassName = baseCppClass;
@@ -2664,6 +2699,9 @@ class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
// Note that builders should only be provided when a def has parameters.
list<AttrOrTypeBuilder> builders = ?;
+ // The list of traits attached to this def.
+ list<Trait> traits = defTraits;
+
// Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of
// the printer/parser.
@@ -2692,10 +2730,10 @@ class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
// Define a new attribute, named `name`, belonging to `dialect` that inherits
// from the given C++ base class.
-class AttrDef<Dialect dialect, string name,
+class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: DialectAttr<dialect, CPred<"">, /*descr*/"">,
- AttrOrTypeDef<"Attr", name, baseCppClass> {
+ AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
// The name of the C++ Attribute class.
string cppClassName = name # "Attr";
@@ -2728,10 +2766,10 @@ class AttrDef<Dialect dialect, string name,
// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
-class TypeDef<Dialect dialect, string name,
+class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
- AttrOrTypeDef<"Type", name, baseCppClass> {
+ AttrOrTypeDef<"Type", name, traits, baseCppClass> {
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 5271fbae0eea3..ab07f43c1b5e8 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -16,6 +16,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Builder.h"
+#include "mlir/TableGen/Trait.h"
namespace llvm {
class DagInit;
@@ -120,6 +121,9 @@ class AttrOrTypeDef {
// Returns the builders of this def.
ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
+ // Returns the traits of this def.
+ ArrayRef<Trait> getTraits() const { return traits; }
+
// Returns whether two AttrOrTypeDefs are equal by checking the equality of
// the underlying record.
bool operator==(const AttrOrTypeDef &other) const;
@@ -136,8 +140,11 @@ class AttrOrTypeDef {
protected:
const llvm::Record *def;
- // The builders of this type definition.
+ // The builders of this definition.
SmallVector<AttrOrTypeBuilder> builders;
+
+ // The traits of this definition.
+ SmallVector<Trait> traits;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index d21b4b213ee41..3da693dde0be9 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -18,9 +18,9 @@
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Dialect.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Region.h"
#include "mlir/TableGen/Successor.h"
+#include "mlir/TableGen/Trait.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
@@ -176,9 +176,7 @@ class Operator {
var_decorator_range getArgDecorators(int index) const;
// Returns the trait wrapper for the given MLIR C++ `trait`.
- // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
- // requiring the raw MLIR trait here.
- const OpTrait *getTrait(llvm::StringRef trait) const;
+ const Trait *getTrait(llvm::StringRef trait) const;
// Regions.
using const_region_iterator = const NamedRegion *;
@@ -209,7 +207,7 @@ class Operator {
unsigned getNumVariadicSuccessors() const;
// Trait.
- using const_trait_iterator = const OpTrait *;
+ using const_trait_iterator = const Trait *;
const_trait_iterator trait_begin() const;
const_trait_iterator trait_end() const;
llvm::iterator_range<const_trait_iterator> getTraits() const;
@@ -325,7 +323,7 @@ class Operator {
SmallVector<NamedSuccessor, 0> successors;
// The traits of the op.
- SmallVector<OpTrait, 4> traits;
+ SmallVector<Trait, 4> traits;
// The regions of this op.
SmallVector<NamedRegion, 1> regions;
diff --git a/mlir/include/mlir/TableGen/SideEffects.h b/mlir/include/mlir/TableGen/SideEffects.h
index 7e464476cea11..c5ced6682c16e 100644
--- a/mlir/include/mlir/TableGen/SideEffects.h
+++ b/mlir/include/mlir/TableGen/SideEffects.h
@@ -41,7 +41,7 @@ class SideEffect : public Operator::VariableDecorator {
// This class represents an instance of a side effect interface applied to an
// operation. This is a wrapper around an OpInterfaceTrait that also includes
// the effects that are applied.
-class SideEffectTrait : public InterfaceOpTrait {
+class SideEffectTrait : public InterfaceTrait {
public:
// Return the effects that are attached to the side effect interface.
Operator::var_decorator_range getEffects() const;
@@ -49,7 +49,7 @@ class SideEffectTrait : public InterfaceOpTrait {
// Return the name of the base C++ effect.
StringRef getBaseEffectName() const;
- static bool classof(const OpTrait *t);
+ static bool classof(const Trait *t);
};
} // end namespace tblgen
diff --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/Trait.h
similarity index 51%
rename from mlir/include/mlir/TableGen/OpTrait.h
rename to mlir/include/mlir/TableGen/Trait.h
index 4ac0c1c4ed25e..52d056d0581e0 100644
--- a/mlir/include/mlir/TableGen/OpTrait.h
+++ b/mlir/include/mlir/TableGen/Trait.h
@@ -1,4 +1,4 @@
-//===- OpTrait.h - OpTrait wrapper class ------------------------*- C++ -*-===//
+//===- Trait.h - Trait wrapper class ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
-// OpTrait wrapper to simplify using TableGen Record defining an MLIR OpTrait.
+// Trait wrapper to simplify using TableGen Record defining an MLIR Trait.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TABLEGEN_OPTRAIT_H_
-#define MLIR_TABLEGEN_OPTRAIT_H_
+#ifndef MLIR_TABLEGEN_TRAIT_H_
+#define MLIR_TABLEGEN_TRAIT_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
@@ -25,28 +25,28 @@ class Record;
namespace mlir {
namespace tblgen {
-struct OpInterface;
+class Interface;
-// Wrapper class with helper methods for accessing OpTrait constraints defined
-// in TableGen.
-class OpTrait {
+// Wrapper class with helper methods for accessing Trait constraints defined in
+// TableGen.
+class Trait {
public:
- // Discriminator for kinds of op traits.
+ // Discriminator for kinds of traits.
enum class Kind {
- // OpTrait corresponding to C++ class.
+ // Trait corresponding to C++ class.
Native,
- // OpTrait corresponding to predicate on operation.
+ // Trait corresponding to a predicate.
Pred,
- // OpTrait controlling op definition generator internals.
+ // Trait controlling definition generator internals.
Internal,
- // OpTrait corresponding to OpInterface.
+ // Trait corresponding to an Interface.
Interface
};
- explicit OpTrait(Kind kind, const llvm::Record *def);
+ explicit Trait(Kind kind, const llvm::Record *def);
- // Returns an OpTrait corresponding to the init provided.
- static OpTrait create(const llvm::Init *init);
+ // Returns a Trait corresponding to the init provided.
+ static Trait create(const llvm::Init *init);
Kind getKind() const { return kind; }
@@ -59,17 +59,17 @@ class OpTrait {
Kind kind;
};
-// OpTrait corresponding to a native C++ OpTrait.
-class NativeOpTrait : public OpTrait {
+// Trait corresponding to a native C++ Trait.
+class NativeTrait : public Trait {
public:
// Returns the trait corresponding to a C++ trait class.
- std::string getTrait() const;
+ std::string getFullyQualifiedTraitName() const;
- static bool classof(const OpTrait *t) { return t->getKind() == Kind::Native; }
+ static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
};
-// OpTrait corresponding to a predicate on the operation.
-class PredOpTrait : public OpTrait {
+// Trait corresponding to a predicate on an entity.
+class PredTrait : public Trait {
public:
// Returns the template for constructing the predicate.
std::string getPredTemplate() const;
@@ -77,30 +77,28 @@ class PredOpTrait : public OpTrait {
// Returns the description of what the predicate is verifying.
StringRef getSummary() const;
- static bool classof(const OpTrait *t) { return t->getKind() == Kind::Pred; }
+ static bool classof(const Trait *t) { return t->getKind() == Kind::Pred; }
};
-// OpTrait controlling op definition generator internals.
-class InternalOpTrait : public OpTrait {
+// Trait controlling definition generator internals.
+class InternalTrait : public Trait {
public:
// Returns the trait controlling op definition generator internals.
- StringRef getTrait() const;
+ StringRef getFullyQualifiedTraitName() const;
- static bool classof(const OpTrait *t) {
- return t->getKind() == Kind::Internal;
- }
+ static bool classof(const Trait *t) { return t->getKind() == Kind::Internal; }
};
-// OpTrait corresponding to an OpInterface on the operation.
-class InterfaceOpTrait : public OpTrait {
+// Trait corresponding to an Interface attached to an entity.
+class InterfaceTrait : public Trait {
public:
- // Returns member function definitions corresponding to the trait,
- OpInterface getOpInterface() const;
+ // Returns the interface corresponding to the trait.
+ Interface getInterface() const;
// Returns the trait corresponding to a C++ trait class.
- std::string getTrait() const;
+ std::string getFullyQualifiedTraitName() const;
- static bool classof(const OpTrait *t) {
+ static bool classof(const Trait *t) {
return t->getKind() == Kind::Interface;
}
@@ -115,4 +113,4 @@ class InterfaceOpTrait : public OpTrait {
} // end namespace tblgen
} // end namespace mlir
-#endif // MLIR_TABLEGEN_OPTRAIT_H_
+#endif // MLIR_TABLEGEN_TRAIT_H_
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 1e4f5e4becdd6..c439b3c827755 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -8,6 +8,7 @@
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -46,6 +47,15 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
builders.emplace_back(builder);
}
}
+
+ // Populate the traits, keeping declaration order and skipping duplicates.
+ if (auto *traitList = def->getValueAsListInit("traits")) {
+ SmallPtrSet<const llvm::Init *, 32> traitSet;
+ traits.reserve(traitList->size());
+ for (auto *traitInit : *traitList)
+ if (traitSet.insert(traitInit).second)
+ traits.push_back(Trait::create(traitInit));
+ }
}
Dialect AttrOrTypeDef::getDialect() const {
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index 557caf1c9c1a5..a97419b193216 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -19,13 +19,13 @@ llvm_add_library(MLIRTableGen STATIC
Interfaces.cpp
Operator.cpp
OpClass.cpp
- OpTrait.cpp
Pass.cpp
Pattern.cpp
Predicate.cpp
Region.cpp
SideEffects.cpp
Successor.cpp
+ Trait.cpp
Type.cpp
DISABLE_LLVM_LINK_LLVM_DYLIB
diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp
deleted file mode 100644
index 4ebfc9bb531b3..0000000000000
--- a/mlir/lib/TableGen/OpTrait.cpp
+++ /dev/null
@@ -1,75 +0,0 @@
-//===- OpTrait.cpp - OpTrait class ----------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// OpTrait wrapper to simplify using TableGen Record defining a MLIR OpTrait.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/TableGen/OpTrait.h"
-#include "mlir/TableGen/Interfaces.h"
-#include "mlir/TableGen/Predicate.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/TableGen/Error.h"
-#include "llvm/TableGen/Record.h"
-
-using namespace mlir;
-using namespace mlir::tblgen;
-
-OpTrait OpTrait::create(const llvm::Init *init) {
- auto def = cast<llvm::DefInit>(init)->getDef();
- if (def->isSubClassOf("PredOpTrait"))
- return OpTrait(Kind::Pred, def);
- if (def->isSubClassOf("GenInternalOpTrait"))
- return OpTrait(Kind::Internal, def);
- if (def->isSubClassOf("OpInterfaceTrait"))
- return OpTrait(Kind::Interface, def);
- assert(def->isSubClassOf("NativeOpTrait"));
- return OpTrait(Kind::Native, def);
-}
-
-OpTrait::OpTrait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
-
-std::string NativeOpTrait::getTrait() const {
- llvm::StringRef trait = def->getValueAsString("trait");
- llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
- return cppNamespace.empty() ? trait.str()
- : (cppNamespace + "::" + trait).str();
-}
-
-llvm::StringRef InternalOpTrait::getTrait() const {
- return def->getValueAsString("trait");
-}
-
-std::string PredOpTrait::getPredTemplate() const {
- auto pred = Pred(def->getValueInit("predicate"));
- return pred.getCondition();
-}
-
-llvm::StringRef PredOpTrait::getSummary() const {
- return def->getValueAsString("summary");
-}
-
-OpInterface InterfaceOpTrait::getOpInterface() const {
- return OpInterface(def);
-}
-
-std::string InterfaceOpTrait::getTrait() const {
- llvm::StringRef trait = def->getValueAsString("trait");
- llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
- return cppNamespace.empty() ? trait.str()
- : (cppNamespace + "::" + trait).str();
-}
-
-bool InterfaceOpTrait::shouldDeclareMethods() const {
- return def->isSubClassOf("DeclareOpInterfaceMethods");
-}
-
-std::vector<StringRef> InterfaceOpTrait::getAlwaysDeclaredMethods() const {
- return def->getValueAsListOfStrings("alwaysOverriddenMethods");
-}
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 209c1ec0d94a7..c1ba549e8abfb 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -11,8 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Operator.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Predicate.h"
+#include "mlir/TableGen/Trait.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/STLExtras.h"
@@ -158,17 +158,17 @@ auto Operator::getArgDecorators(int index) const -> var_decorator_range {
return *arg->getValueAsListInit("decorators");
}
-const OpTrait *Operator::getTrait(StringRef trait) const {
+const Trait *Operator::getTrait(StringRef trait) const {
for (const auto &t : traits) {
- if (const auto *opTrait = dyn_cast<NativeOpTrait>(&t)) {
- if (opTrait->getTrait() == trait)
- return opTrait;
- } else if (const auto *opTrait = dyn_cast<InternalOpTrait>(&t)) {
- if (opTrait->getTrait() == trait)
- return opTrait;
- } else if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&t)) {
- if (opTrait->getTrait() == trait)
- return opTrait;
+ if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
+ if (traitDef->getFullyQualifiedTraitName() == trait)
+ return traitDef;
+ } else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
+ if (traitDef->getFullyQualifiedTraitName() == trait)
+ return traitDef;
+ } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
+ if (traitDef->getFullyQualifiedTraitName() == trait)
+ return traitDef;
}
}
return nullptr;
@@ -314,7 +314,7 @@ void Operator::populateTypeInferenceInfo(
return found;
};
- for (const OpTrait &trait : traits) {
+ for (const Trait &trait : traits) {
const llvm::Record &def = trait.getDef();
// If the infer type op interface was manually added, then treat it as
// intention that the op needs special handling.
@@ -323,8 +323,8 @@ void Operator::populateTypeInferenceInfo(
if (def.isSubClassOf(
llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
return;
- if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&trait))
- if (&opTrait->getDef() == inferTrait)
+ if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
+ if (&traitDef->getDef() == inferTrait)
return;
if (!def.isSubClassOf("AllTypesMatch"))
@@ -344,7 +344,7 @@ void Operator::populateTypeInferenceInfo(
// If the types could be computed, then add type inference trait.
if (allResultsHaveKnownTypes)
- traits.push_back(OpTrait::create(inferTrait->getDefInit()));
+ traits.push_back(Trait::create(inferTrait->getDefInit()));
}
void Operator::populateOpStructure() {
@@ -489,7 +489,7 @@ void Operator::populateOpStructure() {
for (auto *traitInit : *traitList) {
// Keep traits in the same order while skipping over duplicates.
if (traitSet.insert(traitInit).second)
- traits.push_back(OpTrait::create(traitInit));
+ traits.push_back(Trait::create(traitInit));
}
}
diff --git a/mlir/lib/TableGen/SideEffects.cpp b/mlir/lib/TableGen/SideEffects.cpp
index 286cacfdacf8b..a635f198c595a 100644
--- a/mlir/lib/TableGen/SideEffects.cpp
+++ b/mlir/lib/TableGen/SideEffects.cpp
@@ -53,6 +53,6 @@ StringRef SideEffectTrait::getBaseEffectName() const {
return def->getValueAsString("baseEffectName");
}
-bool SideEffectTrait::classof(const OpTrait *t) {
+bool SideEffectTrait::classof(const Trait *t) {
return t->getDef().isSubClassOf("SideEffectsTraitBase");
}
diff --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp
new file mode 100644
index 0000000000000..02bb4d4de64af
--- /dev/null
+++ b/mlir/lib/TableGen/Trait.cpp
@@ -0,0 +1,106 @@
+//===- Trait.cpp ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Trait wrapper to simplify using TableGen Record defining a MLIR Trait.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Trait.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "mlir/TableGen/Predicate.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns the `trait` field of `def`, prefixed by its `cppNamespace` field
+/// and "::" when that namespace is non-empty. Shared by the NativeTrait and
+/// InterfaceTrait accessors below so the two cannot drift apart.
+static std::string getQualifiedTraitName(const llvm::Record *def) {
+  llvm::StringRef trait = def->getValueAsString("trait");
+  llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
+  return cppNamespace.empty() ? trait.str()
+                              : (cppNamespace + "::" + trait).str();
+}
+
+//===----------------------------------------------------------------------===//
+// Trait
+//===----------------------------------------------------------------------===//
+
+/// Wraps the record referenced by `init`, which must be a DefInit, choosing
+/// the Kind from the TableGen trait class the record derives from.
+Trait Trait::create(const llvm::Init *init) {
+  auto def = cast<llvm::DefInit>(init)->getDef();
+  // Check the specialized trait classes first; NativeTrait is the fallback.
+  if (def->isSubClassOf("PredTrait"))
+    return Trait(Kind::Pred, def);
+  if (def->isSubClassOf("GenInternalTrait"))
+    return Trait(Kind::Internal, def);
+  if (def->isSubClassOf("InterfaceTrait"))
+    return Trait(Kind::Interface, def);
+  assert(def->isSubClassOf("NativeTrait") &&
+         "trait record does not derive from a known trait class");
+  return Trait(Kind::Native, def);
+}
+
+Trait::Trait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
+
+//===----------------------------------------------------------------------===//
+// NativeTrait
+//===----------------------------------------------------------------------===//
+
+std::string NativeTrait::getFullyQualifiedTraitName() const {
+  return getQualifiedTraitName(def);
+}
+
+//===----------------------------------------------------------------------===//
+// InternalTrait
+//===----------------------------------------------------------------------===//
+
+/// Internal traits are referred to by their `trait` field verbatim.
+llvm::StringRef InternalTrait::getFullyQualifiedTraitName() const {
+  return def->getValueAsString("trait");
+}
+
+//===----------------------------------------------------------------------===//
+// PredTrait
+//===----------------------------------------------------------------------===//
+
+std::string PredTrait::getPredTemplate() const {
+  Pred pred(def->getValueInit("predicate"));
+  return pred.getCondition();
+}
+
+llvm::StringRef PredTrait::getSummary() const {
+  return def->getValueAsString("summary");
+}
+
+//===----------------------------------------------------------------------===//
+// InterfaceTrait
+//===----------------------------------------------------------------------===//
+
+Interface InterfaceTrait::getInterface() const { return Interface(def); }
+
+std::string InterfaceTrait::getFullyQualifiedTraitName() const {
+  return getQualifiedTraitName(def);
+}
+
+bool InterfaceTrait::shouldDeclareMethods() const {
+  return def->isSubClassOf("DeclareInterfaceMethods");
+}
+
+std::vector<StringRef> InterfaceTrait::getAlwaysDeclaredMethods() const {
+  return def->getValueAsListOfStrings("alwaysOverriddenMethods");
+}
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index d4c9650e85122..ef21d9f9413ec 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -14,6 +14,7 @@ include "mlir/Interfaces/SideEffectInterfaceBase.td"
// A type interface used to test the ODS generation of type interfaces.
def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
+ let cppNamespace = "::mlir::test";
let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeA", (ins "Location":$loc), [{
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 5d23bb5e22404..9821774eeede7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -15,9 +15,11 @@
// To get the test dialect def.
include "TestOps.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
// All of the types will extend this class.
-class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
+class Test_Type<string name, list<Trait> traits = []>
+    : TypeDef<Test_Dialect, name, traits> {}
def SimpleTypeA : Test_Type<"SimpleA"> {
let mnemonic = "smpla";
@@ -151,4 +153,27 @@ def StructType : FieldInfo_Type<"Struct"> {
let mnemonic = "struct";
}
+// Requests ODS-generated declarations for the TestTypeInterface methods.
+def TestType
+    : Test_Type<"Test", [DeclareTypeInterfaceMethods<TestTypeInterface>]> {
+  let mnemonic = "test_type";
+}
+
+// `areCompatible` is declared even if the interface gives it a default.
+def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["areCompatible"]>]> {
+  let mnemonic = "test_type_with_layout";
+  let parameters = (ins "unsigned":$key);
+  let extraClassDeclaration = [{
+ LogicalResult verifyEntries(DataLayoutEntryListRef params,
+ Location loc) const;
+
+ private:
+ unsigned extractKind(DataLayoutEntryListRef params,
+ StringRef expectedKind) const;
+
+ public:
+ }];
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 38ab9c819974d..817b66a000829 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -58,6 +58,34 @@ static void printSignedness(DialectAsmPrinter &printer,
}
}
+// These helpers are only used in this file, but they must live in the
+// mlir::test namespace alongside FieldInfo. They are declared first and defined
+// right after, which adheres to the LLVM coding standards.
+namespace mlir {
+namespace test {
+// Parameters containing a FieldInfo are compared for equality, so an
+// equality operator is required.
+static bool operator==(const FieldInfo &lhs, const FieldInfo &rhs);
+// Parameters containing a FieldInfo are hashed, so a hash is required.
+static llvm::hash_code hash_value(const FieldInfo &info); // NOLINT
+} // namespace test
+} // namespace mlir
+
+// Two FieldInfos are equal when both their names and their types match.
+static bool mlir::test::operator==(const FieldInfo &lhs,
+                                   const FieldInfo &rhs) {
+  return lhs.name == rhs.name && lhs.type == rhs.type;
+}
+
+// Combines the name and the type, mirroring the fields operator== compares.
+static llvm::hash_code mlir::test::hash_value(const FieldInfo &info) { // NOLINT
+  return llvm::hash_combine(info.name, info.type);
+}
+
+//===----------------------------------------------------------------------===//
+// CompoundAType
+//===----------------------------------------------------------------------===//
+
Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
int widthOfSomething;
Type oneType;
@@ -87,29 +115,9 @@ void CompoundAType::print(DialectAsmPrinter &printer) const {
printer << "]>";
}
-// The functions don't need to be in the header file, but need to be in the mlir
-// namespace. Declare them here, then define them immediately below. Separating
-// the declaration and definition adheres to the LLVM coding standards.
-namespace mlir {
-namespace test {
-// FieldInfo is used as part of a parameter, so equality comparison is
-// compulsory.
-static bool operator==(const FieldInfo &a, const FieldInfo &b);
-// FieldInfo is used as part of a parameter, so a hash will be computed.
-static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
-} // namespace test
-} // namespace mlir
-
-// FieldInfo is used as part of a parameter, so equality comparison is
-// compulsory.
-static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) {
- return a.name == b.name && a.type == b.type;
-}
-
-// FieldInfo is used as part of a parameter, so a hash will be computed.
-static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
- return llvm::hash_combine(fi.name, fi.type);
-}
+//===----------------------------------------------------------------------===//
+// TestIntegerType
+//===----------------------------------------------------------------------===//
// Example type validity checker.
LogicalResult
@@ -122,18 +130,58 @@ TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
}
//===----------------------------------------------------------------------===//
-// Tablegen Generated Definitions
+// TestType
//===----------------------------------------------------------------------===//
-#define GET_TYPEDEF_CLASSES
-#include "TestTypeDefs.cpp.inc"
+void TestType::printTypeC(Location loc) const {
+  emitRemark(loc) << *this << " - TestC";
+}
+
+//===----------------------------------------------------------------------===//
+// TestTypeWithLayout
+//===----------------------------------------------------------------------===//
+
+Type TestTypeWithLayoutType::parse(MLIRContext *ctx, DialectAsmParser &parser) {
+  // Expected form: `<` unsigned-integer `>`; a null Type reports failure.
+  unsigned key;
+  if (parser.parseLess() || parser.parseInteger(key) || parser.parseGreater())
+    return Type();
+  return TestTypeWithLayoutType::get(ctx, key);
+}
+
+void TestTypeWithLayoutType::print(DialectAsmPrinter &printer) const {
+  printer << "test_type_with_layout<" << getKey() << ">";
+}
+
+unsigned TestTypeWithLayoutType::getTypeSizeInBits(
+    const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
+  return extractKind(params, "size");
+}
-LogicalResult TestTypeWithLayout::verifyEntries(DataLayoutEntryListRef params,
- Location loc) const {
+unsigned TestTypeWithLayoutType::getABIAlignment(
+    const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
+  return extractKind(params, "alignment");
+}
+
+unsigned TestTypeWithLayoutType::getPreferredAlignment(
+    const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
+  return extractKind(params, "preferred");
+}
+
+bool TestTypeWithLayoutType::areCompatible(
+    DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
+  // Compatible if the old alignment is 1 or the new one does not exceed it.
+  unsigned oldAlign = extractKind(oldLayout, "alignment");
+  return oldAlign == 1 || extractKind(newLayout, "alignment") <= oldAlign;
+}
+
+LogicalResult
+TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
+ Location loc) const {
for (DataLayoutEntryInterface entry : params) {
// This is for testing purposes only, so assert well-formedness.
assert(entry.isTypeEntry() && "unexpected identifier entry");
- assert(entry.getKey().get<Type>().isa<TestTypeWithLayout>() &&
+ assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&
"wrong type passed in");
auto array = entry.getValue().dyn_cast<ArrayAttr>();
assert(array && array.getValue().size() == 2 &&
@@ -149,8 +197,8 @@ LogicalResult TestTypeWithLayout::verifyEntries(DataLayoutEntryListRef params,
return success();
}
-unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
- StringRef expectedKind) const {
+unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
+ StringRef expectedKind) const {
for (DataLayoutEntryInterface entry : params) {
ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue();
StringRef kind = pair.front().cast<StringAttr>().getValue();
@@ -160,12 +208,19 @@ unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
return 1;
}
+//===----------------------------------------------------------------------===//
+// Tablegen Generated Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.cpp.inc"
+
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
void TestDialect::registerTypes() {
- addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
+ addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
>();
@@ -183,17 +238,6 @@ static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
if (parseResult.hasValue())
return genType;
}
- if (typeTag == "test_type")
- return TestType::get(parser.getBuilder().getContext());
-
- if (typeTag == "test_type_with_layout") {
- unsigned val;
- if (parser.parseLess() || parser.parseInteger(val) ||
- parser.parseGreater()) {
- return Type();
- }
- return TestTypeWithLayout::get(parser.getBuilder().getContext(), val);
- }
if (typeTag != "test_rec") {
parser.emitError(parser.getNameLoc()) << "unknown type!";
@@ -234,15 +278,6 @@ static void printTestType(Type type, DialectAsmPrinter &printer,
llvm::SetVector<Type> &stack) {
if (succeeded(generatedTypePrinter(type, printer)))
return;
- if (type.isa<TestType>()) {
- printer << "test_type";
- return;
- }
-
- if (auto t = type.dyn_cast<TestTypeWithLayout>()) {
- printer << "test_type_with_layout<" << t.getKey() << ">";
- return;
- }
auto rec = type.cast<TestRecursiveType>();
printer << "test_rec<" << rec.getName();
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index e2a65a2aa12df..f9a0289f20b01 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -41,25 +41,14 @@ struct FieldInfo {
} // namespace test
} // namespace mlir
+#include "TestTypeInterfaces.h.inc"
+
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"
namespace mlir {
namespace test {
-#include "TestTypeInterfaces.h.inc"
-
-/// This class is a simple test type that uses a generated interface.
-struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
- TestTypeInterface::Trait> {
- using Base::Base;
-
- /// Provide a definition for the necessary interface methods.
- void printTypeC(Location loc) const {
- emitRemark(loc) << *this << " - TestC";
- }
-};
-
/// Storage for simple named recursive types, where the type is identified by
/// its name and can "contain" another type, including itself.
struct TestRecursiveTypeStorage : public TypeStorage {
@@ -108,62 +97,6 @@ class TestRecursiveType
StringRef getName() { return getImpl()->name; }
};
-struct TestTypeWithLayoutStorage : public TypeStorage {
- using KeyTy = unsigned;
-
- explicit TestTypeWithLayoutStorage(unsigned key) : key(key) {}
- bool operator==(const KeyTy &other) const { return other == key; }
-
- static TestTypeWithLayoutStorage *construct(TypeStorageAllocator &allocator,
- const KeyTy &key) {
- return new (allocator.allocate<TestTypeWithLayoutStorage>())
- TestTypeWithLayoutStorage(key);
- }
-
- unsigned key;
-};
-
-class TestTypeWithLayout
- : public Type::TypeBase<TestTypeWithLayout, Type, TestTypeWithLayoutStorage,
- DataLayoutTypeInterface::Trait> {
-public:
- using Base::Base;
-
- static TestTypeWithLayout get(MLIRContext *ctx, unsigned key) {
- return Base::get(ctx, key);
- }
-
- unsigned getKey() { return getImpl()->key; }
-
- unsigned getTypeSizeInBits(const DataLayout &dataLayout,
- DataLayoutEntryListRef params) const {
- return extractKind(params, "size");
- }
-
- unsigned getABIAlignment(const DataLayout &dataLayout,
- DataLayoutEntryListRef params) const {
- return extractKind(params, "alignment");
- }
-
- unsigned getPreferredAlignment(const DataLayout &dataLayout,
- DataLayoutEntryListRef params) const {
- return extractKind(params, "preferred");
- }
-
- bool areCompatible(DataLayoutEntryListRef oldLayout,
- DataLayoutEntryListRef newLayout) const {
- unsigned old = extractKind(oldLayout, "alignment");
- return old == 1 || extractKind(newLayout, "alignment") <= old;
- }
-
- LogicalResult verifyEntries(DataLayoutEntryListRef params,
- Location loc) const;
-
-private:
- unsigned extractKind(DataLayoutEntryListRef params,
- StringRef expectedKind) const;
-};
-
} // namespace test
} // namespace mlir
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index fc95fba3c91cd..a5a41b3039918 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -78,16 +78,16 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// DEF: CompoundAAttrStorage (
// DEF-NEXT: : ::mlir::AttributeStorage(inner),
-// DEF: bool operator==(const KeyTy &key) const {
-// DEF-NEXT: if (!(widthOfSomething == std::get<0>(key)))
+// DEF: bool operator==(const KeyTy &tblgenKey) const {
+// DEF-NEXT: if (!(widthOfSomething == std::get<0>(tblgenKey)))
// DEF-NEXT: return false;
-// DEF-NEXT: if (!(exampleTdType == std::get<1>(key)))
+// DEF-NEXT: if (!(exampleTdType == std::get<1>(tblgenKey)))
// DEF-NEXT: return false;
-// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(key))))
+// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))))
// DEF-NEXT: return false;
-// DEF-NEXT: if (!(dims == std::get<3>(key)))
+// DEF-NEXT: if (!(dims == std::get<3>(tblgenKey)))
// DEF-NEXT: return false;
-// DEF-NEXT: if (!(getType() == std::get<4>(key)))
+// DEF-NEXT: if (!(getType() == std::get<4>(tblgenKey)))
// DEF-NEXT: return false;
// DEF-NEXT: return true;
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index a951df92fe180..c8910eab69ece 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -11,8 +11,10 @@
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
@@ -208,28 +210,29 @@ class DialectAsmPrinter;
/// {1}: The name of the type base class.
/// {2}: The name of the base value type, e.g. Attribute or Type.
/// {3}: The tablegen record type prefix, e.g. Attr or Type.
+/// {4}: The traits of the def class.
static const char *const defDeclSingletonBeginStr = R"(
- class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage> {{
+ class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage{4}> {{
public:
/// Inherit some necessary constructors from '{3}Base'.
using Base::Base;
)";
-/// The code block for the start of a typeDef class declaration -- parametric
-/// case.
+/// The code block for the start of a class declaration -- parametric case.
///
-/// {0}: The name of the typeDef class.
-/// {1}: The name of the type base class.
-/// {2}: The typeDef storage class namespace.
+/// {0}: The name of the def class.
+/// {1}: The name of the base class.
+/// {2}: The def storage class namespace.
/// {3}: The storage class name.
/// {4}: The name of the base value type, e.g. Attribute or Type.
/// {5}: The tablegen record type prefix, e.g. Attr or Type.
+/// {6}: The traits of the def class.
static const char *const defDeclParametricBeginStr = R"(
namespace {2} {
struct {3};
} // end namespace {2}
class {0} : public ::mlir::{4}::{5}Base<{0}, {1},
- {2}::{3}> {{
+ {2}::{3}{6}> {{
public:
/// Inherit some necessary constructors from '{5}Base'.
using Base::Base;
@@ -309,19 +312,71 @@ static void emitBuilderDecls(const AttrOrTypeDef &def, raw_ostream &os,
}
}
+static void emitInterfaceMethodDecls(const InterfaceTrait *trait,
+                                     raw_ostream &os) {
+  // Methods the def asked to have declared even when the interface provides
+  // a default implementation for them.
+  llvm::StringSet<> forcedDecls;
+  for (llvm::StringRef name : trait->getAlwaysDeclaredMethods())
+    forcedDecls.insert(name);
+
+  Interface interface = trait->getInterface();
+  for (const InterfaceMethod &method : interface.getMethods()) {
+    // Methods with an interface-provided body never need a declaration;
+    // defaulted methods only need one when explicitly requested.
+    if (method.getBody())
+      continue;
+    bool requested = forcedDecls.count(method.getName()) != 0;
+    if (method.getDefaultImplementation() && !requested)
+      continue;
+
+    // Print "[static ]<return type> <name>(<type> <arg>, ...)[ const];".
+    const bool isStatic = method.isStatic();
+    os << " " << (isStatic ? "static " : "") << method.getReturnType() << " "
+       << method.getName() << "(";
+    llvm::interleaveComma(
+        method.getArguments(), os,
+        [&](const InterfaceMethod::Argument &arg) {
+          os << arg.type << " " << arg.name;
+        });
+    os << ")" << (isStatic ? "" : " const") << ";\n";
+  }
+}
+
void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
SmallVector<AttrOrTypeParameter, 4> params;
def.getParameters(params);
+ // Build the trait list for this def.
+ std::vector<std::string> traitList;
+ StringSet<> traitSet;
+ for (const Trait &baseTrait : def.getTraits()) {
+ std::string traitStr;
+ if (const auto *trait = dyn_cast<NativeTrait>(&baseTrait))
+ traitStr = trait->getFullyQualifiedTraitName();
+ else if (const auto *trait = dyn_cast<InterfaceTrait>(&baseTrait))
+ traitStr = trait->getFullyQualifiedTraitName();
+ else
+ llvm_unreachable("unexpected Attribute/Type trait type");
+
+ if (traitSet.insert(traitStr).second)
+ traitList.emplace_back(std::move(traitStr));
+ }
+ std::string traitStr;
+ if (!traitList.empty())
+ traitStr = ", " + llvm::join(traitList, ", ");
+
// Emit the beginning string template: either the singleton or parametric
// template.
if (def.getNumParameters() == 0) {
os << formatv(defDeclSingletonBeginStr, def.getCppClassName(),
- def.getCppBaseClassName(), valueType, defTypePrefix);
+ def.getCppBaseClassName(), valueType, defTypePrefix,
+ traitStr);
} else {
os << formatv(defDeclParametricBeginStr, def.getCppClassName(),
def.getCppBaseClassName(), def.getStorageNamespace(),
- def.getStorageClassName(), valueType, defTypePrefix);
+ def.getStorageClassName(), valueType, defTypePrefix,
+ traitStr);
}
// Emit the extra declarations first in case there's a definition in there.
@@ -362,6 +417,14 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
}
}
+ // Emit any interface method declarations.
+ for (const Trait &trait : def.getTraits()) {
+ if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait)) {
+ if (traitDef->shouldDeclareMethods())
+ emitInterfaceMethodDecls(traitDef, os);
+ }
+ }
+
// End the decl.
os << " };\n";
}
@@ -452,7 +515,7 @@ static const char *const defStorageClassConstructorBeginStr = R"(
/// Define a construction method for creating a new instance of this
/// storage.
static {0} *construct(::mlir::{1}StorageAllocator &allocator,
- const KeyTy &key) {{
+ const KeyTy &tblgenKey) {{
)";
/// The storage class' constructor return template.
@@ -558,7 +621,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
paramInitializer, parameterTypeList, valueType);
// * Emit the comparison method.
- os << " bool operator==(const KeyTy &key) const {\n";
+ os << " bool operator==(const KeyTy &tblgenKey) const {\n";
for (auto it : llvm::enumerate(params)) {
os << " if (!(";
@@ -566,7 +629,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
bool isSelfType = isa<AttributeSelfTypeParameter>(it.value());
FmtContext context;
context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName())
- .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(key)");
+ .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(tblgenKey)");
// Use the parameter specified comparator if possible, otherwise default to
// operator==.
@@ -577,13 +640,13 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
os << " return true;\n }\n";
// * Emit the haskKey method.
- os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
+ os << " static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {\n";
// Extract each parameter from the key.
os << " return ::llvm::hash_combine(";
llvm::interleaveComma(
llvm::seq<unsigned>(0, params.size()), os,
- [&](unsigned it) { os << "std::get<" << it << ">(key)"; });
+ [&](unsigned it) { os << "std::get<" << it << ">(tblgenKey)"; });
os << ");\n }\n";
// * Emit the construct method.
@@ -592,7 +655,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
// here and then they can write the definition elsewhere.
if (def.hasStorageCustomConstructor()) {
os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator "
- "&allocator, const KeyTy &key);\n",
+ "&allocator, const KeyTy &tblgenKey);\n",
def.getStorageClassName(), valueType);
// Otherwise, generate one.
@@ -601,7 +664,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(),
valueType);
for (unsigned i = 0, e = params.size(); i < e; ++i) {
- os << formatv(" auto {0} = std::get<{1}>(key);\n",
+ os << formatv(" auto {0} = std::get<{1}>(tblgenKey);\n",
params[i].getName(), i);
}
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index a2c7dd5673bbf..350cc6f18f9cb 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -15,8 +15,8 @@
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 25e4d3360b3ee..413d45a6b8d65 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -18,9 +18,9 @@
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
+#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Path.h"
@@ -430,7 +430,7 @@ class OpEmitter {
void genOpInterfaceMethods();
// Generate op interface methods for the given interface.
- void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
+ void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
// Generate op interface method for the given interface method. If
// 'declaration' is true, generates a declaration, else a definition.
@@ -1719,8 +1719,8 @@ void OpEmitter::genFolderDecls() {
}
}
-void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
- auto interface = opTrait->getOpInterface();
+void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
+ Interface interface = opTrait->getInterface();
// Get the set of methods that should always be declared.
auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
@@ -1757,7 +1757,7 @@ OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
void OpEmitter::genOpInterfaceMethods() {
for (const auto &trait : op.getTraits()) {
- if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
+ if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
if (opTrait->shouldDeclareMethods())
genOpInterfaceMethods(opTrait);
}
@@ -1866,9 +1866,9 @@ void OpEmitter::genTypeInterfaceMethods() {
return;
// Generate 'inferReturnTypes' method declaration using the interface method
// declared in 'InferTypeOpInterface' op interface.
- const auto *trait = dyn_cast<InterfaceOpTrait>(
+ const auto *trait = dyn_cast<InterfaceTrait>(
op.getTrait("::mlir::InferTypeOpInterface::Trait"));
- auto interface = trait->getOpInterface();
+ Interface interface = trait->getInterface();
OpMethod *method = [&]() -> OpMethod * {
for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
if (interfaceMethod.getName() == "inferReturnTypes") {
@@ -1966,7 +1966,7 @@ void OpEmitter::genVerifier() {
genOperandResultVerifier(body, op.getResults(), "result");
for (auto &trait : op.getTraits()) {
- if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
+ if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
body << tgfmt(" if (!($0))\n "
"return emitOpError(\"failed to verify that $1\");\n",
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
@@ -2187,10 +2187,10 @@ void OpEmitter::genTraits() {
// Add the native and interface traits.
for (const auto &trait : op.getTraits()) {
- if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
- opClass.addTrait(opTrait->getTrait());
- else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
- opClass.addTrait(opTrait->getTrait());
+ if (auto opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
+ opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+ else if (auto opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
+ opClass.addTrait(opTrait->getFullyQualifiedTraitName());
}
}
@@ -2379,12 +2379,14 @@ void OpOperandAdaptorEmitter::addVerification() {
// Verify a few traits first so that we can use
// getODSOperands()/getODSResults() in the rest of the verifier.
for (auto &trait : op.getTraits()) {
- if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
- if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") {
+ if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
+ if (t->getFullyQualifiedTraitName() ==
+ "::mlir::OpTrait::AttrSizedOperandSegments") {
body << formatv(checkAttrSizedValueSegmentsCode,
"operand_segment_sizes", op.getNumOperands(),
"operand");
- } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") {
+ } else if (t->getFullyQualifiedTraitName() ==
+ "::mlir::OpTrait::AttrSizedResultSegments") {
body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
op.getNumResults(), "result");
}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 93d2155c3c821..3c3f00f2379c8 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -12,8 +12,8 @@
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
@@ -445,18 +445,12 @@ struct OperationFormat {
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
- hasImplicitTermTrait =
- llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
- return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
- });
+ hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
+ return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
+ });
hasSingleBlockTrait =
- hasImplicitTermTrait ||
- llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
- if (auto *native = dyn_cast<NativeOpTrait>(&trait))
- return native->getTrait() == "::mlir::OpTrait::SingleBlock";
- return false;
- });
+ hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
}
/// Generate the operation parser from this format.
@@ -2416,7 +2410,7 @@ LogicalResult FormatParser::parse() {
// Check for any type traits that we can use for inferring types.
llvm::StringMap<TypeResolutionInstance> variableTyResolver;
- for (const OpTrait &trait : op.getTraits()) {
+ for (const Trait &trait : op.getTraits()) {
const llvm::Record &def = trait.getDef();
if (def.isSubClassOf("AllTypesMatch")) {
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
More information about the Mlir-commits
mailing list