[llvm-branch-commits] [mlir] 95019de - [mlir][IR] Define the singleton builtin types in ODS instead of C++

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Dec 15 13:56:03 PST 2020


Author: River Riddle
Date: 2020-12-15T13:42:19-08:00
New Revision: 95019de8a122619fc038c9fe3c80e625e3456bbf

URL: https://github.com/llvm/llvm-project/commit/95019de8a122619fc038c9fe3c80e625e3456bbf
DIFF: https://github.com/llvm/llvm-project/commit/95019de8a122619fc038c9fe3c80e625e3456bbf.diff

LOG: [mlir][IR] Define the singleton builtin types in ODS instead of C++

This exposes several issues with the current generation that this revision also fixes.
 * TypeDef now allows specifying the base class to use when generating.
 * TypeDef now inherits from DialectType, which allows for using it as a TypeConstraint.
 * Parsers/printers are no longer generated in the header (which removes duplicate symbols), and are now only generated when necessary.
    - Now that generatedTypeParser/Printer are only generated in the definition file,
      existing users will need to manually expose this functionality when necessary.
 * ::get() is no longer generated for singleton types, because it isn't necessary.
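
The dispatch order that the test dialect's removed `parseTestType` relied on can be modeled without MLIR: try the generated dispatcher first, and fall back to hand-written types when it returns a null type. The block below is a minimal, self-contained sketch. `Type`, `generatedTypeParser`, and `parseDialectType` here are hypothetical stand-ins, not the MLIR API. Because the generated dispatcher is now a `static` function emitted into the definition file, the real dialect hook has to live in the same source file that includes the generated `.cpp.inc`.

```cpp
#include <cassert>
#include <string>
#include <string_view>

// Hypothetical stand-in for mlir::Type: an empty name plays the role of a
// null Type().
struct Type {
  std::string name;
  explicit operator bool() const { return !name.empty(); }
};

// Stand-in for the tablegen'erated dispatcher. In the real setup it is a
// `static` function in the .cpp.inc, visible only in the file including it.
// It returns a null Type for mnemonics it does not know.
static Type generatedTypeParser(std::string_view mnemonic) {
  if (mnemonic == "test_integer")
    return Type{"TestIntegerType"};
  return Type{};
}

// Hand-written dialect hook: defer to the generated dispatcher first, then
// handle types that have no mnemonic in ODS.
Type parseDialectType(std::string_view mnemonic) {
  if (Type generated = generatedTypeParser(mnemonic))
    return generated;
  if (mnemonic == "test_type")
    return Type{"TestType"};
  return Type{}; // unknown mnemonic: report failure with a null type
}
```

Keeping the generated dispatcher first means any type that gains a mnemonic in ODS is picked up automatically, while the fallback branch only has to cover types that are still parsed by hand.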

Differential Revision: https://reviews.llvm.org/D93270

Added: 
    mlir/include/mlir/IR/BuiltinDialect.td
    mlir/include/mlir/IR/BuiltinTypes.td

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/BuiltinOps.td
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/CMakeLists.txt
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/TypeDef.h
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/CMakeLists.txt
    mlir/lib/TableGen/Constraint.cpp
    mlir/lib/TableGen/TypeDef.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/mlir-tblgen/typedefs.td
    mlir/tools/mlir-tblgen/TypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 189cd0825af7..c5ffe452b927 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1370,10 +1370,10 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
 
 ## Type Definitions
 
-MLIR defines the TypeDef class hierarchy to enable generation of data types
-from their specifications. A type is defined by specializing the TypeDef
-class with concrete contents for all the fields it requires. For example, an
-integer type could be defined as:
+MLIR defines the TypeDef class hierarchy to enable generation of data types from
+their specifications. A type is defined by specializing the TypeDef class with
+concrete contents for all the fields it requires. For example, an integer type
+could be defined as:
 
 ```tablegen
 // All of the types will extend this class.
@@ -1414,45 +1414,43 @@ def IntegerType : Test_Type<"TestInteger"> {
 ### Type name
 
 The name of the C++ class which gets generated defaults to
-`<classParamName>Type` (e.g. `TestIntegerType` in the above example). This
-can be overridden via the `cppClassName` field. The field `mnemonic` is
-to specify the asm name for parsing. It is optional and not specifying it
-will imply that no parser or printer methods are attached to this class.
+`<classParamName>Type` (e.g. `TestIntegerType` in the above example). This can
+be overridden via the `cppClassName` field. The field `mnemonic` is to specify
+the asm name for parsing. It is optional and not specifying it will imply that
+no parser or printer methods are attached to this class.
 
 ### Type documentation
 
-The `summary` and `description` fields exist and are to be used the same way
-as in Operations. Namely, the summary should be a one-liner and `description`
+The `summary` and `description` fields exist and are to be used the same way as
+in Operations. Namely, the summary should be a one-liner and `description`
 should be a longer explanation.
 
 ### Type parameters
 
-The `parameters` field is a list of the types parameters. If no parameters
-are specified (the default), this type is considered a singleton type.
-Parameters are in the `"c++Type":$paramName` format.
-To use C++ types as parameters which need allocation in the storage
-constructor, there are two options:
+The `parameters` field is a list of the types parameters. If no parameters are
+specified (the default), this type is considered a singleton type. Parameters
+are in the `"c++Type":$paramName` format. To use C++ types as parameters which
+need allocation in the storage constructor, there are two options:
 
-- Set `hasCustomStorageConstructor` to generate the TypeStorage class with
-a constructor which is just declared -- no definition -- so you can write it
-yourself.
-- Use the `TypeParameter` tablegen class instead of the "c++Type" string.
+-   Set `hasCustomStorageConstructor` to generate the TypeStorage class with a
+    constructor which is just declared -- no definition -- so you can write it
+    yourself.
+-   Use the `TypeParameter` tablegen class instead of the "c++Type" string.
 
 ### TypeParameter tablegen class
 
-This is used to further specify attributes about each of the types
-parameters. It includes documentation (`description` and `syntax`), the C++
-type to use, and a custom allocator to use in the storage constructor method.
+This is used to further specify attributes about each of the types parameters.
+It includes documentation (`description` and `syntax`), the C++ type to use, and
+a custom allocator to use in the storage constructor method.
 
 ```tablegen
 // DO NOT DO THIS!
-let parameters = (ins
-  "ArrayRef<int>":$dims);
+let parameters = (ins "ArrayRef<int>":$dims);
 ```
 
-The default storage constructor blindly copies fields by value. It does not
-know anything about the types. In this case, the ArrayRef<int> requires
-allocation with `dims = allocator.copyInto(dims)`.
+The default storage constructor blindly copies fields by value. It does not know
+anything about the types. In this case, the ArrayRef<int> requires allocation
+with `dims = allocator.copyInto(dims)`.
 
 You can specify the necessary constructor by specializing the `TypeParameter`
 tblgen class:
@@ -1460,28 +1458,29 @@ tblgen class:
 ```tablegen
 class ArrayRefIntParam :
     TypeParameter<"::llvm::ArrayRef<int>", "Array of ints"> {
-  let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+  let allocator = "$_dst = $_allocator.copyInto($_self);";
 }
 
 ...
 
-let parameters = (ins
-  ArrayRefIntParam:$dims);
+let parameters = (ins ArrayRefIntParam:$dims);
 ```
 
 The `allocator` code block has the following substitutions:
-- `$_allocator` is the TypeStorageAllocator in which to allocate objects.
-- `$_dst` is the variable in which to place the allocated data.
+
+-   `$_allocator` is the TypeStorageAllocator in which to allocate objects.
+-   `$_dst` is the variable in which to place the allocated data.
 
 MLIR includes several specialized classes for common situations:
-- `StringRefParameter<descriptionOfParam>` for StringRefs.
-- `ArrayRefParameter<arrayOf, descriptionOfParam>` for ArrayRefs of value
-types
-- `SelfAllocationParameter<descriptionOfParam>` for C++ classes which contain
-a method called `allocateInto(StorageAllocator &allocator)` to allocate
-itself into `allocator`.
-- `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays
-of objects which self-allocate as per the last specialization.
+
+-   `StringRefParameter<descriptionOfParam>` for StringRefs.
+-   `ArrayRefParameter<arrayOf, descriptionOfParam>` for ArrayRefs of value
+    types
+-   `SelfAllocationParameter<descriptionOfParam>` for C++ classes which contain
+    a method called `allocateInto(StorageAllocator &allocator)` to allocate
+    itself into `allocator`.
+-   `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays
+    of objects which self-allocate as per the last specialization.
 
 If we were to use one of these included specializations:
 
@@ -1495,45 +1494,46 @@ let parameters = (ins
 
 If a mnemonic is specified, the `printer` and `parser` code fields are active.
 The rules for both are:
-- If null, generate just the declaration.
-- If non-null and non-empty, use the code in the definition. The `$_printer`
-or `$_parser` substitutions are valid and should be used.
-- It is an error to have an empty code block.
-
-For each dialect, two "dispatch" functions will be created: one for parsing
-and one for printing. You should add calls to these in your
-`Dialect::printType` and `Dialect::parseType` methods. They are created in
-the dialect's namespace and their function signatures are:
+
+-   If null, generate just the declaration.
+-   If non-null and non-empty, use the code in the definition. The `$_printer`
+    or `$_parser` substitutions are valid and should be used.
+-   It is an error to have an empty code block.
+
+For each dialect, two "dispatch" functions will be created: one for parsing and
+one for printing. You should add calls to these in your `Dialect::printType` and
+`Dialect::parseType` methods. They are static functions placed alongside the
+type class definitions and have the following function signatures:
+
 ```c++
-Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser,
-                         StringRef mnemonic);
+static Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser, StringRef mnemonic);
 LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer);
 ```
 
-The mnemonic, parser, and printer fields are optional. If they're not
-defined, the generated code will not include any parsing or printing code and
-omit the type from the dispatch functions above. In this case, the dialect
-author is responsible for parsing/printing the types in `Dialect::printType`
-and `Dialect::parseType`.
+The mnemonic, parser, and printer fields are optional. If they're not defined,
+the generated code will not include any parsing or printing code and omit the
+type from the dispatch functions above. In this case, the dialect author is
+responsible for parsing/printing the types in `Dialect::printType` and
+`Dialect::parseType`.
 
 ### Other fields
 
-- If the `genStorageClass` field is set to 1 (the default) a storage class is
-generated with member variables corresponding to each of the specified
-`parameters`.
-- If the `genAccessors` field is 1 (the default) accessor methods will be
-generated on the Type class (e.g. `int getWidth() const` in the example
-above).
-- If the `genVerifyInvariantsDecl` field is set, a declaration for a method
-`static LogicalResult verifyConstructionInvariants(Location, parameters...)`
-is added to the class as well as a `getChecked(Location, parameters...)`
-method which gets the result of `verifyConstructionInvariants` before calling
-`get`.
-- The `storageClass` field can be used to set the name of the storage class.
-- The `storageNamespace` field is used to set the namespace where the storage
-class should sit. Defaults to "detail".
-- The `extraClassDeclaration` field is used to include extra code in the
-class declaration.
+-   If the `genStorageClass` field is set to 1 (the default) a storage class is
+    generated with member variables corresponding to each of the specified
+    `parameters`.
+-   If the `genAccessors` field is 1 (the default) accessor methods will be
+    generated on the Type class (e.g. `int getWidth() const` in the example
+    above).
+-   If the `genVerifyInvariantsDecl` field is set, a declaration for a method
+    `static LogicalResult verifyConstructionInvariants(Location, parameters...)`
+    is added to the class as well as a `getChecked(Location, parameters...)`
+    method which gets the result of `verifyConstructionInvariants` before
+    calling `get`.
+-   The `storageClass` field can be used to set the name of the storage class.
+-   The `storageNamespace` field is used to set the namespace where the storage
+    class should sit. Defaults to "detail".
+-   The `extraClassDeclaration` field is used to include extra code in the class
+    declaration.
 
 ## Debugging Tips
 

diff  --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
new file mode 100644
index 000000000000..383f87bd5d60
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -0,0 +1,27 @@
+//===-- BuiltinDialect.td - Builtin dialect definition -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the Builtin dialect. This dialect
+// contains all of the attributes, operations, and types that are core to MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUILTIN_BASE
+#define BUILTIN_BASE
+
+include "mlir/IR/OpBase.td"
+
+def Builtin_Dialect : Dialect {
+  let summary =
+    "A dialect containing the builtin Attributes, Operations, and Types";
+
+  let name = "";
+  let cppNamespace = "::mlir";
+}
+
+#endif // BUILTIN_BASE

diff  --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index 567a1249f0df..b085f721cfa9 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -14,17 +14,10 @@
 #ifndef BUILTIN_OPS
 #define BUILTIN_OPS
 
+include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 
-def Builtin_Dialect : Dialect {
-  let summary =
-    "A dialect containing the builtin Attributes, Operations, and Types";
-
-  let name = "";
-  let cppNamespace = "::mlir";
-}
-
 // Base class for Builtin dialect ops.
 class Builtin_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<Builtin_Dialect, mnemonic, traits>;

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 8ce5e4045a3a..10e78e5efb72 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -72,23 +72,6 @@ class ComplexType
   Type getElementType();
 };
 
-//===----------------------------------------------------------------------===//
-// IndexType
-//===----------------------------------------------------------------------===//
-
-/// Index is a special integer-like type with unknown platform-dependent bit
-/// width.
-class IndexType : public Type::TypeBase<IndexType, Type, TypeStorage> {
-public:
-  using Base::Base;
-
-  /// Get an instance of the IndexType.
-  static IndexType get(MLIRContext *context);
-
-  /// Storage bit width used for IndexType by internal compiler data structures.
-  static constexpr unsigned kInternalStorageBitWidth = 64;
-};
-
 //===----------------------------------------------------------------------===//
 // IntegerType
 //===----------------------------------------------------------------------===//
@@ -187,67 +170,6 @@ class FloatType : public Type {
   const llvm::fltSemantics &getFloatSemantics();
 };
 
-//===----------------------------------------------------------------------===//
-// BFloat16Type
-
-class BFloat16Type
-    : public Type::TypeBase<BFloat16Type, FloatType, TypeStorage> {
-public:
-  using Base::Base;
-
-  /// Return an instance of the bfloat16 type.
-  static BFloat16Type get(MLIRContext *context);
-};
-
-inline FloatType FloatType::getBF16(MLIRContext *ctx) {
-  return BFloat16Type::get(ctx);
-}
-
-//===----------------------------------------------------------------------===//
-// Float16Type
-
-class Float16Type : public Type::TypeBase<Float16Type, FloatType, TypeStorage> {
-public:
-  using Base::Base;
-
-  /// Return an instance of the float16 type.
-  static Float16Type get(MLIRContext *context);
-};
-
-inline FloatType FloatType::getF16(MLIRContext *ctx) {
-  return Float16Type::get(ctx);
-}
-
-//===----------------------------------------------------------------------===//
-// Float32Type
-
-class Float32Type : public Type::TypeBase<Float32Type, FloatType, TypeStorage> {
-public:
-  using Base::Base;
-
-  /// Return an instance of the float32 type.
-  static Float32Type get(MLIRContext *context);
-};
-
-inline FloatType FloatType::getF32(MLIRContext *ctx) {
-  return Float32Type::get(ctx);
-}
-
-//===----------------------------------------------------------------------===//
-// Float64Type
-
-class Float64Type : public Type::TypeBase<Float64Type, FloatType, TypeStorage> {
-public:
-  using Base::Base;
-
-  /// Return an instance of the float64 type.
-  static Float64Type get(MLIRContext *context);
-};
-
-inline FloatType FloatType::getF64(MLIRContext *ctx) {
-  return Float64Type::get(ctx);
-}
-
 //===----------------------------------------------------------------------===//
 // FunctionType
 //===----------------------------------------------------------------------===//
@@ -276,20 +198,6 @@ class FunctionType
                                         ArrayRef<unsigned> resultIndices);
 };
 
-//===----------------------------------------------------------------------===//
-// NoneType
-//===----------------------------------------------------------------------===//
-
-/// NoneType is a unit type, i.e. a type with exactly one possible value, where
-/// its value does not have a defined dynamic representation.
-class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
-public:
-  using Base::Base;
-
-  /// Get an instance of the NoneType.
-  static NoneType get(MLIRContext *context);
-};
-
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
@@ -720,11 +628,20 @@ class TupleType
     return getTypes()[index];
   }
 };
+} // end namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Tablegen Type Declarations
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/IR/BuiltinTypes.h.inc"
 
 //===----------------------------------------------------------------------===//
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
 inline bool BaseMemRefType::classof(Type type) {
   return type.isa<MemRefType, UnrankedMemRefType>();
 }
@@ -733,6 +650,22 @@ inline bool FloatType::classof(Type type) {
   return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
 }
 
+inline FloatType FloatType::getBF16(MLIRContext *ctx) {
+  return BFloat16Type::get(ctx);
+}
+
+inline FloatType FloatType::getF16(MLIRContext *ctx) {
+  return Float16Type::get(ctx);
+}
+
+inline FloatType FloatType::getF32(MLIRContext *ctx) {
+  return Float32Type::get(ctx);
+}
+
+inline FloatType FloatType::getF64(MLIRContext *ctx) {
+  return Float64Type::get(ctx);
+}
+
 inline bool ShapedType::classof(Type type) {
   return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
                   UnrankedMemRefType, MemRefType>();

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
new file mode 100644
index 000000000000..b540554fb11e
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -0,0 +1,114 @@
+//===- BuiltinTypes.td - Builtin type definitions ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the set of builtin MLIR types, or the set of types necessary for the
+// validity of and defining the IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUILTIN_TYPES
+#define BUILTIN_TYPES
+
+include "mlir/IR/BuiltinDialect.td"
+
+// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
+// This is to differentiate the types here with the ones in OpBase.td. We should
+// remove the definitions in OpBase.td, and repoint users to this file instead.
+
+// Base class for Builtin dialect types.
+class Builtin_Type<string name> : TypeDef<Builtin_Dialect, name> {
+  let mnemonic = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+// Base class for Builtin dialect float types.
+class Builtin_FloatType<string name> : TypeDef<Builtin_Dialect, name,
+                                               "::mlir::FloatType"> {
+  let extraClassDeclaration = [{
+    static }] # name # [{Type get(MLIRContext *context);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// BFloat16Type
+
+def Builtin_BFloat16 : Builtin_FloatType<"BFloat16"> {
+  let summary = "bfloat16 floating-point type";
+}
+
+//===----------------------------------------------------------------------===//
+// Float16Type
+
+def Builtin_Float16 : Builtin_FloatType<"Float16"> {
+  let summary = "16-bit floating-point type";
+}
+
+//===----------------------------------------------------------------------===//
+// Float32Type
+
+def Builtin_Float32 : Builtin_FloatType<"Float32"> {
+  let summary = "32-bit floating-point type";
+}
+
+//===----------------------------------------------------------------------===//
+// Float64Type
+
+def Builtin_Float64 : Builtin_FloatType<"Float64"> {
+  let summary = "64-bit floating-point type";
+}
+
+//===----------------------------------------------------------------------===//
+// IndexType
+//===----------------------------------------------------------------------===//
+
+def Builtin_Index : Builtin_Type<"Index"> {
+  let summary = "Integer-like type with unknown platform-dependent bit width";
+  let description = [{
+    Syntax:
+
+    ```
+    // Target word-sized integer.
+    index-type ::= `index`
+    ```
+
+    The index type is a signless integer whose size is equal to the natural
+    machine word of the target ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#integer-signedness-semantics) )
+    and is used by the affine constructs in MLIR. Unlike fixed-size integers,
+    it cannot be used as an element of vector ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#index-type-disallowed-in-vector-types) ).
+
+    **Rationale:** integers of platform-specific bit widths are practical to
+    express sizes, dimensionalities and subscripts.
+  }];
+  let extraClassDeclaration = [{
+    static IndexType get(MLIRContext *context);
+
+    /// Storage bit width used for IndexType by internal compiler data
+    /// structures.
+    static constexpr unsigned kInternalStorageBitWidth = 64;
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// NoneType
+//===----------------------------------------------------------------------===//
+
+def Builtin_None : Builtin_Type<"None"> {
+  let summary = "A unit type";
+  let description = [{
+    NoneType is a unit type, i.e. a type with exactly one possible value, where
+    its value does not have a defined dynamic representation.
+  }];
+  let extraClassDeclaration = [{
+    static NoneType get(MLIRContext *context);
+  }];
+}
+
+#endif // BUILTIN_TYPES

diff  --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 45649b3e7598..3b7ddbaf2338 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -2,10 +2,18 @@ add_mlir_interface(OpAsmInterface)
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+set(LLVM_TARGET_DEFINITIONS BuiltinDialect.td)
+mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
+add_public_tablegen_target(MLIRBuiltinDialectIncGen)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinOps.td)
 mlir_tablegen(BuiltinOps.h.inc -gen-op-decls)
 mlir_tablegen(BuiltinOps.cpp.inc -gen-op-defs)
-mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
 add_public_tablegen_target(MLIRBuiltinOpsIncGen)
 
+set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
+mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRBuiltinTypesIncGen)
+
 add_mlir_doc(BuiltinOps -gen-op-doc Builtin Dialects/)

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index b031769022d9..aa5ef284de11 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2415,15 +2415,18 @@ def replaceWithValue;
 // Data type generation
 //===----------------------------------------------------------------------===//
 
-// Define a new type belonging to a dialect and called 'name'.
-class TypeDef<Dialect owningdialect, string name> {
-  Dialect dialect = owningdialect;
+// Define a new type, named `name`, belonging to `dialect` that inherits from
+// the given C++ base class.
+class TypeDef<Dialect dialect, string name,
+              string baseCppClass = "::mlir::Type">
+    : DialectType<dialect, CPred<"">> {
+  // The name of the C++ Type class.
   string cppClassName = name # "Type";
+  // The name of the C++ base class to use for this Type.
+  string cppBaseClassName = baseCppClass;
 
   // Short summary of the type.
   string summary = ?;
-  // The longer description of this type.
-  string description = ?;
 
   // Name of storage class to generate or use.
   string storageClass = name # "TypeStorage";
@@ -2477,6 +2480,15 @@ class TypeDef<Dialect owningdialect, string name> {
   bit genVerifyInvariantsDecl = 0;
   // Extra code to include in the class declaration.
   code extraClassDeclaration = [{}];
+
+  // The predicate for when this type is used as a type constraint.
+  let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
+                                 "::" # cppClassName # ">()">;
+  // A constant builder provided when the type has no parameters.
+  let builderCall = !if(!empty(parameters),
+                           "$_builder.getType<" # dialect.cppNamespace #
+                               "::" # cppClassName # ">()",
+                           "");
 }
 
 // 'Parameters' should be subclasses of this or simple strings (which is a

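For a parameterless TypeDef such as `Builtin_Index` (dialect `cppNamespace` `::mlir`, `cppClassName` `IndexType`), the `#` concatenations in the new `predicate` and `builderCall` fields expand to a plain `isa` check and a constant builder. The sketch below mirrors that string pasting in C++ so the expansion can be checked by hand; `expandPredicate` and `expandBuilderCall` are hypothetical helpers for illustration, not part of mlir-tblgen.

```cpp
#include <cassert>
#include <string>

// Mirrors: CPred<"$_self.isa<" # dialect.cppNamespace # "::" #
//                cppClassName # ">()">
std::string expandPredicate(const std::string &cppNamespace,
                            const std::string &cppClassName) {
  return "$_self.isa<" + cppNamespace + "::" + cppClassName + ">()";
}

// Mirrors the !if(!empty(parameters), ..., "") builderCall: only
// parameterless (singleton) types get a constant builder.
std::string expandBuilderCall(const std::string &cppNamespace,
                              const std::string &cppClassName,
                              bool hasParameters) {
  if (hasParameters)
    return "";
  return "$_builder.getType<" + cppNamespace + "::" + cppClassName + ">()";
}
```

An empty `builderCall` for parameterized types is what lets the same TypeDef serve as a type constraint without claiming it can be built from nothing.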
diff  --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h
index 462fed322438..796f5cc17859 100644
--- a/mlir/include/mlir/TableGen/TypeDef.h
+++ b/mlir/include/mlir/TableGen/TypeDef.h
@@ -48,6 +48,9 @@ class TypeDef {
   // Returns the name of the C++ class to generate.
   StringRef getCppClassName() const;
 
+  // Returns the name of the C++ base class to use when generating this type.
+  StringRef getCppBaseClassName() const;
+
   // Returns the name of the storage class for this type.
   StringRef getStorageClassName() const;
 

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index aa4eba07cf53..68cf491850d1 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -20,6 +20,13 @@
 using namespace mlir;
 using namespace mlir::detail;
 
+//===----------------------------------------------------------------------===//
+/// Tablegen Type Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/IR/BuiltinTypes.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 /// ComplexType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 42cdb3a91a50..e4b3b9141323 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -33,7 +33,9 @@ add_mlir_library(MLIRIR
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
 
   DEPENDS
+  MLIRBuiltinDialectIncGen
   MLIRBuiltinOpsIncGen
+  MLIRBuiltinTypesIncGen
   MLIRCallInterfacesIncGen
   MLIROpAsmInterfaceIncGen
   MLIRRegionKindInterfaceIncGen

diff  --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index b8e1b9f05acd..71606f90cf45 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -13,6 +13,7 @@
 #include "mlir/TableGen/Constraint.h"
 #include "llvm/TableGen/Record.h"
 
+using namespace mlir;
 using namespace mlir::tblgen;
 
 Constraint::Constraint(const llvm::Record *record)
@@ -56,11 +57,18 @@ std::string Constraint::getConditionTemplate() const {
   return getPredicate().getCondition();
 }
 
-llvm::StringRef Constraint::getDescription() const {
-  auto doc = def->getValueAsString("description");
-  if (doc.empty())
-    return def->getName();
-  return doc;
+StringRef Constraint::getDescription() const {
+  // If a summary is found, we use that given that it is a focused single line
+  // comment.
+  if (Optional<StringRef> summary = def->getValueAsOptionalString("summary"))
+    return *summary;
+  // If a summary can't be found, look for a specific description field to use
+  // for the constraint.
+  StringRef desc = def->getValueAsString("description");
+  if (!desc.empty())
+    return desc;
+  // Otherwise, fallback to the name of the constraint definition.
+  return def->getName();
 }
 
 AppliedConstraint::AppliedConstraint(Constraint &&constraint,

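The rewritten `Constraint::getDescription()` now prefers a one-line `summary`, then a non-empty `description`, and finally the def's name. The sketch below restates that precedence with standard C++ types so it can be exercised in isolation; `ConstraintRecord` and `describe` are hypothetical names, and `std::optional` stands in for the unset (`?`) case that `getValueAsOptionalString` reports.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical view of the fields a constraint record may carry.
struct ConstraintRecord {
  std::string name;                   // the def's name
  std::optional<std::string> summary; // nullopt when `summary = ?`
  std::string description;            // empty when not provided
};

// Same precedence as the new getDescription(): summary first, then a
// non-empty description, then the record name as a last resort.
std::string describe(const ConstraintRecord &rec) {
  if (rec.summary)
    return *rec.summary;
  if (!rec.description.empty())
    return rec.description;
  return rec.name;
}
```

Preferring the summary matters now that TypeDefs double as constraints: error messages get the focused one-liner (e.g. "A unit type" for `Builtin_None`) rather than a multi-paragraph description.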
diff  --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp
index d8412b6b4be5..b666d02d7eb9 100644
--- a/mlir/lib/TableGen/TypeDef.cpp
+++ b/mlir/lib/TableGen/TypeDef.cpp
@@ -31,6 +31,10 @@ StringRef TypeDef::getCppClassName() const {
   return def->getValueAsString("cppClassName");
 }
 
+StringRef TypeDef::getCppBaseClassName() const {
+  return def->getValueAsString("cppBaseClassName");
+}
+
 bool TypeDef::hasDescription() const {
   const llvm::RecordVal *s = def->getValue("description");
   return s != nullptr && isa<llvm::StringInit>(s->getValue());

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 5e9bae893363..8aec98434f15 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -15,7 +15,6 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringSwitch.h"
 
 using namespace mlir;
@@ -183,77 +182,6 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return builder.create<TestOpConstant>(loc, type, value);
 }
 
-static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
-                          llvm::SetVector<Type> &stack) {
-  StringRef typeTag;
-  if (failed(parser.parseKeyword(&typeTag)))
-    return Type();
-
-  auto genType = generatedTypeParser(ctxt, parser, typeTag);
-  if (genType != Type())
-    return genType;
-
-  if (typeTag == "test_type")
-    return TestType::get(parser.getBuilder().getContext());
-
-  if (typeTag != "test_rec")
-    return Type();
-
-  StringRef name;
-  if (parser.parseLess() || parser.parseKeyword(&name))
-    return Type();
-  auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
-
-  // If this type already has been parsed above in the stack, expect just the
-  // name.
-  if (stack.contains(rec)) {
-    if (failed(parser.parseGreater()))
-      return Type();
-    return rec;
-  }
-
-  // Otherwise, parse the body and update the type.
-  if (failed(parser.parseComma()))
-    return Type();
-  stack.insert(rec);
-  Type subtype = parseTestType(ctxt, parser, stack);
-  stack.pop_back();
-  if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
-    return Type();
-
-  return rec;
-}
-
-Type TestDialect::parseType(DialectAsmParser &parser) const {
-  llvm::SetVector<Type> stack;
-  return parseTestType(getContext(), parser, stack);
-}
-
-static void printTestType(Type type, DialectAsmPrinter &printer,
-                          llvm::SetVector<Type> &stack) {
-  if (succeeded(generatedTypePrinter(type, printer)))
-    return;
-  if (type.isa<TestType>()) {
-    printer << "test_type";
-    return;
-  }
-
-  auto rec = type.cast<TestRecursiveType>();
-  printer << "test_rec<" << rec.getName();
-  if (!stack.contains(rec)) {
-    printer << ", ";
-    stack.insert(rec);
-    printTestType(rec.getBody(), printer, stack);
-    stack.pop_back();
-  }
-  printer << ">";
-}
-
-void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
-  llvm::SetVector<Type> stack;
-  printTestType(type, printer, stack);
-}
-
 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute namedAttr) {
   if (namedAttr.first == "test.invalid_attr")

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 0a91be30b53a..14a3e862d85c 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -12,9 +12,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestTypes.h"
+#include "TestDialect.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -116,5 +119,84 @@ LogicalResult TestIntegerType::verifyConstructionInvariants(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Tablegen Generated Definitions
+//===----------------------------------------------------------------------===//
+
 #define GET_TYPEDEF_CLASSES
 #include "TestTypeDefs.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TestDialect
+//===----------------------------------------------------------------------===//
+
+static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
+                          llvm::SetVector<Type> &stack) {
+  StringRef typeTag;
+  if (failed(parser.parseKeyword(&typeTag)))
+    return Type();
+
+  auto genType = generatedTypeParser(ctxt, parser, typeTag);
+  if (genType != Type())
+    return genType;
+
+  if (typeTag == "test_type")
+    return TestType::get(parser.getBuilder().getContext());
+
+  if (typeTag != "test_rec")
+    return Type();
+
+  StringRef name;
+  if (parser.parseLess() || parser.parseKeyword(&name))
+    return Type();
+  auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
+
+  // If this type already has been parsed above in the stack, expect just the
+  // name.
+  if (stack.contains(rec)) {
+    if (failed(parser.parseGreater()))
+      return Type();
+    return rec;
+  }
+
+  // Otherwise, parse the body and update the type.
+  if (failed(parser.parseComma()))
+    return Type();
+  stack.insert(rec);
+  Type subtype = parseTestType(ctxt, parser, stack);
+  stack.pop_back();
+  if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
+    return Type();
+
+  return rec;
+}
+
+Type TestDialect::parseType(DialectAsmParser &parser) const {
+  llvm::SetVector<Type> stack;
+  return parseTestType(getContext(), parser, stack);
+}
+
+static void printTestType(Type type, DialectAsmPrinter &printer,
+                          llvm::SetVector<Type> &stack) {
+  if (succeeded(generatedTypePrinter(type, printer)))
+    return;
+  if (type.isa<TestType>()) {
+    printer << "test_type";
+    return;
+  }
+
+  auto rec = type.cast<TestRecursiveType>();
+  printer << "test_rec<" << rec.getName();
+  if (!stack.contains(rec)) {
+    printer << ", ";
+    stack.insert(rec);
+    printTestType(rec.getBody(), printer, stack);
+    stack.pop_back();
+  }
+  printer << ">";
+}
+
+void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  llvm::SetVector<Type> stack;
+  printTestType(type, printer, stack);
+}

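The recursive printer that moves into TestTypes.cpp keeps a stack of the recursive types it is currently printing. A type that refers back to one of them is then printed as just its name, so printing does not recurse forever. Below is a minimal standalone sketch of that cycle-breaking idea. `Node`, `printNode`, and `printType` are hypothetical stand-ins, not MLIR API, and `std::vector` takes the place of `llvm::SetVector`.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for a type: either the leaf `test_type` or a named
// `test_rec` whose body may refer back to itself (directly or indirectly).
struct Node {
  bool isLeaf = true;
  std::string name;           // only meaningful when !isLeaf
  const Node *body = nullptr; // only meaningful when !isLeaf; assumed non-null
};

// Mirrors the shape of printTestType: `stack` holds the names of recursive
// types currently being printed, so a back-reference prints only its name.
static void printNode(const Node &n, std::string &out,
                      std::vector<std::string> &stack) {
  if (n.isLeaf) {
    out += "test_type";
    return;
  }
  out += "test_rec<" + n.name;
  if (std::find(stack.begin(), stack.end(), n.name) == stack.end()) {
    // First visit: print the body with this type pushed onto the stack.
    out += ", ";
    stack.push_back(n.name);
    printNode(*n.body, out, stack);
    stack.pop_back();
  }
  out += ">";
}

// Entry point, analogous to TestDialect::printType creating an empty stack.
std::string printType(const Node &n) {
  std::vector<std::string> stack;
  std::string out;
  printNode(n, out, stack);
  return out;
}
```

For a type `a` whose body is `a` itself, this prints `test_rec<a, test_rec<a>>`. That is the same shape printTestType produces for a self-referencing `TestRecursiveType`.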
diff  --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 6db866a10c8d..6e6e1c04643a 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -11,9 +11,6 @@ include "mlir/IR/OpBase.td"
 // DECL: class DialectAsmPrinter;
 // DECL: } // namespace mlir
 
-// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);
-// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer);
-
 // DEF: #ifdef GET_TYPEDEF_LIST
 // DEF: #undef GET_TYPEDEF_LIST
 // DEF: ::mlir::test::SimpleAType,

diff  --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
index 0990a5afb884..64239441581d 100644
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -92,7 +92,7 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
   /// llvm::formatv will call this function when using an instance as a
   /// replacement value.
   void format(raw_ostream &os, StringRef options) override {
-    if (params.size() && prependComma)
+    if (!params.empty() && prependComma)
       os << ", ";
 
     switch (emitFormat) {
@@ -146,8 +146,9 @@ class DialectAsmPrinter;
 /// case.
 ///
 /// {0}: The name of the typeDef class.
+/// {1}: The name of the type base class.
 static const char *const typeDefDeclSingletonBeginStr = R"(
-  class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
+  class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
   public:
     /// Inherit some necessary constructors from 'TypeBase'.
     using Base::Base;
@@ -158,15 +159,16 @@ static const char *const typeDefDeclSingletonBeginStr = R"(
 /// case.
 ///
 /// {0}: The name of the typeDef class.
-/// {1}: The typeDef storage class namespace.
-/// {2}: The storage class name.
-/// {3}: The list of parameters with types.
+/// {1}: The name of the type base class.
+/// {2}: The typeDef storage class namespace.
+/// {3}: The storage class name.
+/// {4}: The list of parameters with types.
 static const char *const typeDefDeclParametricBeginStr = R"(
-  namespace {1} {
-    struct {2};
+  namespace {2} {
+    struct {3};
   }
-  class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
-                                        {1}::{2}> {{
+  class {0}: public ::mlir::Type::TypeBase<{0}, {1},
+                                        {2}::{3}> {{
   public:
     /// Inherit some necessary constructors from 'TypeBase'.
     using Base::Base;
@@ -196,10 +198,11 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
   // template.
   if (typeDef.getNumParameters() == 0)
     os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
-                  typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+                  typeDef.getCppBaseClassName());
   else
     os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
-                  typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+                  typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(),
+                  typeDef.getStorageClassName());
 
   // Emit the extra declarations first in case there's a type definition in
   // there.
@@ -208,8 +211,10 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
 
   TypeParamCommaFormatter emitTypeNamePairsAfterComma(
       TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
-  os << llvm::formatv("    static {0} get(::mlir::MLIRContext* ctxt{1});\n",
-                      typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
+  if (!params.empty()) {
+    os << llvm::formatv("    static {0} get(::mlir::MLIRContext* ctxt{1});\n",
+                        typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
+  }
 
   // Emit the verify invariants declaration.
   if (typeDef.genVerifyInvariantsDecl())
@@ -252,17 +257,9 @@ static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
   // Output the common "header".
   os << typeDefDeclHeader;
 
-  if (typeDefs.size() > 0) {
+  if (!typeDefs.empty()) {
     NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
 
-    // Well known print/parse dispatch function declarations. These are called
-    // from Dialect::parseType() and Dialect::printType() methods.
-    os << "  ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
-          "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
-    os << "  ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
-          "::mlir::DialectAsmPrinter& printer);\n";
-    os << "\n";
-
     // Declare all the type classes first (in case they reference each other).
     for (const TypeDef &typeDef : typeDefs)
       os << "  class " << typeDef.getCppClassName() << ";\n";
@@ -488,14 +485,16 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
   if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
     emitStorageClass(typeDef, os);
 
-  os << llvm::formatv(
-      "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
-      "  return Base::get(ctxt{2});\n}\n",
-      typeDef.getCppClassName(),
-      TypeParamCommaFormatter(
-          TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
-      TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
-                              parameters));
+  if (!parameters.empty()) {
+    os << llvm::formatv(
+        "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
+        "  return Base::get(ctxt{2});\n}\n",
+        typeDef.getCppClassName(),
+        TypeParamCommaFormatter(
+            TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+        TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
+                                parameters));
+  }
 
   // Emit the parameter accessors.
   if (typeDef.genAccessors())
@@ -526,38 +525,40 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
 
 /// Emit the dialect printer/parser dispatcher. User's code should call these
 /// functions from their dialect's print/parse methods.
-static void emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs,
-                                   raw_ostream &os) {
-  if (typeDefs.size() == 0)
+static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
+  if (llvm::none_of(types, [](const TypeDef &type) {
+        return type.getMnemonic().hasValue();
+      })) {
     return;
-  const Dialect &dialect = typeDefs.begin()->getDialect();
-  NamespaceEmitter ns(os, dialect);
+  }
 
-  // The parser dispatch is just a list of if-elses, matching on the mnemonic
-  // and calling the class's parse function.
-  os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+  // The parser dispatch is just a list of if-elses, matching on the
+  // mnemonic and calling the class's parse function.
+  os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* "
+        "ctxt, "
         "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
-  for (const TypeDef &typeDef : typeDefs)
-    if (typeDef.getMnemonic())
+  for (const TypeDef &type : types)
+    if (type.getMnemonic())
       os << formatv("  if (mnemonic == {0}::{1}::getMnemonic()) return "
                     "{0}::{1}::parse(ctxt, parser);\n",
-                    typeDef.getDialect().getCppNamespace(),
-                    typeDef.getCppClassName());
+                    type.getDialect().getCppNamespace(),
+                    type.getCppClassName());
   os << "  return ::mlir::Type();\n";
   os << "}\n\n";
 
   // The printer dispatch uses llvm::TypeSwitch to find and call the correct
   // printer.
-  os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
+  os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
+        "type, "
         "::mlir::DialectAsmPrinter& printer) {\n"
      << "  ::mlir::LogicalResult found = ::mlir::success();\n"
      << "  ::llvm::TypeSwitch<::mlir::Type>(type)\n";
-  for (auto typeDef : typeDefs)
-    if (typeDef.getMnemonic())
+  for (const TypeDef &type : types)
+    if (type.getMnemonic())
       os << formatv("    .Case<{0}::{1}>([&](::mlir::Type t) {{ "
                     "t.dyn_cast<{0}::{1}>().print(printer); })\n",
-                    typeDef.getDialect().getCppNamespace(),
-                    typeDef.getCppClassName());
+                    type.getDialect().getCppNamespace(),
+                    type.getCppClassName());
   os << "    .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
         "});\n"
      << "  return found;\n"

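After this change, `generatedTypeParser`/`generatedTypePrinter` are emitted as `static` functions into the definition file. They are only emitted when at least one type declares a mnemonic. Because they now have internal linkage, a dialect's `parseType`/`printType` presumably need to be defined in the translation unit that includes the generated `.cpp.inc`. This is consistent with `TestDialect::parseType`/`printType` moving into TestTypes.cpp above. The following is a simplified sketch of the parser-dispatch emission. It assumes a hypothetical `TypeDefInfo` struct in place of `mlir::tblgen::TypeDef` and uses `std::ostringstream` in place of `raw_ostream`/`formatv`. It is not the generator's actual code.

```cpp
#include <algorithm>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a tblgen TypeDef record.
struct TypeDefInfo {
  std::string cppNamespace;
  std::string cppClassName;
  std::optional<std::string> mnemonic;
};

// Sketch of the guard and emission in emitParsePrintDispatch: nothing is
// emitted unless some type has a mnemonic, and the emitted parser is `static`.
std::string emitParserDispatch(const std::vector<TypeDefInfo> &types) {
  if (std::none_of(types.begin(), types.end(), [](const TypeDefInfo &t) {
        return t.mnemonic.has_value();
      }))
    return "";

  std::ostringstream os;
  os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
        "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
  // One if-statement per type with a mnemonic, dispatching to its parse().
  for (const TypeDefInfo &t : types)
    if (t.mnemonic)
      os << "  if (mnemonic == " << t.cppNamespace << "::" << t.cppClassName
         << "::getMnemonic()) return " << t.cppNamespace << "::"
         << t.cppClassName << "::parse(ctxt, parser);\n";
  os << "  return ::mlir::Type();\n}\n";
  return os.str();
}
```

Skipping emission entirely when no type has a mnemonic avoids defining an unused `static` function in the definition file.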