[Mlir-commits] [mlir] 83ef862 - [mlir] Add support for generating Attribute classes for ODS
River Riddle
llvmlistbot at llvm.org
Wed Mar 3 16:48:48 PST 2021
Author: River Riddle
Date: 2021-03-03T16:41:49-08:00
New Revision: 83ef862fad6b14dd1651f5e31e331eb89a95f0ff
URL: https://github.com/llvm/llvm-project/commit/83ef862fad6b14dd1651f5e31e331eb89a95f0ff
DIFF: https://github.com/llvm/llvm-project/commit/83ef862fad6b14dd1651f5e31e331eb89a95f0ff.diff
LOG: [mlir] Add support for generating Attribute classes for ODS
The support for attributes closely maps that of Types (basically 1-1) given that Attributes are defined in exactly the same way as Types. All of the current ODS TypeDef classes get an Attr equivalent. The generation of the attribute classes themselves share the same generator as types.
Differential Revision: https://reviews.llvm.org/D97589
Added:
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestAttributes.h
mlir/test/mlir-tblgen/attrdefs.td
mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Modified:
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/CodeGenHelpers.h
mlir/lib/TableGen/CMakeLists.txt
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/typedefs.td
mlir/tools/mlir-tblgen/CMakeLists.txt
mlir/tools/mlir-tblgen/OpDocGen.cpp
Removed:
mlir/include/mlir/TableGen/TypeDef.h
mlir/lib/TableGen/TypeDef.cpp
mlir/tools/mlir-tblgen/TypeDefGen.cpp
################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index f0f5f1d0e4ea..4001f3ec175d 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2465,20 +2465,20 @@ def replaceWithValue;
//===----------------------------------------------------------------------===//
-// Data type generation
+// Attribute and Type generation
//===----------------------------------------------------------------------===//
-// Class for defining a custom type getter.
+// Class for defining a custom getter.
//
-// TableGen generates several generic getter methods for each type by default,
-// corresponding to the specified dag parameters. If the default generated ones
-// cannot cover some use case, custom getters can be defined using instances of
-// this class.
+// TableGen generates several generic getter methods for each attribute and type
+// by default, corresponding to the specified dag parameters. If the default
+// generated ones cannot cover some use case, custom getters can be defined
+// using instances of this class.
//
// The signature of the `get` is always either:
//
// ```c++
-// static <Type-Name> get(MLIRContext *context, <other-parameters>...) {
+// static <ClassName> get(MLIRContext *context, <other-parameters>...) {
// <body>...
// }
// ```
@@ -2486,7 +2486,7 @@ def replaceWithValue;
// or:
//
// ```c++
-// static <TypeName> get(MLIRContext *context, <parameters>...);
+// static <ClassName> get(MLIRContext *context, <parameters>...);
// ```
//
// To define a custom getter, the parameter list and body should be passed
@@ -2503,7 +2503,7 @@ def replaceWithValue;
// type. For example, the following signature specification
//
// ```
-// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
+// AttrOrTypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
// ```
//
// has an integer parameter and a float parameter with a default value.
@@ -2514,7 +2514,7 @@ def replaceWithValue;
// method should be invoked using `$_get`, e.g.:
//
// ```
-// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{
+// AttrOrTypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{
// return $_get($_ctxt, integerArg, floatArg);
// }]>
// ```
@@ -2522,7 +2522,7 @@ def replaceWithValue;
// This is necessary because the `body` is also used to generate `getChecked`
// methods, which have a different underlying `Base::get*` call.
//
-class TypeBuilder<dag parameters, code bodyCode = ""> {
+class AttrOrTypeBuilder<dag parameters, code bodyCode = ""> {
dag dagParams = parameters;
code body = bodyCode;
@@ -2530,33 +2530,42 @@ class TypeBuilder<dag parameters, code bodyCode = ""> {
// is not implicitly added to the parameter list.
bit hasInferredContextParam = 0;
}
+// Attribute-specific spelling of AttrOrTypeBuilder; adds nothing to it.
+class AttrBuilder<dag parameters, code bodyCode = "">
+ : AttrOrTypeBuilder<parameters, bodyCode>;
+// Type-specific spelling of AttrOrTypeBuilder; adds nothing to it.
+class TypeBuilder<dag parameters, code bodyCode = "">
+ : AttrOrTypeBuilder<parameters, bodyCode>;
-// A class of TypeBuilder that is able to infer the MLIRContext parameter from
-// one of the other builder parameters. Instances of this builder do not have
-// `MLIRContext *` implicitly added to the parameter list.
-class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
+// A class of AttrOrTypeBuilder that is able to infer the MLIRContext parameter
+// from one of the other builder parameters. Instances of this builder do not
+// have `MLIRContext *` implicitly added to the parameter list.
+class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
- : TypeBuilder<parameters, bodyCode> {
+ : AttrOrTypeBuilder<parameters, bodyCode> {
let hasInferredContextParam = 1;
}
+// Attribute- and type-specific spellings of the builder above. Because the
+// shared base derives from AttrOrTypeBuilder (not TypeBuilder), attribute
+// builders are not also classified as type builders.
+class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "">
+ : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
+class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
+ : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
-// Define a new type, named `name`, belonging to `dialect` that inherits from
-// the given C++ base class.
-class TypeDef<Dialect dialect, string name,
- string baseCppClass = "::mlir::Type">
- : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type"> {
- // The name of the C++ base class to use for this Type.
+// Define a new attribute or type, named `name`, that inherits from the given
+// C++ base class.
+class AttrOrTypeDef<string valueType, string name, string baseCppClass> {
+ // The name of the C++ base class to use for this def.
string cppBaseClassName = baseCppClass;
- // Additional, longer human-readable description of what the op does.
+ // Additional, longer human-readable description of what the def does.
string description = "";
// Name of storage class to generate or use.
- string storageClass = name # "TypeStorage";
+ string storageClass = name # valueType # "Storage";
+
// Namespace (withing dialect c++ namespace) in which the storage class
// resides.
string storageNamespace = "detail";
+
// Specify if the storage class is to be generated.
bit genStorageClass = 1;
+
// Specify that the generated storage class has a constructor which is written
// in C++.
bit hasStorageCustomConstructor = 0;
@@ -2568,38 +2577,38 @@ class TypeDef<Dialect dialect, string name,
// (ins
// "<c++ type>":$param1Name,
// "<c++ type>":$param2Name,
- // TypeParameter<"c++ type", "param description">:$param3Name)
- // TypeParameters (or more likely one of their subclasses) are required to add
- // more information about the parameter, specifically:
+ // AttrOrTypeParameter<"c++ type", "param description">:$param3Name)
+ // AttrOrTypeParameters (or more likely one of their subclasses) are required
+ // to add more information about the parameter, specifically:
// - Documentation
// - Code to allocate the parameter (if allocation is needed in the storage
// class constructor)
//
// For example:
- // (ins
- // "int":$width,
- // ArrayRefParameter<"bool", "list of bools">:$yesNoArray)
+ // (ins "int":$width,
+ // ArrayRefParameter<"bool", "list of bools">:$yesNoArray)
//
- // (ArrayRefParameter is a subclass of TypeParameter which has allocation code
- // for re-allocating ArrayRefs. It is defined below.)
+ // (ArrayRefParameter is a subclass of AttrOrTypeParameter which has
+ // allocation code for re-allocating ArrayRefs. It is defined below.)
dag parameters = (ins);
- // Custom type builder methods.
+ // Custom builder methods.
// In addition to the custom builders provided here, and unless
// skipDefaultBuilders is set, a default builder is generated with the
// following signature:
//
// ```c++
- // static <TypeName> get(MLIRContext *, <parameters>);
+ // static <ClassName> get(MLIRContext *, <parameters>);
// ```
//
- // Note that builders should only be provided when a type has parameters.
- list<TypeBuilder> builders = ?;
+ // Note that builders should only be provided when a def has parameters.
+ list<AttrOrTypeBuilder> builders = ?;
// Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of
- // printer/parser for this type.
+ // the printer/parser.
string mnemonic = ?;
+
// If 'mnemonic' specified,
// If null, generate just the declarations.
// If a non-empty code block, just use that code as the definition code.
@@ -2607,29 +2616,53 @@ class TypeDef<Dialect dialect, string name,
code printer = ?;
code parser = ?;
- // If set, generate accessors for each Type parameter.
+ // If set, generate accessors for each parameter.
bit genAccessors = 1;
+
// Avoid generating default get/getChecked functions. Custom get methods must
// be provided.
bit skipDefaultBuilders = 0;
+
// Generate the verify and getChecked methods.
bit genVerifyDecl = 0;
+
// Extra code to include in the class declaration.
code extraClassDeclaration = [{}];
+}
+
+// Define a new attribute, named `name`, belonging to `dialect` that inherits
+// from the given C++ base class.
+// The storage class defaults to `name # "AttrStorage"` (see AttrOrTypeDef).
+class AttrDef<Dialect dialect, string name,
+ string baseCppClass = "::mlir::Attribute">
+ : DialectAttr<dialect, CPred<"">, /*descr*/"">,
+ AttrOrTypeDef<"Attr", name, baseCppClass> {
+ // The name of the C++ Attribute class.
+ string cppClassName = name # "Attr";
- // The predicate for when this type is used as a type constraint.
+ // The predicate for when this def is used as a constraint.
let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
"::" # cppClassName # ">()">;
+}
+
+// Define a new type, named `name`, belonging to `dialect` that inherits from
+// the given C++ base class.
+// The storage class defaults to `name # "TypeStorage"` (see AttrOrTypeDef).
+class TypeDef<Dialect dialect, string name,
+ string baseCppClass = "::mlir::Type">
+ : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
+ AttrOrTypeDef<"Type", name, baseCppClass> {
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
"::" # cppClassName # ">()",
"");
+ // The predicate for when this def is used as a constraint.
+ let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
+ "::" # cppClassName # ">()">;
}
// 'Parameters' should be subclasses of this or simple strings (which is a
-// shorthand for TypeParameter<"C++Type">).
-class TypeParameter<string type, string desc> {
+// shorthand for AttrOrTypeParameter<"C++Type">).
+class AttrOrTypeParameter<string type, string desc> {
// Custom memory allocation code for storage constructor.
code allocator = ?;
// The C++ type of this parameter.
@@ -2639,28 +2672,30 @@ class TypeParameter<string type, string desc> {
// The format string for the asm syntax (documentation only).
string syntax = ?;
}
+// Attribute- and type-specific spellings of AttrOrTypeParameter.
+class AttrParameter<string type, string desc> : AttrOrTypeParameter<type, desc>;
+class TypeParameter<string type, string desc> : AttrOrTypeParameter<type, desc>;
+// The parameter classes below derive from AttrOrTypeParameter, so they can be
+// used in the `parameters` list of both AttrDef and TypeDef.
// For StringRefs, which require allocation.
class StringRefParameter<string desc> :
- TypeParameter<"::llvm::StringRef", desc> {
+ AttrOrTypeParameter<"::llvm::StringRef", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
}
// For standard ArrayRefs, which require allocation.
class ArrayRefParameter<string arrayOf, string desc> :
- TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
+ AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
}
// For classes which require allocation and have their own allocateInto method.
class SelfAllocationParameter<string type, string desc> :
- TypeParameter<type, desc> {
+ AttrOrTypeParameter<type, desc> {
let allocator = [{$_dst = $_self.allocateInto($_allocator);}];
}
// For ArrayRefs which contain things which allocate themselves.
class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
- TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
+ AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
let allocator = [{
llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields;
for (size_t i = 0, e = $_self.size(); i < e; ++i)
@@ -2669,5 +2704,4 @@ class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
}];
}
-
#endif // OP_BASE
diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
similarity index 54%
rename from mlir/include/mlir/TableGen/TypeDef.h
rename to mlir/include/mlir/TableGen/AttrOrTypeDef.h
index a82f85d48863..8ce752014ebd 100644
--- a/mlir/include/mlir/TableGen/TypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -1,4 +1,4 @@
-//===-- TypeDef.h - Record wrapper for type definitions ---------*- C++ -*-===//
+//===-- AttrOrTypeDef.h - Wrapper for attr and type definitions -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
-// TypeDef wrapper to simplify using TableGen Record defining a MLIR type.
+// AttrOrTypeDef, AttrDef, and TypeDef wrappers to simplify using TableGen
+// Records defining MLIR attributes and types.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TABLEGEN_TYPEDEF_H
-#define MLIR_TABLEGEN_TYPEDEF_H
+#ifndef MLIR_TABLEGEN_ATTRORTYPEDEF_H
+#define MLIR_TABLEGEN_ATTRORTYPEDEF_H
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Builder.h"
@@ -25,14 +26,14 @@ class SMLoc;
namespace mlir {
namespace tblgen {
class Dialect;
-class TypeParameter;
+class AttrOrTypeParameter;
//===----------------------------------------------------------------------===//
-// TypeBuilder
+// AttrOrTypeBuilder
//===----------------------------------------------------------------------===//
-/// Wrapper class that represents a Tablegen TypeBuilder.
-class TypeBuilder : public Builder {
+/// Wrapper class that represents a Tablegen AttrOrTypeBuilder.
+class AttrOrTypeBuilder : public Builder {
public:
using Builder::Builder;
@@ -41,22 +42,22 @@ class TypeBuilder : public Builder {
};
//===----------------------------------------------------------------------===//
-// TypeDef
+// AttrOrTypeDef
//===----------------------------------------------------------------------===//
-/// Wrapper class that contains a TableGen TypeDef's record and provides helper
-/// methods for accessing them.
-class TypeDef {
+/// Wrapper class that contains a TableGen AttrOrTypeDef's record and provides
+/// helper methods for accessing them.
+class AttrOrTypeDef {
public:
- explicit TypeDef(const llvm::Record *def);
+ explicit AttrOrTypeDef(const llvm::Record *def);
- // Get the dialect for which this type belongs.
+ // Get the dialect for which this def belongs.
Dialect getDialect() const;
- // Returns the name of this TypeDef record.
+ // Returns the name of this AttrOrTypeDef record.
StringRef getName() const;
- // Query functions for the documentation of the operator.
+ // Query functions for the documentation of the def.
bool hasDescription() const;
StringRef getDescription() const;
bool hasSummary() const;
@@ -65,13 +66,13 @@ class TypeDef {
// Returns the name of the C++ class to generate.
StringRef getCppClassName() const;
- // Returns the name of the C++ base class to use when generating this type.
+ // Returns the name of the C++ base class to use when generating this def.
StringRef getCppBaseClassName() const;
- // Returns the name of the storage class for this type.
+ // Returns the name of the storage class for this def.
StringRef getStorageClassName() const;
- // Returns the C++ namespace for this types storage class.
+ // Returns the C++ namespace for this def's storage class.
StringRef getStorageNamespace() const;
// Returns true if we should generate the storage class.
@@ -80,10 +81,11 @@ class TypeDef {
// Indicates whether or not to generate the storage class constructor.
bool hasStorageCustomConstructor() const;
- // Fill a list with this types parameters. See TypeDef in OpBase.td for
+ // Fill a list with this def's parameters. See AttrOrTypeDef in OpBase.td for
// documentation of parameter usage.
- void getParameters(SmallVectorImpl<TypeParameter> &) const;
- // Return the number of type parameters
+ void getParameters(SmallVectorImpl<AttrOrTypeParameter> &) const;
+
+ // Return the number of parameters
unsigned getNumParameters() const;
// Return the keyword/mnemonic to use in the printer/parser methods if we are
@@ -94,19 +96,18 @@ class TypeDef {
// return a non-value. Otherwise, return the contents of that code block.
Optional<StringRef> getPrinterCode() const;
- // Returns the code to use as the types parser method. If not specified,
- // return a non-value. Otherwise, return the contents of that code block.
+ // Returns the code to use as the parser method. If not specified, returns
+ // None. Otherwise, returns the contents of that code block.
Optional<StringRef> getParserCode() const;
- // Returns true if the accessors based on the types parameters should be
- // generated.
+ // Returns true if the accessors based on the parameters should be generated.
bool genAccessors() const;
// Return true if we need to generate the verify declaration and getChecked
// method.
bool genVerifyDecl() const;
- // Returns the dialects extra class declaration code.
+ // Returns the def's extra class declaration code.
Optional<StringRef> getExtraDecls() const;
// Get the code location (for error printing).
@@ -116,54 +117,80 @@ class TypeDef {
// generation.
bool skipDefaultBuilders() const;
- // Returns the builders of this type.
- ArrayRef<TypeBuilder> getBuilders() const { return builders; }
+ // Returns the builders of this def.
+ ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
- // Returns whether two TypeDefs are equal by checking the equality of the
- // underlying record.
- bool operator==(const TypeDef &other) const;
+ // Returns whether two AttrOrTypeDefs are equal by checking the equality of
+ // the underlying record.
+ bool operator==(const AttrOrTypeDef &other) const;
- // Compares two TypeDefs by comparing the names of the dialects.
- bool operator<(const TypeDef &other) const;
+ // Compares two AttrOrTypeDefs by comparing the names of their records.
+ bool operator<(const AttrOrTypeDef &other) const;
- // Returns whether the TypeDef is defined.
+ // Returns whether the AttrOrTypeDef is defined.
operator bool() const { return def != nullptr; }
private:
const llvm::Record *def;
// The builders of this type definition.
- SmallVector<TypeBuilder> builders;
+ SmallVector<AttrOrTypeBuilder> builders;
+};
+
+//===----------------------------------------------------------------------===//
+// AttrDef
+//===----------------------------------------------------------------------===//
+
+/// This class represents a wrapper around a tablegen AttrDef record.
+/// It adds no state: all accessors and constructors come from AttrOrTypeDef.
+class AttrDef : public AttrOrTypeDef {
+public:
+ using AttrOrTypeDef::AttrOrTypeDef;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeDef
+//===----------------------------------------------------------------------===//
+
+/// This class represents a wrapper around a tablegen TypeDef record.
+/// It adds no state: all accessors and constructors come from AttrOrTypeDef.
+class TypeDef : public AttrOrTypeDef {
+public:
+ using AttrOrTypeDef::AttrOrTypeDef;
};
//===----------------------------------------------------------------------===//
-// TypeParameter
+// AttrOrTypeParameter
//===----------------------------------------------------------------------===//
-// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
-// to parameterize them.
-class TypeParameter {
+// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
+// AttrOrTypeDefs to parameterize them.
+class AttrOrTypeParameter {
public:
+ /// Wraps argument number `index` of the parameter dag `def`.
- explicit TypeParameter(const llvm::DagInit *def, unsigned num)
- : def(def), num(num) {}
+ explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
+ : def(def), index(index) {}
// Get the parameter name.
StringRef getName() const;
+
+ // If specified, get the custom allocator code for this parameter.
+ // Returns None for plain string parameters or when no allocator is set.
Optional<StringRef> getAllocator() const;
+
// Get the C++ type of this parameter.
+ // For a plain string parameter, the string itself is the C++ type.
StringRef getCppType() const;
+
// Get a description of this parameter for documentation purposes.
+ // Returns None for plain string parameters.
Optional<StringRef> getSummary() const;
+
// Get the assembly syntax documentation.
+ // Falls back to the C++ type when a parameter def has no string `syntax`.
StringRef getSyntax() const;
private:
+ /// The underlying tablegen parameter list this parameter is a part of.
const llvm::DagInit *def;
- const unsigned num;
+ /// The index of the parameter within the parameter list (`def`).
+ unsigned index;
};
} // end namespace tblgen
} // end namespace mlir
-#endif // MLIR_TABLEGEN_TYPEDEF_H
+#endif // MLIR_TABLEGEN_ATTRORTYPEDEF_H
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index 7ba83aae55bb..9d651ac08e56 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -23,14 +23,15 @@ namespace tblgen {
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
public:
+ /// Emits `#ifdef name` and `#undef name`; the destructor emits the matching
+ /// `#endif // name`.
- IfDefScope(llvm::StringRef name, llvm::raw_ostream &os) : name(name), os(os) {
+ IfDefScope(llvm::StringRef name, llvm::raw_ostream &os)
+ : name(name.str()), os(os) {
os << "#ifdef " << name << "\n"
<< "#undef " << name << "\n\n";
}
~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }
private:
- llvm::StringRef name;
+ // Owned copy of the macro name: the destructor prints it after the
+ // constructor has returned, so it must not refer to caller-owned storage.
+ std::string name;
llvm::raw_ostream &os;
};
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
new file mode 100644
index 000000000000..e82f0f069ddf
--- /dev/null
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -0,0 +1,221 @@
+//===- AttrOrTypeDef.cpp - AttrOrTypeDef wrapper classes ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeBuilder
+//===----------------------------------------------------------------------===//
+
+/// Reports whether the builder record sets `hasInferredContextParam`, i.e.
+/// whether the MLIRContext is derived from another builder parameter instead
+/// of being added implicitly to the parameter list.
+bool AttrOrTypeBuilder::hasInferredContextParameter() const {
+  const bool infersContext = def->getValueAsBit("hasInferredContextParam");
+  return infersContext;
+}
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeDef
+//===----------------------------------------------------------------------===//
+
+/// Wraps `def` and eagerly collects its custom builders.
+///
+/// Emits a fatal error if any builder parameter is unnamed, or if
+/// `skipDefaultBuilders` is set while no custom builder is provided.
+AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
+  // `builders` defaults to `?` (unset) in ODS, in which case the cast below
+  // yields null and we fall through to the skipDefaultBuilders check.
+  auto *builderList =
+      dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
+  if (builderList && !builderList->empty()) {
+    for (llvm::Init *init : builderList->getValues()) {
+      AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
+                                def->getLoc());
+
+      // Ensure that all parameters have names.
+      for (const AttrOrTypeBuilder::Parameter &param :
+           builder.getParameters()) {
+        if (!param.getName())
+          PrintFatalError(def->getLoc(), "builder parameters must have a name");
+      }
+      builders.emplace_back(builder);
+    }
+  } else if (skipDefaultBuilders()) {
+    PrintFatalError(
+        def->getLoc(),
+        "default builders are skipped and no custom builders provided");
+  }
+}
+
+/// Returns the dialect this def belongs to, or a null Dialect when the
+/// `dialect` field does not hold a record.
+Dialect AttrOrTypeDef::getDialect() const {
+  llvm::Init *dialectInit = def->getValue("dialect")->getValue();
+  if (auto *dialectDef = dyn_cast<llvm::DefInit>(dialectInit))
+    return Dialect(dialectDef->getDef());
+  return Dialect(nullptr);
+}
+
+/// Returns the name of the underlying TableGen record.
+StringRef AttrOrTypeDef::getName() const { return def->getName(); }
+
+/// Returns the `cppClassName` field.
+StringRef AttrOrTypeDef::getCppClassName() const {
+  StringRef className = def->getValueAsString("cppClassName");
+  return className;
+}
+
+/// Returns the `cppBaseClassName` field.
+StringRef AttrOrTypeDef::getCppBaseClassName() const {
+  StringRef baseName = def->getValueAsString("cppBaseClassName");
+  return baseName;
+}
+
+/// Returns true if `record` has a field named `fieldName` whose value is a
+/// string.
+static bool hasStringValuedField(const llvm::Record *record,
+                                 StringRef fieldName) {
+  const llvm::RecordVal *field = record->getValue(fieldName);
+  if (!field)
+    return false;
+  return isa<llvm::StringInit>(field->getValue());
+}
+
+bool AttrOrTypeDef::hasDescription() const {
+  return hasStringValuedField(def, "description");
+}
+
+StringRef AttrOrTypeDef::getDescription() const {
+  StringRef text = def->getValueAsString("description");
+  return text;
+}
+
+bool AttrOrTypeDef::hasSummary() const {
+  return hasStringValuedField(def, "summary");
+}
+
+StringRef AttrOrTypeDef::getSummary() const {
+  StringRef text = def->getValueAsString("summary");
+  return text;
+}
+
+/// Returns the `storageClass` field.
+StringRef AttrOrTypeDef::getStorageClassName() const {
+  StringRef storageName = def->getValueAsString("storageClass");
+  return storageName;
+}
+
+/// Returns the `storageNamespace` field.
+StringRef AttrOrTypeDef::getStorageNamespace() const {
+  StringRef storageNs = def->getValueAsString("storageNamespace");
+  return storageNs;
+}
+
+/// Returns the `genStorageClass` bit.
+bool AttrOrTypeDef::genStorageClass() const {
+  const bool generate = def->getValueAsBit("genStorageClass");
+  return generate;
+}
+
+/// Returns the `hasStorageCustomConstructor` bit.
+bool AttrOrTypeDef::hasStorageCustomConstructor() const {
+  const bool customCtor = def->getValueAsBit("hasStorageCustomConstructor");
+  return customCtor;
+}
+
+/// Appends one AttrOrTypeParameter per argument of the `parameters` dag to
+/// `parameters`; elements already in the vector are kept.
+void AttrOrTypeDef::getParameters(
+    SmallVectorImpl<AttrOrTypeParameter> &parameters) const {
+  if (auto *parametersDag = def->getValueAsDag("parameters")) {
+    unsigned numParams = parametersDag->getNumArgs();
+    // Grow once up front instead of on each append.
+    parameters.reserve(parameters.size() + numParams);
+    for (unsigned i = 0; i < numParams; ++i)
+      parameters.emplace_back(parametersDag, i);
+  }
+}
+
+/// Returns the number of arguments of the `parameters` dag, or 0 if absent.
+unsigned AttrOrTypeDef::getNumParameters() const {
+  auto *parametersDag = def->getValueAsDag("parameters");
+  return parametersDag ? parametersDag->getNumArgs() : 0;
+}
+
+/// Returns the optional `mnemonic` field.
+Optional<StringRef> AttrOrTypeDef::getMnemonic() const {
+  Optional<StringRef> mnemonic = def->getValueAsOptionalString("mnemonic");
+  return mnemonic;
+}
+
+/// Returns the optional `printer` code field.
+Optional<StringRef> AttrOrTypeDef::getPrinterCode() const {
+  Optional<StringRef> printerCode = def->getValueAsOptionalString("printer");
+  return printerCode;
+}
+
+/// Returns the optional `parser` code field.
+Optional<StringRef> AttrOrTypeDef::getParserCode() const {
+  Optional<StringRef> parserCode = def->getValueAsOptionalString("parser");
+  return parserCode;
+}
+
+/// Returns the `genAccessors` bit.
+bool AttrOrTypeDef::genAccessors() const {
+  const bool wantAccessors = def->getValueAsBit("genAccessors");
+  return wantAccessors;
+}
+
+/// Returns the `genVerifyDecl` bit.
+bool AttrOrTypeDef::genVerifyDecl() const {
+  const bool wantVerify = def->getValueAsBit("genVerifyDecl");
+  return wantVerify;
+}
+
+/// Returns `extraClassDeclaration`, or an empty Optional when it is empty.
+Optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
+  StringRef decls = def->getValueAsString("extraClassDeclaration");
+  if (decls.empty())
+    return Optional<StringRef>();
+  return decls;
+}
+
+/// Returns the source locations of the underlying record.
+ArrayRef<llvm::SMLoc> AttrOrTypeDef::getLoc() const { return def->getLoc(); }
+
+/// Returns the `skipDefaultBuilders` bit.
+bool AttrOrTypeDef::skipDefaultBuilders() const {
+  const bool skip = def->getValueAsBit("skipDefaultBuilders");
+  return skip;
+}
+
+/// Two defs are equal iff they wrap the same TableGen record.
+bool AttrOrTypeDef::operator==(const AttrOrTypeDef &other) const {
+  return other.def == def;
+}
+
+/// Orders defs by the names of their underlying records.
+bool AttrOrTypeDef::operator<(const AttrOrTypeDef &other) const {
+  StringRef lhsName = getName();
+  return lhsName < other.getName();
+}
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeParameter
+//===----------------------------------------------------------------------===//
+
+/// Returns the name bound to this argument in the `parameters` dag.
+StringRef AttrOrTypeParameter::getName() const {
+  llvm::StringInit *argName = def->getArgName(index);
+  return argName->getValue();
+}
+
+/// Returns the custom allocator code of this parameter.
+///
+/// Returns None for plain-string parameters, for parameter defs without an
+/// `allocator` field, and when that field is unset. Emits a fatal error if
+/// `allocator` holds a non-string value, or if the dag argument is neither a
+/// string nor a def.
+Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
+  llvm::Init *parameterType = def->getArg(index);
+  if (isa<llvm::StringInit>(parameterType))
+    return Optional<StringRef>();
+
+  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
+    llvm::RecordVal *code = param->getDef()->getValue("allocator");
+    if (!code)
+      return Optional<StringRef>();
+    if (auto *ci = dyn_cast<llvm::StringInit>(code->getValue()))
+      return ci->getValue();
+    if (isa<llvm::UnsetInit>(code->getValue()))
+      return Optional<StringRef>();
+
+    // Report the field that was actually read (`allocator`, not `printer`).
+    llvm::PrintFatalError(
+        param->getDef()->getLoc(),
+        "Record `" + def->getArgName(index)->getValue() +
+            "', field `allocator' does not have a code initializer!");
+  }
+
+  // Same wording as getSyntax; no trailing newline, which the diagnostic
+  // printer would otherwise turn into a blank line.
+  llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
+                        "defs which inherit from AttrOrTypeParameter");
+}
+
+/// Returns the C++ type of this parameter: the string itself for a plain
+/// string argument, or the `cppType` field of a parameter def. Emits a fatal
+/// error for any other kind of dag argument.
+StringRef AttrOrTypeParameter::getCppType() const {
+  llvm::Init *arg = def->getArg(index);
+  if (auto *param = dyn_cast<llvm::DefInit>(arg))
+    return param->getDef()->getValueAsString("cppType");
+  if (auto *typeString = dyn_cast<llvm::StringInit>(arg))
+    return typeString->getValue();
+  llvm::PrintFatalError(
+      "Parameters DAG arguments must be either strings or defs "
+      "which inherit from AttrOrTypeParameter\n");
+}
+
+/// Returns the `summary` of a parameter def, or None for plain-string
+/// parameters and for defs without a string-valued `summary` field.
+Optional<StringRef> AttrOrTypeParameter::getSummary() const {
+  auto *parameterType = def->getArg(index);
+  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
+    // getValue returns null when the def has no `summary` field at all (e.g.
+    // a def that does not inherit from AttrOrTypeParameter); check before
+    // dereferencing, as getAllocator and getSyntax already do.
+    const llvm::RecordVal *desc = param->getDef()->getValue("summary");
+    if (desc && isa<llvm::StringInit>(desc->getValue()))
+      return cast<llvm::StringInit>(desc->getValue())->getValue();
+  }
+  return Optional<StringRef>();
+}
+
+/// Returns the assembly-syntax documentation for this parameter.
+///
+/// For a plain string argument this is the string itself. For a parameter def
+/// it is the `syntax` field when that holds a string, otherwise the C++ type.
+/// Emits a fatal error for any other kind of dag argument.
+StringRef AttrOrTypeParameter::getSyntax() const {
+  llvm::Init *arg = def->getArg(index);
+  if (auto *typeString = dyn_cast<llvm::StringInit>(arg))
+    return typeString->getValue();
+
+  auto *param = dyn_cast<llvm::DefInit>(arg);
+  if (!param)
+    llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
+                          "defs which inherit from AttrOrTypeParameter");
+
+  const llvm::RecordVal *syntaxField = param->getDef()->getValue("syntax");
+  if (!syntaxField || !isa<llvm::StringInit>(syntaxField->getValue()))
+    return getCppType();
+  return cast<llvm::StringInit>(syntaxField->getValue())->getValue();
+}
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index fa52dde27a40..557caf1c9c1a 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -11,6 +11,7 @@
llvm_add_library(MLIRTableGen STATIC
Argument.cpp
Attribute.cpp
+ AttrOrTypeDef.cpp
Builder.cpp
Constraint.cpp
Dialect.cpp
@@ -26,7 +27,6 @@ llvm_add_library(MLIRTableGen STATIC
SideEffects.cpp
Successor.cpp
Type.cpp
- TypeDef.cpp
DISABLE_LLVM_LINK_LLVM_DYLIB
diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp
deleted file mode 100644
index d76748d85835..000000000000
--- a/mlir/lib/TableGen/TypeDef.cpp
+++ /dev/null
@@ -1,212 +0,0 @@
-//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// TypeDef wrapper to simplify using TableGen Record defining a MLIR dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/TableGen/TypeDef.h"
-#include "mlir/TableGen/Dialect.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/TableGen/Error.h"
-#include "llvm/TableGen/Record.h"
-
-using namespace mlir;
-using namespace mlir::tblgen;
-
-//===----------------------------------------------------------------------===//
-// TypeBuilder
-//===----------------------------------------------------------------------===//
-
-/// Returns true if this builder is able to infer the MLIRContext parameter.
-/// The flag is read from the builder record's `hasInferredContextParam` bit.
-bool TypeBuilder::hasInferredContextParameter() const {
- return def->getValueAsBit("hasInferredContextParam");
-}
-
-//===----------------------------------------------------------------------===//
-// TypeDef
-//===----------------------------------------------------------------------===//
-
-/// Returns the dialect referenced by the record's `dialect` field, or a null
-/// Dialect when that field does not hold a def.
-/// NOTE(review): the result of getValue("dialect") is dereferenced without a
-/// null check, so a record lacking the field entirely would crash here.
-Dialect TypeDef::getDialect() const {
- auto *dialectDef =
- dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
- if (dialectDef == nullptr)
- return Dialect(nullptr);
- return Dialect(dialectDef->getDef());
-}
-
-/// Returns the name of the wrapped tablegen record.
-StringRef TypeDef::getName() const { return def->getName(); }
-/// Returns the record's `cppClassName` field.
-StringRef TypeDef::getCppClassName() const {
- return def->getValueAsString("cppClassName");
-}
-
-/// Returns the record's `cppBaseClassName` field.
-StringRef TypeDef::getCppBaseClassName() const {
- return def->getValueAsString("cppBaseClassName");
-}
-
-/// True if the record has a `description` field holding a string.
-bool TypeDef::hasDescription() const {
- const llvm::RecordVal *s = def->getValue("description");
- return s != nullptr && isa<llvm::StringInit>(s->getValue());
-}
-
-/// Returns the `description` field; presumably callers check
-/// hasDescription() first.
-StringRef TypeDef::getDescription() const {
- return def->getValueAsString("description");
-}
-
-/// True if the record has a `summary` field holding a string.
-bool TypeDef::hasSummary() const {
- const llvm::RecordVal *s = def->getValue("summary");
- return s != nullptr && isa<llvm::StringInit>(s->getValue());
-}
-
-/// Returns the `summary` field; presumably callers check hasSummary() first.
-StringRef TypeDef::getSummary() const {
- return def->getValueAsString("summary");
-}
-
-/// Returns the record's `storageClass` field (the storage class name).
-StringRef TypeDef::getStorageClassName() const {
- return def->getValueAsString("storageClass");
-}
-/// Returns the record's `storageNamespace` field.
-StringRef TypeDef::getStorageNamespace() const {
- return def->getValueAsString("storageNamespace");
-}
-
-/// Returns the record's `genStorageClass` bit.
-bool TypeDef::genStorageClass() const {
- return def->getValueAsBit("genStorageClass");
-}
-/// Returns the record's `hasStorageCustomConstructor` bit.
-bool TypeDef::hasStorageCustomConstructor() const {
- return def->getValueAsBit("hasStorageCustomConstructor");
-}
-/// Appends one TypeParameter per argument of the record's `parameters` DAG to
-/// `parameters`; appends nothing when that DAG is null.
-void TypeDef::getParameters(SmallVectorImpl<TypeParameter> &parameters) const {
-  auto *parametersDag = def->getValueAsDag("parameters");
-  if (parametersDag != nullptr) {
-    size_t numParams = parametersDag->getNumArgs();
-    for (unsigned i = 0; i < numParams; i++)
-      parameters.push_back(TypeParameter(parametersDag, i));
-  }
-}
-/// Returns the number of arguments in the `parameters` DAG, or 0 if it is null.
-unsigned TypeDef::getNumParameters() const {
- auto *parametersDag = def->getValueAsDag("parameters");
- return parametersDag ? parametersDag->getNumArgs() : 0;
-}
-/// Returns the `mnemonic` field, if it is set.
-llvm::Optional<StringRef> TypeDef::getMnemonic() const {
- return def->getValueAsOptionalString("mnemonic");
-}
-/// Returns the custom `printer` code, if it is set.
-llvm::Optional<StringRef> TypeDef::getPrinterCode() const {
- return def->getValueAsOptionalString("printer");
-}
-/// Returns the custom `parser` code, if it is set.
-llvm::Optional<StringRef> TypeDef::getParserCode() const {
- return def->getValueAsOptionalString("parser");
-}
-/// Returns the record's `genAccessors` bit.
-bool TypeDef::genAccessors() const {
- return def->getValueAsBit("genAccessors");
-}
-/// Returns the record's `genVerifyDecl` bit.
-bool TypeDef::genVerifyDecl() const {
- return def->getValueAsBit("genVerifyDecl");
-}
-/// Returns `extraClassDeclaration`, treating an empty string as absent.
-llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
- auto value = def->getValueAsString("extraClassDeclaration");
- return value.empty() ? llvm::Optional<StringRef>() : value;
-}
-/// Returns the source locations of the wrapped record.
-llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
-
-/// Returns the record's `skipDefaultBuilders` bit.
-bool TypeDef::skipDefaultBuilders() const {
- return def->getValueAsBit("skipDefaultBuilders");
-}
-
-/// Two TypeDefs are equal when they wrap the same record.
-bool TypeDef::operator==(const TypeDef &other) const {
- return def == other.def;
-}
-
-/// Orders TypeDefs by record name.
-bool TypeDef::operator<(const TypeDef &other) const {
- return getName() < other.getName();
-}
-
-//===----------------------------------------------------------------------===//
-// TypeDef construction
-//===----------------------------------------------------------------------===//
-
-/// Wraps `def` and eagerly collects its custom builders from the `builders`
-/// list. Every builder parameter must be named, and a def that skips the
-/// default builders must provide at least one custom builder; either violation
-/// is a fatal error.
-TypeDef::TypeDef(const llvm::Record *def) : def(def) {
-  // Populate the builders.
-  auto *builderList =
-      dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
-  if (builderList && !builderList->empty()) {
-    for (llvm::Init *init : builderList->getValues()) {
-      TypeBuilder builder(cast<llvm::DefInit>(init)->getDef(), def->getLoc());
-
-      // Ensure that all parameters have names.
-      for (const TypeBuilder::Parameter &param : builder.getParameters()) {
-        if (!param.getName())
-          PrintFatalError(def->getLoc(),
-                          "type builder parameters must have a name");
-      }
-      builders.emplace_back(builder);
-    }
-  } else if (skipDefaultBuilders()) {
-    PrintFatalError(
-        def->getLoc(),
-        "default builders are skipped and no custom builders provided");
-  }
-}
-
-/// TypeParameter accessors. Each parameter is argument `num` of the owning
-/// def's `parameters` DAG.
-
-/// Returns the DAG argument name of this parameter.
-StringRef TypeParameter::getName() const {
- return def->getArgName(num)->getValue();
-}
-/// Returns the custom `allocator` code of a def-based parameter. Plain string
-/// parameters, and defs whose `allocator` is missing or unset, have none.
-Optional<StringRef> TypeParameter::getAllocator() const {
- llvm::Init *parameterType = def->getArg(num);
- if (isa<llvm::StringInit>(parameterType))
- return llvm::Optional<StringRef>();
-
- if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
- llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator");
- if (!code)
- return llvm::Optional<StringRef>();
- if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(code->getValue()))
- return ci->getValue();
- if (isa<llvm::UnsetInit>(code->getValue()))
- return llvm::Optional<StringRef>();
-
- // NOTE(review): the field checked above is `allocator`, but this message
- // names `printer`.
- llvm::PrintFatalError(
- typeParameter->getDef()->getLoc(),
- "Record `" + def->getArgName(num)->getValue() +
- "', field `printer' does not have a code initializer!");
- }
-
- llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
- "defs which inherit from TypeParameter\n");
-}
-/// Returns the C++ type: the string itself, or a def's `cppType` field.
-StringRef TypeParameter::getCppType() const {
- auto *parameterType = def->getArg(num);
- if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
- return stringType->getValue();
- if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType))
- return typeParameter->getDef()->getValueAsString("cppType");
- llvm::PrintFatalError(
- "Parameters DAG arguments must be either strings or defs "
- "which inherit from TypeParameter\n");
-}
-/// Returns a def-based parameter's string `summary`, if any.
-/// NOTE(review): `desc` is not null-checked; a def without a `summary` field
-/// would be dereferenced here.
-Optional<StringRef> TypeParameter::getSummary() const {
- auto *parameterType = def->getArg(num);
- if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
- const auto *desc = typeParameter->getDef()->getValue("summary");
- if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
- return ci->getValue();
- }
- return Optional<StringRef>();
-}
-/// Returns the syntax: the string itself, a def's string `syntax` field, or
-/// otherwise the C++ type.
-StringRef TypeParameter::getSyntax() const {
- auto *parameterType = def->getArg(num);
- if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
- return stringType->getValue();
- if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
- const auto *syntax = typeParameter->getDef()->getValue("syntax");
- // The `isa` check makes the `dyn_cast` below infallible (a `cast` would do).
- if (syntax && isa<llvm::StringInit>(syntax->getValue()))
- return dyn_cast<llvm::StringInit>(syntax->getValue())->getValue();
- return getCppType();
- }
- llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
- "defs which inherit from TypeParameter");
-}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 895b7755029c..a3aa9f33c673 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -11,10 +11,15 @@ mlir_tablegen(TestOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(TestOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRTestInterfaceIncGen)
+set(LLVM_TARGET_DEFINITIONS TestAttrDefs.td)
+mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRTestAttrDefIncGen)
+
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
-add_public_tablegen_target(MLIRTestDefIncGen)
+add_public_tablegen_target(MLIRTestTypeDefIncGen)
set(LLVM_TARGET_DEFINITIONS TestOps.td)
@@ -30,6 +35,7 @@ add_public_tablegen_target(MLIRTestOpsIncGen)
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestDialect
+ TestAttributes.cpp
TestDialect.cpp
TestInterfaces.cpp
TestPatterns.cpp
@@ -39,8 +45,9 @@ add_mlir_library(MLIRTestDialect
EXCLUDE_FROM_LIBMLIR
DEPENDS
+ MLIRTestAttrDefIncGen
MLIRTestInterfaceIncGen
- MLIRTestDefIncGen
+ MLIRTestTypeDefIncGen
MLIRTestOpsIncGen
LINK_LIBS PUBLIC
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
new file mode 100644
index 000000000000..8b3ebaa6e632
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -0,0 +1,44 @@
+//===-- TestAttrDefs.td - Test dialect attr definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen data attribute definitions for Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_ATTRDEFS
+#define TEST_ATTRDEFS
+
+// To get the test dialect definition.
+include "TestOps.td"
+
+// All of the attributes will extend this class.
+class Test_Attr<string name> : AttrDef<Test_Dialect, name>;
+
+def SimpleAttrA : Test_Attr<"SimpleA"> {
+ let mnemonic = "smpla";
+}
+
+// A more complex parameterized attribute.
+def CompoundAttrA : Test_Attr<"CompoundA"> {
+ let mnemonic = "cmpnd_a";
+
+ // List of attribute parameters.
+ let parameters = (
+ ins
+ "int":$widthOfSomething,
+ "::mlir::Type":$oneType,
+ // This is special syntax since ArrayRefs require allocation in the
+ // constructor.
+ ArrayRefParameter<
+ "int", // The parameter C++ type.
+ "An example of an array of ints" // Parameter description.
+ >: $arrayOfInts
+ );
+}
+
+#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
new file mode 100644
index 000000000000..39328b6c1d10
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -0,0 +1,82 @@
+//===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains attributes defined by the TestDialect for testing various
+// features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestAttributes.h"
+#include "TestDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::test;
+
+/// Parses the body of a `cmpnd_a` attribute (the mnemonic has already been
+/// consumed by the dialect): `<` width `,` type `,` `[` ints `]` `>`, where the
+/// integer list may be empty. The attribute `type` is not used. Returns a null
+/// attribute on failure.
+Attribute CompoundAAttr::parse(MLIRContext *context, DialectAsmParser &parser,
+                               Type type) {
+  int widthOfSomething;
+  Type oneType;
+  SmallVector<int, 4> arrayOfInts;
+  if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
+      parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
+      parser.parseLSquare())
+    return Attribute();
+
+  // Parse the optional comma-separated integer list. parseOptionalInteger
+  // returns an empty OptionalParseResult when no integer is present (e.g.
+  // `[]`), which must be checked before it is dereferenced.
+  int intVal;
+  OptionalParseResult firstInt = parser.parseOptionalInteger(intVal);
+  if (firstInt.hasValue()) {
+    if (failed(*firstInt))
+      return Attribute();
+    arrayOfInts.push_back(intVal);
+    // Every comma must be followed by another integer (rejects `[5,]`).
+    while (succeeded(parser.parseOptionalComma())) {
+      if (parser.parseInteger(intVal))
+        return Attribute();
+      arrayOfInts.push_back(intVal);
+    }
+  }
+
+  if (parser.parseRSquare() || parser.parseGreater())
+    return Attribute();
+  return get(context, widthOfSomething, oneType, arrayOfInts);
+}
+
+/// Prints this attribute as `cmpnd_a<width, type, [i0, i1, ...]>`, the form
+/// CompoundAAttr::parse reads back once the dialect has consumed `cmpnd_a`.
+void CompoundAAttr::print(DialectAsmPrinter &printer) const {
+  printer << "cmpnd_a<" << getWidthOfSomething();
+  printer << ", " << getOneType();
+  printer << ", [";
+  llvm::interleaveComma(getArrayOfInts(), printer);
+  printer << "]>";
+}
+
+//===----------------------------------------------------------------------===//
+// Tablegen Generated Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "TestAttrDefs.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TestDialect
+//===----------------------------------------------------------------------===//
+
+/// Parses a test dialect attribute: reads its mnemonic keyword and hands off
+/// to the tablegen'd dispatcher, reporting "unknown test attribute" when that
+/// yields no attribute.
+/// NOTE(review): the generated dispatcher presumably also returns a null
+/// attribute when a known mnemonic fails to parse its body, in which case a
+/// second, misleading diagnostic is emitted here.
+Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
+                                      Type type) const {
+  StringRef mnemonic;
+  if (succeeded(parser.parseKeyword(&mnemonic))) {
+    Attribute parsed =
+        generatedAttributeParser(getContext(), parser, mnemonic, type);
+    if (parsed)
+      return parsed;
+    parser.emitError(parser.getNameLoc(), "unknown test attribute");
+  }
+  return Attribute();
+}
+
+/// Prints a test dialect attribute via the tablegen'd dispatcher. Its result is
+/// ignored, so an attribute the dispatcher does not handle prints nothing.
+void TestDialect::printAttribute(Attribute attr,
+                                 DialectAsmPrinter &printer) const {
+  (void)generatedAttributePrinter(attr, printer);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
new file mode 100644
index 000000000000..0eaa78eae590
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -0,0 +1,27 @@
+//===- TestAttributes.h - MLIR Test Dialect Attributes ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains attributes defined by the TestDialect for testing
+// various features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTATTRIBUTES_H
+#define MLIR_TESTATTRIBUTES_H
+
+#include <tuple>
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+
+#define GET_ATTRDEF_CLASSES
+#include "TestAttrDefs.h.inc"
+
+#endif // MLIR_TESTATTRIBUTES_H
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index b6af6c978b37..143db6fdb89a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestAttributes.h"
#include "TestTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
@@ -168,6 +169,10 @@ void TestDialect::initialize() {
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "TestAttrDefs.cpp.inc"
+ >();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>();
addTypes<TestType, TestRecursiveType,
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 828c31d82e11..458347d36ad9 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -27,6 +27,13 @@ def Test_Dialect : Dialect {
let hasOperationAttrVerify = 1;
let hasRegionArgAttrVerify = 1;
let hasRegionResultAttrVerify = 1;
+
+ let extraClassDeclaration = [{
+ Attribute parseAttribute(DialectAsmParser &parser,
+ Type type) const override;
+ void printAttribute(Attribute attr,
+ DialectAsmPrinter &printer) const override;
+ }];
}
class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
new file mode 100644
index 000000000000..36ea2cb46ece
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -0,0 +1,96 @@
+// RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/OpBase.td"
+
+// DECL: #ifdef GET_ATTRDEF_CLASSES
+// DECL: #undef GET_ATTRDEF_CLASSES
+
+// DECL: namespace mlir {
+// DECL: class DialectAsmParser;
+// DECL: class DialectAsmPrinter;
+// DECL: } // namespace mlir
+
+// DEF: #ifdef GET_ATTRDEF_LIST
+// DEF: #undef GET_ATTRDEF_LIST
+// DEF: ::mlir::test::SimpleAAttr,
+// DEF: ::mlir::test::CompoundAAttr,
+// DEF: ::mlir::test::IndexAttr,
+// DEF: ::mlir::test::SingleParameterAttr
+
+// DEF-LABEL: ::mlir::Attribute generatedAttributeParser(::mlir::MLIRContext *context,
+// DEF-NEXT: ::mlir::DialectAsmParser &parser,
+// DEF-NEXT: ::llvm::StringRef mnemonic, ::mlir::Type type) {
+// DEF: if (mnemonic == ::mlir::test::CompoundAAttr::getMnemonic()) return ::mlir::test::CompoundAAttr::parse(context, parser, type);
+// DEF-NEXT: if (mnemonic == ::mlir::test::IndexAttr::getMnemonic()) return ::mlir::test::IndexAttr::parse(context, parser, type);
+// DEF-NEXT: return ::mlir::Attribute();
+
+def Test_Dialect: Dialect {
+// DECL-NOT: TestDialect
+// DEF-NOT: TestDialect
+ let name = "TestDialect";
+ let cppNamespace = "::mlir::test";
+}
+
+class TestAttr<string name> : AttrDef<Test_Dialect, name> { }
+
+def A_SimpleAttrA : TestAttr<"SimpleA"> {
+// DECL: class SimpleAAttr : public ::mlir::Attribute
+}
+
+// A more complex parameterized attribute
+def B_CompoundAttrA : TestAttr<"CompoundA"> {
+ let summary = "A more complex parameterized attribute";
+ let description = "This attribute is to test a reasonably complex attribute";
+ let mnemonic = "cmpnd_a";
+ let parameters = (
+ ins
+ "int":$widthOfSomething,
+ "::mlir::test::SimpleTypeA": $exampleTdType,
+ "SomeCppStruct": $exampleCppType,
+ ArrayRefParameter<"int", "Matrix dimensions">:$dims,
+ "::mlir::Type":$inner
+ );
+
+ let genVerifyDecl = 1;
+
+// DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
+// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
+// DECL: return ::llvm::StringLiteral("cmpnd_a");
+// DECL: }
+// DECL: static ::mlir::Attribute parse(::mlir::MLIRContext *context,
+// DECL-NEXT: ::mlir::DialectAsmParser &parser, ::mlir::Type type);
+// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
+// DECL: int getWidthOfSomething() const;
+// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
+// DECL: SomeCppStruct getExampleCppType() const;
+}
+
+def C_IndexAttr : TestAttr<"Index"> {
+ let mnemonic = "index";
+
+ let parameters = (
+ ins
+ StringRefParameter<"Label for index">:$label
+ );
+
+// DECL-LABEL: class IndexAttr : public ::mlir::Attribute
+// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
+// DECL: return ::llvm::StringLiteral("index");
+// DECL: }
+// DECL: static ::mlir::Attribute parse(::mlir::MLIRContext *context,
+// DECL-NEXT: ::mlir::DialectAsmParser &parser, ::mlir::Type type);
+// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
+}
+
+def D_SingleParameterAttr : TestAttr<"SingleParameter"> {
+ let parameters = (
+ ins
+ "int": $num
+ );
+// DECL-LABEL: struct SingleParameterAttrStorage;
+// DECL-LABEL: class SingleParameterAttr
+// DECL-NEXT: detail::SingleParameterAttrStorage
+}
diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
new file mode 100644
index 000000000000..8c167ffc2854
--- /dev/null
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -0,0 +1,5 @@
+// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func private @compoundA()
+// CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
+func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [5, 6]>}
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index a895c5cfd872..57103ccc7db9 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -19,9 +19,11 @@ include "mlir/IR/OpBase.td"
// DEF: ::mlir::test::SingleParameterType,
// DEF: ::mlir::test::IntegerType
-// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &parser, ::llvm::StringRef mnemonic)
+// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context,
+// DEF-NEXT: ::mlir::DialectAsmParser &parser,
+// DEF-NEXT: ::llvm::StringRef mnemonic) {
// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser);
-// DEF return ::mlir::Type();
+// DEF: return ::mlir::Type();
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
new file mode 100644
index 000000000000..3df9fd9ac90d
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -0,0 +1,849 @@
+//===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+#define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+/// Collect into `resultDefs` the AttrOrTypeDefs built from `records` that
+/// belong to `selectedDialect`. If no dialect is selected, every record must
+/// belong to the same dialect, which is then used; records spanning several
+/// dialects are a fatal error.
+static void collectAllDefs(StringRef selectedDialect,
+                           ArrayRef<llvm::Record *> records,
+                           SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
+  // `records` is a non-owning view, so the caller's vector is not copied.
+  auto defs = llvm::map_range(
+      records, [](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
+  if (defs.empty())
+    return;
+
+  StringRef dialectName = selectedDialect;
+  if (dialectName.empty()) {
+    // Infer the dialect, requiring that every def agrees on it.
+    Dialect dialect(nullptr);
+    for (const AttrOrTypeDef &def : defs) {
+      if (!dialect) {
+        dialect = def.getDialect();
+      } else if (dialect != def.getDialect()) {
+        llvm::PrintFatalError("defs belonging to more than one dialect. Must "
+                              "select one via '--(attr|type)defs-dialect'");
+      }
+    }
+    dialectName = dialect.getName();
+  }
+
+  for (const AttrOrTypeDef &def : defs)
+    if (def.getDialect().getName().equals(dialectName))
+      resultDefs.push_back(def);
+}
+
+//===----------------------------------------------------------------------===//
+// ParamCommaFormatter
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Pass an instance of this class to llvm::formatv() to emit a comma separated
+/// list of parameters, each rendered according to its 'EmitFormat'. Unless
+/// `prependComma` is false, a non-empty list is preceded by ", " so it can be
+/// spliced directly after an existing argument. `params` is a non-owning view;
+/// the referenced parameters must outlive any use of the formatter.
+class ParamCommaFormatter : public llvm::detail::format_adapter {
+public:
+  /// Choose the output format
+  enum EmitFormat {
+    /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
+    /// [...]".
+    TypeNamePairs,
+
+    /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
+    TypeNameInitializer,
+
+    /// Emit "param1Name, param2Name, [...]".
+    JustParams,
+  };
+
+  ParamCommaFormatter(EmitFormat emitFormat,
+                      ArrayRef<AttrOrTypeParameter> params,
+                      bool prependComma = true)
+      : emitFormat(emitFormat), params(params), prependComma(prependComma) {}
+
+  /// llvm::formatv will call this function when using an instance as a
+  /// replacement value. The `options` string is ignored.
+  void format(raw_ostream &os, StringRef options) override {
+    if (params.empty())
+      return;
+    if (prependComma)
+      os << ", ";
+    interleaveComma(params, os, [&](const AttrOrTypeParameter &param) {
+      emitParam(param, os);
+    });
+  }
+
+private:
+  /// Emit a single parameter in the selected format.
+  void emitParam(const AttrOrTypeParameter &param, raw_ostream &os) const {
+    switch (emitFormat) {
+    case EmitFormat::TypeNamePairs:
+      // "paramType paramName"
+      os << param.getCppType() << " " << param.getName();
+      break;
+    case EmitFormat::TypeNameInitializer:
+      // "paramName(paramName)"
+      os << param.getName() << "(" << param.getName() << ")";
+      break;
+    case EmitFormat::JustParams:
+      // "paramName"
+      os << param.getName();
+      break;
+    }
+  }
+
+  EmitFormat emitFormat;
+  ArrayRef<AttrOrTypeParameter> params;
+  bool prependComma;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// DefGenerator
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Base class for the generators that emit C++ declarations and definitions
+/// for tablegen attribute or type def records. Subclasses choose which records
+/// to process and fill in the Attr/Type specific fields below.
+class DefGenerator {
+public:
+  /// Emit the C++ declarations for the collected defs; `selectedDialect`
+  /// presumably narrows them as in collectAllDefs.
+  bool emitDecls(StringRef selectedDialect);
+  /// Emit the C++ definitions for the collected defs; `selectedDialect`
+  /// presumably narrows them as in collectAllDefs.
+  bool emitDefs(StringRef selectedDialect);
+
+protected:
+  /// Takes ownership of `defs`, the records to generate code for.
+  DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
+      : defRecords(std::move(defs)), os(os), isAttrGenerator(false) {}
+
+  /// Emit the declaration of a single def.
+  void emitDefDecl(const AttrOrTypeDef &def);
+  /// Emit the list of def type names.
+  void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
+  /// Emit the code to dispatch between different defs during parsing/printing.
+  void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
+  /// Emit the definition of a single def.
+  void emitDefDef(const AttrOrTypeDef &def);
+  /// Emit the storage class for the given def.
+  void emitStorageClass(const AttrOrTypeDef &def);
+  /// Emit the parser/printer for the given def.
+  void emitParsePrint(const AttrOrTypeDef &def);
+
+  /// The set of def records to emit.
+  std::vector<llvm::Record *> defRecords;
+  /// The stream to emit to.
+  raw_ostream &os;
+  /// The prefix of the tablegen def name, e.g. Attr or Type.
+  StringRef defTypePrefix;
+  /// The C++ base value type of the def, e.g. Attribute or Type.
+  StringRef valueType;
+  /// Flag indicating if this generator is for Attributes. False if the
+  /// generator is for types.
+  bool isAttrGenerator;
+};
+
+/// A specialized generator for AttrDefs: processes every record deriving from
+/// `AttrDef`.
+struct AttrDefGenerator : public DefGenerator {
+  AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
+      : DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os) {
+    isAttrGenerator = true;
+    defTypePrefix = "Attr";
+    valueType = "Attribute";
+  }
+};
+/// A specialized generator for TypeDefs: processes every record deriving from
+/// `TypeDef`.
+struct TypeDefGenerator : public DefGenerator {
+  TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
+      : DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os) {
+    defTypePrefix = "Type";
+    valueType = "Type";
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: Declarations
+//===----------------------------------------------------------------------===//
+
+/// Print this above all the other declarations. Forward-declares
+/// DialectAsmParser and DialectAsmPrinter, which the generated parse/print
+/// declarations (defDeclParsePrintStr) refer to.
+static const char *const typeDefDeclHeader = R"(
+namespace mlir {
+class DialectAsmParser;
+class DialectAsmPrinter;
+} // namespace mlir
+)";
+
+/// The code block for the start of an attribute or type def class declaration
+/// -- singleton (parameterless) case.
+///
+/// {0}: The name of the def class.
+/// {1}: The C++ base class of the def (its `cppBaseClassName`).
+/// {2}: The name of the base value type, e.g. Attribute or Type.
+/// {3}: The tablegen record type prefix, e.g. Attr or Type.
+static const char *const defDeclSingletonBeginStr = R"(
+ class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage> {{
+ public:
+ /// Inherit some necessary constructors from '{3}Base'.
+ using Base::Base;
+)";
+
+/// The code block for the start of an attribute or type def class declaration
+/// -- parametric case. Also forward-declares the storage class.
+///
+/// {0}: The name of the def class.
+/// {1}: The C++ base class of the def (its `cppBaseClassName`).
+/// {2}: The storage class namespace.
+/// {3}: The storage class name.
+/// {4}: The name of the base value type, e.g. Attribute or Type.
+/// {5}: The tablegen record type prefix, e.g. Attr or Type.
+///
+/// NOTE(review): the unescaped `{` after `namespace {2}` is emitted literally
+/// only because llvm::formatv treats a `{` that is followed by another `{`
+/// before any `}` as literal text; escaping it as `{{` would be more robust.
+static const char *const defDeclParametricBeginStr = R"(
+ namespace {2} {
+ struct {3};
+ } // end namespace {2}
+ class {0} : public ::mlir::{4}::{5}Base<{0}, {1},
+ {2}::{3}> {{
+ public:
+ /// Inherit some necessary constructors from '{5}Base'.
+ using Base::Base;
+
+)";
+
+/// The code snippet for print/parse of an Attribute/Type.
+///
+/// {0}: The name of the base value type, e.g. Attribute or Type.
+/// {1}: Extra trailing parser parameters including their leading comma, e.g.
+///      ", ::mlir::Type type" for attributes; empty for types.
+static const char *const defDeclParsePrintStr = R"(
+ static ::mlir::{0} parse(::mlir::MLIRContext *context,
+ ::mlir::DialectAsmParser &parser{1});
+ void print(::mlir::DialectAsmPrinter &printer) const;
+)";
+
+/// The code block for the verify method declaration.
+///
+/// {0}: The def's parameters as ", type name" pairs (a ParamCommaFormatter in
+///      TypeNamePairs mode).
+static const char *const defDeclVerifyStr = R"(
+ using Base::getChecked;
+ static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0});
+)";
+
+/// Emit the builder declarations for the given def. Unless
+/// `skipDefaultBuilders` is set, this emits the default `get` builder (plus
+/// `getChecked` when a verifier is requested); it then emits every
+/// user-specified builder in the same two variants.
+///
+/// `paramTypes` renders the def's parameters as ", type name" pairs and is
+/// spliced into the default builder signatures.
+static void emitBuilderDecls(const AttrOrTypeDef &def, raw_ostream &os,
+                             ParamCommaFormatter &paramTypes) {
+  StringRef typeClass = def.getCppClassName();
+  bool genCheckedMethods = def.genVerifyDecl();
+  if (!def.skipDefaultBuilders()) {
+    os << llvm::formatv(
+        " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
+        paramTypes);
+    if (genCheckedMethods) {
+      os << llvm::formatv(" static {0} "
+                          "getChecked(llvm::function_ref<::mlir::"
+                          "InFlightDiagnostic()> emitError, "
+                          "::mlir::MLIRContext *context{1});\n",
+                          typeClass, paramTypes);
+    }
+  }
+
+  // Generate the builders specified by the user.
+  for (const AttrOrTypeBuilder &builder : def.getBuilders()) {
+    std::string paramStr;
+    llvm::raw_string_ostream paramOS(paramStr);
+    llvm::interleaveComma(
+        builder.getParameters(), paramOS,
+        [&](const AttrOrTypeBuilder::Parameter &param) {
+          // Note: AttrOrTypeBuilder parameters are guaranteed to have names.
+          paramOS << param.getCppType() << " " << *param.getName();
+          if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
+            paramOS << " = " << *defaultParamValue;
+        });
+    paramOS.flush();
+
+    // Generate the `get` variant of the builder. The MLIRContext parameter is
+    // omitted when the builder can infer it.
+    os << " static " << typeClass << " get(";
+    if (!builder.hasInferredContextParameter()) {
+      os << "::mlir::MLIRContext *context";
+      if (!paramStr.empty())
+        os << ", ";
+    }
+    os << paramStr << ");\n";
+
+    // Generate the `getChecked` variant of the builder. The diagnostic type is
+    // fully qualified as `::mlir::InFlightDiagnostic`, matching the default
+    // `getChecked` above, so a nested `mlir` namespace inside the dialect's
+    // C++ namespace cannot capture the lookup.
+    if (genCheckedMethods) {
+      os << " static " << typeClass
+         << " getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> "
+            "emitError";
+      if (!builder.hasInferredContextParameter())
+        os << ", ::mlir::MLIRContext *context";
+      if (!paramStr.empty())
+        os << ", ";
+      os << paramStr << ");\n";
+    }
+  }
+}
+
+/// Emit the C++ class declaration for the given attribute or type def.
+///
+/// The declaration contains, in order:
+///   - the class header (singleton form when there are no parameters,
+///     parametric form otherwise);
+///   - any user-provided extra declarations;
+///   - builder and verifier declarations (parametric defs only);
+///   - `getMnemonic`, plus parser/printer declarations when the def also has
+///     parser code, printer code, or parameters (only if a mnemonic is set);
+///   - the parameter accessors (when requested).
+void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
+  SmallVector<AttrOrTypeParameter, 4> params;
+  def.getParameters(params);
+
+  // Emit the beginning string template: either the singleton or parametric
+  // template.
+  if (def.getNumParameters() == 0) {
+    os << formatv(defDeclSingletonBeginStr, def.getCppClassName(),
+                  def.getCppBaseClassName(), valueType, defTypePrefix);
+  } else {
+    os << formatv(defDeclParametricBeginStr, def.getCppClassName(),
+                  def.getCppBaseClassName(), def.getStorageNamespace(),
+                  def.getStorageClassName(), valueType, defTypePrefix);
+  }
+
+  // Emit the extra declarations first in case there's a definition in there.
+  if (Optional<StringRef> extraDecl = def.getExtraDecls())
+    os << *extraDecl << "\n";
+
+  ParamCommaFormatter emitTypeNamePairsAfterComma(
+      ParamCommaFormatter::EmitFormat::TypeNamePairs, params);
+  if (!params.empty()) {
+    emitBuilderDecls(def, os, emitTypeNamePairsAfterComma);
+
+    // Emit the verify invariants declaration.
+    if (def.genVerifyDecl())
+      os << llvm::formatv(defDeclVerifyStr, emitTypeNamePairsAfterComma);
+  }
+
+  // Emit the mnemonic, if specified.
+  if (Optional<StringRef> mnemonic = def.getMnemonic()) {
+    os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n"
+       << " return ::llvm::StringLiteral(\"" << *mnemonic << "\");\n"
+       << " }\n";
+
+    // If mnemonic specified, emit print/parse declarations.
+    if (def.getParserCode() || def.getPrinterCode() || !params.empty()) {
+      os << llvm::formatv(defDeclParsePrintStr, valueType,
+                          isAttrGenerator ? ", ::mlir::Type type" : "");
+    }
+  }
+
+  // Declare one `get<Name>()` accessor per parameter, reusing the parameters
+  // collected above instead of querying the def again.
+  if (def.genAccessors()) {
+    for (const AttrOrTypeParameter &parameter : params) {
+      SmallString<16> name = parameter.getName();
+      name[0] = llvm::toUpper(name[0]);
+      os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name);
+    }
+  }
+
+  // End the decl.
+  os << " };\n";
+}
+
+/// Emit the declarations of every def selected for `selectedDialect`.
+///
+/// Everything is wrapped in a `GET_<PREFIX>DEF_CLASSES` guard. The common
+/// header is always written. Forward declarations and class declarations
+/// follow only when at least one def was collected. Always returns false.
+bool DefGenerator::emitDecls(StringRef selectedDialect) {
+  emitSourceFileHeader((defTypePrefix + "Def Declarations").str(), os);
+  IfDefScope guard("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
+
+  // The common "header" is written even when there is nothing to declare.
+  os << typeDefDeclHeader;
+
+  SmallVector<AttrOrTypeDef, 16> collected;
+  collectAllDefs(selectedDialect, defRecords, collected);
+  if (!collected.empty()) {
+    // Emit everything through a NamespaceEmitter for the first def's dialect.
+    NamespaceEmitter namespaces(os, collected.front().getDialect());
+
+    // Forward-declare every class so the declarations may refer to each
+    // other.
+    for (const AttrOrTypeDef &def : collected)
+      os << " class " << def.getCppClassName() << ";\n";
+
+    for (const AttrOrTypeDef &def : collected)
+      emitDefDecl(def);
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Def List
+//===----------------------------------------------------------------------===//
+
+/// Emit the fully qualified class names of `defs`, separated by ",\n" and
+/// terminated by a newline, inside a `GET_<PREFIX>DEF_LIST` guard.
+void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
+  IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_LIST", os);
+  StringRef separator = "";
+  for (const AttrOrTypeDef &def : defs) {
+    os << separator << def.getDialect().getCppNamespace() << "::"
+       << def.getCppClassName();
+    separator = ",\n";
+  }
+  os << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Definitions
+//===----------------------------------------------------------------------===//
+
+/// The code block used to start the auto-generated parser function,
+/// `generated<ValueType>Parser`. emitParsePrintDispatch appends one
+/// mnemonic check per def and the closing fallback.
+///
+/// {0}: The name of the base value type, e.g. Attribute or Type.
+/// {1}: Additional parser parameters: `, ::mlir::Type type` for attributes,
+///      empty for types.
+/// Note: `{{` is formatv's escape for a literal `{`.
+static const char *const defParserDispatchStartStr = R"(
+static ::mlir::{0} generated{0}Parser(::mlir::MLIRContext *context,
+ ::mlir::DialectAsmParser &parser,
+ ::llvm::StringRef mnemonic{1}) {{
+)";
+
+/// The code block used to start the auto-generated printer function,
+/// `generated<ValueType>Printer`, which dispatches on the value with
+/// `llvm::TypeSwitch`. emitParsePrintDispatch appends the `.Case` entries and
+/// the `.Default` clause.
+///
+/// {0}: The name of the base value type, e.g. Attribute or Type.
+static const char *const defPrinterDispatchStartStr = R"(
+static ::mlir::LogicalResult generated{0}Printer(
+ ::mlir::{0} def, ::mlir::DialectAsmPrinter &printer) {{
+ return ::llvm::TypeSwitch<::mlir::{0}, ::mlir::LogicalResult>(def)
+)";
+
+/// Beginning of the storage class for a parametric def. It opens the storage
+/// namespace and declares the struct with its constructor, key type, and key
+/// comparison. emitStorageClass appends the remaining members and closes the
+/// struct and the namespace.
+///
+/// {0}: Storage class namespace.
+/// {1}: Storage class C++ name.
+/// {2}: Constructor parameters as `type name` pairs (no leading comma).
+/// {3}: Member initializer list, `name(name)` per parameter (no leading
+///      comma).
+/// {4}: Comma-separated parameter names, used to build a key for comparison.
+/// {5}: Comma-separated parameter C++ types, forming the key tuple.
+/// {6}: The name of the base value type, e.g. Attribute or Type; the struct
+///      derives from `::mlir::{6}Storage`.
+static const char *const defStorageClassBeginStr = R"(
+namespace {0} {{
+ struct {1} : public ::mlir::{6}Storage {{
+ {1} ({2})
+ : {3} {{ }
+
+ /// The hash key is a tuple of the parameter types.
+ using KeyTy = std::tuple<{5}>;
+
+ /// Define the comparison function for the key type.
+ bool operator==(const KeyTy &key) const {{
+ return key == KeyTy({4});
+ }
+)";
+
+/// Opening of the generated `construct` method, which builds a new storage
+/// instance from a key. emitStorageClass then emits the body (unboxing the
+/// key plus any custom allocation code) followed by
+/// `defStorageClassConstructorEndStr`.
+///
+/// {0}: storage class name.
+/// {1}: The name of the base value type, e.g. Attribute or Type; selects
+///      `::mlir::{1}StorageAllocator`.
+static const char *const defStorageClassConstructorBeginStr = R"(
+ /// Define a construction method for creating a new instance of this
+ /// storage.
+ static {0} *construct(::mlir::{1}StorageAllocator &allocator,
+ const KeyTy &key) {{
+)";
+
+/// Closing of the generated `construct` method. It placement-allocates the
+/// storage with the allocator and constructs it from the parameter locals,
+/// which custom allocation code may have reassigned.
+///
+/// {0}: storage class name.
+/// {1}: Comma-separated parameter names, passed to the storage constructor.
+static const char *const defStorageClassConstructorEndStr = R"(
+ return new (allocator.allocate<{0}>())
+ {0}({1});
+ }
+)";
+
+/// Emit the user-provided allocation code, if any, for each parameter of the
+/// given def. The code is emitted into the body of the storage class's
+/// generated `construct` method.
+///
+/// Each snippet is formatted with tgfmt, where `$_allocator` expands to
+/// `allocator` and both `$_self` and `$_dst` expand to the parameter's name.
+/// The parameter's name is the local variable holding the unboxed key
+/// element, which the snippet may reassign.
+static void emitStorageParameterAllocation(const AttrOrTypeDef &def,
+                                           raw_ostream &os) {
+  SmallVector<AttrOrTypeParameter> parameters;
+  def.getParameters(parameters);
+  FmtContext fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
+  for (AttrOrTypeParameter &parameter : parameters) {
+    if (Optional<StringRef> allocCode = parameter.getAllocator()) {
+      fmtCtxt.withSelf(parameter.getName());
+      fmtCtxt.addSubst("_dst", parameter.getName());
+      os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
+    }
+  }
+}
+
+/// Emit the storage class backing a parametric def.
+///
+/// The generated struct:
+///   - lives in the def's storage namespace;
+///   - derives from `::mlir::<ValueType>Storage`;
+///   - keys instances by a tuple of all parameter values and hashes every
+///     tuple element;
+///   - keeps each parameter as a member.
+/// A `construct` method is generated unless the def requests a custom storage
+/// constructor; in that case only its declaration is emitted.
+void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
+  SmallVector<AttrOrTypeParameter, 4> params;
+  def.getParameters(params);
+
+  // Comma-separated parameter names and C++ types, reused below.
+  std::string nameList =
+      join(map_range(params,
+                     [](const AttrOrTypeParameter &p) { return p.getName(); }),
+           ", ");
+  std::string typeList = join(
+      map_range(params,
+                [](const AttrOrTypeParameter &p) { return p.getCppType(); }),
+      ", ");
+
+  // 1) Struct header, constructor, key type, and key comparison.
+  os << formatv(
+      defStorageClassBeginStr, def.getStorageNamespace(),
+      def.getStorageClassName(),
+      ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
+                          params, /*prependComma=*/false),
+      ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNameInitializer,
+                          params, /*prependComma=*/false),
+      nameList, typeList, valueType);
+
+  // 2) hashKey: combine every element of the key tuple.
+  os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
+  os << " return ::llvm::hash_combine(";
+  for (unsigned idx = 0, e = params.size(); idx != e; ++idx) {
+    if (idx != 0)
+      os << ", ";
+    os << "std::get<" << idx << ">(key)";
+  }
+  os << ");\n }\n";
+
+  // 3) construct: generated unless the user supplies their own definition.
+  if (!def.hasStorageCustomConstructor()) {
+    os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(),
+                  valueType);
+    // Unbox each key element into a local named after its parameter...
+    unsigned idx = 0;
+    for (const AttrOrTypeParameter &param : params)
+      os << formatv(" auto {0} = std::get<{1}>(key);\n", param.getName(),
+                    idx++);
+    // ...let any custom allocation code reassign those locals...
+    emitStorageParameterAllocation(def, os);
+    // ...and placement-construct the storage from them.
+    os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(),
+                  nameList);
+  } else {
+    // Only declare it; the user writes the definition elsewhere.
+    os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator "
+                        "&allocator, const KeyTy &key);\n",
+                        def.getStorageClassName(), valueType);
+  }
+
+  // 4) One member per parameter, then close the struct and the namespace.
+  for (const AttrOrTypeParameter &param : params)
+    os << " " << param.getCppType() << " " << param.getName() << ";\n";
+  os << " };\n";
+  os << "} // namespace " << def.getStorageNamespace() << "\n";
+}
+
+/// Emit the out-of-line `print` and `parse` definitions for a def whose
+/// printer and/or parser code was specified in ODS.
+///
+/// Printer code is formatted with `$_printer` -> `printer`. Parser code is
+/// formatted with `$_parser` -> `parser`, `$_ctxt` -> `context`, and, for
+/// attributes only, `$_type` -> `type`. A code field that is specified but
+/// empty is a fatal error.
+void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
+  // Printer: only emitted when the def provides printer code.
+  Optional<StringRef> printerCode = def.getPrinterCode();
+  if (printerCode) {
+    os << "void " << def.getCppClassName()
+       << "::print(::mlir::DialectAsmPrinter &printer) const {\n";
+    if (printerCode->empty())
+      PrintFatalError(def.getLoc(),
+                      def.getName() +
+                          ": printer (if specified) must have non-empty code");
+    FmtContext printerCtx = FmtContext().addSubst("_printer", "printer");
+    os << tgfmt(*printerCode, &printerCtx) << "\n}\n";
+  }
+
+  // Parser: only emitted when the def provides parser code.
+  Optional<StringRef> parserCode = def.getParserCode();
+  if (!parserCode)
+    return;
+
+  FmtContext parserCtx;
+  parserCtx.addSubst("_parser", "parser").addSubst("_ctxt", "context");
+  os << llvm::formatv("::mlir::{0} {1}::parse(::mlir::MLIRContext *context, "
+                      "::mlir::DialectAsmParser &parser",
+                      valueType, def.getCppClassName());
+  // Attributes additionally receive the expected type.
+  if (isAttrGenerator) {
+    os << ", ::mlir::Type type";
+    parserCtx.addSubst("_type", "type");
+  }
+  os << ") {\n";
+  if (parserCode->empty())
+    PrintFatalError(def.getLoc(),
+                    def.getName() +
+                        ": parser (if specified) must have non-empty code");
+  os << tgfmt(*parserCode, &parserCtx) << "\n}\n";
+}
+
+/// Replace every occurrence of `from` in `str` with `to` and return the
+/// resulting string.
+///
+/// Scanning resumes after each inserted replacement, so text introduced by
+/// `to` is never re-matched. This also prevents an infinite loop when `to`
+/// contains `from`. An empty `from` matches nothing, and `str` is returned
+/// unchanged instead of looping forever.
+static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
+  if (from.empty())
+    return str;
+  size_t pos = 0;
+  while ((pos = str.find(from.data(), pos, from.size())) !=
+         std::string::npos) {
+    str.replace(pos, from.size(), to.data(), to.size());
+    pos += to.size();
+  }
+  return str;
+}
+
+/// Emit the out-of-line definitions of the builders for the given def.
+///
+/// Default builders: unless skipped, `get` forwards all parameters to
+/// `Base::get`. When the def requests a verifier, `getChecked` forwards them
+/// to `Base::getChecked`.
+///
+/// User-specified builders with a body:
+///   - `get` is emitted with the body formatted via tgfmt. `$_get` binds to
+///     `Base::get`, and `$_ctxt` binds to `context` unless the context
+///     parameter is inferred.
+///   - `getChecked` (verifier only) is emitted after textually rewriting
+///     `$_get(` to `Base::getChecked(emitErrorFn, `. For inferred-context
+///     builders this rewritten body is not passed through tgfmt.
+static void emitBuilderDefs(const AttrOrTypeDef &def, raw_ostream &os,
+                            ArrayRef<AttrOrTypeParameter> params) {
+  bool genCheckedMethods = def.genVerifyDecl();
+  StringRef className = def.getCppClassName();
+  if (!def.skipDefaultBuilders()) {
+    os << llvm::formatv(
+        "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
+        " return Base::get(context{2});\n}\n",
+        className,
+        ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
+                            params),
+        ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams,
+                            params));
+    if (genCheckedMethods) {
+      // Fully qualify emitted names so the generated definition does not
+      // depend on the user's enclosing namespace.
+      os << llvm::formatv(
+          "{0} {0}::getChecked("
+          "::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, "
+          "::mlir::MLIRContext *context{1}) {{\n"
+          " return Base::getChecked(emitError, context{2});\n}\n",
+          className,
+          ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
+                              params),
+          ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams,
+                              params));
+    }
+  }
+
+  auto builderFmtCtx =
+      FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get");
+  auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get");
+  auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context");
+
+  // Generate the builders specified by the user.
+  for (const AttrOrTypeBuilder &builder : def.getBuilders()) {
+    // Builders without a body get no generated definition here; presumably
+    // the user defines them elsewhere.
+    Optional<StringRef> body = builder.getBody();
+    if (!body)
+      continue;
+
+    // Default values belong to the declaration only, so they are omitted.
+    std::string paramStr;
+    llvm::raw_string_ostream paramOS(paramStr);
+    llvm::interleaveComma(builder.getParameters(), paramOS,
+                          [&](const AttrOrTypeBuilder::Parameter &param) {
+                            // Note: AttrOrTypeBuilder parameters are guaranteed
+                            // to have names.
+                            paramOS << param.getCppType() << " "
+                                    << *param.getName();
+                          });
+    paramOS.flush();
+
+    // Emit the `get` variant of the builder.
+    os << llvm::formatv("{0} {0}::get(", className);
+    if (!builder.hasInferredContextParameter()) {
+      os << "::mlir::MLIRContext *context";
+      if (!paramStr.empty())
+        os << ", ";
+      os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
+                          tgfmt(*body, &builderFmtCtx).str());
+    } else {
+      os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
+                          tgfmt(*body, &inferredCtxBuilderFmtCtx).str());
+    }
+
+    // Emit the `getChecked` variant of the builder.
+    if (genCheckedMethods) {
+      os << llvm::formatv("{0} "
+                          "{0}::getChecked(::llvm::function_ref<::mlir::"
+                          "InFlightDiagnostic()> emitErrorFn",
+                          className);
+      std::string checkedBody =
+          replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, ");
+      if (!builder.hasInferredContextParameter()) {
+        os << ", ::mlir::MLIRContext *context";
+        checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str();
+      }
+      if (!paramStr.empty())
+        os << ", ";
+      os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody);
+    }
+  }
+}
+
+/// Print all the def-specific definition code, emitted through a
+/// NamespaceEmitter for the def's dialect.
+///
+/// For parametric defs this is:
+///   - the storage class, when the def asks for it to be generated;
+///   - the builders;
+///   - the parameter accessors, only when the storage class is generated
+///     too.
+/// Then, if the def has a mnemonic, any user-provided parser/printer code is
+/// emitted.
+void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
+  NamespaceEmitter ns(os, def.getDialect());
+
+  SmallVector<AttrOrTypeParameter, 4> parameters;
+  def.getParameters(parameters);
+  if (!parameters.empty()) {
+    // Emit the storage class, if requested and necessary.
+    if (def.genStorageClass())
+      emitStorageClass(def);
+
+    // Emit the builders for this def.
+    emitBuilderDefs(def, os, parameters);
+
+    // Generate accessor definitions only if we also generate the storage class.
+    // Otherwise, let the user define the exact accessor definition.
+    if (def.genAccessors() && def.genStorageClass()) {
+      for (const AttrOrTypeParameter &parameter : parameters) {
+        // Accessor name: `foo` -> `getFoo`.
+        SmallString<16> name = parameter.getName();
+        name[0] = llvm::toUpper(name[0]);
+        os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
+                      parameter.getCppType(), name, parameter.getName(),
+                      def.getCppClassName());
+      }
+    }
+  }
+
+  // If mnemonic is specified maybe print definitions for the parser and printer
+  // code, if they're specified.
+  if (def.getMnemonic())
+    emitParsePrint(def);
+}
+
+/// Emit the dialect printer/parser dispatcher. User's code should call these
+/// functions from their dialect's print/parse methods.
+///
+/// Two static functions are generated, `generated<ValueType>Parser` and
+/// `generated<ValueType>Printer`. They cover only defs that have a mnemonic,
+/// and nothing is emitted when no def has one. The parser returns a null
+/// `::mlir::<ValueType>()` when no mnemonic matches. The printer returns
+/// failure for a value that matches none of the cases.
+void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
+ // Nothing to dispatch to unless at least one def has a mnemonic.
+ if (llvm::none_of(defs, [](const AttrOrTypeDef &def) {
+ return def.getMnemonic().hasValue();
+ })) {
+ return;
+ }
+
+ // The parser dispatch is just a list of if-elses, matching on the mnemonic
+ // and calling the def's parse function.
+ os << llvm::formatv(defParserDispatchStartStr, valueType,
+ isAttrGenerator ? ", ::mlir::Type type" : "");
+ for (const AttrOrTypeDef &def : defs) {
+ if (def.getMnemonic()) {
+ os << formatv(
+ " if (mnemonic == {0}::{1}::getMnemonic()) return {0}::{1}::",
+ def.getDialect().getCppNamespace(), def.getCppClassName());
+
+ // If the def has no parameters and no parser code, just invoke a normal
+ // `get`.
+ if (def.getNumParameters() == 0 && !def.getParserCode()) {
+ os << "get(context);\n";
+ continue;
+ }
+
+ // Otherwise call the def's `parse`; attributes also forward the type.
+ os << "parse(context, parser" << (isAttrGenerator ? ", type" : "")
+ << ");\n";
+ }
+ }
+ // No mnemonic matched: return a null value.
+ os << " return ::mlir::" << valueType << "();\n";
+ os << "}\n\n";
+
+ // The printer dispatch uses llvm::TypeSwitch to find and call the correct
+ // printer.
+ os << llvm::formatv(defPrinterDispatchStartStr, valueType);
+ for (const AttrOrTypeDef &def : defs) {
+ Optional<StringRef> mnemonic = def.getMnemonic();
+ if (!mnemonic)
+ continue;
+
+ StringRef cppNamespace = def.getDialect().getCppNamespace();
+ StringRef cppClassName = def.getCppClassName();
+ os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ",
+ cppNamespace, cppClassName);
+
+ // If the def has no parameters and no printer, just print the mnemonic.
+ if (def.getNumParameters() == 0 && !def.getPrinterCode()) {
+ os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
+ cppClassName);
+ } else {
+ os << "t.print(printer);";
+ }
+ os << "\n return ::mlir::success();\n })\n";
+ }
+ // Values handled by none of the cases report failure.
+ os << llvm::formatv(
+ " .Default([](::mlir::{0}) {{ return ::mlir::failure(); });\n}\n\n",
+ valueType);
+}
+
+/// Emit the definitions for every def selected for `selectedDialect`.
+///
+/// Output order: the `GET_<PREFIX>DEF_LIST` class list, then, inside a
+/// `GET_<PREFIX>DEF_CLASSES` guard, the parser/printer dispatch functions and
+/// each def's definitions. Always returns false.
+bool DefGenerator::emitDefs(StringRef selectedDialect) {
+  emitSourceFileHeader((defTypePrefix + "Def Definitions").str(), os);
+
+  SmallVector<AttrOrTypeDef, 16> collected;
+  collectAllDefs(selectedDialect, defRecords, collected);
+  if (!collected.empty()) {
+    emitTypeDefList(collected);
+
+    IfDefScope guard("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
+    emitParsePrintDispatch(collected);
+    for (const AttrOrTypeDef &def : collected)
+      emitDefDef(def);
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Registration hooks
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// AttrDef
+
+// Option category and the `--attrdefs-dialect` option. Its value is passed as
+// the selected dialect to the attribute generator's emitDefs/emitDecls.
+// NOTE(review): `cl::CommaSeparated` on a scalar `std::string` option splits
+// "a,b" into several occurrences. Confirm that this (and how collectAllDefs
+// treats the selection) is the intended behavior.
+static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
+static llvm::cl::opt<std::string>
+ attrDialect("attrdefs-dialect",
+ llvm::cl::desc("Generate attributes for this dialect"),
+ llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
+
+// `-gen-attrdef-defs` / `-gen-attrdef-decls`: each run builds a fresh
+// AttrDefGenerator over the records and returns its result.
+static mlir::GenRegistration
+ genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ AttrDefGenerator generator(records, os);
+ return generator.emitDefs(attrDialect);
+ });
+static mlir::GenRegistration
+ genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ AttrDefGenerator generator(records, os);
+ return generator.emitDecls(attrDialect);
+ });
+
+//===----------------------------------------------------------------------===//
+// TypeDef
+
+// Option category and the `--typedefs-dialect` option. Its value is passed as
+// the selected dialect to the type generator's emitDefs/emitDecls.
+// NOTE(review): `cl::CommaSeparated` on a scalar `std::string` option splits
+// "a,b" into several occurrences. Confirm that this is the intended behavior.
+static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
+static llvm::cl::opt<std::string>
+ typeDialect("typedefs-dialect",
+ llvm::cl::desc("Generate types for this dialect"),
+ llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
+
+// `-gen-typedef-defs` / `-gen-typedef-decls`: each run builds a fresh
+// TypeDefGenerator over the records and returns its result.
+static mlir::GenRegistration
+ genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ TypeDefGenerator generator(records, os);
+ return generator.emitDefs(typeDialect);
+ });
+static mlir::GenRegistration
+ genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ TypeDefGenerator generator(records, os);
+ return generator.emitDecls(typeDialect);
+ });
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 32c0d739b517..c45eaee9386c 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -5,6 +5,7 @@ set(LLVM_LINK_COMPONENTS
)
add_tablegen(mlir-tblgen MLIR
+ AttrOrTypeDefGen.cpp
DialectGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
@@ -22,7 +23,6 @@ add_tablegen(mlir-tblgen MLIR
RewriterGen.cpp
SPIRVUtilsGen.cpp
StructsGen.cpp
- TypeDefGen.cpp
)
set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index c1a25bd506fd..bc130ecd5e52 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -13,9 +13,9 @@
#include "DocGenUtilities.h"
#include "mlir/Support/IndentedOstream.h"
+#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
-#include "mlir/TableGen/TypeDef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -164,7 +164,7 @@ static void emitTypeDoc(const Type &type, raw_ostream &os) {
/// Emit the assembly format of a type.
static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) {
- SmallVector<TypeParameter, 4> parameters;
+ SmallVector<AttrOrTypeParameter, 4> parameters;
td.getParameters(parameters);
if (parameters.size() == 0) {
os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic()
@@ -198,7 +198,7 @@ static void emitTypeDefDoc(TypeDef td, raw_ostream &os) {
}
// Emit attribute documentation.
- SmallVector<TypeParameter, 4> parameters;
+ SmallVector<AttrOrTypeParameter, 4> parameters;
td.getParameters(parameters);
if (!parameters.empty()) {
os << "\n#### Type parameters:\n\n";
diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
deleted file mode 100644
index 689c35ee79c3..000000000000
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ /dev/null
@@ -1,739 +0,0 @@
-//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// TypeDefGen uses the description of typeDefs to generate C++ definitions.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/TableGen/CodeGenHelpers.h"
-#include "mlir/TableGen/Format.h"
-#include "mlir/TableGen/GenInfo.h"
-#include "mlir/TableGen/TypeDef.h"
-#include "llvm/ADT/SmallSet.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/TableGen/Error.h"
-#include "llvm/TableGen/TableGenBackend.h"
-
-#define DEBUG_TYPE "mlir-tblgen-typedefgen"
-
-using namespace mlir;
-using namespace mlir::tblgen;
-
-static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
-static llvm::cl::opt<std::string>
- selectedDialect("typedefs-dialect",
- llvm::cl::desc("Gen types for this dialect"),
- llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
-
-/// Find all the TypeDefs for the specified dialect. If no dialect specified and
-/// can only find one dialect's types, use that.
-static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
- SmallVectorImpl<TypeDef> &typeDefs) {
- auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
- auto defs = llvm::map_range(
- recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
- if (defs.empty())
- return;
-
- StringRef dialectName;
- if (selectedDialect.getNumOccurrences() == 0) {
- if (defs.empty())
- return;
-
- llvm::SmallSet<Dialect, 4> dialects;
- for (const TypeDef typeDef : defs)
- dialects.insert(typeDef.getDialect());
- if (dialects.size() != 1)
- llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must "
- "select one via '--typedefs-dialect'");
-
- dialectName = (*dialects.begin()).getName();
- } else if (selectedDialect.getNumOccurrences() == 1) {
- dialectName = selectedDialect.getValue();
- } else {
- llvm::PrintFatalError("Cannot select multiple dialects for which to "
- "generate types via '--typedefs-dialect'.");
- }
-
- for (const TypeDef typeDef : defs)
- if (typeDef.getDialect().getName().equals(dialectName))
- typeDefs.push_back(typeDef);
-}
-
-namespace {
-
-/// Pass an instance of this class to llvm::formatv() to emit a comma separated
-/// list of parameters in the format by 'EmitFormat'.
-class TypeParamCommaFormatter : public llvm::detail::format_adapter {
-public:
-  /// Choose the output format
-  enum EmitFormat {
-    /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
-    /// [...]".
-    TypeNamePairs,
-
-    /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
-    TypeNameInitializer,
-
-    /// Emit "param1Name, param2Name, [...]".
-    JustParams,
-  };
-
-  TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params,
-                          bool prependComma = true)
-      : emitFormat(emitFormat), params(params), prependComma(prependComma) {}
-
-  /// llvm::formatv will call this function when using an instance as a
-  /// replacement value.
-  void format(raw_ostream &os, StringRef options) override {
-    if (!params.empty() && prependComma)
-      os << ", ";
-
-    switch (emitFormat) {
-    case EmitFormat::TypeNamePairs:
-      interleaveComma(params, os,
-                      [&](const TypeParameter &p) { emitTypeNamePair(p, os); });
-      break;
-    case EmitFormat::TypeNameInitializer:
-      interleaveComma(params, os, [&](const TypeParameter &p) {
-        emitTypeNameInitializer(p, os);
-      });
-      break;
-    case EmitFormat::JustParams:
-      interleaveComma(params, os,
-                      [&](const TypeParameter &p) { os << p.getName(); });
-      break;
-    }
-  }
-
-private:
-  // Emit "paramType paramName".
-  static void emitTypeNamePair(const TypeParameter &param, raw_ostream &os) {
-    os << param.getCppType() << " " << param.getName();
-  }
-  // Emit "paramName(paramName)"
-  void emitTypeNameInitializer(const TypeParameter &param, raw_ostream &os) {
-    os << param.getName() << "(" << param.getName() << ")";
-  }
-
-  EmitFormat emitFormat;
-  ArrayRef<TypeParameter> params;
-  bool prependComma;
-};
-
-} // end anonymous namespace
-
-//===----------------------------------------------------------------------===//
-// GEN: TypeDef declarations
-//===----------------------------------------------------------------------===//
-
-/// Print this above all the other declarations. Contains type declarations used
-/// later on.
-static const char *const typeDefDeclHeader = R"(
-namespace mlir {
-class DialectAsmParser;
-class DialectAsmPrinter;
-} // namespace mlir
-)";
-
-/// The code block for the start of a typeDef class declaration -- singleton
-/// case.
-///
-/// {0}: The name of the typeDef class.
-/// {1}: The name of the type base class.
-static const char *const typeDefDeclSingletonBeginStr = R"(
- class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
- public:
- /// Inherit some necessary constructors from 'TypeBase'.
- using Base::Base;
-
-)";
-
-/// The code block for the start of a typeDef class declaration -- parametric
-/// case.
-///
-/// {0}: The name of the typeDef class.
-/// {1}: The name of the type base class.
-/// {2}: The typeDef storage class namespace.
-/// {3}: The storage class name.
-/// {4}: The list of parameters with types.
-static const char *const typeDefDeclParametricBeginStr = R"(
- namespace {2} {
- struct {3};
- } // end namespace {2}
- class {0} : public ::mlir::Type::TypeBase<{0}, {1},
- {2}::{3}> {{
- public:
- /// Inherit some necessary constructors from 'TypeBase'.
- using Base::Base;
-
-)";
-
-/// The snippet for print/parse.
-static const char *const typeDefParsePrint = R"(
- static ::mlir::Type parse(::mlir::MLIRContext *context,
- ::mlir::DialectAsmParser &parser);
- void print(::mlir::DialectAsmPrinter &printer) const;
-)";
-
-/// The code block for the verify method declaration.
-///
-/// {0}: List of parameters, parameters style.
-static const char *const typeDefDeclVerifyStr = R"(
- using Base::getChecked;
- static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0});
-)";
-
-/// Emit the builders for the given type.
-static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
-                                 TypeParamCommaFormatter &paramTypes) {
-  StringRef typeClass = typeDef.getCppClassName();
-  bool genCheckedMethods = typeDef.genVerifyDecl();
-  if (!typeDef.skipDefaultBuilders()) {
-    os << llvm::formatv(
-        " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
-        paramTypes);
-    if (genCheckedMethods) {
-      os << llvm::formatv(" static {0} "
-                          "getChecked(llvm::function_ref<::mlir::"
-                          "InFlightDiagnostic()> emitError, "
-                          "::mlir::MLIRContext *context{1});\n",
-                          typeClass, paramTypes);
-    }
-  }
-
-  // Generate the builders specified by the user.
-  for (const TypeBuilder &builder : typeDef.getBuilders()) {
-    std::string paramStr;
-    llvm::raw_string_ostream paramOS(paramStr);
-    llvm::interleaveComma(
-        builder.getParameters(), paramOS,
-        [&](const TypeBuilder::Parameter &param) {
-          // Note: TypeBuilder parameters are guaranteed to have names.
-          paramOS << param.getCppType() << " " << *param.getName();
-          if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
-            paramOS << " = " << *defaultParamValue;
-        });
-    paramOS.flush();
-
-    // Generate the `get` variant of the builder.
-    os << " static " << typeClass << " get(";
-    if (!builder.hasInferredContextParameter()) {
-      os << "::mlir::MLIRContext *context";
-      if (!paramStr.empty())
-        os << ", ";
-    }
-    os << paramStr << ");\n";
-
-    // Generate the `getChecked` variant of the builder.
-    if (genCheckedMethods) {
-      os << " static " << typeClass
-         << " getChecked(llvm::function_ref<mlir::InFlightDiagnostic()> "
-            "emitError";
-      if (!builder.hasInferredContextParameter())
-        os << ", ::mlir::MLIRContext *context";
-      if (!paramStr.empty())
-        os << ", ";
-      os << paramStr << ");\n";
-    }
-  }
-}
-
-/// Generate the declaration for the given typeDef class.
-static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
-  SmallVector<TypeParameter, 4> params;
-  typeDef.getParameters(params);
-
-  // Emit the beginning string template: either the singleton or parametric
-  // template.
-  if (typeDef.getNumParameters() == 0)
-    os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
-                  typeDef.getCppBaseClassName());
-  else
-    os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
-                  typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(),
-                  typeDef.getStorageClassName());
-
-  // Emit the extra declarations first in case there's a type definition in
-  // there.
-  if (Optional<StringRef> extraDecl = typeDef.getExtraDecls())
-    os << *extraDecl << "\n";
-
-  TypeParamCommaFormatter emitTypeNamePairsAfterComma(
-      TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
-  if (!params.empty()) {
-    emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
-
-    // Emit the verify invariants declaration.
-    if (typeDef.genVerifyDecl())
-      os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
-  }
-
-  // Emit the mnenomic, if specified.
-  if (auto mnenomic = typeDef.getMnemonic()) {
-    os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n"
-       << " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n"
-       << " }\n";
-
-    // If mnemonic specified, emit print/parse declarations.
-    if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty())
-      os << typeDefParsePrint;
-  }
-
-  if (typeDef.genAccessors()) {
-    SmallVector<TypeParameter, 4> parameters;
-    typeDef.getParameters(parameters);
-
-    for (TypeParameter &parameter : parameters) {
-      SmallString<16> name = parameter.getName();
-      name[0] = llvm::toUpper(name[0]);
-      os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name);
-    }
-  }
-
-  // End the typeDef decl.
-  os << " };\n";
-}
-
-/// Main entry point for decls.
-static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
- raw_ostream &os) {
- emitSourceFileHeader("TypeDef Declarations", os);
-
- SmallVector<TypeDef, 16> typeDefs;
- findAllTypeDefs(recordKeeper, typeDefs);
-
- IfDefScope scope("GET_TYPEDEF_CLASSES", os);
-
- // Output the common "header".
- os << typeDefDeclHeader;
-
- if (!typeDefs.empty()) {
- NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
-
- // Declare all the type classes first (in case they reference each other).
- for (const TypeDef &typeDef : typeDefs)
- os << " class " << typeDef.getCppClassName() << ";\n";
-
- // Declare all the typedefs.
- for (const TypeDef &typeDef : typeDefs)
- emitTypeDefDecl(typeDef, os);
- }
-
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// GEN: TypeDef list
-//===----------------------------------------------------------------------===//
-
-static void emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
- raw_ostream &os) {
- IfDefScope scope("GET_TYPEDEF_LIST", os);
- for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
- os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
- if (i < typeDefs.end() - 1)
- os << ",\n";
- else
- os << "\n";
- }
-}
-
-//===----------------------------------------------------------------------===//
-// GEN: TypeDef definitions
-//===----------------------------------------------------------------------===//
-
/// Opening part of a generated storage class. The struct is intentionally left
/// open: `emitStorageClass` appends the `hashKey` and `construct` methods, the
/// member variables, and the closing brace and namespace comment.
/// {0}: Storage class namespace.
/// {1}: Storage class C++ name.
/// {2}: Constructor parameter list (comma-separated `type name` pairs).
/// {3}: Constructor member-initializer list.
/// {4}: Comma-separated parameter names (used to build a `KeyTy`).
/// {5}: Comma-separated parameter C++ types (the element types of `KeyTy`).
/// NOTE(review): the emitted comment "a pair of the integer and type params"
/// is inaccurate; `KeyTy` is a `std::tuple` of all the parameter types.
static const char *const typeDefStorageClassBegin = R"(
namespace {0} {{
 struct {1} : public ::mlir::TypeStorage {{
 {1} ({2})
 : {3} {{ }

 /// The hash key for this storage is a pair of the integer and type params.
 using KeyTy = std::tuple<{5}>;

 /// Define the comparison function for the key type.
 bool operator==(const KeyTy &key) const {{
 return key == KeyTy({4});
 }
)";

/// Opening of the generated `construct` method. Its body is left open so the
/// caller can unpack the key and emit any per-parameter allocation code.
/// {0}: Storage class name.
static const char *const typeDefStorageClassConstructorBegin = R"(
 /// Define a construction method for creating a new instance of this storage.
 static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{
)";

/// Closing of the generated `construct` method: allocates the storage object
/// through the storage allocator, constructs it from the (possibly
/// reallocated) parameters, and closes the method.
/// {0}: Storage class name.
/// {1}: Comma-separated parameter names.
static const char *const typeDefStorageClassConstructorReturn = R"(
 return new (allocator.allocate<{0}>())
 {0}({1});
 }
)";
-
-/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
-static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
- SmallVector<TypeParameter, 4> parameters;
- typeDef.getParameters(parameters);
- auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
- for (TypeParameter ¶meter : parameters) {
- auto allocCode = parameter.getAllocator();
- if (allocCode) {
- fmtCtxt.withSelf(parameter.getName());
- fmtCtxt.addSubst("_dst", parameter.getName());
- os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
- }
- }
-}
-
-/// Emit the storage class code for type 'typeDef'.
-/// This includes (in-order):
-/// 1) typeDefStorageClassBegin, which includes:
-/// - The class constructor.
-/// - The KeyTy definition.
-/// - The equality (==) operator.
-/// 2) The hashKey method.
-/// 3) The construct method.
-/// 4) The list of parameters as the storage class member variables.
-static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
- SmallVector<TypeParameter, 4> parameters;
- typeDef.getParameters(parameters);
-
- // Initialize a bunch of variables to be used later on.
- auto parameterNames = map_range(
- parameters, [](TypeParameter parameter) { return parameter.getName(); });
- auto parameterTypes = map_range(parameters, [](TypeParameter parameter) {
- return parameter.getCppType();
- });
- auto parameterList = join(parameterNames, ", ");
- auto parameterTypeList = join(parameterTypes, ", ");
-
- // 1) Emit most of the storage class up until the hashKey body.
- os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
- typeDef.getStorageClassName(),
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
- parameters, /*prependComma=*/false),
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNameInitializer,
- parameters, /*prependComma=*/false),
- parameterList, parameterTypeList);
-
- // 2) Emit the haskKey method.
- os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
- // Extract each parameter from the key.
- for (size_t i = 0, e = parameters.size(); i < e; ++i)
- os << llvm::formatv(" const auto &{0} = std::get<{1}>(key);\n",
- parameters[i].getName(), i);
- // Then combine them all. This requires all the parameters types to have a
- // hash_value defined.
- os << llvm::formatv(
- " return ::llvm::hash_combine({0});\n }\n",
- TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
- parameters, /* prependComma */ false));
-
- // 3) Emit the construct method.
- if (typeDef.hasStorageCustomConstructor()) {
- // If user wants to build the storage constructor themselves, declare it
- // here and then they can write the definition elsewhere.
- os << " static " << typeDef.getStorageClassName()
- << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
- "&key);\n";
- } else {
- // If not, autogenerate one.
-
- // First, unbox the parameters.
- os << formatv(typeDefStorageClassConstructorBegin,
- typeDef.getStorageClassName());
- for (size_t i = 0; i < parameters.size(); ++i) {
- os << formatv(" auto {0} = std::get<{1}>(key);\n",
- parameters[i].getName(), i);
- }
- // Second, reassign the parameter variables with allocation code, if it's
- // specified.
- emitParameterAllocationCode(typeDef, os);
-
- // Last, return an allocated copy.
- os << formatv(typeDefStorageClassConstructorReturn,
- typeDef.getStorageClassName(), parameterList);
- }
-
- // 4) Emit the parameters as storage class members.
- for (auto parameter : parameters) {
- os << " " << parameter.getCppType() << " " << parameter.getName()
- << ";\n";
- }
- os << " };\n";
-
- os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
-}
-
-/// Emit the parser and printer for a particular type, if they're specified.
-void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
- // Emit the printer code, if specified.
- if (auto printerCode = typeDef.getPrinterCode()) {
- // Both the mnenomic and printerCode must be defined (for parity with
- // parserCode).
- os << "void " << typeDef.getCppClassName()
- << "::print(::mlir::DialectAsmPrinter &printer) const {\n";
- if (*printerCode == "") {
- // If no code specified, emit error.
- PrintFatalError(typeDef.getLoc(),
- typeDef.getName() +
- ": printer (if specified) must have non-empty code");
- }
- auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
- os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
- }
-
- // emit a parser, if specified.
- if (auto parserCode = typeDef.getParserCode()) {
- // The mnenomic must be defined so the dispatcher knows how to dispatch.
- os << "::mlir::Type " << typeDef.getCppClassName()
- << "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &"
- "parser) "
- "{\n";
- if (*parserCode == "") {
- // if no code specified, emit error.
- PrintFatalError(typeDef.getLoc(),
- typeDef.getName() +
- ": parser (if specified) must have non-empty code");
- }
- auto fmtCtxt =
- FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context");
- os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
- }
-}
-
/// Replace every non-overlapping occurrence of `from` in `str` with `to`,
/// scanning left to right, and return the resulting string.
///
/// Scanning resumes just after each inserted `to`. A replacement that itself
/// contains `from` is therefore never re-matched, and cannot loop forever.
/// An empty `from` matches nothing and returns `str` unchanged.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
  if (from.empty())
    return str;
  size_t pos = 0;
  while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) {
    str.replace(pos, from.size(), to.data(), to.size());
    // Skip over the inserted text so it is not searched again.
    pos += to.size();
  }
  return str;
}
-
-/// Emit the builders for the given type.
-static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
- ArrayRef<TypeParameter> typeDefParams) {
- bool genCheckedMethods = typeDef.genVerifyDecl();
- StringRef typeClass = typeDef.getCppClassName();
- if (!typeDef.skipDefaultBuilders()) {
- os << llvm::formatv(
- "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
- " return Base::get(context{2});\n}\n",
- typeClass,
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams),
- TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
- typeDefParams));
- if (genCheckedMethods) {
- os << llvm::formatv(
- "{0} {0}::getChecked("
- "llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, "
- "::mlir::MLIRContext *context{1}) {{\n"
- " return Base::getChecked(emitError, context{2});\n}\n",
- typeClass,
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
- typeDefParams),
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams));
- }
- }
-
- auto builderFmtCtx =
- FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get");
- auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get");
- auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context");
-
- // Generate the builders specified by the user.
- for (const TypeBuilder &builder : typeDef.getBuilders()) {
- Optional<StringRef> body = builder.getBody();
- if (!body)
- continue;
- std::string paramStr;
- llvm::raw_string_ostream paramOS(paramStr);
- llvm::interleaveComma(builder.getParameters(), paramOS,
- [&](const TypeBuilder::Parameter ¶m) {
- // Note: TypeBuilder parameters are guaranteed to
- // have names.
- paramOS << param.getCppType() << " "
- << *param.getName();
- });
- paramOS.flush();
-
- // Emit the `get` variant of the builder.
- os << llvm::formatv("{0} {0}::get(", typeClass);
- if (!builder.hasInferredContextParameter()) {
- os << "::mlir::MLIRContext *context";
- if (!paramStr.empty())
- os << ", ";
- os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
- tgfmt(*body, &builderFmtCtx).str());
- } else {
- os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
- tgfmt(*body, &inferredCtxBuilderFmtCtx).str());
- }
-
- // Emit the `getChecked` variant of the builder.
- if (genCheckedMethods) {
- os << llvm::formatv("{0} "
- "{0}::getChecked(llvm::function_ref<::mlir::"
- "InFlightDiagnostic()> emitErrorFn",
- typeClass);
- std::string checkedBody =
- replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, ");
- if (!builder.hasInferredContextParameter()) {
- os << ", ::mlir::MLIRContext *context";
- checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str();
- }
- if (!paramStr.empty())
- os << ", ";
- os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody);
- }
- }
-}
-
-/// Print all the typedef-specific definition code.
-static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) {
- NamespaceEmitter ns(os, typeDef.getDialect());
-
- SmallVector<TypeParameter, 4> parameters;
- typeDef.getParameters(parameters);
- if (!parameters.empty()) {
- // Emit the storage class, if requested and necessary.
- if (typeDef.genStorageClass())
- emitStorageClass(typeDef, os);
-
- // Emit the builders for this type.
- emitTypeBuilderDefs(typeDef, os, parameters);
-
- // Generate accessor definitions only if we also generate the storage class.
- // Otherwise, let the user define the exact accessor definition.
- if (typeDef.genAccessors() && typeDef.genStorageClass()) {
- // Emit the parameter accessors.
- for (const TypeParameter ¶meter : parameters) {
- SmallString<16> name = parameter.getName();
- name[0] = llvm::toUpper(name[0]);
- os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
- parameter.getCppType(), name, parameter.getName(),
- typeDef.getCppClassName());
- }
- }
- }
-
- // If mnemonic is specified maybe print definitions for the parser and printer
- // code, if they're specified.
- if (typeDef.getMnemonic())
- emitParserPrinter(typeDef, os);
-}
-
-/// Emit the dialect printer/parser dispatcher. User's code should call these
-/// functions from their dialect's print/parse methods.
-static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
- if (llvm::none_of(types, [](const TypeDef &type) {
- return type.getMnemonic().hasValue();
- })) {
- return;
- }
-
- // The parser dispatch is just a list of if-elses, matching on the
- // mnemonic and calling the class's parse function.
- os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *"
- "context, ::mlir::DialectAsmParser &parser, "
- "::llvm::StringRef mnemonic) {\n";
- for (const TypeDef &type : types) {
- if (type.getMnemonic()) {
- os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
- "{0}::{1}::",
- type.getDialect().getCppNamespace(),
- type.getCppClassName());
-
- // If the type has no parameters and no parser code, just invoke a normal
- // `get`.
- if (type.getNumParameters() == 0 && !type.getParserCode())
- os << "get(context);\n";
- else
- os << "parse(context, parser);\n";
- }
- }
- os << " return ::mlir::Type();\n";
- os << "}\n\n";
-
- // The printer dispatch uses llvm::TypeSwitch to find and call the correct
- // printer.
- os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
- "type, "
- "::mlir::DialectAsmPrinter &printer) {\n"
- << " return ::llvm::TypeSwitch<::mlir::Type, "
- "::mlir::LogicalResult>(type)\n";
- for (const TypeDef &type : types) {
- if (Optional<StringRef> mnemonic = type.getMnemonic()) {
- StringRef cppNamespace = type.getDialect().getCppNamespace();
- StringRef cppClassName = type.getCppClassName();
- os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ",
- cppNamespace, cppClassName);
-
- // If the type has no parameters and no printer code, just print the
- // mnemonic.
- if (type.getNumParameters() == 0 && !type.getPrinterCode())
- os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
- cppClassName);
- else
- os << "t.print(printer);";
- os << "\n return ::mlir::success();\n })\n";
- }
- }
- os << " .Default([](::mlir::Type) { return ::mlir::failure(); });\n"
- << "}\n\n";
-}
-
-/// Entry point for typedef definitions.
-static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
- raw_ostream &os) {
- emitSourceFileHeader("TypeDef Definitions", os);
-
- SmallVector<TypeDef, 16> typeDefs;
- findAllTypeDefs(recordKeeper, typeDefs);
- emitTypeDefList(typeDefs, os);
-
- IfDefScope scope("GET_TYPEDEF_CLASSES", os);
- emitParsePrintDispatch(typeDefs, os);
- for (const TypeDef &typeDef : typeDefs)
- emitTypeDefDef(typeDef, os);
-
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// GEN: TypeDef registration hooks
-//===----------------------------------------------------------------------===//
-
-static mlir::GenRegistration
- genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
- [](const llvm::RecordKeeper &records, raw_ostream &os) {
- return emitTypeDefDefs(records, os);
- });
-
-static mlir::GenRegistration
- genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
- [](const llvm::RecordKeeper &records, raw_ostream &os) {
- return emitTypeDefDecls(records, os);
- });
More information about the Mlir-commits
mailing list