[Mlir-commits] [mlir] 948be58 - [mlir][TypeDefGen] Add support for adding builders when generating a TypeDef

River Riddle llvmlistbot at llvm.org
Mon Jan 11 12:08:46 PST 2021


Author: River Riddle
Date: 2021-01-11T12:06:22-08:00
New Revision: 948be58258dd81d56b1057657193f7dcf6dfa9bd

URL: https://github.com/llvm/llvm-project/commit/948be58258dd81d56b1057657193f7dcf6dfa9bd
DIFF: https://github.com/llvm/llvm-project/commit/948be58258dd81d56b1057657193f7dcf6dfa9bd.diff

LOG: [mlir][TypeDefGen] Add support for adding builders when generating a TypeDef

This allows for specifying additional get/getChecked methods that should be generated on the type, and acts similarly to how OpBuilders work. TypeBuilders have two additional components though:
* InferredContextParam
  - Bit indicating that the context parameter of a get method is inferred from one of the builder parameters
* checkedBody
  - A code block representing the body of the equivalent getChecked method.

Differential Revision: https://reviews.llvm.org/D94274

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/TypeDef.h
    mlir/lib/TableGen/TypeDef.cpp
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/mlir-tblgen/typedefs.td
    mlir/tools/mlir-tblgen/TypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index bfd3d43c60b9..dd522904dd73 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1536,6 +1536,171 @@ responsible for parsing/printing the types in `Dialect::printType` and
 -   The `extraClassDeclaration` field is used to include extra code in the class
     declaration.
 
+### Type builder methods
+
+For each type, there are a few builders (`get`/`getChecked`) automatically
+generated based on the parameters of the type. For example, given the following
+type definition:
+
+```tablegen
+def MyType : ... {
+  let parameters = (ins "int":$intParam);
+}
+```
+
+The following builders are generated:
+
+```c++
+// Type builders are named `get`, and return a new instance of a type for a
+// given set of parameters.
+static MyType get(MLIRContext *context, int intParam);
+
+// If `genVerifyInvariantsDecl` is set to 1, the following method is also
+// generated.
+static MyType getChecked(Location loc, int intParam);
+```
+
+If these autogenerated methods are not desired, such as when they conflict with
+a custom builder method, a type can set `skipDefaultBuilders` to 1 to signal
+that they should not be generated.
+
+#### Custom type builder methods
+
+The default build methods may cover a majority of the simple cases related to
+type construction, but when they cannot satisfy a type's needs, you can define
+additional convenience get methods in the `builders` field as follows:
+
+```tablegen
+def MyType : ... {
+  let parameters = (ins "int":$intParam);
+
+  let builders = [
+    TypeBuilder<(ins "int":$intParam)>,
+    TypeBuilder<(ins CArg<"int", "0">:$intParam)>,
+    TypeBuilder<(ins CArg<"int", "0">:$intParam), [{
+      // Write the body of the `get` builder inline here.
+      return Base::get($_ctxt, intParam);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
+      // This builder states that it can infer an MLIRContext instance from
+      // its arguments.
+      return Base::get(typeParam.getContext(), ...);
+    }]>,
+  ];
+}
+```
+
+The `builders` field is a list of custom builders that are added to the type
+class. In this example, we provide several different convenience builders that
+are useful in different scenarios. The `ins` prefix is common to many function
+declarations in ODS, which use a TableGen [`dag`](#tablegen-syntax). What
+follows is a comma-separated list of types (quoted string or CArg) and names
+prefixed with the `$` sign. The use of `CArg` allows for providing a default
+value to that argument. Let's take a look at each of these builders
+individually.
+
+The first builder will generate the declaration of a builder method that looks
+like:
+
+```tablegen
+  let builders = [
+    TypeBuilder<(ins "int":$intParam)>,
+  ];
+```
+
+```c++
+class MyType : /*...*/ {
+  /*...*/
+  static MyType get(::mlir::MLIRContext *context, int intParam);
+};
+```
+
+This builder is identical to the one that will be automatically generated for
+`MyType`. The `context` parameter is implicitly added by the generator, and is
+used when building the final Type instance (with `Base::get`). The distinction
+here is that we can provide the implementation of this `get` method. With this
+style of builder definition, only the declaration is generated; the implementor
+of `MyType` will need to provide a definition of `MyType::get`.
+
+The second builder will generate the declaration of a builder method that looks
+like:
+
+```tablegen
+  let builders = [
+    TypeBuilder<(ins CArg<"int", "0">:$intParam)>,
+  ];
+```
+
+```c++
+class MyType : /*...*/ {
+  /*...*/
+  static MyType get(::mlir::MLIRContext *context, int intParam = 0);
+};
+```
+
+The constraints here are identical to the first builder example except for the
+fact that `intParam` now has a default value attached.
+
+The third builder will generate the declaration of a builder method that looks
+like:
+
+```tablegen
+  let builders = [
+    TypeBuilder<(ins CArg<"int", "0">:$intParam), [{
+      // Write the body of the `get` builder inline here.
+      return Base::get($_ctxt, intParam);
+    }]>,
+  ];
+```
+
+```c++
+class MyType : /*...*/ {
+  /*...*/
+  static MyType get(::mlir::MLIRContext *context, int intParam = 0);
+};
+
+MyType MyType::get(::mlir::MLIRContext *context, int intParam) {
+  // Write the body of the `get` builder inline here.
+  return Base::get(context, intParam);
+}
+```
+
+This is identical to the second builder example. The difference is that now, a
+definition for the builder method will be generated automatically using the
+provided code block as the body. When specifying the body inline, `$_ctxt` may
+be used to access the `MLIRContext *` parameter.
+
+The fourth builder will generate the declaration of a builder method that looks
+like:
+
+```tablegen
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
+      // This builder states that it can infer an MLIRContext instance from
+      // its arguments.
+      return Base::get(typeParam.getContext(), ...);
+    }]>,
+  ];
+```
+
+```c++
+class MyType : /*...*/ {
+  /*...*/
+  static MyType get(Type typeParam);
+};
+
+MyType MyType::get(Type typeParam) {
+  // This builder states that it can infer an MLIRContext instance from its
+  // arguments.
+  return Base::get(typeParam.getContext(), ...);
+}
+```
+
+In this builder example, the main difference from the third builder example is
+that the `MLIRContext` parameter is no longer added. This is because the builder
+type used, `TypeBuilderWithInferredContext`, implies that the context parameter
+is not necessary, as it can be inferred from the arguments to the builder.
+
 ## Debugging Tips
 
 ### Run `mlir-tblgen` to see the generated content
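The builder styles documented above can be approximated outside of MLIR with a small, self-contained C++ sketch. `Context`, `ElementType`, and this `MyType` are hypothetical stand-ins for `MLIRContext`, `Type`, and the generated class, and a private constructor takes the place of `Base::get`; only the shape of the `get` overloads is meant to mirror what the generator declares.

```c++
#include <string>

// Hypothetical stand-in for ::mlir::MLIRContext.
struct Context {
  std::string name;
};

// Hypothetical stand-in for an argument that already knows its context,
// playing the role of `Type` in the `TypeBuilderWithInferredContext` example.
struct ElementType {
  Context *ctx;
  Context *getContext() const { return ctx; }
};

class MyType {
public:
  // Like `TypeBuilder<(ins CArg<"int", "0">:$intParam), [{...}]>`: the
  // context parameter is prepended and the CArg default is preserved.
  static MyType get(Context *context, int intParam = 0) {
    return MyType(context, intParam);
  }

  // Like `TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{...}]>`:
  // no context parameter; it is recovered from the argument instead.
  // The value 1 is an arbitrary placeholder for the elided `...` in the docs.
  static MyType get(ElementType typeParam) {
    return get(typeParam.getContext(), /*intParam=*/1);
  }

  Context *getContext() const { return context; }
  int getIntParam() const { return intParam; }

private:
  // Stands in for `Base::get`, which would unique the instance in the context.
  MyType(Context *context, int intParam)
      : context(context), intParam(intParam) {}

  Context *context;
  int intParam;
};
```

In real ODS, the first three example builders would likely not be declared together, since they all produce the same C++ signature; the sketch keeps one defaulted variant next to the inferred-context variant so both call shapes can coexist.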

diff  --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index b1a42e99126c..3bfbccf618d6 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -74,7 +74,7 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
     VectorType vector;
     if ($_parser.parseType(vector))
       return Type();
-    return get(ctxt, vector.getShape(), vector.getElementType());
+    return get($_ctxt, vector.getShape(), vector.getElementType());
   }];
 
   let extraClassDeclaration = [{
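The ArmSVE change swaps a hard-coded `ctxt` for the `$_ctxt` placeholder, so the parser body no longer depends on what the generator names its context argument (this commit renames it to `context`). The sketch below illustrates placeholder substitution with a hypothetical `substitutePlaceholders` helper; it is a simplification, not the actual `tgfmt`/`FmtContext` implementation, which supports more than plain `$_name` lookups.

```c++
#include <cctype>
#include <cstddef>
#include <map>
#include <string>

// Replace every `$name` placeholder that has an entry in `substs` with its
// replacement text; unknown placeholders are copied through unchanged.
std::string substitutePlaceholders(
    const std::string &code, const std::map<std::string, std::string> &substs) {
  std::string out;
  std::size_t i = 0;
  while (i < code.size()) {
    if (code[i] == '$') {
      // Scan the identifier following `$` (letters, digits, underscores).
      std::size_t j = i + 1;
      while (j < code.size() &&
             (std::isalnum(static_cast<unsigned char>(code[j])) ||
              code[j] == '_'))
        ++j;
      auto it = substs.find(code.substr(i + 1, j - i - 1));
      if (it != substs.end()) {
        out += it->second;
        i = j;
        continue;
      }
    }
    out += code[i++];
  }
  return out;
}
```

With `_ctxt` mapped to `context` and `_parser` mapped to `parser`, `return get($_ctxt, ...)` becomes `return get(context, ...)`, matching the substitutions `emitParserPrinter` registers in this commit.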

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index dc3e8a6367cd..73ddbc1d56eb 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2430,6 +2430,73 @@ def replaceWithValue;
 // Data type generation
 //===----------------------------------------------------------------------===//
 
+// Class for defining a custom type getter.
+//
+// TableGen generates several generic getter methods for each type by default,
+// corresponding to the specified dag parameters. If the default generated ones
+// cannot cover some use case, custom getters can be defined using instances of
+// this class.
+//
+// Unless the context is inferred (see `TypeBuilderWithInferredContext`
+// below), the signature of the `get` method is either:
+//
+// ```c++
+// static <TypeName> get(MLIRContext *context, <other-parameters>...) {
+//   <body>...
+// }
+// ```
+//
+// or:
+//
+// ```c++
+// static <TypeName> get(MLIRContext *context, <other-parameters>...);
+// ```
+//
+// To define a custom getter, the parameter list and body should be passed
+// in as separate template arguments to this class. The parameter list is a
+// TableGen DAG with `ins` operation with named arguments, which has either:
+//   - string initializers ("Type":$name) to represent a typed parameter, or
+//   - CArg-typed initializers (CArg<"Type", "default">:$name) to represent a
+//     typed parameter that may have a default value.
+// The type string is used verbatim to produce code and, therefore, must be a
+// valid C++ type. It is used inside the C++ namespace of the parent Type's
+// dialect; explicit namespace qualification like `::mlir` may be necessary if
+// Types are not placed inside the `mlir` namespace. The default value string is
+// used verbatim to produce code and must be a valid C++ initializer for the
+// given type. For example, the following signature specification
+//
+// ```
+// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
+// ```
+//
+// has an integer parameter and a float parameter with a default value.
+//
+// If an empty string is passed in for `body`, then *only* the builder
+// declaration will be generated; this provides a way to define complicated
+// builders entirely in C++.
+//
+// `checkedBody` is similar to `body`, but is the code block used when
+// generating a `getChecked` method.
+class TypeBuilder<dag parameters, code bodyCode = "",
+                  code checkedBodyCode = ""> {
+  dag dagParams = parameters;
+  code body = bodyCode;
+  code checkedBody = checkedBodyCode;
+
+  // If set, the context parameter can be inferred from one of the other
+  // parameters and is not implicitly added to the parameter list.
+  bit hasInferredContextParam = 0;
+}
+
+// A class of TypeBuilder that is able to infer the MLIRContext parameter from
+// one of the other builder parameters. Instances of this builder do not have
+// `MLIRContext *` implicitly added to the parameter list.
+class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
+                                     code checkedBodyCode = "">
+  : TypeBuilder<parameters, bodyCode> {
+  code checkedBody = checkedBodyCode;
+  let hasInferredContextParam = 1;
+}
+
 // Define a new type, named `name`, belonging to `dialect` that inherits from
 // the given C++ base class.
 class TypeDef<Dialect dialect, string name,
@@ -2475,6 +2542,18 @@ class TypeDef<Dialect dialect, string name,
   // for re-allocating ArrayRefs. It is defined below.)
   dag parameters = (ins);
 
+  // Custom type builder methods.
+  // In addition to the custom builders provided here, and unless
+  // skipDefaultBuilders is set, a default builder is generated with the
+  // following signature:
+  //
+  // ```c++
+  // static <TypeName> get(MLIRContext *, <parameters>);
+  // ```
+  //
+  // Note that builders should only be provided when a type has parameters.
+  list<TypeBuilder> builders = ?;
+
   // Use the lowercased name as the keyword for parsing/printing. Specify only
   // if you want tblgen to generate declarations and/or definitions of
   // printer/parser for this type.
@@ -2488,6 +2567,9 @@ class TypeDef<Dialect dialect, string name,
 
   // If set, generate accessors for each Type parameter.
   bit genAccessors = 1;
+  // Avoid generating default get/getChecked functions. Custom get methods must
+  // be provided.
+  bit skipDefaultBuilders = 0;
   // Generate the verifyConstructionInvariants declaration and getChecked
   // method.
   bit genVerifyInvariantsDecl = 0;
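`TypeBuilder` pairs an `ins` dag with optional bodies and a `hasInferredContextParam` bit. As a rough illustration of how those fields turn into a C++ declaration, here is a standalone sketch that follows the same rules as `emitTypeBuilderDecls` in `TypeDefGen.cpp`: prepend `::mlir::MLIRContext *context` unless the context is inferred, and render `CArg` defaults as `= value`. `BuilderParam` and `emitGetDecl` are hypothetical names; indentation and the `getChecked` variant are omitted.

```c++
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified model of one builder parameter from the `ins` dag.
struct BuilderParam {
  std::string cppType;
  std::string name;
  std::optional<std::string> defaultValue; // Set for CArg<"T", "default">.
};

// Build the `get` declaration for one custom builder.
std::string emitGetDecl(const std::string &typeClass,
                        const std::vector<BuilderParam> &params,
                        bool hasInferredContext) {
  // Comma-separated "Type name[ = default]" list.
  std::string paramStr;
  for (std::size_t i = 0; i < params.size(); ++i) {
    if (i != 0)
      paramStr += ", ";
    paramStr += params[i].cppType + " " + params[i].name;
    if (params[i].defaultValue)
      paramStr += " = " + *params[i].defaultValue;
  }

  // The context parameter is only added when it cannot be inferred.
  std::string decl = "static " + typeClass + " get(";
  if (!hasInferredContext) {
    decl += "::mlir::MLIRContext *context";
    if (!paramStr.empty())
      decl += ", ";
  }
  return decl + paramStr + ");";
}
```

For `TypeBuilder<(ins CArg<"int", "0">:$intParam)>` on `MyType`, this yields `static MyType get(::mlir::MLIRContext *context, int intParam = 0);`, the same declaration the documentation shows for the second builder.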

diff  --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h
index 1be5140011f0..73a3a1002d0a 100644
--- a/mlir/include/mlir/TableGen/TypeDef.h
+++ b/mlir/include/mlir/TableGen/TypeDef.h
@@ -14,24 +14,45 @@
 #define MLIR_TABLEGEN_TYPEDEF_H
 
 #include "mlir/Support/LLVM.h"
-#include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/Builder.h"
 
 namespace llvm {
-class Record;
 class DagInit;
+class Record;
 class SMLoc;
 } // namespace llvm
 
 namespace mlir {
 namespace tblgen {
-
+class Dialect;
 class TypeParameter;
 
+//===----------------------------------------------------------------------===//
+// TypeBuilder
+//===----------------------------------------------------------------------===//
+
+/// Wrapper class that represents a Tablegen TypeBuilder.
+class TypeBuilder : public Builder {
+public:
+  using Builder::Builder;
+
+  /// Return an optional code body used for the `getChecked` variant of this
+  /// builder.
+  Optional<StringRef> getCheckedBody() const;
+
+  /// Returns true if this builder is able to infer the MLIRContext parameter.
+  bool hasInferredContextParameter() const;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeDef
+//===----------------------------------------------------------------------===//
+
 /// Wrapper class that contains a TableGen TypeDef's record and provides helper
 /// methods for accessing them.
 class TypeDef {
 public:
-  explicit TypeDef(const llvm::Record *def) : def(def) {}
+  explicit TypeDef(const llvm::Record *def);
 
   // Get the dialect for which this type belongs.
   Dialect getDialect() const;
@@ -95,6 +116,13 @@ class TypeDef {
   // Get the code location (for error printing).
   ArrayRef<llvm::SMLoc> getLoc() const;
 
+  // Returns true if the default get/getChecked methods should be skipped during
+  // generation.
+  bool skipDefaultBuilders() const;
+
+  // Returns the builders of this type.
+  ArrayRef<TypeBuilder> getBuilders() const { return builders; }
+
   // Returns whether two TypeDefs are equal by checking the equality of the
   // underlying record.
   bool operator==(const TypeDef &other) const;
@@ -107,8 +135,15 @@ class TypeDef {
 
 private:
   const llvm::Record *def;
+
+  // The builders of this type definition.
+  SmallVector<TypeBuilder> builders;
 };
 
+//===----------------------------------------------------------------------===//
+// TypeParameter
+//===----------------------------------------------------------------------===//
+
 // A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
 // to parameterize them.
 class TypeParameter {

diff  --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp
index 5c4f5348d1b3..d6adbc20ef76 100644
--- a/mlir/lib/TableGen/TypeDef.cpp
+++ b/mlir/lib/TableGen/TypeDef.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/TypeDef.h"
+#include "mlir/TableGen/Dialect.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -18,6 +19,26 @@
 using namespace mlir;
 using namespace mlir::tblgen;
 
+//===----------------------------------------------------------------------===//
+// TypeBuilder
+//===----------------------------------------------------------------------===//
+
+/// Return an optional code body used for the `getChecked` variant of this
+/// builder.
+Optional<StringRef> TypeBuilder::getCheckedBody() const {
+  Optional<StringRef> body = def->getValueAsOptionalString("checkedBody");
+  return body && !body->empty() ? body : llvm::None;
+}
+
+/// Returns true if this builder is able to infer the MLIRContext parameter.
+bool TypeBuilder::hasInferredContextParameter() const {
+  return def->getValueAsBit("hasInferredContextParam");
+}
+
+//===----------------------------------------------------------------------===//
+// TypeDef
+//===----------------------------------------------------------------------===//
+
 Dialect TypeDef::getDialect() const {
   auto *dialectDef =
       dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
@@ -98,6 +119,11 @@ llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
 llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
+
+bool TypeDef::skipDefaultBuilders() const {
+  return def->getValueAsBit("skipDefaultBuilders");
+}
+
 bool TypeDef::operator==(const TypeDef &other) const {
   return def == other.def;
 }
@@ -106,6 +132,33 @@ bool TypeDef::operator<(const TypeDef &other) const {
   return getName() < other.getName();
 }
 
+//===----------------------------------------------------------------------===//
+// TypeParameter
+//===----------------------------------------------------------------------===//
+
+TypeDef::TypeDef(const llvm::Record *def) : def(def) {
+  // Populate the builders.
+  auto *builderList =
+      dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
+  if (builderList && !builderList->empty()) {
+    for (llvm::Init *init : builderList->getValues()) {
+      TypeBuilder builder(cast<llvm::DefInit>(init)->getDef(), def->getLoc());
+
+      // Ensure that all parameters have names.
+      for (const TypeBuilder::Parameter &param : builder.getParameters()) {
+        if (!param.getName())
+          PrintFatalError(def->getLoc(),
+                          "type builder parameters must have a name");
+      }
+      builders.emplace_back(builder);
+    }
+  } else if (skipDefaultBuilders()) {
+    PrintFatalError(
+        def->getLoc(),
+        "default builders are skipped and no custom builders provided");
+  }
+}
+
 StringRef TypeParameter::getName() const {
   return def->getArgName(num)->getValue();
 }
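The new `TypeDef` constructor rejects two malformed configurations: a builder parameter without a name, and `skipDefaultBuilders` set with no custom builders. A hypothetical `validateBuilders` function restating those checks in plain C++ might look like this; unlike the real code it returns the error message instead of calling `PrintFatalError`, and it models each builder only by its optional parameter names.

```c++
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of a custom builder: each parameter may or may not be
// named (a dag argument such as `"int"` without `:$name` has no name).
struct BuilderSpec {
  std::vector<std::optional<std::string>> paramNames;
};

// Returns an error message for an invalid configuration, or std::nullopt.
std::optional<std::string>
validateBuilders(const std::vector<BuilderSpec> &builders,
                 bool skipDefaultBuilders) {
  if (builders.empty()) {
    // Without defaults or custom builders, the type would have no `get`.
    if (skipDefaultBuilders)
      return "default builders are skipped and no custom builders provided";
    return std::nullopt;
  }
  // Every parameter must be named so it can be emitted as `Type name`.
  for (const BuilderSpec &builder : builders)
    for (const auto &name : builder.paramNames)
      if (!name)
        return "type builder parameters must have a name";
  return std::nullopt;
}
```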

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 80927dff62c2..0e2c11a2ecb0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -51,9 +51,9 @@ def IntegerType : Test_Type<"TestInteger"> {
   let genVerifyInvariantsDecl = 1;
   let parameters = (
     ins
+    "unsigned":$width,
     // SignednessSemantics is defined below.
-    "::mlir::test::TestIntegerType::SignednessSemantics":$signedness,
-    "unsigned":$width
+    "::mlir::test::TestIntegerType::SignednessSemantics":$signedness
   );
 
   // We define the printer inline.
@@ -63,6 +63,17 @@ def IntegerType : Test_Type<"TestInteger"> {
     $_printer << ", " << getImpl()->width << ">";
   }];
 
+  // Define custom builder methods.
+  let builders = [
+    TypeBuilder<(ins "unsigned":$width,
+                     CArg<"SignednessSemantics", "Signless">:$signedness), [{
+      return Base::get($_ctxt, width, signedness);
+    }], [{
+      return Base::getChecked($_loc, width, signedness);
+    }]>
+  ];
+  let skipDefaultBuilders = 1;
+
   // The parser is defined here also.
   let parser = [{
     if (parser.parseLess()) return Type();
@@ -73,7 +84,7 @@ def IntegerType : Test_Type<"TestInteger"> {
     if ($_parser.parseInteger(width)) return Type();
     if ($_parser.parseGreater()) return Type();
     Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
-    return getChecked(loc, signedness, width);
+    return getChecked(loc, width, signedness);
   }];
 
   // Any extra code one wants in the type's class declaration.
@@ -85,9 +96,6 @@ def IntegerType : Test_Type<"TestInteger"> {
       Unsigned, /// Unsigned integer
     };
 
-    /// This extra function is necessary since it doesn't include signedness
-    static IntegerType getChecked(unsigned width, Location location);
-
     /// Return true if this is a signless integer type.
     bool isSignless() const { return getSignedness() == Signless; }
     /// Return true if this is a signed integer type.
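The test type's parameters were reordered so that `width` comes first. A likely reason is that the new custom builder gives `signedness` a `CArg` default of `Signless`, and C++ only allows default arguments on trailing parameters. A minimal sketch with a hypothetical `TestIntegerLike` struct (no context parameter, no storage uniquing) shows the resulting call shapes:

```c++
// Hypothetical stand-in mirroring TestIntegerType's parameter order after the
// reorder in this commit.
struct TestIntegerLike {
  enum SignednessSemantics { Signless, Signed, Unsigned };

  unsigned width;
  SignednessSemantics signedness;

  // `signedness` can carry a default only because it is the last parameter.
  static TestIntegerLike get(unsigned width,
                             SignednessSemantics signedness = Signless) {
    return TestIntegerLike{width, signedness};
  }
};
```

Callers can now write `get(8)` for a signless type and pass the signedness only when it differs, which is why the parser and `verifyConstructionInvariants` switch to the `(width, signedness)` order as well.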

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 14a3e862d85c..094e5c9fc631 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -113,7 +113,7 @@ static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
 
 // Example type validity checker.
 LogicalResult TestIntegerType::verifyConstructionInvariants(
-    Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) {
+    Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) {
   if (width > 8)
     return failure();
   return success();
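With `genVerifyInvariantsDecl = 1`, each builder also gets a `getChecked` variant that takes a `Location` and is expected to run `verifyConstructionInvariants` before building the type. The sketch below models that flow with hypothetical `Location` and `CheckedInteger` types, reusing the test's rule that widths above 8 are invalid. `std::optional` stands in for the null type MLIR returns on failure, and the recorded diagnostic is illustrative only (the test verifier above emits none).

```c++
#include <optional>
#include <string>
#include <vector>

// Hypothetical location that collects diagnostics, loosely modeled on the
// role ::mlir::Location plays in getChecked.
struct Location {
  std::vector<std::string> *diagnostics;
};

struct CheckedInteger {
  unsigned width;

  // Same rule as the test's verifyConstructionInvariants: widths above 8 fail.
  static bool verifyConstructionInvariants(Location loc, unsigned width) {
    if (width > 8) {
      loc.diagnostics->push_back("invalid width " + std::to_string(width));
      return false;
    }
    return true;
  }

  // getChecked-style builder: verify first and return an empty result on
  // failure instead of constructing an invalid instance.
  static std::optional<CheckedInteger> getChecked(Location loc,
                                                  unsigned width) {
    if (!verifyConstructionInvariants(loc, width))
      return std::nullopt;
    return CheckedInteger{width};
  }
};
```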

diff  --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 6e6e1c04643a..5471519ab60a 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -19,8 +19,8 @@ include "mlir/IR/OpBase.td"
 // DEF: ::mlir::test::SingleParameterType,
 // DEF: ::mlir::test::IntegerType
 
-// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
-// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
+// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &parser, ::llvm::StringRef mnemonic)
+// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser);
 // DEF return ::mlir::Type();
 
 def Test_Dialect: Dialect {
@@ -33,7 +33,7 @@ def Test_Dialect: Dialect {
 class TestType<string name> : TypeDef<Test_Dialect, name> { }
 
 def A_SimpleTypeA : TestType<"SimpleA"> {
-// DECL: class SimpleAType: public ::mlir::Type
+// DECL: class SimpleAType : public ::mlir::Type
 }
 
 def RTLValueType : Type<CPred<"isRTLValueType($_self)">, "Type"> {
@@ -56,12 +56,13 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
 
   let genVerifyInvariantsDecl = 1;
 
-// DECL-LABEL: class CompoundAType: public ::mlir::Type
+// DECL-LABEL: class CompoundAType : public ::mlir::Type
+// DECL: static CompoundAType getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
 // DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
-// DECL: static ::mlir::Type getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
 // DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
-// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
-// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
+// DECL-NEXT: ::mlir::DialectAsmParser &parser);
+// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
 // DECL: int getWidthOfSomething() const;
 // DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
 // DECL: SomeCppStruct getExampleCppType() const;
@@ -75,10 +76,11 @@ def C_IndexType : TestType<"Index"> {
       StringRefParameter<"Label for index">:$label
     );
 
-// DECL-LABEL: class IndexType: public ::mlir::Type
+// DECL-LABEL: class IndexType : public ::mlir::Type
 // DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
-// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
-// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
+// DECL-NEXT: ::mlir::DialectAsmParser &parser);
+// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
 }
 
 def D_SingleParameterType : TestType<"SingleParameter"> {
@@ -100,7 +102,7 @@ def E_IntegerType : TestType<"Integer"> {
         TypeParameter<"unsigned", "Bitwidth of integer">:$width
     );
 
-// DECL-LABEL: IntegerType: public ::mlir::Type
+// DECL-LABEL: IntegerType : public ::mlir::Type
 
     let extraClassDeclaration = [{
   /// Signedness semantics.

diff  --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
index 20168168bc8d..9cbd53270983 100644
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -148,7 +148,7 @@ class DialectAsmPrinter;
 /// {0}: The name of the typeDef class.
 /// {1}: The name of the type base class.
 static const char *const typeDefDeclSingletonBeginStr = R"(
-  class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
+  class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
   public:
     /// Inherit some necessary constructors from 'TypeBase'.
     using Base::Base;
@@ -166,9 +166,9 @@ static const char *const typeDefDeclSingletonBeginStr = R"(
 static const char *const typeDefDeclParametricBeginStr = R"(
   namespace {2} {
     struct {3};
-  }
-  class {0}: public ::mlir::Type::TypeBase<{0}, {1},
-                                        {2}::{3}> {{
+  } // end namespace {2}
+  class {0} : public ::mlir::Type::TypeBase<{0}, {1},
+                                         {2}::{3}> {{
   public:
     /// Inherit some necessary constructors from 'TypeBase'.
     using Base::Base;
@@ -177,18 +177,68 @@ static const char *const typeDefDeclParametricBeginStr = R"(
 
 /// The snippet for print/parse.
 static const char *const typeDefParsePrint = R"(
-    static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
-    void print(::mlir::DialectAsmPrinter& printer) const;
+    static ::mlir::Type parse(::mlir::MLIRContext *context,
+                              ::mlir::DialectAsmParser &parser);
+    void print(::mlir::DialectAsmPrinter &printer) const;
 )";
 
 /// The code block for the verifyConstructionInvariants and getChecked.
 ///
-/// {0}: List of parameters, parameters style.
+/// {0}: The name of the typeDef class.
+/// {1}: List of parameters, parameters style.
 static const char *const typeDefDeclVerifyStr = R"(
-    static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{0});
-    static ::mlir::Type getChecked(::mlir::Location loc{0});
+    static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1});
 )";
 
+/// Emit the builders for the given type.
+static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
+                                 TypeParamCommaFormatter &paramTypes) {
+  StringRef typeClass = typeDef.getCppClassName();
+  bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
+  if (!typeDef.skipDefaultBuilders()) {
+    os << llvm::formatv(
+        "    static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
+        paramTypes);
+    if (genCheckedMethods) {
+      os << llvm::formatv(
+          "    static {0} getChecked(::mlir::Location loc{1});\n", typeClass,
+          paramTypes);
+    }
+  }
+
+  // Generate the builders specified by the user.
+  for (const TypeBuilder &builder : typeDef.getBuilders()) {
+    std::string paramStr;
+    llvm::raw_string_ostream paramOS(paramStr);
+    llvm::interleaveComma(
+        builder.getParameters(), paramOS,
+        [&](const TypeBuilder::Parameter &param) {
+          // Note: TypeBuilder parameters are guaranteed to have names.
+          paramOS << param.getCppType() << " " << *param.getName();
+          if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
+            paramOS << " = " << *defaultParamValue;
+        });
+    paramOS.flush();
+
+    // Generate the `get` variant of the builder.
+    os << "    static " << typeClass << " get(";
+    if (!builder.hasInferredContextParameter()) {
+      os << "::mlir::MLIRContext *context";
+      if (!paramStr.empty())
+        os << ", ";
+    }
+    os << paramStr << ");\n";
+
+    // Generate the `getChecked` variant of the builder.
+    if (genCheckedMethods) {
+      os << "    static " << typeClass << " getChecked(::mlir::Location loc";
+      if (!paramStr.empty())
+        os << ", " << paramStr;
+      os << ");\n";
+    }
+  }
+}
+
 /// Generate the declaration for the given typeDef class.
 static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
   SmallVector<TypeParameter, 4> params;
@@ -212,13 +262,13 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
   TypeParamCommaFormatter emitTypeNamePairsAfterComma(
       TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
   if (!params.empty()) {
-    os << llvm::formatv("    static {0} get(::mlir::MLIRContext* ctxt{1});\n",
-                        typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
-  }
+    emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
 
-  // Emit the verify invariants declaration.
-  if (typeDef.genVerifyInvariantsDecl())
-    os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
+    // Emit the verify invariants declaration.
+    if (typeDef.genVerifyInvariantsDecl())
+      os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(),
+                          emitTypeNamePairsAfterComma);
+  }
 
   // Emit the mnenomic, if specified.
   if (auto mnenomic = typeDef.getMnemonic()) {
@@ -226,7 +276,8 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
        << "\"; }\n";
 
     // If mnemonic specified, emit print/parse declarations.
-    os << typeDefParsePrint;
+    if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty())
+      os << typeDefParsePrint;
   }
 
   if (typeDef.genAccessors()) {
@@ -330,17 +381,6 @@ static const char *const typeDefStorageClassConstructorReturn = R"(
     }
 )";
 
-/// The code block for the getChecked definition.
-///
-/// {0}: List of parameters, parameters style.
-/// {1}: C++ type class name.
-/// {2}: Comma separated list of parameter names.
-static const char *const typeDefDefGetCheckeStr = R"(
-    ::mlir::Type {1}::getChecked(Location loc{0}) {{
-      return Base::getChecked(loc{2});
-    }
-)";
-
 /// Use tgfmt to emit custom allocation code for each parameter, if necessary.
 static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
   SmallVector<TypeParameter, 4> parameters;
@@ -403,13 +443,13 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
                               parameters, /* prependComma */ false));
 
   // 3) Emit the construct method.
-  if (typeDef.hasStorageCustomConstructor())
+  if (typeDef.hasStorageCustomConstructor()) {
     // If user wants to build the storage constructor themselves, declare it
     // here and then they can write the definition elsewhere.
     os << "    static " << typeDef.getStorageClassName()
        << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
           "&key);\n";
-  else {
+  } else {
     // If not, autogenerate one.
 
     // First, unbox the parameters.
@@ -445,7 +485,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
     // Both the mnenomic and printerCode must be defined (for parity with
     // parserCode).
     os << "void " << typeDef.getCppClassName()
-       << "::print(::mlir::DialectAsmPrinter& printer) const {\n";
+       << "::print(::mlir::DialectAsmPrinter &printer) const {\n";
     if (*printerCode == "") {
       // If no code specified, emit error.
       PrintFatalError(typeDef.getLoc(),
@@ -460,7 +500,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
   if (auto parserCode = typeDef.getParserCode()) {
     // The mnenomic must be defined so the dispatcher knows how to dispatch.
     os << "::mlir::Type " << typeDef.getCppClassName()
-       << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
+       << "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &"
           "parser) "
           "{\n";
     if (*parserCode == "") {
@@ -470,51 +510,112 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
                           ": parser (if specified) must have non-empty code");
     }
     auto fmtCtxt =
-        FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
+        FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context");
     os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
   }
 }
 
-/// Print all the typedef-specific definition code.
-static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
-  NamespaceEmitter ns(os, typeDef.getDialect());
-  SmallVector<TypeParameter, 4> parameters;
-  typeDef.getParameters(parameters);
-
-  // Emit the storage class, if requested and necessary.
-  if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
-    emitStorageClass(typeDef, os);
-
-  if (!parameters.empty()) {
+/// Emit the builders for the given type.
+static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
+                                ArrayRef<TypeParameter> typeDefParams) {
+  bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
+  StringRef typeClass = typeDef.getCppClassName();
+  if (!typeDef.skipDefaultBuilders()) {
     os << llvm::formatv(
-        "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
-        "  return Base::get(ctxt{2});\n}\n",
-        typeDef.getCppClassName(),
+        "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
+        "  return Base::get(context{2});\n}\n",
+        typeClass,
         TypeParamCommaFormatter(
-            TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+            TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams),
         TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
-                                parameters));
+                                typeDefParams));
+    if (genCheckedMethods) {
+      os << llvm::formatv(
+          "{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
+          "  return Base::getChecked(loc{2});\n}\n",
+          typeClass,
+          TypeParamCommaFormatter(
+              TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
+              typeDefParams),
+          TypeParamCommaFormatter(
+              TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams));
+    }
   }
 
-  // Emit the parameter accessors.
-  if (typeDef.genAccessors())
-    for (const TypeParameter &parameter : parameters) {
-      SmallString<16> name = parameter.getName();
-      name[0] = llvm::toUpper(name[0]);
-      os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
-                    parameter.getCppType(), name, parameter.getName(),
-                    typeDef.getCppClassName());
+  // Generate the builders specified by the user.
+  auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context");
+  auto checkedBuilderFmtCtx = FmtContext()
+                                  .addSubst("_loc", "loc")
+                                  .addSubst("_ctxt", "loc.getContext()");
+  for (const TypeBuilder &builder : typeDef.getBuilders()) {
+    Optional<StringRef> body = builder.getBody();
+    Optional<StringRef> checkedBody =
+        genCheckedMethods ? builder.getCheckedBody() : llvm::None;
+    if (!body && !checkedBody)
+      continue;
+    std::string paramStr;
+    llvm::raw_string_ostream paramOS(paramStr);
+    llvm::interleaveComma(builder.getParameters(), paramOS,
+                          [&](const TypeBuilder::Parameter &param) {
+                            // Note: TypeBuilder parameters are guaranteed to
+                            // have names.
+                            paramOS << param.getCppType() << " "
+                                    << *param.getName();
+                          });
+    paramOS.flush();
+
+    // Emit the `get` variant of the builder.
+    if (body) {
+      os << llvm::formatv("{0} {0}::get(", typeClass);
+      if (!builder.hasInferredContextParameter()) {
+        os << "::mlir::MLIRContext *context";
+        if (!paramStr.empty())
+          os << ", ";
+        os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr,
+                            tgfmt(*body, &builderFmtCtx).str());
+      } else {
+        os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr, *body);
+      }
     }
 
-  // Generate getChecked() method.
-  if (typeDef.genVerifyInvariantsDecl()) {
-    os << llvm::formatv(
-        typeDefDefGetCheckeStr,
-        TypeParamCommaFormatter(
-            TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
-        typeDef.getCppClassName(),
-        TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
-                                parameters));
+    // Emit the `getChecked` variant of the builder.
+    if (checkedBody) {
+      os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc",
+                          typeClass);
+      if (!paramStr.empty())
+        os << ", " << paramStr;
+      os << llvm::formatv(") {{\n  {0};\n}\n",
+                          tgfmt(*checkedBody, &checkedBuilderFmtCtx));
+    }
+  }
+}
+
+/// Print all the typedef-specific definition code.
+static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) {
+  NamespaceEmitter ns(os, typeDef.getDialect());
+
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+  if (!parameters.empty()) {
+    // Emit the storage class, if requested and necessary.
+    if (typeDef.genStorageClass())
+      emitStorageClass(typeDef, os);
+
+    // Emit the builders for this type.
+    emitTypeBuilderDefs(typeDef, os, parameters);
+
+    // Generate accessor definitions only if we also generate the storage class.
+    // Otherwise, let the user define the exact accessor definition.
+    if (typeDef.genAccessors() && typeDef.genStorageClass()) {
+      // Emit the parameter accessors.
+      for (const TypeParameter &parameter : parameters) {
+        SmallString<16> name = parameter.getName();
+        name[0] = llvm::toUpper(name[0]);
+        os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
+                      parameter.getCppType(), name, parameter.getName(),
+                      typeDef.getCppClassName());
+      }
+    }
   }
 
   // If mnemonic is specified maybe print definitions for the parser and printer
@@ -534,9 +635,9 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
 
   // The parser dispatch is just a list of if-elses, matching on the
   // mnemonic and calling the class's parse function.
-  os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* "
-        "ctxt, "
-        "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
+  os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *"
+        "context, ::mlir::DialectAsmParser &parser, "
+        "::llvm::StringRef mnemonic) {\n";
   for (const TypeDef &type : types) {
     if (type.getMnemonic()) {
       os << formatv("  if (mnemonic == {0}::{1}::getMnemonic()) return "
@@ -547,9 +648,9 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
       // If the type has no parameters and no parser code, just invoke a normal
       // `get`.
       if (type.getNumParameters() == 0 && !type.getParserCode())
-        os << "get(ctxt);\n";
+        os << "get(context);\n";
       else
-        os << "parse(ctxt, parser);\n";
+        os << "parse(context, parser);\n";
     }
   }
   os << "  return ::mlir::Type();\n";
@@ -559,7 +660,7 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
   // printer.
   os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
         "type, "
-        "::mlir::DialectAsmPrinter& printer) {\n"
+        "::mlir::DialectAsmPrinter &printer) {\n"
      << "  return ::llvm::TypeSwitch<::mlir::Type, "
         "::mlir::LogicalResult>(type)\n";
   for (const TypeDef &type : types) {
@@ -594,7 +695,7 @@ static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
 
   IfDefScope scope("GET_TYPEDEF_CLASSES", os);
   emitParsePrintDispatch(typeDefs, os);
-  for (auto typeDef : typeDefs)
+  for (const TypeDef &typeDef : typeDefs)
     emitTypeDefDef(typeDef, os);
 
   return false;
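For reference, here is how the new `emitTypeBuilderDefs` shapes the signature of each user-specified builder. The `get` variant gets a leading `::mlir::MLIRContext *context` parameter, with `$_ctxt` substituted, unless the builder infers its context. The `getChecked` variant, emitted only when `genVerifyInvariantsDecl()` is set, always takes a leading `::mlir::Location loc`, with `$_loc` mapped to `loc` and `$_ctxt` to `loc.getContext()`.

The standalone sketch below mirrors that logic using plain `std::string` in place of LLVM's `formatv`/`tgfmt`. All names in it (`BuilderParam`, `substitute`, `joinParams`, `emitGet`, `emitGetChecked`) are hypothetical, and `substitute` is only a simplified stand-in for `tgfmt` placeholder expansion. As in the diff, an inferred-context body is emitted verbatim.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for TypeBuilder::Parameter: a C++ type and a name.
struct BuilderParam {
  std::string cppType;
  std::string name;
};

// Simplified stand-in for tgfmt: replace every `key` in `body` with `value`.
std::string substitute(std::string body, const std::string &key,
                       const std::string &value) {
  for (std::size_t pos = body.find(key); pos != std::string::npos;
       pos = body.find(key, pos + value.size()))
    body.replace(pos, key.size(), value);
  return body;
}

// Join parameters as "Type name, Type name", like llvm::interleaveComma.
std::string joinParams(const std::vector<BuilderParam> &params) {
  std::string out;
  for (std::size_t i = 0; i < params.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += params[i].cppType + " " + params[i].name;
  }
  return out;
}

// `get` variant: prepend an MLIRContext parameter (and expand `$_ctxt`)
// unless the builder infers the context from its own parameters.
std::string emitGet(const std::string &typeClass,
                    const std::vector<BuilderParam> &params,
                    bool inferredContext, const std::string &body) {
  std::string paramStr = joinParams(params);
  std::string out = typeClass + " " + typeClass + "::get(";
  std::string emittedBody = body; // inferred-context bodies stay verbatim
  if (!inferredContext) {
    out += "::mlir::MLIRContext *context";
    if (!paramStr.empty())
      out += ", ";
    emittedBody = substitute(body, "$_ctxt", "context");
  }
  return out + paramStr + ") {\n  " + emittedBody + ";\n}\n";
}

// `getChecked` variant: always takes a Location first; `$_loc` becomes `loc`
// and `$_ctxt` becomes `loc.getContext()`.
std::string emitGetChecked(const std::string &typeClass,
                           const std::vector<BuilderParam> &params,
                           const std::string &checkedBody) {
  std::string paramStr = joinParams(params);
  std::string out =
      typeClass + " " + typeClass + "::getChecked(::mlir::Location loc";
  if (!paramStr.empty())
    out += ", " + paramStr;
  std::string body = substitute(checkedBody, "$_loc", "loc");
  body = substitute(body, "$_ctxt", "loc.getContext()");
  return out + ") {\n  " + body + ";\n}\n";
}
```

For example, `emitGet("TestType", {{"int", "x"}}, false, "return Base::get($_ctxt, x)")` yields a `TestType TestType::get(::mlir::MLIRContext *context, int x)` definition whose body is `return Base::get(context, x);`.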