[Mlir-commits] [mlir] 95019de - [mlir][IR] Define the singleton builtin types in ODS instead of C++
River Riddle
llvmlistbot at llvm.org
Tue Dec 15 13:52:01 PST 2020
Author: River Riddle
Date: 2020-12-15T13:42:19-08:00
New Revision: 95019de8a122619fc038c9fe3c80e625e3456bbf
URL: https://github.com/llvm/llvm-project/commit/95019de8a122619fc038c9fe3c80e625e3456bbf
DIFF: https://github.com/llvm/llvm-project/commit/95019de8a122619fc038c9fe3c80e625e3456bbf.diff
LOG: [mlir][IR] Define the singleton builtin types in ODS instead of C++
This exposes several issues with the current generation that this revision also fixes.
* TypeDef now allows specifying the base class to use when generating.
* TypeDef now inherits from DialectType, which allows for using it as a TypeConstraint
* Parser/Printers are now no longer generated in the header(removing duplicate symbols), and are now only generated when necessary.
- Now that generatedTypeParser/Printer are only generated in the definition file,
existing users will need to manually expose this functionality when necessary.
* ::get() is no longer generated for singleton types, because it isn't necessary.
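The sub-bullet about `generatedTypeParser`/`generatedTypePrinter` can be illustrated roughly as follows. This assumes the generated dispatch functions end up `static` in the definition file, as the updated OpDefinitions.md text describes. Under that assumption, a dialect reaches them by defining its type-parsing hook in the same translation unit that includes the generated `.cpp.inc`. That appears to be why this patch moves the TestDialect parsing and printing code into TestTypes.cpp. `Type`, `MyDialect`, and the string-based signature below are hypothetical stand-ins, not the MLIR API.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for mlir::Type: an empty name means "no type",
// mirroring how a null Type signals failure in the real API.
struct Type {
  std::string name;
  explicit operator bool() const { return !name.empty(); }
};

// Stand-in for what would come from the generated .cpp.inc. Marking it
// `static` gives it internal linkage, so only this translation unit sees it.
static Type generatedTypeParser(const std::string &mnemonic) {
  if (mnemonic == "test_integer")
    return Type{"TestIntegerType"};
  return Type{};
}

// Hypothetical dialect whose hook is the public entry point.
struct MyDialect {
  Type parseType(const std::string &mnemonic) const;
};

// Defining the hook in the same file as the generated code is what makes
// the static parser reachable from other translation units.
Type MyDialect::parseType(const std::string &mnemonic) const {
  if (Type generated = generatedTypeParser(mnemonic))
    return generated;
  // Hand-written fallback for types without a mnemonic in ODS.
  if (mnemonic == "test_type")
    return Type{"TestType"};
  return Type{};
}
```

The printer side would follow the same pattern: a `printType` hook defined next to the generated code tries `generatedTypePrinter` first and falls back to hand-written printing.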
Differential Revision: https://reviews.llvm.org/D93270
Added:
mlir/include/mlir/IR/BuiltinDialect.td
mlir/include/mlir/IR/BuiltinTypes.td
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/BuiltinOps.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/CMakeLists.txt
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/TypeDef.h
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/CMakeLists.txt
mlir/lib/TableGen/Constraint.cpp
mlir/lib/TableGen/TypeDef.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/mlir-tblgen/typedefs.td
mlir/tools/mlir-tblgen/TypeDefGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 189cd0825af7..c5ffe452b927 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1370,10 +1370,10 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
## Type Definitions
-MLIR defines the TypeDef class hierarchy to enable generation of data types
-from their specifications. A type is defined by specializing the TypeDef
-class with concrete contents for all the fields it requires. For example, an
-integer type could be defined as:
+MLIR defines the TypeDef class hierarchy to enable generation of data types from
+their specifications. A type is defined by specializing the TypeDef class with
+concrete contents for all the fields it requires. For example, an integer type
+could be defined as:
```tablegen
// All of the types will extend this class.
@@ -1414,45 +1414,43 @@ def IntegerType : Test_Type<"TestInteger"> {
### Type name
The name of the C++ class which gets generated defaults to
-`<classParamName>Type` (e.g. `TestIntegerType` in the above example). This
-can be overridden via the `cppClassName` field. The field `mnemonic` is
-to specify the asm name for parsing. It is optional and not specifying it
-will imply that no parser or printer methods are attached to this class.
+`<classParamName>Type` (e.g. `TestIntegerType` in the above example). This can
+be overridden via the `cppClassName` field. The field `mnemonic` is to specify
+the asm name for parsing. It is optional and not specifying it will imply that
+no parser or printer methods are attached to this class.
### Type documentation
-The `summary` and `description` fields exist and are to be used the same way
-as in Operations. Namely, the summary should be a one-liner and `description`
+The `summary` and `description` fields exist and are to be used the same way as
+in Operations. Namely, the summary should be a one-liner and `description`
should be a longer explanation.
### Type parameters
-The `parameters` field is a list of the types parameters. If no parameters
-are specified (the default), this type is considered a singleton type.
-Parameters are in the `"c++Type":$paramName` format.
-To use C++ types as parameters which need allocation in the storage
-constructor, there are two options:
+The `parameters` field is a list of the types parameters. If no parameters are
+specified (the default), this type is considered a singleton type. Parameters
+are in the `"c++Type":$paramName` format. To use C++ types as parameters which
+need allocation in the storage constructor, there are two options:
-- Set `hasCustomStorageConstructor` to generate the TypeStorage class with
-a constructor which is just declared -- no definition -- so you can write it
-yourself.
-- Use the `TypeParameter` tablegen class instead of the "c++Type" string.
+- Set `hasCustomStorageConstructor` to generate the TypeStorage class with a
+ constructor which is just declared -- no definition -- so you can write it
+ yourself.
+- Use the `TypeParameter` tablegen class instead of the "c++Type" string.
### TypeParameter tablegen class
-This is used to further specify attributes about each of the types
-parameters. It includes documentation (`description` and `syntax`), the C++
-type to use, and a custom allocator to use in the storage constructor method.
+This is used to further specify attributes about each of the types parameters.
+It includes documentation (`description` and `syntax`), the C++ type to use, and
+a custom allocator to use in the storage constructor method.
```tablegen
// DO NOT DO THIS!
-let parameters = (ins
- "ArrayRef<int>":$dims);
+let parameters = (ins "ArrayRef<int>":$dims);
```
-The default storage constructor blindly copies fields by value. It does not
-know anything about the types. In this case, the ArrayRef<int> requires
-allocation with `dims = allocator.copyInto(dims)`.
+The default storage constructor blindly copies fields by value. It does not know
+anything about the types. In this case, the ArrayRef<int> requires allocation
+with `dims = allocator.copyInto(dims)`.
You can specify the necessary constructor by specializing the `TypeParameter`
tblgen class:
@@ -1460,28 +1458,29 @@ tblgen class:
```tablegen
class ArrayRefIntParam :
TypeParameter<"::llvm::ArrayRef<int>", "Array of ints"> {
- let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+ let allocator = "$_dst = $_allocator.copyInto($_self);";
}
...
-let parameters = (ins
- ArrayRefIntParam:$dims);
+let parameters = (ins ArrayRefIntParam:$dims);
```
The `allocator` code block has the following substitutions:
-- `$_allocator` is the TypeStorageAllocator in which to allocate objects.
-- `$_dst` is the variable in which to place the allocated data.
+
+- `$_allocator` is the TypeStorageAllocator in which to allocate objects.
+- `$_dst` is the variable in which to place the allocated data.
MLIR includes several specialized classes for common situations:
-- `StringRefParameter<descriptionOfParam>` for StringRefs.
-- `ArrayRefParameter<arrayOf, descriptionOfParam>` for ArrayRefs of value
-types
-- `SelfAllocationParameter<descriptionOfParam>` for C++ classes which contain
-a method called `allocateInto(StorageAllocator &allocator)` to allocate
-itself into `allocator`.
-- `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays
-of objects which self-allocate as per the last specialization.
+
+- `StringRefParameter<descriptionOfParam>` for StringRefs.
+- `ArrayRefParameter<arrayOf, descriptionOfParam>` for ArrayRefs of value
+ types
+- `SelfAllocationParameter<descriptionOfParam>` for C++ classes which contain
+ a method called `allocateInto(StorageAllocator &allocator)` to allocate
+ itself into `allocator`.
+- `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays
+ of objects which self-allocate as per the last specialization.
If we were to use one of these included specializations:
@@ -1495,45 +1494,46 @@ let parameters = (ins
If a mnemonic is specified, the `printer` and `parser` code fields are active.
The rules for both are:
-- If null, generate just the declaration.
-- If non-null and non-empty, use the code in the definition. The `$_printer`
-or `$_parser` substitutions are valid and should be used.
-- It is an error to have an empty code block.
-
-For each dialect, two "dispatch" functions will be created: one for parsing
-and one for printing. You should add calls to these in your
-`Dialect::printType` and `Dialect::parseType` methods. They are created in
-the dialect's namespace and their function signatures are:
+
+- If null, generate just the declaration.
+- If non-null and non-empty, use the code in the definition. The `$_printer`
+ or `$_parser` substitutions are valid and should be used.
+- It is an error to have an empty code block.
+
+For each dialect, two "dispatch" functions will be created: one for parsing and
+one for printing. You should add calls to these in your `Dialect::printType` and
+`Dialect::parseType` methods. They are static functions placed alongside the
+type class definitions and have the following function signatures:
+
```c++
-Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser,
- StringRef mnemonic);
+static Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser, StringRef mnemonic);
LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer);
```
-The mnemonic, parser, and printer fields are optional. If they're not
-defined, the generated code will not include any parsing or printing code and
-omit the type from the dispatch functions above. In this case, the dialect
-author is responsible for parsing/printing the types in `Dialect::printType`
-and `Dialect::parseType`.
+The mnemonic, parser, and printer fields are optional. If they're not defined,
+the generated code will not include any parsing or printing code and omit the
+type from the dispatch functions above. In this case, the dialect author is
+responsible for parsing/printing the types in `Dialect::printType` and
+`Dialect::parseType`.
### Other fields
-- If the `genStorageClass` field is set to 1 (the default) a storage class is
-generated with member variables corresponding to each of the specified
-`parameters`.
-- If the `genAccessors` field is 1 (the default) accessor methods will be
-generated on the Type class (e.g. `int getWidth() const` in the example
-above).
-- If the `genVerifyInvariantsDecl` field is set, a declaration for a method
-`static LogicalResult verifyConstructionInvariants(Location, parameters...)`
-is added to the class as well as a `getChecked(Location, parameters...)`
-method which gets the result of `verifyConstructionInvariants` before calling
-`get`.
-- The `storageClass` field can be used to set the name of the storage class.
-- The `storageNamespace` field is used to set the namespace where the storage
-class should sit. Defaults to "detail".
-- The `extraClassDeclaration` field is used to include extra code in the
-class declaration.
+- If the `genStorageClass` field is set to 1 (the default) a storage class is
+ generated with member variables corresponding to each of the specified
+ `parameters`.
+- If the `genAccessors` field is 1 (the default) accessor methods will be
+ generated on the Type class (e.g. `int getWidth() const` in the example
+ above).
+- If the `genVerifyInvariantsDecl` field is set, a declaration for a method
+ `static LogicalResult verifyConstructionInvariants(Location, parameters...)`
+ is added to the class as well as a `getChecked(Location, parameters...)`
+ method which gets the result of `verifyConstructionInvariants` before
+ calling `get`.
+- The `storageClass` field can be used to set the name of the storage class.
+- The `storageNamespace` field is used to set the namespace where the storage
+ class should sit. Defaults to "detail".
+- The `extraClassDeclaration` field is used to include extra code in the class
+ declaration.
## Debugging Tips
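The `TypeParameter` discussion in the OpDefinitions.md hunk above notes that the default storage constructor copies fields by value. For an `ArrayRef<int>`, that copies only the view, not the elements it points at. The self-contained sketch below models why `copyInto` is needed. `IntArrayView`, `ToyAllocator`, and `ToyTypeStorage` are hypothetical simplifications, not MLIR's `ArrayRef` or `TypeStorageAllocator`.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

// Hypothetical, simplified stand-in for llvm::ArrayRef<int>: a non-owning
// view, so it dangles if the memory it points at goes away or changes.
struct IntArrayView {
  const int *data;
  std::size_t size;
};

// Hypothetical stand-in for the storage allocator. A std::deque keeps
// references to existing elements valid when new ones are appended.
struct ToyAllocator {
  std::deque<std::vector<int>> arena;

  // Copies the viewed elements into allocator-owned memory and returns a
  // view of the copy, in the spirit of `allocator.copyInto(dims)`.
  IntArrayView copyInto(IntArrayView src) {
    arena.emplace_back(src.data, src.data + src.size);
    const std::vector<int> &copy = arena.back();
    return IntArrayView{copy.data(), copy.size()};
  }
};

// Storage whose constructor does what the `allocator` snippet
// `$_dst = $_allocator.copyInto($_self);` asks for.
struct ToyTypeStorage {
  IntArrayView dims;

  static ToyTypeStorage construct(ToyAllocator &allocator, IntArrayView dims) {
    return ToyTypeStorage{allocator.copyInto(dims)};
  }
};
```

Without the `copyInto` call, `dims` would keep pointing at the caller's buffer, and later changes to that buffer would silently alter, or invalidate, the uniqued type's parameters.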
diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
new file mode 100644
index 000000000000..383f87bd5d60
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -0,0 +1,27 @@
+//===-- BuiltinDialect.td - Builtin dialect definition -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the Builtin dialect. This dialect
+// contains all of the attributes, operations, and types that are core to MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUILTIN_BASE
+#define BUILTIN_BASE
+
+include "mlir/IR/OpBase.td"
+
+def Builtin_Dialect : Dialect {
+ let summary =
+ "A dialect containing the builtin Attributes, Operations, and Types";
+
+ let name = "";
+ let cppNamespace = "::mlir";
+}
+
+#endif // BUILTIN_BASE
diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index 567a1249f0df..b085f721cfa9 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -14,17 +14,10 @@
#ifndef BUILTIN_OPS
#define BUILTIN_OPS
+include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
-def Builtin_Dialect : Dialect {
- let summary =
- "A dialect containing the builtin Attributes, Operations, and Types";
-
- let name = "";
- let cppNamespace = "::mlir";
-}
-
// Base class for Builtin dialect ops.
class Builtin_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Builtin_Dialect, mnemonic, traits>;
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 8ce5e4045a3a..10e78e5efb72 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -72,23 +72,6 @@ class ComplexType
Type getElementType();
};
-//===----------------------------------------------------------------------===//
-// IndexType
-//===----------------------------------------------------------------------===//
-
-/// Index is a special integer-like type with unknown platform-dependent bit
-/// width.
-class IndexType : public Type::TypeBase<IndexType, Type, TypeStorage> {
-public:
- using Base::Base;
-
- /// Get an instance of the IndexType.
- static IndexType get(MLIRContext *context);
-
- /// Storage bit width used for IndexType by internal compiler data structures.
- static constexpr unsigned kInternalStorageBitWidth = 64;
-};
-
//===----------------------------------------------------------------------===//
// IntegerType
//===----------------------------------------------------------------------===//
@@ -187,67 +170,6 @@ class FloatType : public Type {
const llvm::fltSemantics &getFloatSemantics();
};
-//===----------------------------------------------------------------------===//
-// BFloat16Type
-
-class BFloat16Type
- : public Type::TypeBase<BFloat16Type, FloatType, TypeStorage> {
-public:
- using Base::Base;
-
- /// Return an instance of the bfloat16 type.
- static BFloat16Type get(MLIRContext *context);
-};
-
-inline FloatType FloatType::getBF16(MLIRContext *ctx) {
- return BFloat16Type::get(ctx);
-}
-
-//===----------------------------------------------------------------------===//
-// Float16Type
-
-class Float16Type : public Type::TypeBase<Float16Type, FloatType, TypeStorage> {
-public:
- using Base::Base;
-
- /// Return an instance of the float16 type.
- static Float16Type get(MLIRContext *context);
-};
-
-inline FloatType FloatType::getF16(MLIRContext *ctx) {
- return Float16Type::get(ctx);
-}
-
-//===----------------------------------------------------------------------===//
-// Float32Type
-
-class Float32Type : public Type::TypeBase<Float32Type, FloatType, TypeStorage> {
-public:
- using Base::Base;
-
- /// Return an instance of the float32 type.
- static Float32Type get(MLIRContext *context);
-};
-
-inline FloatType FloatType::getF32(MLIRContext *ctx) {
- return Float32Type::get(ctx);
-}
-
-//===----------------------------------------------------------------------===//
-// Float64Type
-
-class Float64Type : public Type::TypeBase<Float64Type, FloatType, TypeStorage> {
-public:
- using Base::Base;
-
- /// Return an instance of the float64 type.
- static Float64Type get(MLIRContext *context);
-};
-
-inline FloatType FloatType::getF64(MLIRContext *ctx) {
- return Float64Type::get(ctx);
-}
-
//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//
@@ -276,20 +198,6 @@ class FunctionType
ArrayRef<unsigned> resultIndices);
};
-//===----------------------------------------------------------------------===//
-// NoneType
-//===----------------------------------------------------------------------===//
-
-/// NoneType is a unit type, i.e. a type with exactly one possible value, where
-/// its value does not have a defined dynamic representation.
-class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
-public:
- using Base::Base;
-
- /// Get an instance of the NoneType.
- static NoneType get(MLIRContext *context);
-};
-
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
@@ -720,11 +628,20 @@ class TupleType
return getTypes()[index];
}
};
+} // end namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Tablegen Type Declarations
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/IR/BuiltinTypes.h.inc"
//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
+namespace mlir {
inline bool BaseMemRefType::classof(Type type) {
return type.isa<MemRefType, UnrankedMemRefType>();
}
@@ -733,6 +650,22 @@ inline bool FloatType::classof(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
}
+inline FloatType FloatType::getBF16(MLIRContext *ctx) {
+ return BFloat16Type::get(ctx);
+}
+
+inline FloatType FloatType::getF16(MLIRContext *ctx) {
+ return Float16Type::get(ctx);
+}
+
+inline FloatType FloatType::getF32(MLIRContext *ctx) {
+ return Float32Type::get(ctx);
+}
+
+inline FloatType FloatType::getF64(MLIRContext *ctx) {
+ return Float64Type::get(ctx);
+}
+
inline bool ShapedType::classof(Type type) {
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
UnrankedMemRefType, MemRefType>();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
new file mode 100644
index 000000000000..b540554fb11e
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -0,0 +1,114 @@
+//===- BuiltinTypes.td - Builtin type definitions ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the set of builtin MLIR types, or the set of types necessary for the
+// validity of and defining the IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUILTIN_TYPES
+#define BUILTIN_TYPES
+
+include "mlir/IR/BuiltinDialect.td"
+
+// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
+// This is to differentiate the types here with the ones in OpBase.td. We should
+// remove the definitions in OpBase.td, and repoint users to this file instead.
+
+// Base class for Builtin dialect types.
+class Builtin_Type<string name> : TypeDef<Builtin_Dialect, name> {
+ let mnemonic = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+// Base class for Builtin dialect float types.
+class Builtin_FloatType<string name> : TypeDef<Builtin_Dialect, name,
+ "::mlir::FloatType"> {
+ let extraClassDeclaration = [{
+ static }] # name # [{Type get(MLIRContext *context);
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// BFloat16Type
+
+def Builtin_BFloat16 : Builtin_FloatType<"BFloat16"> {
+ let summary = "bfloat16 floating-point type";
+}
+
+//===----------------------------------------------------------------------===//
+// Float16Type
+
+def Builtin_Float16 : Builtin_FloatType<"Float16"> {
+ let summary = "16-bit floating-point type";
+}
+
+//===----------------------------------------------------------------------===//
+// Float32Type
+
+def Builtin_Float32 : Builtin_FloatType<"Float32"> {
+ let summary = "32-bit floating-point type";
+}
+
+//===----------------------------------------------------------------------===//
+// Float64Type
+
+def Builtin_Float64 : Builtin_FloatType<"Float64"> {
+ let summary = "64-bit floating-point type";
+}
+
+//===----------------------------------------------------------------------===//
+// IndexType
+//===----------------------------------------------------------------------===//
+
+def Builtin_Index : Builtin_Type<"Index"> {
+ let summary = "Integer-like type with unknown platform-dependent bit width";
+ let description = [{
+ Syntax:
+
+ ```
+ // Target word-sized integer.
+ index-type ::= `index`
+ ```
+
+ The index type is a signless integer whose size is equal to the natural
+ machine word of the target ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#integer-signedness-semantics) )
+ and is used by the affine constructs in MLIR. Unlike fixed-size integers,
+ it cannot be used as an element of vector ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#index-type-disallowed-in-vector-types) ).
+
+ **Rationale:** integers of platform-specific bit widths are practical to
+ express sizes, dimensionalities and subscripts.
+ }];
+ let extraClassDeclaration = [{
+ static IndexType get(MLIRContext *context);
+
+ /// Storage bit width used for IndexType by internal compiler data
+ /// structures.
+ static constexpr unsigned kInternalStorageBitWidth = 64;
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// NoneType
+//===----------------------------------------------------------------------===//
+
+def Builtin_None : Builtin_Type<"None"> {
+ let summary = "A unit type";
+ let description = [{
+ NoneType is a unit type, i.e. a type with exactly one possible value, where
+ its value does not have a defined dynamic representation.
+ }];
+ let extraClassDeclaration = [{
+ static NoneType get(MLIRContext *context);
+ }];
+}
+
+#endif // BUILTIN_TYPES
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 45649b3e7598..3b7ddbaf2338 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -2,10 +2,18 @@ add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
+set(LLVM_TARGET_DEFINITIONS BuiltinDialect.td)
+mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
+add_public_tablegen_target(MLIRBuiltinDialectIncGen)
+
set(LLVM_TARGET_DEFINITIONS BuiltinOps.td)
mlir_tablegen(BuiltinOps.h.inc -gen-op-decls)
mlir_tablegen(BuiltinOps.cpp.inc -gen-op-defs)
-mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
add_public_tablegen_target(MLIRBuiltinOpsIncGen)
+set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
+mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRBuiltinTypesIncGen)
+
add_mlir_doc(BuiltinOps -gen-op-doc Builtin Dialects/)
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index b031769022d9..aa5ef284de11 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2415,15 +2415,18 @@ def replaceWithValue;
// Data type generation
//===----------------------------------------------------------------------===//
-// Define a new type belonging to a dialect and called 'name'.
-class TypeDef<Dialect owningdialect, string name> {
- Dialect dialect = owningdialect;
+// Define a new type, named `name`, belonging to `dialect` that inherits from
+// the given C++ base class.
+class TypeDef<Dialect dialect, string name,
+ string baseCppClass = "::mlir::Type">
+ : DialectType<dialect, CPred<"">> {
+ // The name of the C++ Type class.
string cppClassName = name # "Type";
+ // The name of the C++ base class to use for this Type.
+ string cppBaseClassName = baseCppClass;
// Short summary of the type.
string summary = ?;
- // The longer description of this type.
- string description = ?;
// Name of storage class to generate or use.
string storageClass = name # "TypeStorage";
@@ -2477,6 +2480,15 @@ class TypeDef<Dialect owningdialect, string name> {
bit genVerifyInvariantsDecl = 0;
// Extra code to include in the class declaration.
code extraClassDeclaration = [{}];
+
+ // The predicate for when this type is used as a type constraint.
+ let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
+ "::" # cppClassName # ">()">;
+ // A constant builder provided when the type has no parameters.
+ let builderCall = !if(!empty(parameters),
+ "$_builder.getType<" # dialect.cppNamespace #
+ "::" # cppClassName # ">()",
+ "");
}
// 'Parameters' should be subclasses of this or simple strings (which is a
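The new `predicate` and `builderCall` fields in `TypeDef` are built by TableGen string pasting. The C++ sketch below reproduces the strings they evaluate to. `TypeDefFields` and the two helper functions are hypothetical; they only mirror the `#` concatenation and the `!if(!empty(parameters), ...)` choice shown in the OpBase.td hunk.

```cpp
#include <cassert>
#include <string>

// Hypothetical bundle of the TableGen fields the two `let` statements read.
struct TypeDefFields {
  std::string cppNamespace; // dialect.cppNamespace, e.g. "::mlir"
  std::string cppClassName; // name # "Type", e.g. "IndexType"
  bool hasParameters;       // !empty(parameters) is false when this is true
};

// Mirrors: CPred<"$_self.isa<" # dialect.cppNamespace # "::" # cppClassName # ">()">
std::string predicateFor(const TypeDefFields &t) {
  return "$_self.isa<" + t.cppNamespace + "::" + t.cppClassName + ">()";
}

// Mirrors the !if: a builder call only for parameterless (singleton) types.
std::string builderCallFor(const TypeDefFields &t) {
  if (t.hasParameters)
    return "";
  return "$_builder.getType<" + t.cppNamespace + "::" + t.cppClassName + ">()";
}
```

For `Builtin_Index` (namespace `::mlir`, class `IndexType`), this yields `$_self.isa<::mlir::IndexType>()` and `$_builder.getType<::mlir::IndexType>()`. The empty builder call for parameterized types means no constant builder is offered when constructing the type requires arguments.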
diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h
index 462fed322438..796f5cc17859 100644
--- a/mlir/include/mlir/TableGen/TypeDef.h
+++ b/mlir/include/mlir/TableGen/TypeDef.h
@@ -48,6 +48,9 @@ class TypeDef {
// Returns the name of the C++ class to generate.
StringRef getCppClassName() const;
+ // Returns the name of the C++ base class to use when generating this type.
+ StringRef getCppBaseClassName() const;
+
// Returns the name of the storage class for this type.
StringRef getStorageClassName() const;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index aa4eba07cf53..68cf491850d1 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -20,6 +20,13 @@
using namespace mlir;
using namespace mlir::detail;
+//===----------------------------------------------------------------------===//
+/// Tablegen Type Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/IR/BuiltinTypes.cpp.inc"
+
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 42cdb3a91a50..e4b3b9141323 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -33,7 +33,9 @@ add_mlir_library(MLIRIR
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
DEPENDS
+ MLIRBuiltinDialectIncGen
MLIRBuiltinOpsIncGen
+ MLIRBuiltinTypesIncGen
MLIRCallInterfacesIncGen
MLIROpAsmInterfaceIncGen
MLIRRegionKindInterfaceIncGen
diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index b8e1b9f05acd..71606f90cf45 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -13,6 +13,7 @@
#include "mlir/TableGen/Constraint.h"
#include "llvm/TableGen/Record.h"
+using namespace mlir;
using namespace mlir::tblgen;
Constraint::Constraint(const llvm::Record *record)
@@ -56,11 +57,18 @@ std::string Constraint::getConditionTemplate() const {
return getPredicate().getCondition();
}
-llvm::StringRef Constraint::getDescription() const {
- auto doc = def->getValueAsString("description");
- if (doc.empty())
- return def->getName();
- return doc;
+StringRef Constraint::getDescription() const {
+ // If a summary is found, we use that given that it is a focused single line
+ // comment.
+ if (Optional<StringRef> summary = def->getValueAsOptionalString("summary"))
+ return *summary;
+ // If a summary can't be found, look for a specific description field to use
+ // for the constraint.
+ StringRef desc = def->getValueAsString("description");
+ if (!desc.empty())
+ return desc;
+ // Otherwise, fallback to the name of the constraint definition.
+ return def->getName();
}
AppliedConstraint::AppliedConstraint(Constraint &&constraint,
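The patched `Constraint::getDescription` uses a three-step lookup: `summary` if set, otherwise a non-empty `description`, otherwise the record name. The sketch below models that precedence outside TableGen. `ConstraintRecord` is a hypothetical stand-in for the fields read from an `llvm::Record`, with `std::nullopt` standing in for an unset (`?`) summary.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for the fields getDescription() reads from an
// llvm::Record; it is not the TableGen Record API.
struct ConstraintRecord {
  std::string name;                   // def->getName()
  std::optional<std::string> summary; // std::nullopt models an unset field
  std::string description;            // empty models "no description"
};

// Same precedence as the patched Constraint::getDescription():
// summary first, then a non-empty description, then the record name.
std::string getDescription(const ConstraintRecord &def) {
  if (def.summary)
    return *def.summary;
  if (!def.description.empty())
    return def.description;
  return def.name;
}
```

Because TypeDefs such as `Builtin_Index` set a one-line `summary`, their constraint descriptions come from that summary rather than from the longer multi-line `description` block.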
diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp
index d8412b6b4be5..b666d02d7eb9 100644
--- a/mlir/lib/TableGen/TypeDef.cpp
+++ b/mlir/lib/TableGen/TypeDef.cpp
@@ -31,6 +31,10 @@ StringRef TypeDef::getCppClassName() const {
return def->getValueAsString("cppClassName");
}
+StringRef TypeDef::getCppBaseClassName() const {
+ return def->getValueAsString("cppBaseClassName");
+}
+
bool TypeDef::hasDescription() const {
const llvm::RecordVal *s = def->getValue("description");
return s != nullptr && isa<llvm::StringInit>(s->getValue());
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 5e9bae893363..8aec98434f15 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -15,7 +15,6 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
@@ -183,77 +182,6 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
return builder.create<TestOpConstant>(loc, type, value);
}
-static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
- llvm::SetVector<Type> &stack) {
- StringRef typeTag;
- if (failed(parser.parseKeyword(&typeTag)))
- return Type();
-
- auto genType = generatedTypeParser(ctxt, parser, typeTag);
- if (genType != Type())
- return genType;
-
- if (typeTag == "test_type")
- return TestType::get(parser.getBuilder().getContext());
-
- if (typeTag != "test_rec")
- return Type();
-
- StringRef name;
- if (parser.parseLess() || parser.parseKeyword(&name))
- return Type();
- auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
-
- // If this type already has been parsed above in the stack, expect just the
- // name.
- if (stack.contains(rec)) {
- if (failed(parser.parseGreater()))
- return Type();
- return rec;
- }
-
- // Otherwise, parse the body and update the type.
- if (failed(parser.parseComma()))
- return Type();
- stack.insert(rec);
- Type subtype = parseTestType(ctxt, parser, stack);
- stack.pop_back();
- if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
- return Type();
-
- return rec;
-}
-
-Type TestDialect::parseType(DialectAsmParser &parser) const {
- llvm::SetVector<Type> stack;
- return parseTestType(getContext(), parser, stack);
-}
-
-static void printTestType(Type type, DialectAsmPrinter &printer,
- llvm::SetVector<Type> &stack) {
- if (succeeded(generatedTypePrinter(type, printer)))
- return;
- if (type.isa<TestType>()) {
- printer << "test_type";
- return;
- }
-
- auto rec = type.cast<TestRecursiveType>();
- printer << "test_rec<" << rec.getName();
- if (!stack.contains(rec)) {
- printer << ", ";
- stack.insert(rec);
- printTestType(rec.getBody(), printer, stack);
- stack.pop_back();
- }
- printer << ">";
-}
-
-void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
- llvm::SetVector<Type> stack;
- printTestType(type, printer, stack);
-}
-
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
NamedAttribute namedAttr) {
if (namedAttr.first == "test.invalid_attr")
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 0a91be30b53a..14a3e862d85c 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -12,9 +12,12 @@
//===----------------------------------------------------------------------===//
#include "TestTypes.h"
+#include "TestDialect.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -116,5 +119,84 @@ LogicalResult TestIntegerType::verifyConstructionInvariants(
return success();
}
+//===----------------------------------------------------------------------===//
+// Tablegen Generated Definitions
+//===----------------------------------------------------------------------===//
+
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TestDialect
+//===----------------------------------------------------------------------===//
+
+static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
+ llvm::SetVector<Type> &stack) {
+ StringRef typeTag;
+ if (failed(parser.parseKeyword(&typeTag)))
+ return Type();
+
+ auto genType = generatedTypeParser(ctxt, parser, typeTag);
+ if (genType != Type())
+ return genType;
+
+ if (typeTag == "test_type")
+ return TestType::get(parser.getBuilder().getContext());
+
+ if (typeTag != "test_rec")
+ return Type();
+
+ StringRef name;
+ if (parser.parseLess() || parser.parseKeyword(&name))
+ return Type();
+ auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
+
+ // If this type already has been parsed above in the stack, expect just the
+ // name.
+ if (stack.contains(rec)) {
+ if (failed(parser.parseGreater()))
+ return Type();
+ return rec;
+ }
+
+ // Otherwise, parse the body and update the type.
+ if (failed(parser.parseComma()))
+ return Type();
+ stack.insert(rec);
+ Type subtype = parseTestType(ctxt, parser, stack);
+ stack.pop_back();
+ if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
+ return Type();
+
+ return rec;
+}
+
+Type TestDialect::parseType(DialectAsmParser &parser) const {
+ llvm::SetVector<Type> stack;
+ return parseTestType(getContext(), parser, stack);
+}
+
+static void printTestType(Type type, DialectAsmPrinter &printer,
+ llvm::SetVector<Type> &stack) {
+ if (succeeded(generatedTypePrinter(type, printer)))
+ return;
+ if (type.isa<TestType>()) {
+ printer << "test_type";
+ return;
+ }
+
+ auto rec = type.cast<TestRecursiveType>();
+ printer << "test_rec<" << rec.getName();
+ if (!stack.contains(rec)) {
+ printer << ", ";
+ stack.insert(rec);
+ printTestType(rec.getBody(), printer, stack);
+ stack.pop_back();
+ }
+ printer << ">";
+}
+
+void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
+ llvm::SetVector<Type> stack;
+ printTestType(type, printer, stack);
+}
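The relocated `printTestType` handles self-referential types by tracking which types are currently being printed. When a nested reference points to a type already on that stack, only the type's name is printed. The following standalone C++17 sketch reproduces that idea outside MLIR. `Node`, `printNode`, and `printType` are hypothetical names, and a `std::vector` with linear search stands in for `llvm::SetVector`, which is assumed adequate for shallow nesting.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a type: an empty name means the leaf
// "test_type"; otherwise it is a named recursive type whose body may
// point back at itself. The body is assumed to be set before printing.
struct Node {
  std::string name;
  const Node *body = nullptr;
};

static void printNode(const Node *node, std::string &out,
                      std::vector<const Node *> &stack) {
  if (node->name.empty()) {
    out += "test_type";
    return;
  }
  out += "test_rec<" + node->name;
  // Print the body only if this node is not already being printed further
  // up the current path; otherwise emit just the name to break the cycle.
  if (std::find(stack.begin(), stack.end(), node) == stack.end()) {
    out += ", ";
    stack.push_back(node);
    printNode(node->body, out, stack);
    stack.pop_back();
  }
  out += ">";
}

std::string printType(const Node *node) {
  std::vector<const Node *> stack;
  std::string out;
  printNode(node, out, stack);
  return out;
}
```

For a node `a` whose body is `a` itself, `printType(&a)` yields `test_rec<a, test_rec<a>>`, which matches the textual form that `parseTestType` accepts.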
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 6db866a10c8d..6e6e1c04643a 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -11,9 +11,6 @@ include "mlir/IR/OpBase.td"
// DECL: class DialectAsmPrinter;
// DECL: } // namespace mlir
-// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);
-// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer);
-
// DEF: #ifdef GET_TYPEDEF_LIST
// DEF: #undef GET_TYPEDEF_LIST
// DEF: ::mlir::test::SimpleAType,
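The test no longer expects `generatedTypeParser`/`generatedTypePrinter` declarations in the header. These functions are now emitted only in the definition file, as `static` functions. This appears to be why the patch moves TestDialect's `parseType`/`printType` into TestTypes.cpp, directly after the `#include "TestTypeDefs.cpp.inc"` line. The following minimal standalone sketch shows that arrangement. The mnemonic `simple_a`, the `parseDialectType` hook, and the use of strings in place of `::mlir::Type` are all hypothetical simplifications.

```cpp
#include <cassert>
#include <optional>
#include <string>

// ---- Stand-in for a generated *.cpp.inc (hypothetical) ----
// `static` gives internal linkage: the function is visible only in the
// translation unit that includes the generated file, so no header
// declaration (and no duplicate symbol across TUs) is involved.
static std::optional<std::string>
generatedTypeParser(const std::string &mnemonic) {
  if (mnemonic == "simple_a")
    return std::string("SimpleAType");
  return std::nullopt; // Unknown mnemonic: let the caller try other types.
}

// ---- Hand-written dialect hook, defined in the same translation unit ----
std::optional<std::string> parseDialectType(const std::string &mnemonic) {
  // Try the TableGen-generated types first...
  if (auto generated = generatedTypeParser(mnemonic))
    return generated;
  // ...then fall back to types that are still parsed by hand.
  if (mnemonic == "test_type")
    return std::string("TestType");
  return std::nullopt;
}
```

The real hooks signal "not handled" with a null `::mlir::Type` rather than an empty `std::optional`. The optional is used here only to keep the sketch free of MLIR dependencies.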
diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
index 0990a5afb884..64239441581d 100644
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -92,7 +92,7 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
/// llvm::formatv will call this function when using an instance as a
/// replacement value.
void format(raw_ostream &os, StringRef options) override {
- if (params.size() && prependComma)
+ if (!params.empty() && prependComma)
os << ", ";
switch (emitFormat) {
@@ -146,8 +146,9 @@ class DialectAsmPrinter;
/// case.
///
/// {0}: The name of the typeDef class.
+/// {1}: The name of the type base class.
static const char *const typeDefDeclSingletonBeginStr = R"(
- class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
+ class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
@@ -158,15 +159,16 @@ static const char *const typeDefDeclSingletonBeginStr = R"(
/// case.
///
/// {0}: The name of the typeDef class.
-/// {1}: The typeDef storage class namespace.
-/// {2}: The storage class name.
-/// {3}: The list of parameters with types.
+/// {1}: The name of the type base class.
+/// {2}: The typeDef storage class namespace.
+/// {3}: The storage class name.
+/// {4}: The list of parameters with types.
static const char *const typeDefDeclParametricBeginStr = R"(
- namespace {1} {
- struct {2};
+ namespace {2} {
+ struct {3};
}
- class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
- {1}::{2}> {{
+ class {0}: public ::mlir::Type::TypeBase<{0}, {1},
+ {2}::{3}> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
@@ -196,10 +198,11 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
// template.
if (typeDef.getNumParameters() == 0)
os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
- typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+ typeDef.getCppBaseClassName());
else
os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
- typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+ typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(),
+ typeDef.getStorageClassName());
// Emit the extra declarations first in case there's a type definition in
// there.
@@ -208,8 +211,10 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
- os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
- typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
+ if (!params.empty()) {
+ os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
+ typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
+ }
// Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl())
@@ -252,17 +257,9 @@ static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
// Output the common "header".
os << typeDefDeclHeader;
- if (typeDefs.size() > 0) {
+ if (!typeDefs.empty()) {
NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
- // Well known print/parse dispatch function declarations. These are called
- // from Dialect::parseType() and Dialect::printType() methods.
- os << " ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
- "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
- os << " ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
- "::mlir::DialectAsmPrinter& printer);\n";
- os << "\n";
-
// Declare all the type classes first (in case they reference each other).
for (const TypeDef &typeDef : typeDefs)
os << " class " << typeDef.getCppClassName() << ";\n";
@@ -488,14 +485,16 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
emitStorageClass(typeDef, os);
- os << llvm::formatv(
- "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
- " return Base::get(ctxt{2});\n}\n",
- typeDef.getCppClassName(),
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
- TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
- parameters));
+ if (!parameters.empty()) {
+ os << llvm::formatv(
+ "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
+ " return Base::get(ctxt{2});\n}\n",
+ typeDef.getCppClassName(),
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+ TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
+ parameters));
+ }
// Emit the parameter accessors.
if (typeDef.genAccessors())
@@ -526,38 +525,40 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
-static void emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs,
- raw_ostream &os) {
- if (typeDefs.size() == 0)
+static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
+ if (llvm::none_of(types, [](const TypeDef &type) {
+ return type.getMnemonic().hasValue();
+ })) {
return;
- const Dialect &dialect = typeDefs.begin()->getDialect();
- NamespaceEmitter ns(os, dialect);
+ }
- // The parser dispatch is just a list of if-elses, matching on the mnemonic
- // and calling the class's parse function.
- os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+ // The parser dispatch is just a list of if-elses, matching on the
+ // mnemonic and calling the class's parse function.
+ os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* "
+ "ctxt, "
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
- for (const TypeDef &typeDef : typeDefs)
- if (typeDef.getMnemonic())
+ for (const TypeDef &type : types)
+ if (type.getMnemonic())
os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
"{0}::{1}::parse(ctxt, parser);\n",
- typeDef.getDialect().getCppNamespace(),
- typeDef.getCppClassName());
+ type.getDialect().getCppNamespace(),
+ type.getCppClassName());
os << " return ::mlir::Type();\n";
os << "}\n\n";
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
// printer.
- os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
+ os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
+ "type, "
"::mlir::DialectAsmPrinter& printer) {\n"
<< " ::mlir::LogicalResult found = ::mlir::success();\n"
<< " ::llvm::TypeSwitch<::mlir::Type>(type)\n";
- for (auto typeDef : typeDefs)
- if (typeDef.getMnemonic())
+ for (const TypeDef &type : types)
+ if (type.getMnemonic())
os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ "
"t.dyn_cast<{0}::{1}>().print(printer); })\n",
- typeDef.getDialect().getCppNamespace(),
- typeDef.getCppClassName());
+ type.getDialect().getCppNamespace(),
+ type.getCppClassName());
os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
"});\n"
<< " return found;\n"
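The TypeDefGen.cpp hunks make several pieces of emitted code conditional:

* `get()` is declared and defined only for types with parameters.
* The comma before the parameter list is written only when parameters exist.
* The `TypeBase<...>` base class comes from the TypeDef instead of being hard-coded to `::mlir::Type`.
* The parser dispatcher is emitted as a `static` function, and only when at least one type has a mnemonic.

The following C++17 sketch models that logic. `TypeDefModel` and the `emit*` functions are hypothetical simplifications, not the mlir-tblgen API. The emitted strings drop the dialect namespace qualification that the real generator adds.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for a TableGen TypeDef record.
struct TypeDefModel {
  std::string cppClassName;
  std::optional<std::string> mnemonic;
  std::vector<std::pair<std::string, std::string>> params; // (C++ type, name)
  std::string cppBaseClassName = "::mlir::Type"; // now configurable
};

// ", unsigned width, ..." or "" when there are no parameters
// (the role TypeParamCommaFormatter plays in the real generator).
std::string typeNamePairsAfterComma(const TypeDefModel &def) {
  std::string out;
  for (const auto &param : def.params)
    out += ", " + param.first + " " + param.second;
  return out;
}

// Singleton class head: the base class now comes from the TypeDef.
std::string emitSingletonClassHead(const TypeDefModel &def) {
  return "class " + def.cppClassName + ": public ::mlir::Type::TypeBase<" +
         def.cppClassName + ", " + def.cppBaseClassName +
         ", ::mlir::TypeStorage> {\n";
}

// get() is declared only for parametric types; per the commit log it is
// not necessary for singletons.
std::string emitGetDecl(const TypeDefModel &def) {
  if (def.params.empty())
    return "";
  return "  static " + def.cppClassName + " get(::mlir::MLIRContext* ctxt" +
         typeNamePairsAfterComma(def) + ");\n";
}

// The parser dispatcher is `static` and is skipped entirely when no
// TypeDef declares a mnemonic.
std::string emitParserDispatch(const std::vector<TypeDefModel> &defs) {
  if (std::none_of(defs.begin(), defs.end(), [](const TypeDefModel &def) {
        return def.mnemonic.has_value();
      }))
    return "";
  std::ostringstream os;
  os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
        "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
  for (const TypeDefModel &def : defs)
    if (def.mnemonic)
      os << "  if (mnemonic == " << def.cppClassName
         << "::getMnemonic()) return " << def.cppClassName
         << "::parse(ctxt, parser);\n";
  os << "  return ::mlir::Type();\n}\n";
  return os.str();
}
```

Skipping the dispatcher when nothing has a mnemonic also avoids emitting an unused `static` function, which compilers would otherwise warn about.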