[Mlir-commits] [mlir] 94662ee - [mlir] Add support for adding attribute+type traits/interfaces to tablegen defs

River Riddle llvmlistbot at llvm.org
Thu Apr 15 11:44:12 PDT 2021


Author: River Riddle
Date: 2021-04-15T11:41:51-07:00
New Revision: 94662ee0c175165e60bc09fc73396a3e344829d4

URL: https://github.com/llvm/llvm-project/commit/94662ee0c175165e60bc09fc73396a3e344829d4
DIFF: https://github.com/llvm/llvm-project/commit/94662ee0c175165e60bc09fc73396a3e344829d4.diff

LOG: [mlir] Add support for adding attribute+type traits/interfaces to tablegen defs

This matches the current support provided for operations, and allows attaching traits, interfaces, and using the DeclareInterfaceMethods utility. This was missed when attribute/type generation was first added.

Differential Revision: https://reviews.llvm.org/D100233

Added: 
    mlir/include/mlir/TableGen/Trait.h
    mlir/lib/TableGen/Trait.cpp

Modified: 
    mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/BuiltinLocationAttributes.td
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/include/mlir/TableGen/Operator.h
    mlir/include/mlir/TableGen/SideEffects.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/lib/TableGen/CMakeLists.txt
    mlir/lib/TableGen/Operator.cpp
    mlir/lib/TableGen/SideEffects.cpp
    mlir/test/lib/Dialect/Test/TestInterfaces.td
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/lib/Dialect/Test/TestTypes.h
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/DialectGen.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    mlir/include/mlir/TableGen/OpTrait.h
    mlir/lib/TableGen/OpTrait.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
index 1e0578339ad86..10b6130afd148 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
@@ -20,7 +20,7 @@ include "mlir/Dialect/PDL/IR/PDLDialect.td"
 //===----------------------------------------------------------------------===//
 
 class PDL_Type<string name, string typeMnemonic>
-    : TypeDef<PDL_Dialect, name, "::mlir::pdl::PDLType"> {
+    : TypeDef<PDL_Dialect, name, [], "::mlir::pdl::PDLType"> {
   let mnemonic = typeMnemonic;
 }
 

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 45214535b1f8c..c248ad5822f87 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -23,7 +23,7 @@ include "mlir/IR/BuiltinDialect.td"
 
 // Base class for Builtin dialect attributes.
 class Builtin_Attr<string name, string baseCppClass = "::mlir::Attribute">
-    : AttrDef<Builtin_Dialect, name, baseCppClass> {
+    : AttrDef<Builtin_Dialect, name, [], baseCppClass> {
   let mnemonic = ?;
 }
 

diff  --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index d6347441d09c6..3858b11e05215 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -17,7 +17,7 @@ include "mlir/IR/BuiltinDialect.td"
 
 // Base class for Builtin dialect location attributes.
 class Builtin_LocationAttr<string name>
-    : AttrDef<Builtin_Dialect, name, "::mlir::LocationAttr"> {
+    : AttrDef<Builtin_Dialect, name, [], "::mlir::LocationAttr"> {
   let cppClassName = name;
   let mnemonic = ?;
 }

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index f266f61e182ea..16afcfdfd00a2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -22,7 +22,7 @@ include "mlir/IR/BuiltinDialect.td"
 
 // Base class for Builtin dialect types.
 class Builtin_Type<string name, string baseCppClass = "::mlir::Type">
-    : TypeDef<Builtin_Dialect, name, baseCppClass> {
+    : TypeDef<Builtin_Dialect, name, [], baseCppClass> {
   let mnemonic = ?;
 }
 
@@ -65,8 +65,7 @@ def Builtin_Complex : Builtin_Type<"Complex"> {
 //===----------------------------------------------------------------------===//
 
 // Base class for Builtin dialect float types.
-class Builtin_FloatType<string name> : TypeDef<Builtin_Dialect, name,
-                                               "::mlir::FloatType"> {
+class Builtin_FloatType<string name> : Builtin_Type<name, "::mlir::FloatType"> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
   }];

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 20f3c7fa84130..11038ef12d2ad 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1734,42 +1734,58 @@ def AnySuccessor : Successor<?, "any successor">;
 class VariadicSuccessor<Successor successor>
   : Successor<successor.predicate, successor.summary>;
 
+
 //===----------------------------------------------------------------------===//
-// OpTrait definitions
+// Trait definitions
 //===----------------------------------------------------------------------===//
 
-// OpTrait represents a trait regarding an op.
-class OpTrait;
+// Trait represents a trait regarding an attribute, operation, or type.
+class Trait;
 
-// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The
-// purpose to wrap around C++ symbol string with this class is to make
-// traits specified for ops in TableGen less alien and more integrated.
-class NativeOpTrait<string name> : OpTrait {
+// NativeTrait corresponds to the MLIR C++ trait mechanism. The purpose of
+// wrapping the C++ symbol string with this class is to make traits specified
+// for entities in TableGen less alien and more integrated.
+class NativeTrait<string name, string entityType> : Trait {
   string trait = name;
-  string cppNamespace = "::mlir::OpTrait";
+  string cppNamespace = "::mlir::" # entityType # "Trait";
 }
 
-// ParamNativeOpTrait corresponds to the template-parameterized traits in the
-// C++ implementation.  MLIR uses nested class templates to implement such
-// traits leading to constructs of the form "TraitName<Parameters>::Impl". Use
-// the value in `prop` as the trait name and the value in `params` as
-// parameters to construct the native trait class name.
-class ParamNativeOpTrait<string prop, string params>
-    : NativeOpTrait<prop # "<" # params # ">::Impl">;
+// ParamNativeTrait corresponds to the template-parameterized traits in the C++
+// implementation. MLIR uses nested class templates to implement such traits
+// leading to constructs of the form "TraitName<Parameters>::Impl". Use the
+// value in `prop` as the trait name and the value in `params` as parameters to
+// construct the native trait class name.
+class ParamNativeTrait<string prop, string params, string entityType>
+    : NativeTrait<prop # "<" # params # ">::Impl", entityType>;
 
-// GenInternalOpTrait is an op trait that does not have direct C++ mapping but
-// affects op definition generator internals, like how op builders and
+// GenInternalTrait is a trait that does not have direct C++ mapping but affects
+// an entity's definition generator internals, like how operation builders and
 // operand/attribute/result getters are generated.
-class GenInternalOpTrait<string prop> : OpTrait {
-  string trait = "::mlir::OpTrait::" # prop;
+class GenInternalTrait<string prop, string entityType> : Trait {
+  string trait = "::mlir::" # entityType # "Trait::" # prop;
 }
 
-// PredOpTrait is an op trait implemented by way of a predicate on the op.
-class PredOpTrait<string descr, Pred pred> : OpTrait {
+// PredTrait is a trait implemented by way of a predicate on an entity.
+class PredTrait<string descr, Pred pred> : Trait {
   string summary = descr;
   Pred predicate = pred;
 }
 
+//===----------------------------------------------------------------------===//
+// OpTrait definitions
+//===----------------------------------------------------------------------===//
+
+// OpTrait represents a trait regarding an operation.
+// TODO: Remove this class in favor of using Trait.
+class OpTrait;
+
+// These classes are used to define operation specific traits.
+class NativeOpTrait<string name> : NativeTrait<name, "Op">, OpTrait;
+class ParamNativeOpTrait<string prop, string params>
+    : ParamNativeTrait<prop, params, "Op">, OpTrait;
+class GenInternalOpTrait<string prop> : GenInternalTrait<prop, "Op">, OpTrait;
+class PredOpTrait<string descr, Pred pred> : PredTrait<descr, pred>, OpTrait;
+
 // Op defines an affine scope.
 def AffineScope : NativeOpTrait<"AffineScope">;
 // Op defines an automatic allocation scope.
@@ -1895,23 +1911,28 @@ class CArg<string ty, string value = ""> {
   string defaultValue = value;
 }
 
-// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
-// C++. The purpose to wrap around C++ symbol string with this class is to make
+// InterfaceTrait corresponds to a specific 'Interface' class defined in C++.
+// The purpose to wrap around C++ symbol string with this class is to make
 // interfaces specified for ops in TableGen less alien and more integrated.
-class OpInterfaceTrait<string name, code verifyBody = [{}]>
-          : NativeOpTrait<""> {
+class InterfaceTrait<string name> : NativeTrait<"", ""> {
   let trait = name # "::Trait";
   let cppNamespace = "";
 
-  // Specify the body of the verification function. `$_op` will be replaced with
-  // the operation being verified.
-  code verify = verifyBody;
-
   // An optional code block containing extra declarations to place in the
   // interface trait declaration.
   code extraTraitClassDeclaration = "";
 }
 
+// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
+// C++. The purpose to wrap around C++ symbol string with this class is to make
+// interfaces specified for ops in TableGen less alien and more integrated.
+class OpInterfaceTrait<string name, code verifyBody = [{}]>
+    : InterfaceTrait<name>, OpTrait {
+  // Specify the body of the verification function. `$_op` will be replaced with
+  // the operation being verified.
+  code verify = verifyBody;
+}
+
 // This class represents a single, optionally static, interface method.
 // Note: non-static interface methods have an implicit parameter, either
 // $_op/$_attr/$_type corresponding to an instance of the derived value.
@@ -1967,39 +1988,52 @@ class Interface<string name> {
 }
 
 // AttrInterface represents an interface registered to an attribute.
-class AttrInterface<string name> : Interface<name> {
-  // An optional code block containing extra declarations to place in the
-  // interface trait declaration.
-  code extraTraitClassDeclaration = "";
-}
+class AttrInterface<string name> : Interface<name>, InterfaceTrait<name>;
 
 // OpInterface represents an interface registered to an operation.
 class OpInterface<string name> : Interface<name>, OpInterfaceTrait<name>;
 
 // TypeInterface represents an interface registered to a type.
-class TypeInterface<string name> : Interface<name> {
-  // An optional code block containing extra declarations to place in the
-  // interface trait declaration.
-  code extraTraitClassDeclaration = "";
-}
+class TypeInterface<string name> : Interface<name>, InterfaceTrait<name>;
 
-// Whether to declare the op interface methods in the op's header. This class
-// simply wraps an OpInterface but is used to indicate that the method
+// Whether to declare the interface methods in the user entity's header. This
+// class simply wraps an Interface but is used to indicate that the method
 // declarations should be generated. This class takes an optional set of methods
 // that should have declarations generated even if the method has a default
 // implementation.
+class DeclareInterfaceMethods<Interface interface,
+                              list<string> overridenMethods = []> {
+    // This field contains a set of method names that should always have their
+    // declarations generated. This allows for generating declarations for
+    // methods with default implementations that need to be overridden.
+    list<string> alwaysOverriddenMethods = overridenMethods;
+}
+class DeclareAttrInterfaceMethods<AttrInterface interface,
+                                  list<string> overridenMethods = []>
+      : DeclareInterfaceMethods<interface, overridenMethods>,
+        AttrInterface<interface.cppClassName> {
+    let description = interface.description;
+    let cppClassName = interface.cppClassName;
+    let cppNamespace = interface.cppNamespace;
+    let methods = interface.methods;
+}
 class DeclareOpInterfaceMethods<OpInterface interface,
                                 list<string> overridenMethods = []>
-      : OpInterface<interface.cppClassName> {
+      : DeclareInterfaceMethods<interface, overridenMethods>,
+        OpInterface<interface.cppClassName> {
+    let description = interface.description;
+    let cppClassName = interface.cppClassName;
+    let cppNamespace = interface.cppNamespace;
+    let methods = interface.methods;
+}
+class DeclareTypeInterfaceMethods<TypeInterface interface,
+                                  list<string> overridenMethods = []>
+      : DeclareInterfaceMethods<interface, overridenMethods>,
+        TypeInterface<interface.cppClassName> {
     let description = interface.description;
     let cppClassName = interface.cppClassName;
     let cppNamespace = interface.cppNamespace;
     let methods = interface.methods;
-
-    // This field contains a set of method names that should always have their
-    // declarations generated. This allows for generating declarations for
-    // methods with default implementations that need to be overridden.
-    list<string> alwaysOverriddenMethods = overridenMethods;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2609,7 +2643,8 @@ class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
 
 // Define a new attribute or type, named `name`, that inherits from the given
 // C++ base class.
-class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
+class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
+                    string baseCppClass> {
   // The name of the C++ base class to use for this def.
   string cppBaseClassName = baseCppClass;
 
@@ -2664,6 +2699,9 @@ class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
   // Note that builders should only be provided when a def has parameters.
   list<AttrOrTypeBuilder> builders = ?;
 
+  // The list of traits attached to this def.
+  list<Trait> traits = defTraits;
+
   // Use the lowercased name as the keyword for parsing/printing. Specify only
   // if you want tblgen to generate declarations and/or definitions of
   // the printer/parser.
@@ -2692,10 +2730,10 @@ class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
 
 // Define a new attribute, named `name`, belonging to `dialect` that inherits
 // from the given C++ base class.
-class AttrDef<Dialect dialect, string name,
+class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
               string baseCppClass = "::mlir::Attribute">
     : DialectAttr<dialect, CPred<"">, /*descr*/"">,
-      AttrOrTypeDef<"Attr", name, baseCppClass> {
+      AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
   // The name of the C++ Attribute class.
   string cppClassName = name # "Attr";
 
@@ -2728,10 +2766,10 @@ class AttrDef<Dialect dialect, string name,
 
 // Define a new type, named `name`, belonging to `dialect` that inherits from
 // the given C++ base class.
-class TypeDef<Dialect dialect, string name,
+class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
               string baseCppClass = "::mlir::Type">
     : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
-      AttrOrTypeDef<"Type", name, baseCppClass> {
+      AttrOrTypeDef<"Type", name, traits, baseCppClass> {
   // A constant builder provided when the type has no parameters.
   let builderCall = !if(!empty(parameters),
                            "$_builder.getType<" # dialect.cppNamespace #

diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 5271fbae0eea3..ab07f43c1b5e8 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Builder.h"
+#include "mlir/TableGen/Trait.h"
 
 namespace llvm {
 class DagInit;
@@ -120,6 +121,9 @@ class AttrOrTypeDef {
   // Returns the builders of this def.
   ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
 
+  // Returns the traits of this def.
+  ArrayRef<Trait> getTraits() const { return traits; }
+
   // Returns whether two AttrOrTypeDefs are equal by checking the equality of
   // the underlying record.
   bool operator==(const AttrOrTypeDef &other) const;
@@ -136,8 +140,11 @@ class AttrOrTypeDef {
 protected:
   const llvm::Record *def;
 
-  // The builders of this type definition.
+  // The builders of this definition.
   SmallVector<AttrOrTypeBuilder> builders;
+
+  // The traits of this definition.
+  SmallVector<Trait> traits;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index d21b4b213ee41..3da693dde0be9 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -18,9 +18,9 @@
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Builder.h"
 #include "mlir/TableGen/Dialect.h"
-#include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Region.h"
 #include "mlir/TableGen/Successor.h"
+#include "mlir/TableGen/Trait.h"
 #include "mlir/TableGen/Type.h"
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
@@ -176,9 +176,7 @@ class Operator {
   var_decorator_range getArgDecorators(int index) const;
 
   // Returns the trait wrapper for the given MLIR C++ `trait`.
-  // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
-  // requiring the raw MLIR trait here.
-  const OpTrait *getTrait(llvm::StringRef trait) const;
+  const Trait *getTrait(llvm::StringRef trait) const;
 
   // Regions.
   using const_region_iterator = const NamedRegion *;
@@ -209,7 +207,7 @@ class Operator {
   unsigned getNumVariadicSuccessors() const;
 
   // Trait.
-  using const_trait_iterator = const OpTrait *;
+  using const_trait_iterator = const Trait *;
   const_trait_iterator trait_begin() const;
   const_trait_iterator trait_end() const;
   llvm::iterator_range<const_trait_iterator> getTraits() const;
@@ -325,7 +323,7 @@ class Operator {
   SmallVector<NamedSuccessor, 0> successors;
 
   // The traits of the op.
-  SmallVector<OpTrait, 4> traits;
+  SmallVector<Trait, 4> traits;
 
   // The regions of this op.
   SmallVector<NamedRegion, 1> regions;

diff  --git a/mlir/include/mlir/TableGen/SideEffects.h b/mlir/include/mlir/TableGen/SideEffects.h
index 7e464476cea11..c5ced6682c16e 100644
--- a/mlir/include/mlir/TableGen/SideEffects.h
+++ b/mlir/include/mlir/TableGen/SideEffects.h
@@ -41,7 +41,7 @@ class SideEffect : public Operator::VariableDecorator {
 // This class represents an instance of a side effect interface applied to an
 // operation. This is a wrapper around an OpInterfaceTrait that also includes
 // the effects that are applied.
-class SideEffectTrait : public InterfaceOpTrait {
+class SideEffectTrait : public InterfaceTrait {
 public:
   // Return the effects that are attached to the side effect interface.
   Operator::var_decorator_range getEffects() const;
@@ -49,7 +49,7 @@ class SideEffectTrait : public InterfaceOpTrait {
   // Return the name of the base C++ effect.
   StringRef getBaseEffectName() const;
 
-  static bool classof(const OpTrait *t);
+  static bool classof(const Trait *t);
 };
 
 } // end namespace tblgen

diff  --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/Trait.h
similarity index 51%
rename from mlir/include/mlir/TableGen/OpTrait.h
rename to mlir/include/mlir/TableGen/Trait.h
index 4ac0c1c4ed25e..52d056d0581e0 100644
--- a/mlir/include/mlir/TableGen/OpTrait.h
+++ b/mlir/include/mlir/TableGen/Trait.h
@@ -1,4 +1,4 @@
-//===- OpTrait.h - OpTrait wrapper class ------------------------*- C++ -*-===//
+//===- Trait.h - Trait wrapper class ----------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// OpTrait wrapper to simplify using TableGen Record defining an MLIR OpTrait.
+// Trait wrapper to simplify using TableGen Record defining an MLIR Trait.
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_TABLEGEN_OPTRAIT_H_
-#define MLIR_TABLEGEN_OPTRAIT_H_
+#ifndef MLIR_TABLEGEN_TRAIT_H_
+#define MLIR_TABLEGEN_TRAIT_H_
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringRef.h"
@@ -25,28 +25,28 @@ class Record;
 namespace mlir {
 namespace tblgen {
 
-struct OpInterface;
+class Interface;
 
-// Wrapper class with helper methods for accessing OpTrait constraints defined
-// in TableGen.
-class OpTrait {
+// Wrapper class with helper methods for accessing Trait constraints defined in
+// TableGen.
+class Trait {
 public:
-  // Discriminator for kinds of op traits.
+  // Discriminator for kinds of traits.
   enum class Kind {
-    // OpTrait corresponding to C++ class.
+    // Trait corresponding to C++ class.
     Native,
-    // OpTrait corresponding to predicate on operation.
+    // Trait corresponding to a predicate.
     Pred,
-    // OpTrait controlling op definition generator internals.
+    // Trait controlling definition generator internals.
     Internal,
-    // OpTrait corresponding to OpInterface.
+    // Trait corresponding to an Interface.
     Interface
   };
 
-  explicit OpTrait(Kind kind, const llvm::Record *def);
+  explicit Trait(Kind kind, const llvm::Record *def);
 
-  // Returns an OpTrait corresponding to the init provided.
-  static OpTrait create(const llvm::Init *init);
+  // Returns a Trait corresponding to the init provided.
+  static Trait create(const llvm::Init *init);
 
   Kind getKind() const { return kind; }
 
@@ -59,17 +59,17 @@ class OpTrait {
   Kind kind;
 };
 
-// OpTrait corresponding to a native C++ OpTrait.
-class NativeOpTrait : public OpTrait {
+// Trait corresponding to a native C++ Trait.
+class NativeTrait : public Trait {
 public:
   // Returns the trait corresponding to a C++ trait class.
-  std::string getTrait() const;
+  std::string getFullyQualifiedTraitName() const;
 
-  static bool classof(const OpTrait *t) { return t->getKind() == Kind::Native; }
+  static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
 };
 
-// OpTrait corresponding to a predicate on the operation.
-class PredOpTrait : public OpTrait {
+// Trait corresponding to a predicate on an entity.
+class PredTrait : public Trait {
 public:
   // Returns the template for constructing the predicate.
   std::string getPredTemplate() const;
@@ -77,30 +77,28 @@ class PredOpTrait : public OpTrait {
   // Returns the description of what the predicate is verifying.
   StringRef getSummary() const;
 
-  static bool classof(const OpTrait *t) { return t->getKind() == Kind::Pred; }
+  static bool classof(const Trait *t) { return t->getKind() == Kind::Pred; }
 };
 
-// OpTrait controlling op definition generator internals.
-class InternalOpTrait : public OpTrait {
+// Trait controlling definition generator internals.
+class InternalTrait : public Trait {
 public:
   // Returns the trait controlling op definition generator internals.
-  StringRef getTrait() const;
+  StringRef getFullyQualifiedTraitName() const;
 
-  static bool classof(const OpTrait *t) {
-    return t->getKind() == Kind::Internal;
-  }
+  static bool classof(const Trait *t) { return t->getKind() == Kind::Internal; }
 };
 
-// OpTrait corresponding to an OpInterface on the operation.
-class InterfaceOpTrait : public OpTrait {
+// Trait corresponding to an Interface attached to an entity.
+class InterfaceTrait : public Trait {
 public:
-  // Returns member function definitions corresponding to the trait,
-  OpInterface getOpInterface() const;
+  // Returns interface corresponding to the trait.
+  Interface getInterface() const;
 
   // Returns the trait corresponding to a C++ trait class.
-  std::string getTrait() const;
+  std::string getFullyQualifiedTraitName() const;
 
-  static bool classof(const OpTrait *t) {
+  static bool classof(const Trait *t) {
     return t->getKind() == Kind::Interface;
   }
 
@@ -115,4 +113,4 @@ class InterfaceOpTrait : public OpTrait {
 } // end namespace tblgen
 } // end namespace mlir
 
-#endif // MLIR_TABLEGEN_OPTRAIT_H_
+#endif // MLIR_TABLEGEN_TRAIT_H_

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 1e4f5e4becdd6..c439b3c827755 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -46,6 +47,15 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
       builders.emplace_back(builder);
     }
   }
+
+  // Populate the traits, skipping duplicates while preserving their order.
+  if (auto *traitList = def->getValueAsListInit("traits")) {
+    SmallPtrSet<const llvm::Init *, 32> traitSet;
+    traits.reserve(traitList->size()); // Upper bound on the unique count.
+    for (auto *traitInit : *traitList)
+      if (traitSet.insert(traitInit).second)
+        traits.push_back(Trait::create(traitInit));
+  }
 }
 
 Dialect AttrOrTypeDef::getDialect() const {

diff  --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index 557caf1c9c1a5..a97419b193216 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -19,13 +19,13 @@ llvm_add_library(MLIRTableGen STATIC
   Interfaces.cpp
   Operator.cpp
   OpClass.cpp
-  OpTrait.cpp
   Pass.cpp
   Pattern.cpp
   Predicate.cpp
   Region.cpp
   SideEffects.cpp
   Successor.cpp
+  Trait.cpp
   Type.cpp
 
   DISABLE_LLVM_LINK_LLVM_DYLIB

diff  --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp
deleted file mode 100644
index 4ebfc9bb531b3..0000000000000
--- a/mlir/lib/TableGen/OpTrait.cpp
+++ /dev/null
@@ -1,75 +0,0 @@
-//===- OpTrait.cpp - OpTrait class ----------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// OpTrait wrapper to simplify using TableGen Record defining a MLIR OpTrait.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/TableGen/OpTrait.h"
-#include "mlir/TableGen/Interfaces.h"
-#include "mlir/TableGen/Predicate.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/TableGen/Error.h"
-#include "llvm/TableGen/Record.h"
-
-using namespace mlir;
-using namespace mlir::tblgen;
-
-OpTrait OpTrait::create(const llvm::Init *init) {
-  auto def = cast<llvm::DefInit>(init)->getDef();
-  if (def->isSubClassOf("PredOpTrait"))
-    return OpTrait(Kind::Pred, def);
-  if (def->isSubClassOf("GenInternalOpTrait"))
-    return OpTrait(Kind::Internal, def);
-  if (def->isSubClassOf("OpInterfaceTrait"))
-    return OpTrait(Kind::Interface, def);
-  assert(def->isSubClassOf("NativeOpTrait"));
-  return OpTrait(Kind::Native, def);
-}
-
-OpTrait::OpTrait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
-
-std::string NativeOpTrait::getTrait() const {
-  llvm::StringRef trait = def->getValueAsString("trait");
-  llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
-  return cppNamespace.empty() ? trait.str()
-                              : (cppNamespace + "::" + trait).str();
-}
-
-llvm::StringRef InternalOpTrait::getTrait() const {
-  return def->getValueAsString("trait");
-}
-
-std::string PredOpTrait::getPredTemplate() const {
-  auto pred = Pred(def->getValueInit("predicate"));
-  return pred.getCondition();
-}
-
-llvm::StringRef PredOpTrait::getSummary() const {
-  return def->getValueAsString("summary");
-}
-
-OpInterface InterfaceOpTrait::getOpInterface() const {
-  return OpInterface(def);
-}
-
-std::string InterfaceOpTrait::getTrait() const {
-  llvm::StringRef trait = def->getValueAsString("trait");
-  llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
-  return cppNamespace.empty() ? trait.str()
-                              : (cppNamespace + "::" + trait).str();
-}
-
-bool InterfaceOpTrait::shouldDeclareMethods() const {
-  return def->isSubClassOf("DeclareOpInterfaceMethods");
-}
-
-std::vector<StringRef> InterfaceOpTrait::getAlwaysDeclaredMethods() const {
-  return def->getValueAsListOfStrings("alwaysOverriddenMethods");
-}

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 209c1ec0d94a7..c1ba549e8abfb 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -11,8 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/Operator.h"
-#include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Predicate.h"
+#include "mlir/TableGen/Trait.h"
 #include "mlir/TableGen/Type.h"
 #include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/STLExtras.h"
@@ -158,17 +158,17 @@ auto Operator::getArgDecorators(int index) const -> var_decorator_range {
   return *arg->getValueAsListInit("decorators");
 }
 
-const OpTrait *Operator::getTrait(StringRef trait) const {
+const Trait *Operator::getTrait(StringRef trait) const {
   for (const auto &t : traits) {
-    if (const auto *opTrait = dyn_cast<NativeOpTrait>(&t)) {
-      if (opTrait->getTrait() == trait)
-        return opTrait;
-    } else if (const auto *opTrait = dyn_cast<InternalOpTrait>(&t)) {
-      if (opTrait->getTrait() == trait)
-        return opTrait;
-    } else if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&t)) {
-      if (opTrait->getTrait() == trait)
-        return opTrait;
+    if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
+      if (traitDef->getFullyQualifiedTraitName() == trait)
+        return traitDef;
+    } else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
+      if (traitDef->getFullyQualifiedTraitName() == trait)
+        return traitDef;
+    } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
+      if (traitDef->getFullyQualifiedTraitName() == trait)
+        return traitDef;
     }
   }
   return nullptr;
@@ -314,7 +314,7 @@ void Operator::populateTypeInferenceInfo(
     return found;
   };
 
-  for (const OpTrait &trait : traits) {
+  for (const Trait &trait : traits) {
     const llvm::Record &def = trait.getDef();
     // If the infer type op interface was manually added, then treat it as
     // intention that the op needs special handling.
@@ -323,8 +323,8 @@ void Operator::populateTypeInferenceInfo(
     if (def.isSubClassOf(
             llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
       return;
-    if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&trait))
-      if (&opTrait->getDef() == inferTrait)
+    if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
+      if (&traitDef->getDef() == inferTrait)
         return;
 
     if (!def.isSubClassOf("AllTypesMatch"))
@@ -344,7 +344,7 @@ void Operator::populateTypeInferenceInfo(
 
   // If the types could be computed, then add type inference trait.
   if (allResultsHaveKnownTypes)
-    traits.push_back(OpTrait::create(inferTrait->getDefInit()));
+    traits.push_back(Trait::create(inferTrait->getDefInit()));
 }
 
 void Operator::populateOpStructure() {
@@ -489,7 +489,7 @@ void Operator::populateOpStructure() {
     for (auto *traitInit : *traitList) {
       // Keep traits in the same order while skipping over duplicates.
       if (traitSet.insert(traitInit).second)
-        traits.push_back(OpTrait::create(traitInit));
+        traits.push_back(Trait::create(traitInit));
     }
   }
 

diff  --git a/mlir/lib/TableGen/SideEffects.cpp b/mlir/lib/TableGen/SideEffects.cpp
index 286cacfdacf8b..a635f198c595a 100644
--- a/mlir/lib/TableGen/SideEffects.cpp
+++ b/mlir/lib/TableGen/SideEffects.cpp
@@ -53,6 +53,6 @@ StringRef SideEffectTrait::getBaseEffectName() const {
   return def->getValueAsString("baseEffectName");
 }
 
-bool SideEffectTrait::classof(const OpTrait *t) {
+bool SideEffectTrait::classof(const Trait *t) {
   return t->getDef().isSubClassOf("SideEffectsTraitBase");
 }

diff  --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp
new file mode 100644
index 0000000000000..02bb4d4de64af
--- /dev/null
+++ b/mlir/lib/TableGen/Trait.cpp
@@ -0,0 +1,93 @@
+//===- Trait.cpp ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Trait wrapper to simplify using TableGen Record defining a MLIR Trait.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Trait.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "mlir/TableGen/Predicate.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// Trait
+//===----------------------------------------------------------------------===//
+
+Trait Trait::create(const llvm::Init *init) {
+  auto def = cast<llvm::DefInit>(init)->getDef();
+  if (def->isSubClassOf("PredTrait"))
+    return Trait(Kind::Pred, def);
+  if (def->isSubClassOf("GenInternalTrait"))
+    return Trait(Kind::Internal, def);
+  if (def->isSubClassOf("InterfaceTrait"))
+    return Trait(Kind::Interface, def);
+  assert(def->isSubClassOf("NativeTrait"));
+  return Trait(Kind::Native, def);
+}
+
+Trait::Trait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
+
+//===----------------------------------------------------------------------===//
+// NativeTrait
+//===----------------------------------------------------------------------===//
+
+std::string NativeTrait::getFullyQualifiedTraitName() const {
+  llvm::StringRef trait = def->getValueAsString("trait");
+  llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
+  return cppNamespace.empty() ? trait.str()
+                              : (cppNamespace + "::" + trait).str();
+}
+
+//===----------------------------------------------------------------------===//
+// InternalTrait
+//===----------------------------------------------------------------------===//
+
+llvm::StringRef InternalTrait::getFullyQualifiedTraitName() const {
+  return def->getValueAsString("trait");
+}
+
+//===----------------------------------------------------------------------===//
+// PredTrait
+//===----------------------------------------------------------------------===//
+
+std::string PredTrait::getPredTemplate() const {
+  auto pred = Pred(def->getValueInit("predicate"));
+  return pred.getCondition();
+}
+
+llvm::StringRef PredTrait::getSummary() const {
+  return def->getValueAsString("summary");
+}
+
+//===----------------------------------------------------------------------===//
+// InterfaceTrait
+//===----------------------------------------------------------------------===//
+
+Interface InterfaceTrait::getInterface() const { return Interface(def); }
+
+std::string InterfaceTrait::getFullyQualifiedTraitName() const {
+  llvm::StringRef trait = def->getValueAsString("trait");
+  llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
+  return cppNamespace.empty() ? trait.str()
+                              : (cppNamespace + "::" + trait).str();
+}
+
+bool InterfaceTrait::shouldDeclareMethods() const {
+  return def->isSubClassOf("DeclareInterfaceMethods");
+}
+
+std::vector<StringRef> InterfaceTrait::getAlwaysDeclaredMethods() const {
+  return def->getValueAsListOfStrings("alwaysOverriddenMethods");
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index d4c9650e85122..ef21d9f9413ec 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -14,6 +14,7 @@ include "mlir/Interfaces/SideEffectInterfaceBase.td"
 
 // A type interface used to test the ODS generation of type interfaces.
 def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
+  let cppNamespace = "::mlir::test";
   let methods = [
     InterfaceMethod<"Prints the type name.",
       "void", "printTypeA", (ins "Location":$loc), [{

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 5d23bb5e22404..9821774eeede7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -15,9 +15,11 @@
 
 // To get the test dialect def.
 include "TestOps.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
 
 // All of the types will extend this class.
-class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
+class Test_Type<string name, list<Trait> traits = []>
+  : TypeDef<Test_Dialect, name, traits>;
 
 def SimpleTypeA : Test_Type<"SimpleA"> {
   let mnemonic = "smpla";
@@ -151,4 +153,27 @@ def StructType : FieldInfo_Type<"Struct"> {
     let mnemonic = "struct";
 }
 
+def TestType : Test_Type<"Test", [
+  DeclareTypeInterfaceMethods<TestTypeInterface>
+]> {
+  let mnemonic = "test_type";
+}
+
+def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
+  DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["areCompatible"]>
+]> {
+  let mnemonic = "test_type_with_layout";
+  let parameters = (ins "unsigned":$key);
+  let extraClassDeclaration = [{
+    LogicalResult verifyEntries(DataLayoutEntryListRef params,
+                                Location loc) const;
+
+  private:
+    unsigned extractKind(DataLayoutEntryListRef params,
+                         StringRef expectedKind) const;
+
+  public:
+  }];
+}
+
 #endif // TEST_TYPEDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 38ab9c819974d..817b66a000829 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -58,6 +58,34 @@ static void printSignedness(DialectAsmPrinter &printer,
   }
 }
 
+// The functions don't need to be in the header file, but need to be in the mlir
+// namespace. Declare them here, then define them immediately below. Separating
+// the declaration and definition adheres to the LLVM coding standards.
+namespace mlir {
+namespace test {
+// FieldInfo is used as part of a parameter, so equality comparison is
+// compulsory.
+static bool operator==(const FieldInfo &a, const FieldInfo &b);
+// FieldInfo is used as part of a parameter, so a hash will be computed.
+static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
+} // namespace test
+} // namespace mlir
+
+// FieldInfo is used as part of a parameter, so equality comparison is
+// compulsory.
+static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) {
+  return a.name == b.name && a.type == b.type;
+}
+
+// FieldInfo is used as part of a parameter, so a hash will be computed.
+static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
+  return llvm::hash_combine(fi.name, fi.type);
+}
+
+//===----------------------------------------------------------------------===//
+// CompoundAType
+//===----------------------------------------------------------------------===//
+
 Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
   int widthOfSomething;
   Type oneType;
@@ -87,29 +115,9 @@ void CompoundAType::print(DialectAsmPrinter &printer) const {
   printer << "]>";
 }
 
-// The functions don't need to be in the header file, but need to be in the mlir
-// namespace. Declare them here, then define them immediately below. Separating
-// the declaration and definition adheres to the LLVM coding standards.
-namespace mlir {
-namespace test {
-// FieldInfo is used as part of a parameter, so equality comparison is
-// compulsory.
-static bool operator==(const FieldInfo &a, const FieldInfo &b);
-// FieldInfo is used as part of a parameter, so a hash will be computed.
-static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
-} // namespace test
-} // namespace mlir
-
-// FieldInfo is used as part of a parameter, so equality comparison is
-// compulsory.
-static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) {
-  return a.name == b.name && a.type == b.type;
-}
-
-// FieldInfo is used as part of a parameter, so a hash will be computed.
-static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
-  return llvm::hash_combine(fi.name, fi.type);
-}
+//===----------------------------------------------------------------------===//
+// TestIntegerType
+//===----------------------------------------------------------------------===//
 
 // Example type validity checker.
 LogicalResult
@@ -122,18 +130,58 @@ TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 //===----------------------------------------------------------------------===//
-// Tablegen Generated Definitions
+// TestType
 //===----------------------------------------------------------------------===//
 
-#define GET_TYPEDEF_CLASSES
-#include "TestTypeDefs.cpp.inc"
+void TestType::printTypeC(Location loc) const {
+  emitRemark(loc) << *this << " - TestC";
+}
+
+//===----------------------------------------------------------------------===//
+// TestTypeWithLayout
+//===----------------------------------------------------------------------===//
+
+Type TestTypeWithLayoutType::parse(MLIRContext *ctx, DialectAsmParser &parser) {
+  unsigned val;
+  if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater())
+    return Type();
+  return TestTypeWithLayoutType::get(ctx, val);
+}
+
+void TestTypeWithLayoutType::print(DialectAsmPrinter &printer) const {
+  printer << "test_type_with_layout<" << getKey() << ">";
+}
+
+unsigned
+TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
+                                          DataLayoutEntryListRef params) const {
+  return extractKind(params, "size");
+}
 
-LogicalResult TestTypeWithLayout::verifyEntries(DataLayoutEntryListRef params,
-                                                Location loc) const {
+unsigned
+TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
+                                        DataLayoutEntryListRef params) const {
+  return extractKind(params, "alignment");
+}
+
+unsigned TestTypeWithLayoutType::getPreferredAlignment(
+    const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
+  return extractKind(params, "preferred");
+}
+
+bool TestTypeWithLayoutType::areCompatible(
+    DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
+  unsigned old = extractKind(oldLayout, "alignment");
+  return old == 1 || extractKind(newLayout, "alignment") <= old;
+}
+
+LogicalResult
+TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
+                                      Location loc) const {
   for (DataLayoutEntryInterface entry : params) {
     // This is for testing purposes only, so assert well-formedness.
     assert(entry.isTypeEntry() && "unexpected identifier entry");
-    assert(entry.getKey().get<Type>().isa<TestTypeWithLayout>() &&
+    assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&
            "wrong type passed in");
     auto array = entry.getValue().dyn_cast<ArrayAttr>();
     assert(array && array.getValue().size() == 2 &&
@@ -149,8 +197,8 @@ LogicalResult TestTypeWithLayout::verifyEntries(DataLayoutEntryListRef params,
   return success();
 }
 
-unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
-                                         StringRef expectedKind) const {
+unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
+                                             StringRef expectedKind) const {
   for (DataLayoutEntryInterface entry : params) {
     ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue();
     StringRef kind = pair.front().cast<StringAttr>().getValue();
@@ -160,12 +208,19 @@ unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
   return 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Tablegen Generated Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
 
 void TestDialect::registerTypes() {
-  addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
+  addTypes<TestRecursiveType,
 #define GET_TYPEDEF_LIST
 #include "TestTypeDefs.cpp.inc"
            >();
@@ -183,17 +238,6 @@ static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
     if (parseResult.hasValue())
       return genType;
   }
-  if (typeTag == "test_type")
-    return TestType::get(parser.getBuilder().getContext());
-
-  if (typeTag == "test_type_with_layout") {
-    unsigned val;
-    if (parser.parseLess() || parser.parseInteger(val) ||
-        parser.parseGreater()) {
-      return Type();
-    }
-    return TestTypeWithLayout::get(parser.getBuilder().getContext(), val);
-  }
 
   if (typeTag != "test_rec") {
     parser.emitError(parser.getNameLoc()) << "unknown type!";
@@ -234,15 +278,6 @@ static void printTestType(Type type, DialectAsmPrinter &printer,
                           llvm::SetVector<Type> &stack) {
   if (succeeded(generatedTypePrinter(type, printer)))
     return;
-  if (type.isa<TestType>()) {
-    printer << "test_type";
-    return;
-  }
-
-  if (auto t = type.dyn_cast<TestTypeWithLayout>()) {
-    printer << "test_type_with_layout<" << t.getKey() << ">";
-    return;
-  }
 
   auto rec = type.cast<TestRecursiveType>();
   printer << "test_rec<" << rec.getName();

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index e2a65a2aa12df..f9a0289f20b01 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -41,25 +41,14 @@ struct FieldInfo {
 } // namespace test
 } // namespace mlir
 
+#include "TestTypeInterfaces.h.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "TestTypeDefs.h.inc"
 
 namespace mlir {
 namespace test {
 
-#include "TestTypeInterfaces.h.inc"
-
-/// This class is a simple test type that uses a generated interface.
-struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
-                                        TestTypeInterface::Trait> {
-  using Base::Base;
-
-  /// Provide a definition for the necessary interface methods.
-  void printTypeC(Location loc) const {
-    emitRemark(loc) << *this << " - TestC";
-  }
-};
-
 /// Storage for simple named recursive types, where the type is identified by
 /// its name and can "contain" another type, including itself.
 struct TestRecursiveTypeStorage : public TypeStorage {
@@ -108,62 +97,6 @@ class TestRecursiveType
   StringRef getName() { return getImpl()->name; }
 };
 
-struct TestTypeWithLayoutStorage : public TypeStorage {
-  using KeyTy = unsigned;
-
-  explicit TestTypeWithLayoutStorage(unsigned key) : key(key) {}
-  bool operator==(const KeyTy &other) const { return other == key; }
-
-  static TestTypeWithLayoutStorage *construct(TypeStorageAllocator &allocator,
-                                              const KeyTy &key) {
-    return new (allocator.allocate<TestTypeWithLayoutStorage>())
-        TestTypeWithLayoutStorage(key);
-  }
-
-  unsigned key;
-};
-
-class TestTypeWithLayout
-    : public Type::TypeBase<TestTypeWithLayout, Type, TestTypeWithLayoutStorage,
-                            DataLayoutTypeInterface::Trait> {
-public:
-  using Base::Base;
-
-  static TestTypeWithLayout get(MLIRContext *ctx, unsigned key) {
-    return Base::get(ctx, key);
-  }
-
-  unsigned getKey() { return getImpl()->key; }
-
-  unsigned getTypeSizeInBits(const DataLayout &dataLayout,
-                             DataLayoutEntryListRef params) const {
-    return extractKind(params, "size");
-  }
-
-  unsigned getABIAlignment(const DataLayout &dataLayout,
-                           DataLayoutEntryListRef params) const {
-    return extractKind(params, "alignment");
-  }
-
-  unsigned getPreferredAlignment(const DataLayout &dataLayout,
-                                 DataLayoutEntryListRef params) const {
-    return extractKind(params, "preferred");
-  }
-
-  bool areCompatible(DataLayoutEntryListRef oldLayout,
-                     DataLayoutEntryListRef newLayout) const {
-    unsigned old = extractKind(oldLayout, "alignment");
-    return old == 1 || extractKind(newLayout, "alignment") <= old;
-  }
-
-  LogicalResult verifyEntries(DataLayoutEntryListRef params,
-                              Location loc) const;
-
-private:
-  unsigned extractKind(DataLayoutEntryListRef params,
-                       StringRef expectedKind) const;
-};
-
 } // namespace test
 } // namespace mlir
 

diff  --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index fc95fba3c91cd..a5a41b3039918 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -78,16 +78,16 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // DEF: CompoundAAttrStorage (
 // DEF-NEXT: : ::mlir::AttributeStorage(inner),
 
-// DEF: bool operator==(const KeyTy &key) const {
-// DEF-NEXT: if (!(widthOfSomething == std::get<0>(key)))
+// DEF: bool operator==(const KeyTy &tblgenKey) const {
+// DEF-NEXT: if (!(widthOfSomething == std::get<0>(tblgenKey)))
 // DEF-NEXT:   return false;
-// DEF-NEXT: if (!(exampleTdType == std::get<1>(key)))
+// DEF-NEXT: if (!(exampleTdType == std::get<1>(tblgenKey)))
 // DEF-NEXT:   return false;
-// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(key))))
+// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))))
 // DEF-NEXT:   return false;
-// DEF-NEXT: if (!(dims == std::get<3>(key)))
+// DEF-NEXT: if (!(dims == std::get<3>(tblgenKey)))
 // DEF-NEXT:   return false;
-// DEF-NEXT: if (!(getType() == std::get<4>(key)))
+// DEF-NEXT: if (!(getType() == std::get<4>(tblgenKey)))
 // DEF-NEXT:   return false;
 // DEF-NEXT: return true;
 

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index a951df92fe180..c8910eab69ece 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -11,8 +11,10 @@
 #include "mlir/TableGen/CodeGenHelpers.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/TableGenBackend.h"
@@ -208,28 +210,29 @@ class DialectAsmPrinter;
 /// {1}: The name of the type base class.
 /// {2}: The name of the base value type, e.g. Attribute or Type.
 /// {3}: The tablegen record type prefix, e.g. Attr or Type.
+/// {4}: The traits of the def class.
 static const char *const defDeclSingletonBeginStr = R"(
-  class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage> {{
+  class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage{4}> {{
   public:
     /// Inherit some necessary constructors from '{3}Base'.
     using Base::Base;
 )";
 
-/// The code block for the start of a typeDef class declaration -- parametric
-/// case.
+/// The code block for the start of a class declaration -- parametric case.
 ///
-/// {0}: The name of the typeDef class.
-/// {1}: The name of the type base class.
-/// {2}: The typeDef storage class namespace.
+/// {0}: The name of the def class.
+/// {1}: The name of the base class.
+/// {2}: The def storage class namespace.
 /// {3}: The storage class name.
 /// {4}: The name of the base value type, e.g. Attribute or Type.
 /// {5}: The tablegen record type prefix, e.g. Attr or Type.
+/// {6}: The traits of the def class.
 static const char *const defDeclParametricBeginStr = R"(
   namespace {2} {
     struct {3};
   } // end namespace {2}
   class {0} : public ::mlir::{4}::{5}Base<{0}, {1},
-                                         {2}::{3}> {{
+                                         {2}::{3}{6}> {{
   public:
     /// Inherit some necessary constructors from '{5}Base'.
     using Base::Base;
@@ -309,19 +312,71 @@ static void emitBuilderDecls(const AttrOrTypeDef &def, raw_ostream &os,
   }
 }
 
+static void emitInterfaceMethodDecls(const InterfaceTrait *trait,
+                                     raw_ostream &os) {
+  Interface interface = trait->getInterface();
+
+  // Get the set of methods that should always be declared.
+  auto alwaysDeclaredMethodsVec = trait->getAlwaysDeclaredMethods();
+  llvm::StringSet<> alwaysDeclaredMethods;
+  alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
+                               alwaysDeclaredMethodsVec.end());
+
+  for (const InterfaceMethod &method : interface.getMethods()) {
+    // Don't declare if the method has a body.
+    if (method.getBody())
+      continue;
+    // Don't declare if the method has a default implementation and the def
+    // didn't request that it always be declared.
+    if (method.getDefaultImplementation() &&
+        !alwaysDeclaredMethods.count(method.getName()))
+      continue;
+
+    // Emit the method declaration.
+    os << "    " << (method.isStatic() ? "static " : "")
+       << method.getReturnType() << " " << method.getName() << "(";
+    llvm::interleaveComma(method.getArguments(), os,
+                          [&](const InterfaceMethod::Argument &arg) {
+                            os << arg.type << " " << arg.name;
+                          });
+    os << ")" << (method.isStatic() ? "" : " const") << ";\n";
+  }
+}
+
 void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
   SmallVector<AttrOrTypeParameter, 4> params;
   def.getParameters(params);
 
+  // Build the trait list for this def.
+  std::vector<std::string> traitList;
+  StringSet<> traitSet;
+  for (const Trait &baseTrait : def.getTraits()) {
+    std::string traitStr;
+    if (const auto *trait = dyn_cast<NativeTrait>(&baseTrait))
+      traitStr = trait->getFullyQualifiedTraitName();
+    else if (const auto *trait = dyn_cast<InterfaceTrait>(&baseTrait))
+      traitStr = trait->getFullyQualifiedTraitName();
+    else
+      llvm_unreachable("unexpected Attribute/Type trait type");
+
+    if (traitSet.insert(traitStr).second)
+      traitList.emplace_back(std::move(traitStr));
+  }
+  std::string traitStr;
+  if (!traitList.empty())
+    traitStr = ", " + llvm::join(traitList, ", ");
+
   // Emit the beginning string template: either the singleton or parametric
   // template.
   if (def.getNumParameters() == 0) {
     os << formatv(defDeclSingletonBeginStr, def.getCppClassName(),
-                  def.getCppBaseClassName(), valueType, defTypePrefix);
+                  def.getCppBaseClassName(), valueType, defTypePrefix,
+                  traitStr);
   } else {
     os << formatv(defDeclParametricBeginStr, def.getCppClassName(),
                   def.getCppBaseClassName(), def.getStorageNamespace(),
-                  def.getStorageClassName(), valueType, defTypePrefix);
+                  def.getStorageClassName(), valueType, defTypePrefix,
+                  traitStr);
   }
 
   // Emit the extra declarations first in case there's a definition in there.
@@ -362,6 +417,14 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
     }
   }
 
+  // Emit any interface method declarations.
+  for (const Trait &trait : def.getTraits()) {
+    if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait)) {
+      if (traitDef->shouldDeclareMethods())
+        emitInterfaceMethodDecls(traitDef, os);
+    }
+  }
+
   // End the decl.
   os << "  };\n";
 }
@@ -452,7 +515,7 @@ static const char *const defStorageClassConstructorBeginStr = R"(
     /// Define a construction method for creating a new instance of this
     /// storage.
     static {0} *construct(::mlir::{1}StorageAllocator &allocator,
-                          const KeyTy &key) {{
+                          const KeyTy &tblgenKey) {{
 )";
 
 /// The storage class' constructor return template.
@@ -558,7 +621,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
       paramInitializer, parameterTypeList, valueType);
 
   // * Emit the comparison method.
-  os << "  bool operator==(const KeyTy &key) const {\n";
+  os << "  bool operator==(const KeyTy &tblgenKey) const {\n";
   for (auto it : llvm::enumerate(params)) {
     os << "    if (!(";
 
@@ -566,7 +629,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
     bool isSelfType = isa<AttributeSelfTypeParameter>(it.value());
     FmtContext context;
     context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName())
-        .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(key)");
+        .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(tblgenKey)");
 
     // Use the parameter specified comparator if possible, otherwise default to
     // operator==.
@@ -577,13 +640,13 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
   os << "    return true;\n  }\n";
 
   // * Emit the haskKey method.
-  os << "  static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
+  os << "  static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {\n";
 
   // Extract each parameter from the key.
   os << "      return ::llvm::hash_combine(";
   llvm::interleaveComma(
       llvm::seq<unsigned>(0, params.size()), os,
-      [&](unsigned it) { os << "std::get<" << it << ">(key)"; });
+      [&](unsigned it) { os << "std::get<" << it << ">(tblgenKey)"; });
   os << ");\n    }\n";
 
   // * Emit the construct method.
@@ -592,7 +655,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
   // here and then they can write the definition elsewhere.
   if (def.hasStorageCustomConstructor()) {
     os << llvm::formatv("    static {0} *construct(::mlir::{1}StorageAllocator "
-                        "&allocator, const KeyTy &key);\n",
+                        "&allocator, const KeyTy &tblgenKey);\n",
                         def.getStorageClassName(), valueType);
 
     // Otherwise, generate one.
@@ -601,7 +664,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
     os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(),
                   valueType);
     for (unsigned i = 0, e = params.size(); i < e; ++i) {
-      os << formatv("      auto {0} = std::get<{1}>(key);\n",
+      os << formatv("      auto {0} = std::get<{1}>(tblgenKey);\n",
                     params[i].getName(), i);
     }
 

diff  --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index a2c7dd5673bbf..350cc6f18f9cb 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -15,8 +15,8 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Interfaces.h"
 #include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Trait.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/CommandLine.h"

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 25e4d3360b3ee..413d45a6b8d65 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -18,9 +18,9 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Interfaces.h"
 #include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Operator.h"
 #include "mlir/TableGen/SideEffects.h"
+#include "mlir/TableGen/Trait.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Path.h"
@@ -430,7 +430,7 @@ class OpEmitter {
   void genOpInterfaceMethods();
 
   // Generate op interface methods for the given interface.
-  void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
+  void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
 
   // Generate op interface method for the given interface method. If
   // 'declaration' is true, generates a declaration, else a definition.
@@ -1719,8 +1719,8 @@ void OpEmitter::genFolderDecls() {
   }
 }
 
-void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
-  auto interface = opTrait->getOpInterface();
+void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
+  Interface interface = opTrait->getInterface();
 
   // Get the set of methods that should always be declared.
   auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
@@ -1757,7 +1757,7 @@ OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
 
 void OpEmitter::genOpInterfaceMethods() {
   for (const auto &trait : op.getTraits()) {
-    if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
+    if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
       if (opTrait->shouldDeclareMethods())
         genOpInterfaceMethods(opTrait);
   }
@@ -1866,9 +1866,9 @@ void OpEmitter::genTypeInterfaceMethods() {
     return;
   // Generate 'inferReturnTypes' method declaration using the interface method
   // declared in 'InferTypeOpInterface' op interface.
-  const auto *trait = dyn_cast<InterfaceOpTrait>(
+  const auto *trait = dyn_cast<InterfaceTrait>(
       op.getTrait("::mlir::InferTypeOpInterface::Trait"));
-  auto interface = trait->getOpInterface();
+  Interface interface = trait->getInterface();
   OpMethod *method = [&]() -> OpMethod * {
     for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
       if (interfaceMethod.getName() == "inferReturnTypes") {
@@ -1966,7 +1966,7 @@ void OpEmitter::genVerifier() {
   genOperandResultVerifier(body, op.getResults(), "result");
 
   for (auto &trait : op.getTraits()) {
-    if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
+    if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
       body << tgfmt("  if (!($0))\n    "
                     "return emitOpError(\"failed to verify that $1\");\n",
                     &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
@@ -2187,10 +2187,10 @@ void OpEmitter::genTraits() {
 
   // Add the native and interface traits.
   for (const auto &trait : op.getTraits()) {
-    if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
-      opClass.addTrait(opTrait->getTrait());
-    else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
-      opClass.addTrait(opTrait->getTrait());
+    if (auto opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
+      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+    else if (auto opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
+      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
   }
 }
 
@@ -2379,12 +2379,14 @@ void OpOperandAdaptorEmitter::addVerification() {
   // Verify a few traits first so that we can use
   // getODSOperands()/getODSResults() in the rest of the verifier.
   for (auto &trait : op.getTraits()) {
-    if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
-      if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") {
+    if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      if (t->getFullyQualifiedTraitName() ==
+          "::mlir::OpTrait::AttrSizedOperandSegments") {
         body << formatv(checkAttrSizedValueSegmentsCode,
                         "operand_segment_sizes", op.getNumOperands(),
                         "operand");
-      } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") {
+      } else if (t->getFullyQualifiedTraitName() ==
+                 "::mlir::OpTrait::AttrSizedResultSegments") {
         body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
                         op.getNumResults(), "result");
       }

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 93d2155c3c821..3c3f00f2379c8 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -12,8 +12,8 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Interfaces.h"
 #include "mlir/TableGen/OpClass.h"
-#include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Trait.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SetVector.h"
@@ -445,18 +445,12 @@ struct OperationFormat {
     operandTypes.resize(op.getNumOperands(), TypeResolution());
     resultTypes.resize(op.getNumResults(), TypeResolution());
 
-    hasImplicitTermTrait =
-        llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
-          return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
-        });
+    hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
+      return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
+    });
 
     hasSingleBlockTrait =
-        hasImplicitTermTrait ||
-        llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
-          if (auto *native = dyn_cast<NativeOpTrait>(&trait))
-            return native->getTrait() == "::mlir::OpTrait::SingleBlock";
-          return false;
-        });
+        hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
   }
 
   /// Generate the operation parser from this format.
@@ -2416,7 +2410,7 @@ LogicalResult FormatParser::parse() {
 
   // Check for any type traits that we can use for inferring types.
   llvm::StringMap<TypeResolutionInstance> variableTyResolver;
-  for (const OpTrait &trait : op.getTraits()) {
+  for (const Trait &trait : op.getTraits()) {
     const llvm::Record &def = trait.getDef();
     if (def.isSubClassOf("AllTypesMatch")) {
       handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),


        


More information about the Mlir-commits mailing list