[Mlir-commits] [mlir] 948be58 - [mlir][TypeDefGen] Add support for adding builders when generating a TypeDef
River Riddle
llvmlistbot at llvm.org
Mon Jan 11 12:08:46 PST 2021
Author: River Riddle
Date: 2021-01-11T12:06:22-08:00
New Revision: 948be58258dd81d56b1057657193f7dcf6dfa9bd
URL: https://github.com/llvm/llvm-project/commit/948be58258dd81d56b1057657193f7dcf6dfa9bd
DIFF: https://github.com/llvm/llvm-project/commit/948be58258dd81d56b1057657193f7dcf6dfa9bd.diff
LOG: [mlir][TypeDefGen] Add support for adding builders when generating a TypeDef
This allows for specifying additional get/getChecked methods that should be generated on the type, and acts similarly to how OpBuilders work. TypeBuilders have two additional components though:
* InferredContextParam
- Bit indicating that the context parameter of a get method is inferred from one of the builder parameters
* checkedBody
- A code block representing the body of the equivalent getChecked method.
Differential Revision: https://reviews.llvm.org/D94274
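The two extra components are easier to see with a small, self-contained C++ mock of the kind of methods the new builders produce. This is only a sketch. `Context`, `Location`, `ElemType` and `MyType` are hypothetical stand-ins for `MLIRContext`, `Location`, `Type` and a generated type class. The checked variant reports failure through `std::optional` rather than through MLIR's own error mechanism, and its width limit mirrors the example checker in `TestTypes.cpp`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for ::mlir::MLIRContext, ::mlir::Location and
// ::mlir::Type, reduced to what this sketch needs.
struct Context {};

struct Location {
  Context *ctx;
  Context *getContext() const { return ctx; }
};

struct ElemType {
  Context *ctx;
  Context *getContext() const { return ctx; }
};

// Hypothetical type class with a single `unsigned` parameter.
struct MyType {
  Context *context;
  unsigned width;

  // Regular builder: the context is an implicit leading parameter, as for
  // TypeBuilder<(ins "unsigned":$width)>.
  static MyType get(Context *context, unsigned width) {
    return MyType{context, width};
  }

  // Inferred-context builder: there is no context parameter, because it is
  // recovered from an argument (the TypeBuilderWithInferredContext case).
  static MyType get(ElemType elem, unsigned width) {
    return get(elem.getContext(), width);
  }

  // Checked builder: takes a Location, verifies the parameters first, and
  // signals failure with an empty optional (an assumption of this sketch).
  // The context is taken from the location, like `$_ctxt` in a checkedBody.
  static std::optional<MyType> getChecked(Location loc, unsigned width) {
    if (width > 8) // same limit as the example checker in TestTypes.cpp
      return std::nullopt;
    return get(loc.getContext(), width);
  }
};
```

For example, `MyType::get(ElemType{&ctx}, 8)` needs no context argument, while `MyType::getChecked(Location{&ctx}, 16)` yields an empty optional.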
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/TypeDef.h
mlir/lib/TableGen/TypeDef.cpp
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/mlir-tblgen/typedefs.td
mlir/tools/mlir-tblgen/TypeDefGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index bfd3d43c60b9..dd522904dd73 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1536,6 +1536,171 @@ responsible for parsing/printing the types in `Dialect::printType` and
- The `extraClassDeclaration` field is used to include extra code in the class
declaration.
+### Type builder methods
+
+For each type, there are a few builders (`get`/`getChecked`) automatically
+generated based on the parameters of the type. For example, given the following
+type definition:
+
+```tablegen
+def MyType : ... {
+ let parameters = (ins "int":$intParam);
+}
+```
+
+The following builders are generated:
+
+```c++
+// Type builders are named `get`, and return a new instance of a type for a
+// given set of parameters.
+static MyType get(MLIRContext *context, int intParam);
+
+// If `genVerifyInvariantsDecl` is set to 1, the following method is also
+// generated.
+static MyType getChecked(Location loc, int intParam);
+```
+
+If these autogenerated methods are not desired, such as when they conflict with
+a custom builder method, a type can set `skipDefaultBuilders` to 1 to signal
+that they should not be generated.
+
+#### Custom type builder methods
+
+The default build methods may cover a majority of the simple cases related to
+type construction, but when they cannot satisfy a type's needs, you can define
+additional convenience get methods in the `builders` field as follows:
+
+```tablegen
+def MyType : ... {
+ let parameters = (ins "int":$intParam);
+
+ let builders = [
+ TypeBuilder<(ins "int":$intParam)>,
+ TypeBuilder<(ins CArg<"int", "0">:$intParam)>,
+ TypeBuilder<(ins CArg<"int", "0">:$intParam), [{
+ // Write the body of the `get` builder inline here.
+ return Base::get($_ctxt, intParam);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
+ // This builder states that it can infer an MLIRContext instance from
+ // its arguments.
+ return Base::get(typeParam.getContext(), ...);
+ }]>,
+ ];
+}
+```
+
+The `builders` field is a list of custom builders that are added to the type
+class. In this example, we provide several different convenience builders that
+are useful in different scenarios. The `ins` prefix is common to many function
+declarations in ODS, which use a TableGen [`dag`](#tablegen-syntax). What
+follows is a comma-separated list of types (quoted string or CArg) and names
+prefixed with the `$` sign. The use of `CArg` allows for providing a default
+value to that argument. Let's take a look at each of these builders individually.
+
+The first builder will generate the declaration of a builder method that looks
+like:
+
+```tablegen
+ let builders = [
+ TypeBuilder<(ins "int":$intParam)>,
+ ];
+```
+
+```c++
+class MyType : /*...*/ {
+ /*...*/
+ static MyType get(::mlir::MLIRContext *context, int intParam);
+};
+```
+
+This builder is identical to the one that will be automatically generated for
+`MyType`. The `context` parameter is implicitly added by the generator, and is
+used when building the final Type instance (with `Base::get`). The distinction
+here is that we can provide the implementation of this `get` method. With this
+style of builder definition, only the declaration is generated; the implementor
+of MyType will need to provide a definition of `MyType::get`.
+
+The second builder will generate the declaration of a builder method that looks
+like:
+
+```tablegen
+ let builders = [
+ TypeBuilder<(ins CArg<"int", "0">:$intParam)>,
+ ];
+```
+
+```c++
+class MyType : /*...*/ {
+ /*...*/
+ static MyType get(::mlir::MLIRContext *context, int intParam = 0);
+};
+```
+
+The constraints here are identical to the first builder example except for the
+fact that `intParam` now has a default value attached.
+
+The third builder will generate the declaration of a builder method that looks
+like:
+
+```tablegen
+ let builders = [
+ TypeBuilder<(ins CArg<"int", "0">:$intParam), [{
+ // Write the body of the `get` builder inline here.
+ return Base::get($_ctxt, intParam);
+ }]>,
+ ];
+```
+
+```c++
+class MyType : /*...*/ {
+ /*...*/
+ static MyType get(::mlir::MLIRContext *context, int intParam = 0);
+};
+
+MyType MyType::get(::mlir::MLIRContext *context, int intParam) {
+ // Write the body of the `get` builder inline here.
+ return Base::get(context, intParam);
+}
+```
+
+This is identical to the second builder example. The difference is that now, a
+definition for the builder method will be generated automatically using the
+provided code block as the body. When specifying the body inline, `$_ctxt` may
+be used to access the `MLIRContext *` parameter.
+
+The fourth builder will generate the declaration of a builder method that looks
+like:
+
+```tablegen
+ let builders = [
+ TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
+ // This builder states that it can infer an MLIRContext instance from
+ // its arguments.
+ return Base::get(typeParam.getContext(), ...);
+ }]>,
+ ];
+```
+
+```c++
+class MyType : /*...*/ {
+ /*...*/
+ static MyType get(Type typeParam);
+};
+
+MyType MyType::get(Type typeParam) {
+ // This builder states that it can infer an MLIRContext instance from its
+ // arguments.
+ return Base::get(typeParam.getContext(), ...);
+}
+```
+
+In this builder example, the main difference from the third builder example
+is that the `MLIRContext` parameter is no longer added. This is because the
+builder type used, `TypeBuilderWithInferredContext`, implies that the context
+parameter is not necessary as it can be inferred from the arguments to the
+builder.
+
## Debugging Tips
### Run `mlir-tblgen` to see the generated content
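In the builder walkthrough above, the generated declarations and definitions differ in two mechanical ways. First, a `CArg` default value appears only in the declaration. Second, `::mlir::MLIRContext *context` is prepended unless the builder infers its context. The following plain C++ sketch reproduces that string formatting. `Param` and `formatGet` are hypothetical helpers and not part of `mlir-tblgen`, which does the equivalent with `llvm::interleaveComma` and a `raw_ostream`.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of one builder parameter: its C++ type, its name and an
// optional default value (the second argument of CArg).
struct Param {
  std::string type;
  std::string name;
  std::optional<std::string> defaultValue;
};

// Formats the parameter list of a `get` builder. Default values are emitted
// only when `forDecl` is true, because C++ allows them on the declaration
// alone. The context parameter is skipped when it can be inferred.
std::string formatGet(const std::vector<Param> &params, bool inferredContext,
                      bool forDecl) {
  std::string out = "get(";
  bool first = true;
  if (!inferredContext) {
    out += "::mlir::MLIRContext *context";
    first = false;
  }
  for (const Param &p : params) {
    if (!first)
      out += ", ";
    first = false;
    out += p.type + " " + p.name;
    if (forDecl && p.defaultValue)
      out += " = " + *p.defaultValue;
  }
  return out + ")";
}
```

With `{{"int", "intParam", "0"}}`, the declaration form is `get(::mlir::MLIRContext *context, int intParam = 0)` and the definition form drops the `= 0`, matching the second and third builder examples.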
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index b1a42e99126c..3bfbccf618d6 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -74,7 +74,7 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
VectorType vector;
if ($_parser.parseType(vector))
return Type();
- return get(ctxt, vector.getShape(), vector.getElementType());
+ return get($_ctxt, vector.getShape(), vector.getElementType());
}];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index dc3e8a6367cd..73ddbc1d56eb 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2430,6 +2430,73 @@ def replaceWithValue;
// Data type generation
//===----------------------------------------------------------------------===//
+// Class for defining a custom type getter.
+//
+// TableGen generates several generic getter methods for each type by default,
+// corresponding to the specified dag parameters. If the default generated ones
+// cannot cover some use case, custom getters can be defined using instances of
+// this class.
+//
+// Unless the context is inferred, the `get` signature is always either:
+//
+// ```c++
+// static <TypeName> get(MLIRContext *context, <other-parameters>...) {
+// <body>...
+// }
+// ```
+//
+// or:
+//
+// ```c++
+// static <TypeName> get(MLIRContext *context, <parameters>...);
+// ```
+//
+// To define a custom getter, the parameter list and body should be passed
+// in as separate template arguments to this class. The parameter list is a
+// TableGen DAG with `ins` operation with named arguments, which has either:
+// - string initializers ("Type":$name) to represent a typed parameter, or
+// - CArg-typed initializers (CArg<"Type", "default">:$name) to represent a
+// typed parameter that may have a default value.
+// The type string is used verbatim to produce code and, therefore, must be a
+// valid C++ type. It is used inside the C++ namespace of the parent Type's
+// dialect; explicit namespace qualification like `::mlir` may be necessary if
+// Types are not placed inside the `mlir` namespace. The default value string is
+// used verbatim to produce code and must be a valid C++ initializer for the
+// given type. For example, the following signature specification
+//
+// ```
+// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
+// ```
+//
+// has an integer parameter and a float parameter with a default value.
+//
+// If an empty string is passed in for `body`, then *only* the builder
+// declaration will be generated; this provides a way to define complicated
+// builders entirely in C++.
+//
+// `checkedBody` is similar to `body`, but is the code block used when
+// generating a `getChecked` method.
+class TypeBuilder<dag parameters, code bodyCode = "",
+ code checkedBodyCode = ""> {
+ dag dagParams = parameters;
+ code body = bodyCode;
+ code checkedBody = checkedBodyCode;
+
+ // The context parameter can be inferred from one of the other parameters and
+ // is not implicitly added to the parameter list.
+ bit hasInferredContextParam = 0;
+}
+
+// A class of TypeBuilder that is able to infer the MLIRContext parameter from
+// one of the other builder parameters. Instances of this builder do not have
+// `MLIRContext *` implicitly added to the parameter list.
+class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
+ code checkedBodyCode = "">
+ : TypeBuilder<parameters, bodyCode> {
+ code checkedBody = checkedBodyCode;
+ let hasInferredContextParam = 1;
+}
+
// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
class TypeDef<Dialect dialect, string name,
@@ -2475,6 +2542,18 @@ class TypeDef<Dialect dialect, string name,
// for re-allocating ArrayRefs. It is defined below.)
dag parameters = (ins);
+ // Custom type builder methods.
+ // In addition to the custom builders provided here, and unless
+ // skipDefaultBuilders is set, a default builder is generated with the
+ // following signature:
+ //
+ // ```c++
+ // static <TypeName> get(MLIRContext *, <parameters>);
+ // ```
+ //
+ // Note that builders should only be provided when a type has parameters.
+ list<TypeBuilder> builders = ?;
+
// Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of
// printer/parser for this type.
@@ -2488,6 +2567,9 @@ class TypeDef<Dialect dialect, string name,
// If set, generate accessors for each Type parameter.
bit genAccessors = 1;
+ // Avoid generating default get/getChecked functions. Custom get methods must
+ // be provided.
+ bit skipDefaultBuilders = 0;
// Generate the verifyConstructionInvariants declaration and getChecked
// method.
bit genVerifyInvariantsDecl = 0;
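Builder `body` and `checkedBody` blocks are written with placeholders instead of concrete parameter names. Based on the `TypeDefGen.cpp` changes in this commit, `$_ctxt` expands to `context` in a `get` body, while in a `getChecked` body `$_loc` expands to `loc` and `$_ctxt` expands to `loc.getContext()`. The following stand-alone sketch imitates that expansion with a naive search-and-replace. `expand` is a hypothetical helper. The real generator uses `FmtContext` and `tgfmt`, which handle more cases than this.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Replaces every occurrence of each placeholder (e.g. "$_ctxt") in `body`
// with its substitution, matching placeholders as plain substrings.
// Placeholders must be non-empty, or the inner loop would never advance.
std::string expand(std::string body,
                   const std::vector<std::pair<std::string, std::string>> &subs) {
  for (const auto &[placeholder, value] : subs) {
    std::string::size_type pos = 0;
    while ((pos = body.find(placeholder, pos)) != std::string::npos) {
      body.replace(pos, placeholder.size(), value);
      pos += value.size(); // continue searching after the inserted text
    }
  }
  return body;
}
```

Applied to the `TestIntegerType` checked body, `expand("return Base::getChecked($_loc, width, signedness);", {{"$_loc", "loc"}, {"$_ctxt", "loc.getContext()"}})` yields `return Base::getChecked(loc, width, signedness);`.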
diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h
index 1be5140011f0..73a3a1002d0a 100644
--- a/mlir/include/mlir/TableGen/TypeDef.h
+++ b/mlir/include/mlir/TableGen/TypeDef.h
@@ -14,24 +14,45 @@
#define MLIR_TABLEGEN_TYPEDEF_H
#include "mlir/Support/LLVM.h"
-#include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/Builder.h"
namespace llvm {
-class Record;
class DagInit;
+class Record;
class SMLoc;
} // namespace llvm
namespace mlir {
namespace tblgen {
-
+class Dialect;
class TypeParameter;
+//===----------------------------------------------------------------------===//
+// TypeBuilder
+//===----------------------------------------------------------------------===//
+
+/// Wrapper class that represents a Tablegen TypeBuilder.
+class TypeBuilder : public Builder {
+public:
+ using Builder::Builder;
+
+ /// Return an optional code body used for the `getChecked` variant of this
+ /// builder.
+ Optional<StringRef> getCheckedBody() const;
+
+ /// Returns true if this builder is able to infer the MLIRContext parameter.
+ bool hasInferredContextParameter() const;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeDef
+//===----------------------------------------------------------------------===//
+
/// Wrapper class that contains a TableGen TypeDef's record and provides helper
/// methods for accessing them.
class TypeDef {
public:
- explicit TypeDef(const llvm::Record *def) : def(def) {}
+ explicit TypeDef(const llvm::Record *def);
// Get the dialect for which this type belongs.
Dialect getDialect() const;
@@ -95,6 +116,13 @@ class TypeDef {
// Get the code location (for error printing).
ArrayRef<llvm::SMLoc> getLoc() const;
+ // Returns true if the default get/getChecked methods should be skipped during
+ // generation.
+ bool skipDefaultBuilders() const;
+
+ // Returns the builders of this type.
+ ArrayRef<TypeBuilder> getBuilders() const { return builders; }
+
// Returns whether two TypeDefs are equal by checking the equality of the
// underlying record.
bool operator==(const TypeDef &other) const;
@@ -107,8 +135,15 @@ class TypeDef {
private:
const llvm::Record *def;
+
+ // The builders of this type definition.
+ SmallVector<TypeBuilder> builders;
};
+//===----------------------------------------------------------------------===//
+// TypeParameter
+//===----------------------------------------------------------------------===//
+
// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
// to parameterize them.
class TypeParameter {
diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp
index 5c4f5348d1b3..d6adbc20ef76 100644
--- a/mlir/lib/TableGen/TypeDef.cpp
+++ b/mlir/lib/TableGen/TypeDef.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/TypeDef.h"
+#include "mlir/TableGen/Dialect.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -18,6 +19,26 @@
using namespace mlir;
using namespace mlir::tblgen;
+//===----------------------------------------------------------------------===//
+// TypeBuilder
+//===----------------------------------------------------------------------===//
+
+/// Return an optional code body used for the `getChecked` variant of this
+/// builder.
+Optional<StringRef> TypeBuilder::getCheckedBody() const {
+ Optional<StringRef> body = def->getValueAsOptionalString("checkedBody");
+ return body && !body->empty() ? body : llvm::None;
+}
+
+/// Returns true if this builder is able to infer the MLIRContext parameter.
+bool TypeBuilder::hasInferredContextParameter() const {
+ return def->getValueAsBit("hasInferredContextParam");
+}
+
+//===----------------------------------------------------------------------===//
+// TypeDef
+//===----------------------------------------------------------------------===//
+
Dialect TypeDef::getDialect() const {
auto *dialectDef =
dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
@@ -98,6 +119,11 @@ llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}
llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
+
+bool TypeDef::skipDefaultBuilders() const {
+ return def->getValueAsBit("skipDefaultBuilders");
+}
+
bool TypeDef::operator==(const TypeDef &other) const {
return def == other.def;
}
@@ -106,6 +132,33 @@ bool TypeDef::operator<(const TypeDef &other) const {
return getName() < other.getName();
}
+//===----------------------------------------------------------------------===//
+// TypeParameter
+//===----------------------------------------------------------------------===//
+
+TypeDef::TypeDef(const llvm::Record *def) : def(def) {
+ // Populate the builders.
+ auto *builderList =
+ dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
+ if (builderList && !builderList->empty()) {
+ for (llvm::Init *init : builderList->getValues()) {
+ TypeBuilder builder(cast<llvm::DefInit>(init)->getDef(), def->getLoc());
+
+ // Ensure that all parameters have names.
+ for (const TypeBuilder::Parameter &param : builder.getParameters()) {
+ if (!param.getName())
+ PrintFatalError(def->getLoc(),
+ "type builder parameters must have a name");
+ }
+ builders.emplace_back(builder);
+ }
+ } else if (skipDefaultBuilders()) {
+ PrintFatalError(
+ def->getLoc(),
+ "default builders are skipped and no custom builders provided");
+ }
+}
+
StringRef TypeParameter::getName() const {
return def->getArgName(num)->getValue();
}
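The new `TypeDef` constructor enforces two rules. Every custom builder parameter must be named, and `skipDefaultBuilders` is only accepted when at least one custom builder exists. The following plain C++ sketch restates those checks. `BuilderParam`, `BuilderSpec` and `validateBuilders` are hypothetical, and where the real constructor calls `PrintFatalError`, this version returns the error message instead.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of a builder parameter. An empty name mirrors an
// unnamed dag argument such as `(ins "int")`.
struct BuilderParam {
  std::string cppType;
  std::optional<std::string> name;
};

struct BuilderSpec {
  std::vector<BuilderParam> params;
};

// Returns an error message if the builder configuration is invalid, or
// std::nullopt if it is accepted. The messages match the TypeDef constructor.
std::optional<std::string>
validateBuilders(const std::vector<BuilderSpec> &builders,
                 bool skipDefaultBuilders) {
  if (builders.empty()) {
    if (skipDefaultBuilders)
      return "default builders are skipped and no custom builders provided";
    return std::nullopt;
  }
  for (const BuilderSpec &builder : builders)
    for (const BuilderParam &param : builder.params)
      if (!param.name)
        return "type builder parameters must have a name";
  return std::nullopt;
}
```

In this model, `TestIntegerType` is accepted: it sets `skipDefaultBuilders = 1` but also provides one custom builder whose parameters are all named.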
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 80927dff62c2..0e2c11a2ecb0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -51,9 +51,9 @@ def IntegerType : Test_Type<"TestInteger"> {
let genVerifyInvariantsDecl = 1;
let parameters = (
ins
+ "unsigned":$width,
// SignednessSemantics is defined below.
- "::mlir::test::TestIntegerType::SignednessSemantics":$signedness,
- "unsigned":$width
+ "::mlir::test::TestIntegerType::SignednessSemantics":$signedness
);
// We define the printer inline.
@@ -63,6 +63,17 @@ def IntegerType : Test_Type<"TestInteger"> {
$_printer << ", " << getImpl()->width << ">";
}];
+ // Define custom builder methods.
+ let builders = [
+ TypeBuilder<(ins "unsigned":$width,
+ CArg<"SignednessSemantics", "Signless">:$signedness), [{
+ return Base::get($_ctxt, width, signedness);
+ }], [{
+ return Base::getChecked($_loc, width, signedness);
+ }]>
+ ];
+ let skipDefaultBuilders = 1;
+
// The parser is defined here also.
let parser = [{
if (parser.parseLess()) return Type();
@@ -73,7 +84,7 @@ def IntegerType : Test_Type<"TestInteger"> {
if ($_parser.parseInteger(width)) return Type();
if ($_parser.parseGreater()) return Type();
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
- return getChecked(loc, signedness, width);
+ return getChecked(loc, width, signedness);
}];
// Any extra code one wants in the type's class declaration.
@@ -85,9 +96,6 @@ def IntegerType : Test_Type<"TestInteger"> {
Unsigned, /// Unsigned integer
};
- /// This extra function is necessary since it doesn't include signedness
- static IntegerType getChecked(unsigned width, Location location);
-
/// Return true if this is a signless integer type.
bool isSignless() const { return getSignedness() == Signless; }
/// Return true if this is a signed integer type.
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 14a3e862d85c..094e5c9fc631 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -113,7 +113,7 @@ static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
// Example type validity checker.
LogicalResult TestIntegerType::verifyConstructionInvariants(
- Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) {
+ Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) {
if (width > 8)
return failure();
return success();
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 6e6e1c04643a..5471519ab60a 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -19,8 +19,8 @@ include "mlir/IR/OpBase.td"
// DEF: ::mlir::test::SingleParameterType,
// DEF: ::mlir::test::IntegerType
-// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
-// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
+// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &parser, ::llvm::StringRef mnemonic)
+// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser);
// DEF return ::mlir::Type();
def Test_Dialect: Dialect {
@@ -33,7 +33,7 @@ def Test_Dialect: Dialect {
class TestType<string name> : TypeDef<Test_Dialect, name> { }
def A_SimpleTypeA : TestType<"SimpleA"> {
-// DECL: class SimpleAType: public ::mlir::Type
+// DECL: class SimpleAType : public ::mlir::Type
}
def RTLValueType : Type<CPred<"isRTLValueType($_self)">, "Type"> {
@@ -56,12 +56,13 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
let genVerifyInvariantsDecl = 1;
-// DECL-LABEL: class CompoundAType: public ::mlir::Type
+// DECL-LABEL: class CompoundAType : public ::mlir::Type
+// DECL: static CompoundAType getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
-// DECL: static ::mlir::Type getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
-// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
-// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
+// DECL-NEXT: ::mlir::DialectAsmParser &parser);
+// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
// DECL: SomeCppStruct getExampleCppType() const;
@@ -75,10 +76,11 @@ def C_IndexType : TestType<"Index"> {
StringRefParameter<"Label for index">:$label
);
-// DECL-LABEL: class IndexType: public ::mlir::Type
+// DECL-LABEL: class IndexType : public ::mlir::Type
// DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
-// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
-// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
+// DECL-NEXT: ::mlir::DialectAsmParser &parser);
+// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
}
def D_SingleParameterType : TestType<"SingleParameter"> {
@@ -100,7 +102,7 @@ def E_IntegerType : TestType<"Integer"> {
TypeParameter<"unsigned", "Bitwidth of integer">:$width
);
-// DECL-LABEL: IntegerType: public ::mlir::Type
+// DECL-LABEL: IntegerType : public ::mlir::Type
let extraClassDeclaration = [{
/// Signedness semantics.
diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
index 20168168bc8d..9cbd53270983 100644
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -148,7 +148,7 @@ class DialectAsmPrinter;
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
static const char *const typeDefDeclSingletonBeginStr = R"(
- class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
+ class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
@@ -166,9 +166,9 @@ static const char *const typeDefDeclSingletonBeginStr = R"(
static const char *const typeDefDeclParametricBeginStr = R"(
namespace {2} {
struct {3};
- }
- class {0}: public ::mlir::Type::TypeBase<{0}, {1},
- {2}::{3}> {{
+ } // end namespace {2}
+ class {0} : public ::mlir::Type::TypeBase<{0}, {1},
+ {2}::{3}> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
@@ -177,18 +177,68 @@ static const char *const typeDefDeclParametricBeginStr = R"(
/// The snippet for print/parse.
static const char *const typeDefParsePrint = R"(
- static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
- void print(::mlir::DialectAsmPrinter& printer) const;
+ static ::mlir::Type parse(::mlir::MLIRContext *context,
+ ::mlir::DialectAsmParser &parser);
+ void print(::mlir::DialectAsmPrinter &printer) const;
)";
/// The code block for the verifyConstructionInvariants and getChecked.
///
-/// {0}: List of parameters, parameters style.
+/// {0}: The name of the typeDef class.
+/// {1}: List of parameters, parameters style.
static const char *const typeDefDeclVerifyStr = R"(
- static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{0});
- static ::mlir::Type getChecked(::mlir::Location loc{0});
+ static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1});
)";
+/// Emit the builders for the given type.
+static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
+ TypeParamCommaFormatter &paramTypes) {
+ StringRef typeClass = typeDef.getCppClassName();
+ bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
+ if (!typeDef.skipDefaultBuilders()) {
+ os << llvm::formatv(
+ " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
+ paramTypes);
+ if (genCheckedMethods) {
+ os << llvm::formatv(
+ " static {0} getChecked(::mlir::Location loc{1});\n", typeClass,
+ paramTypes);
+ }
+ }
+
+ // Generate the builders specified by the user.
+ for (const TypeBuilder &builder : typeDef.getBuilders()) {
+ std::string paramStr;
+ llvm::raw_string_ostream paramOS(paramStr);
+ llvm::interleaveComma(
+ builder.getParameters(), paramOS,
+ [&](const TypeBuilder::Parameter &param) {
+ // Note: TypeBuilder parameters are guaranteed to have names.
+ paramOS << param.getCppType() << " " << *param.getName();
+ if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
+ paramOS << " = " << *defaultParamValue;
+ });
+ paramOS.flush();
+
+ // Generate the `get` variant of the builder.
+ os << " static " << typeClass << " get(";
+ if (!builder.hasInferredContextParameter()) {
+ os << "::mlir::MLIRContext *context";
+ if (!paramStr.empty())
+ os << ", ";
+ }
+ os << paramStr << ");\n";
+
+ // Generate the `getChecked` variant of the builder.
+ if (genCheckedMethods) {
+ os << " static " << typeClass << " getChecked(::mlir::Location loc";
+ if (!paramStr.empty())
+ os << ", " << paramStr;
+ os << ");\n";
+ }
+ }
+}
+
/// Generate the declaration for the given typeDef class.
static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> params;
@@ -212,13 +262,13 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
if (!params.empty()) {
- os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
- typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
- }
+ emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
- // Emit the verify invariants declaration.
- if (typeDef.genVerifyInvariantsDecl())
- os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
+ // Emit the verify invariants declaration.
+ if (typeDef.genVerifyInvariantsDecl())
+ os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(),
+ emitTypeNamePairsAfterComma);
+ }
// Emit the mnenomic, if specified.
if (auto mnenomic = typeDef.getMnemonic()) {
@@ -226,7 +276,8 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
<< "\"; }\n";
// If mnemonic specified, emit print/parse declarations.
- os << typeDefParsePrint;
+ if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty())
+ os << typeDefParsePrint;
}
if (typeDef.genAccessors()) {
@@ -330,17 +381,6 @@ static const char *const typeDefStorageClassConstructorReturn = R"(
}
)";
-/// The code block for the getChecked definition.
-///
-/// {0}: List of parameters, parameters style.
-/// {1}: C++ type class name.
-/// {2}: Comma separated list of parameter names.
-static const char *const typeDefDefGetCheckeStr = R"(
- ::mlir::Type {1}::getChecked(Location loc{0}) {{
- return Base::getChecked(loc{2});
- }
-)";
-
/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
@@ -403,13 +443,13 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
parameters, /* prependComma */ false));
// 3) Emit the construct method.
- if (typeDef.hasStorageCustomConstructor())
+ if (typeDef.hasStorageCustomConstructor()) {
// If user wants to build the storage constructor themselves, declare it
// here and then they can write the definition elsewhere.
os << " static " << typeDef.getStorageClassName()
<< " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
"&key);\n";
- else {
+ } else {
// If not, autogenerate one.
// First, unbox the parameters.
@@ -445,7 +485,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
// Both the mnenomic and printerCode must be defined (for parity with
// parserCode).
os << "void " << typeDef.getCppClassName()
- << "::print(::mlir::DialectAsmPrinter& printer) const {\n";
+ << "::print(::mlir::DialectAsmPrinter &printer) const {\n";
if (*printerCode == "") {
// If no code specified, emit error.
PrintFatalError(typeDef.getLoc(),
@@ -460,7 +500,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
if (auto parserCode = typeDef.getParserCode()) {
// The mnenomic must be defined so the dispatcher knows how to dispatch.
os << "::mlir::Type " << typeDef.getCppClassName()
- << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
+ << "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &"
"parser) "
"{\n";
if (*parserCode == "") {
@@ -470,51 +510,112 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
": parser (if specified) must have non-empty code");
}
auto fmtCtxt =
- FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
+ FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context");
os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
}
}
-/// Print all the typedef-specific definition code.
-static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
- NamespaceEmitter ns(os, typeDef.getDialect());
- SmallVector<TypeParameter, 4> parameters;
- typeDef.getParameters(parameters);
-
- // Emit the storage class, if requested and necessary.
- if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
- emitStorageClass(typeDef, os);
-
- if (!parameters.empty()) {
+/// Emit the builders for the given type.
+static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
+ ArrayRef<TypeParameter> typeDefParams) {
+ bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
+ StringRef typeClass = typeDef.getCppClassName();
+ if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
- "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
- " return Base::get(ctxt{2});\n}\n",
- typeDef.getCppClassName(),
+ "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
+ " return Base::get(context{2});\n}\n",
+ typeClass,
TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams),
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
- parameters));
+ typeDefParams));
+ if (genCheckedMethods) {
+ os << llvm::formatv(
+ "{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
+ " return Base::getChecked(loc{2});\n}\n",
+ typeClass,
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
+ typeDefParams),
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams));
+ }
}
- // Emit the parameter accessors.
- if (typeDef.genAccessors())
- for (const TypeParameter &parameter : parameters) {
- SmallString<16> name = parameter.getName();
- name[0] = llvm::toUpper(name[0]);
- os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
- parameter.getCppType(), name, parameter.getName(),
- typeDef.getCppClassName());
+ // Generate the builders specified by the user.
+ auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context");
+ auto checkedBuilderFmtCtx = FmtContext()
+ .addSubst("_loc", "loc")
+ .addSubst("_ctxt", "loc.getContext()");
+ for (const TypeBuilder &builder : typeDef.getBuilders()) {
+ Optional<StringRef> body = builder.getBody();
+ Optional<StringRef> checkedBody =
+ genCheckedMethods ? builder.getCheckedBody() : llvm::None;
+ if (!body && !checkedBody)
+ continue;
+ std::string paramStr;
+ llvm::raw_string_ostream paramOS(paramStr);
+ llvm::interleaveComma(builder.getParameters(), paramOS,
+ [&](const TypeBuilder::Parameter &param) {
+ // Note: TypeBuilder parameters are guaranteed to
+ // have names.
+ paramOS << param.getCppType() << " "
+ << *param.getName();
+ });
+ paramOS.flush();
+
+ // Emit the `get` variant of the builder.
+ if (body) {
+ os << llvm::formatv("{0} {0}::get(", typeClass);
+ if (!builder.hasInferredContextParameter()) {
+ os << "::mlir::MLIRContext *context";
+ if (!paramStr.empty())
+ os << ", ";
+ os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
+ tgfmt(*body, &builderFmtCtx).str());
+ } else {
+ os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, *body);
+ }
}
- // Generate getChecked() method.
- if (typeDef.genVerifyInvariantsDecl()) {
- os << llvm::formatv(
- typeDefDefGetCheckeStr,
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
- typeDef.getCppClassName(),
- TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
- parameters));
+ // Emit the `getChecked` variant of the builder.
+ if (checkedBody) {
+ os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc",
+ typeClass);
+ if (!paramStr.empty())
+ os << ", " << paramStr;
+ os << llvm::formatv(") {{\n {0};\n}\n",
+ tgfmt(*checkedBody, &checkedBuilderFmtCtx));
+ }
+ }
+}
+
+/// Print all the typedef-specific definition code.
+static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) {
+ NamespaceEmitter ns(os, typeDef.getDialect());
+
+ SmallVector<TypeParameter, 4> parameters;
+ typeDef.getParameters(parameters);
+ if (!parameters.empty()) {
+ // Emit the storage class, if requested and necessary.
+ if (typeDef.genStorageClass())
+ emitStorageClass(typeDef, os);
+
+ // Emit the builders for this type.
+ emitTypeBuilderDefs(typeDef, os, parameters);
+
+ // Generate accessor definitions only if we also generate the storage class.
+ // Otherwise, let the user define the exact accessor definition.
+ if (typeDef.genAccessors() && typeDef.genStorageClass()) {
+ // Emit the parameter accessors.
+ for (const TypeParameter &parameter : parameters) {
+ SmallString<16> name = parameter.getName();
+ name[0] = llvm::toUpper(name[0]);
+ os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
+ parameter.getCppType(), name, parameter.getName(),
+ typeDef.getCppClassName());
+ }
+ }
}
// If mnemonic is specified maybe print definitions for the parser and printer
@@ -534,9 +635,9 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
// The parser dispatch is just a list of if-elses, matching on the
// mnemonic and calling the class's parse function.
- os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* "
- "ctxt, "
- "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
+ os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *"
+ "context, ::mlir::DialectAsmParser &parser, "
+ "::llvm::StringRef mnemonic) {\n";
for (const TypeDef &type : types) {
if (type.getMnemonic()) {
os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
@@ -547,9 +648,9 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
// If the type has no parameters and no parser code, just invoke a normal
// `get`.
if (type.getNumParameters() == 0 && !type.getParserCode())
- os << "get(ctxt);\n";
+ os << "get(context);\n";
else
- os << "parse(ctxt, parser);\n";
+ os << "parse(context, parser);\n";
}
}
os << " return ::mlir::Type();\n";
@@ -559,7 +660,7 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
// printer.
os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
"type, "
- "::mlir::DialectAsmPrinter& printer) {\n"
+ "::mlir::DialectAsmPrinter &printer) {\n"
<< " return ::llvm::TypeSwitch<::mlir::Type, "
"::mlir::LogicalResult>(type)\n";
for (const TypeDef &type : types) {
@@ -594,7 +695,7 @@ static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
IfDefScope scope("GET_TYPEDEF_CLASSES", os);
emitParsePrintDispatch(typeDefs, os);
- for (auto typeDef : typeDefs)
+ for (const TypeDef &typeDef : typeDefs)
emitTypeDefDef(typeDef, os);
return false;
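The new `emitTypeBuilderDefs` function is mostly string assembly. The following is a minimal, standalone sketch of how it handles one user-specified builder, written with only the C++ standard library so it can be read without the TableGen headers. It is not the mlir-tblgen code. `BuilderSketch`, `substitute`, `emitBuilder`, and the type name `MyType` are hypothetical. `substitute` is a simplified stand-in for `FmtContext`/`tgfmt` that only does plain `$name` replacement. An empty `checkedBody` stands in for both "no checked body given" and the `genVerifyInvariantsDecl()` gate.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for TableGen's FmtContext + tgfmt:
// replaces every "$<key>" placeholder in `code` with its mapped value.
std::string substitute(
    std::string code,
    const std::vector<std::pair<std::string, std::string>> &subs) {
  for (const auto &sub : subs) {
    const std::string placeholder = "$" + sub.first;
    std::string::size_type pos = code.find(placeholder);
    while (pos != std::string::npos) {
      code.replace(pos, placeholder.size(), sub.second);
      // Resume after the inserted text so it is never re-scanned.
      pos = code.find(placeholder, pos + sub.second.size());
    }
  }
  return code;
}

// Hypothetical mirror of the TypeBuilder pieces the emitter reads.
struct BuilderSketch {
  std::string params;      // already-joined "Type name, ..." list
  std::string body;        // `get` body; empty means no `get` variant
  std::string checkedBody; // `getChecked` body; empty means none
  bool inferredContext = false;
};

// Emits the get/getChecked definitions for one user-specified builder.
std::string emitBuilder(const std::string &typeClass,
                        const BuilderSketch &b) {
  std::string out;
  if (!b.body.empty()) {
    out += typeClass + " " + typeClass + "::get(";
    if (!b.inferredContext) {
      // The context is an explicit leading parameter; $_ctxt names it.
      out += "::mlir::MLIRContext *context";
      if (!b.params.empty())
        out += ", ";
      out += b.params + ") {\n  " +
             substitute(b.body, {{"_ctxt", "context"}}) + ";\n}\n";
    } else {
      // The context comes from a builder parameter; as in the diff, the
      // body is emitted verbatim, with no placeholder substitution.
      out += b.params + ") {\n  " + b.body + ";\n}\n";
    }
  }
  if (!b.checkedBody.empty()) {
    // getChecked always takes a Location first. The context is derived
    // from it, so $_ctxt expands to loc.getContext().
    out += typeClass + " " + typeClass + "::getChecked(::mlir::Location loc";
    if (!b.params.empty())
      out += ", " + b.params;
    out += ") {\n  " +
           substitute(b.checkedBody,
                      {{"_loc", "loc"}, {"_ctxt", "loc.getContext()"}}) +
           ";\n}\n";
  }
  return out;
}
```

For example, consider a builder with parameter list `"int width"`, body `"return Base::get($_ctxt, width)"`, and checked body `"return Base::getChecked($_loc, width)"`. For that builder, the sketch emits `MyType MyType::get(::mlir::MLIRContext *context, int width)` and `MyType MyType::getChecked(::mlir::Location loc, int width)`. This matches the parameter order the emitter above produces. The get/getChecked pair shares one parameter list, and only the leading context/location argument differs. That is why the emitter builds `paramStr` once per builder and reuses it for both variants.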