[Mlir-commits] [mlir] ca6bd9c - [mlir][ods] AttrOrTypeGen uses Class
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 1 08:53:27 PST 2021
Author: Mogball
Date: 2021-12-01T16:53:23Z
New Revision: ca6bd9cd4320e675026c559cc3f8bf810a89d4ce
URL: https://github.com/llvm/llvm-project/commit/ca6bd9cd4320e675026c559cc3f8bf810a89d4ce
DIFF: https://github.com/llvm/llvm-project/commit/ca6bd9cd4320e675026c559cc3f8bf810a89d4ce.diff
LOG: [mlir][ods] AttrOrTypeGen uses Class
AttrOrType def generator uses `Class` code gen helper,
instead of naked raw_ostream.
Depends on D113714 and D114807
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D113715
Added:
mlir/tools/mlir-tblgen/OpClass.cpp
mlir/tools/mlir-tblgen/OpClass.h
Modified:
mlir/cmake/modules/AddMLIR.cmake
mlir/include/mlir/Support/IndentedOstream.h
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/include/mlir/TableGen/Class.h
mlir/include/mlir/TableGen/CodeGenHelpers.h
mlir/include/mlir/TableGen/Format.h
mlir/lib/Support/IndentedOstream.cpp
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/lib/TableGen/Class.cpp
mlir/lib/TableGen/Format.cpp
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/test/mlir-tblgen/attrdefs.td
mlir/test/mlir-tblgen/default-type-attr-print-parser.td
mlir/test/mlir-tblgen/typedefs.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
mlir/tools/mlir-tblgen/CMakeLists.txt
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpDocGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
mlir/unittests/Support/IndentedOstreamTest.cpp
Removed:
################################################################################
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 6f8c1c65ffa31..770d2aa06d2de 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -12,8 +12,8 @@ function(add_mlir_dialect dialect dialect_namespace)
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
- mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls)
- mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs)
+ mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls -typedefs-dialect=${dialect_namespace})
+ mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs -typedefs-dialect=${dialect_namespace})
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
mlir_tablegen(${dialect}Dialect.cpp.inc -gen-dialect-defs -dialect=${dialect_namespace})
add_public_tablegen_target(MLIR${dialect}IncGen)
diff --git a/mlir/include/mlir/Support/IndentedOstream.h b/mlir/include/mlir/Support/IndentedOstream.h
index 9a755bc7ebb08..79c66995c67ae 100644
--- a/mlir/include/mlir/Support/IndentedOstream.h
+++ b/mlir/include/mlir/Support/IndentedOstream.h
@@ -29,34 +29,38 @@ class raw_indented_ostream : public raw_ostream {
/// Simple RAII struct to use to indentation around entering/exiting region.
struct DelimitedScope {
explicit DelimitedScope(raw_indented_ostream &os, StringRef open = "",
- StringRef close = "")
- : os(os), open(open), close(close) {
+ StringRef close = "", bool indent = true)
+ : os(os), open(open), close(close), indent(indent) {
os << open;
- os.indent();
+ if (indent)
+ os.indent();
}
~DelimitedScope() {
- os.unindent();
+ if (indent)
+ os.unindent();
os << close;
}
raw_indented_ostream &os;
private:
- llvm::StringRef open, close;
+ StringRef open, close;
+ bool indent;
};
/// Returns the underlying (unindented) raw_ostream.
raw_ostream &getOStream() const { return os; }
/// Returns DelimitedScope.
- DelimitedScope scope(StringRef open = "", StringRef close = "") {
- return DelimitedScope(*this, open, close);
+ DelimitedScope scope(StringRef open = "", StringRef close = "",
+ bool indent = true) {
+ return DelimitedScope(*this, open, close, indent);
}
- /// Re-indents by removing the leading whitespace from the first non-empty
- /// line from every line of the string, skipping over empty lines at the
- /// start.
- raw_indented_ostream &reindent(StringRef str);
+ /// Prints a string re-indented to the current indent. Re-indents by removing
+ /// the leading whitespace from the first non-empty line from every line of
+ /// the string, skipping over empty lines at the start.
+ raw_indented_ostream &printReindented(StringRef str);
/// Increases the indent and returning this raw_indented_ostream.
raw_indented_ostream &indent() {
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 09294c2fa8081..303edfdf143d1 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -27,7 +27,6 @@ class SMLoc;
namespace mlir {
namespace tblgen {
class Dialect;
-class AttrOrTypeParameter;
//===----------------------------------------------------------------------===//
// AttrOrTypeBuilder
@@ -42,6 +41,76 @@ class AttrOrTypeBuilder : public Builder {
bool hasInferredContextParameter() const;
};
+//===----------------------------------------------------------------------===//
+// AttrOrTypeParameter
+//===----------------------------------------------------------------------===//
+
+// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
+// AttrOrTypeDefs to parameterize them.
+class AttrOrTypeParameter {
+public:
+ explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
+ : def(def), index(index) {}
+
+ // Get the parameter name.
+ StringRef getName() const;
+
+ // If specified, get the custom allocator code for this parameter.
+ Optional<StringRef> getAllocator() const;
+
+ // If specified, get the custom comparator code for this parameter.
+ Optional<StringRef> getComparator() const;
+
+ // Get the C++ type of this parameter.
+ StringRef getCppType() const;
+
+ // Get the C++ accessor type of this parameter.
+ StringRef getCppAccessorType() const;
+
+ // Get the C++ storage type of this parameter.
+ StringRef getCppStorageType() const;
+
+ // Get an optional C++ parameter parser.
+ Optional<StringRef> getParser() const;
+
+ // Get an optional C++ parameter printer.
+ Optional<StringRef> getPrinter() const;
+
+ // Get a description of this parameter for documentation purposes.
+ Optional<StringRef> getSummary() const;
+
+ // Get the assembly syntax documentation.
+ StringRef getSyntax() const;
+
+ // Return the underlying def of this parameter.
+ const llvm::Init *getDef() const;
+
+ // The parameter is pointer-comparable.
+ bool operator==(const AttrOrTypeParameter &other) const {
+ return def == other.def && index == other.index;
+ }
+ bool operator!=(const AttrOrTypeParameter &other) const {
+ return !(*this == other);
+ }
+
+private:
+ /// The underlying tablegen parameter list this parameter is a part of.
+ const llvm::DagInit *def;
+ /// The index of the parameter within the parameter list (`def`).
+ unsigned index;
+};
+
+//===----------------------------------------------------------------------===//
+// AttributeSelfTypeParameter
+//===----------------------------------------------------------------------===//
+
+// A wrapper class for the AttributeSelfTypeParameter tblgen class. This
+// represents a parameter of mlir::Type that is the value type of an AttrDef.
+class AttributeSelfTypeParameter : public AttrOrTypeParameter {
+public:
+ static bool classof(const AttrOrTypeParameter *param);
+};
+
//===----------------------------------------------------------------------===//
// AttrOrTypeDef
//===----------------------------------------------------------------------===//
@@ -82,9 +151,8 @@ class AttrOrTypeDef {
// Indicates whether or not to generate the storage class constructor.
bool hasStorageCustomConstructor() const;
- // Fill a list with this def's parameters. See AttrOrTypeDef in OpBase.td for
- // documentation of parameter usage.
- void getParameters(SmallVectorImpl<AttrOrTypeParameter> &) const;
+ /// Get the parameters of this attribute or type.
+ ArrayRef<AttrOrTypeParameter> getParameters() const { return parameters; }
// Return the number of parameters
unsigned getNumParameters() const;
@@ -104,6 +172,19 @@ class AttrOrTypeDef {
// Returns the custom assembly format, if one was specified.
Optional<StringRef> getAssemblyFormat() const;
+ // An attribute or type with parameters needs a parser.
+ bool needsParserPrinter() const { return getNumParameters() != 0; }
+
+ // Returns true if this attribute or type has a generated parser.
+ bool hasGeneratedParser() const {
+ return getParserCode() || getAssemblyFormat();
+ }
+
+ // Returns true if this attribute or type has a generated printer.
+ bool hasGeneratedPrinter() const {
+ return getPrinterCode() || getAssemblyFormat();
+ }
+
// Returns true if the accessors based on the parameters should be generated.
bool genAccessors() const;
@@ -148,6 +229,9 @@ class AttrOrTypeDef {
// The traits of this definition.
SmallVector<Trait> traits;
+
+ /// The parameters of this attribute or type.
+ SmallVector<AttrOrTypeParameter> parameters;
};
//===----------------------------------------------------------------------===//
@@ -176,68 +260,6 @@ class TypeDef : public AttrOrTypeDef {
using AttrOrTypeDef::AttrOrTypeDef;
};
-//===----------------------------------------------------------------------===//
-// AttrOrTypeParameter
-//===----------------------------------------------------------------------===//
-
-// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
-// AttrOrTypeDefs to parameterize them.
-class AttrOrTypeParameter {
-public:
- explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
- : def(def), index(index) {}
-
- // Get the parameter name.
- StringRef getName() const;
-
- // If specified, get the custom allocator code for this parameter.
- Optional<StringRef> getAllocator() const;
-
- // If specified, get the custom comparator code for this parameter.
- Optional<StringRef> getComparator() const;
-
- // Get the C++ type of this parameter.
- StringRef getCppType() const;
-
- // Get the C++ accessor type of this parameter.
- StringRef getCppAccessorType() const;
-
- // Get the C++ storage type of this parameter.
- StringRef getCppStorageType() const;
-
- // Get an optional C++ parameter parser.
- Optional<StringRef> getParser() const;
-
- // Get an optional C++ parameter printer.
- Optional<StringRef> getPrinter() const;
-
- // Get a description of this parameter for documentation purposes.
- Optional<StringRef> getSummary() const;
-
- // Get the assembly syntax documentation.
- StringRef getSyntax() const;
-
- // Return the underlying def of this parameter.
- const llvm::Init *getDef() const;
-
-private:
- /// The underlying tablegen parameter list this parameter is a part of.
- const llvm::DagInit *def;
- /// The index of the parameter within the parameter list (`def`).
- unsigned index;
-};
-
-//===----------------------------------------------------------------------===//
-// AttributeSelfTypeParameter
-//===----------------------------------------------------------------------===//
-
-// A wrapper class for the AttributeSelfTypeParameter tblgen class. This
-// represents a parameter of mlir::Type that is the value type of an AttrDef.
-class AttributeSelfTypeParameter : public AttrOrTypeParameter {
-public:
- static bool classof(const AttrOrTypeParameter *param);
-};
-
} // end namespace tblgen
} // end namespace mlir
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 9eaf066f7f2c4..ea56158fe3319 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines several classes for Op C++ code emission. They are only
+// This file defines several classes for C++ code emission. They are only
// expected to be used by MLIR TableGen backends.
//
-// We emit the op declaration and definition into separate files: *Ops.h.inc
-// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
-// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
+// We emit the declarations and definitions into separate files: *.h.inc and
+// *.cpp.inc. The former is to be included in the dialect *.h and the latter for
+// dialect *.cpp. This way provides a cleaner interface.
//
// In order to do this split, we need to track method signature and
// implementation logic separately. Signature information is used for both
@@ -23,6 +23,7 @@
#ifndef MLIR_TABLEGEN_CLASS_H_
#define MLIR_TABLEGEN_CLASS_H_
+#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "llvm/ADT/SetVector.h"
@@ -30,7 +31,6 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
-#include "llvm/Support/raw_ostream.h"
#include <set>
#include <string>
@@ -61,18 +61,16 @@ class MethodParameter {
/*defaultValue=*/"", optional) {}
/// Write the parameter as part of a method declaration.
- void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
+ void writeDeclTo(raw_indented_ostream &os) const;
/// Write the parameter as part of a method definition.
- void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
+ void writeDefTo(raw_indented_ostream &os) const;
/// Get the C++ type.
- const std::string &getType() const { return type; }
+ StringRef getType() const { return type; }
/// Returns true if the parameter has a default value.
bool hasDefaultValue() const { return !defaultValue.empty(); }
private:
- void writeTo(raw_ostream &os, bool emitDefault) const;
-
/// The C++ type.
std::string type;
/// The variable name.
@@ -95,9 +93,9 @@ class MethodParameters {
: parameters(std::move(parameters)) {}
/// Write the parameters as part of a method declaration.
- void writeDeclTo(raw_ostream &os) const;
+ void writeDeclTo(raw_indented_ostream &os) const;
/// Write the parameters as part of a method definition.
- void writeDefTo(raw_ostream &os) const;
+ void writeDefTo(raw_indented_ostream &os) const;
/// Determine whether this list of parameters "subsumes" another, which occurs
/// when this parameter list is identical to the other and has zero or more
@@ -108,21 +106,39 @@ class MethodParameters {
unsigned getNumParameters() const { return parameters.size(); }
private:
- llvm::SmallVector<MethodParameter> parameters;
+ /// The list of parameters.
+ SmallVector<MethodParameter> parameters;
};
/// This class contains the signature of a C++ method, including the return
/// type. method name, and method parameters.
class MethodSignature {
public:
- MethodSignature(StringRef retType, StringRef name,
+ /// Create a method signature with a return type, a method name, and a list of
+ /// parameters. Take ownership of the list.
+ template <typename RetTypeT, typename NameT>
+ MethodSignature(RetTypeT &&retType, NameT &&name,
SmallVector<MethodParameter> &&parameters)
- : returnType(retType), methodName(name),
+ : returnType(stringify(std::forward<RetTypeT>(retType))),
+ methodName(stringify(std::forward<NameT>(name))),
parameters(std::move(parameters)) {}
- template <typename... Parameters>
- MethodSignature(StringRef retType, StringRef name, Parameters &&...parameters)
- : returnType(retType), methodName(name),
- parameters({std::forward<Parameters>(parameters)...}) {}
+ /// Create a method signature with a return type, a method name, and a list of
+ /// parameters.
+ template <typename RetTypeT, typename NameT>
+ MethodSignature(RetTypeT &&retType, NameT &&name,
+ ArrayRef<MethodParameter> parameters)
+ : MethodSignature(std::forward<RetTypeT>(retType),
+ std::forward<NameT>(name),
+ SmallVector<MethodParameter>(parameters.begin(),
+ parameters.end())) {}
+ /// Create a method signature with a return type, a method name, and a
+ /// variadic list of parameters.
+ template <typename RetTypeT, typename NameT, typename... Parameters>
+ MethodSignature(RetTypeT &&retType, NameT &&name, Parameters &&...parameters)
+ : MethodSignature(std::forward<RetTypeT>(retType),
+ std::forward<NameT>(name),
+ ArrayRef<MethodParameter>(
+ {std::forward<Parameters>(parameters)...})) {}
/// Determine whether a method with this signature makes a method with
/// `other` signature redundant. This occurs if the signatures have the same
@@ -140,12 +156,12 @@ class MethodSignature {
unsigned getNumParameters() const { return parameters.getNumParameters(); }
/// Write the signature as part of a method declaration.
- void writeDeclTo(raw_ostream &os) const;
+ void writeDeclTo(raw_indented_ostream &os) const;
/// Write the signature as part of a method definition. `namePrefix` is to be
/// prepended to the method name (typically namespaces for qualifying the
/// method definition).
- void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+ void writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const;
private:
/// The method's C++ return type.
@@ -156,61 +172,174 @@ class MethodSignature {
MethodParameters parameters;
};
-/// Class for holding the body of an op's method for C++ code emission
+/// This class contains the body of a C++ method.
class MethodBody {
public:
- explicit MethodBody(bool declOnly);
+ /// Create a method body, indicating whether it should be elided for methods
+ /// that are declaration-only.
+ MethodBody(bool declOnly);
+
+ /// Define a move constructor to correctly initialize the streams.
+ MethodBody(MethodBody &&other)
+ : declOnly(other.declOnly), body(std::move(other.body)), stringOs(body),
+ os(stringOs) {}
+ /// Define a move assignment operator. `raw_ostream` has deleted assignment
+ /// operators, so reinitialize the whole object.
+ MethodBody &operator=(MethodBody &&body) {
+ this->~MethodBody();
+ new (this) MethodBody(std::move(body));
+ return *this;
+ }
+
+ /// Write a value to the method body.
+ template <typename ValueT>
+ MethodBody &operator<<(ValueT &&value) {
+ if (!declOnly) {
+ os << std::forward<ValueT>(value);
+ os.flush();
+ }
+ return *this;
+ }
- MethodBody &operator<<(Twine content);
- MethodBody &operator<<(int content);
- MethodBody &operator<<(const FmtObjectBase &content);
+ /// Write the method body to the output stream. The body can be written as
+ /// part of the declaration of an inline method or just in the definition.
+ void writeTo(raw_indented_ostream &os) const;
- void writeTo(raw_ostream &os) const;
+ /// Indent the output stream.
+ MethodBody &indent() {
+ os.indent();
+ return *this;
+ }
+ /// Unindent the output stream.
+ MethodBody &unindent() {
+ os.unindent();
+ return *this;
+ }
+ /// Create a delimited scope: immediately print `open`, indent if `indent` is
+ /// true, and print `close` on object destruction.
+ raw_indented_ostream::DelimitedScope
+ scope(StringRef open = "", StringRef close = "", bool indent = false) {
+ return os.scope(open, close, indent);
+ }
+
+ /// Get the underlying indented output stream.
+ raw_indented_ostream &getStream() { return os; }
private:
- /// Whether this class should record method body.
- bool isEffective;
- /// The body of the method.
+ /// Whether the body should be elided.
+ bool declOnly;
+ /// The body data.
std::string body;
+ /// The string output stream.
+ llvm::raw_string_ostream stringOs;
+ /// An indented output stream for formatting input.
+ raw_indented_ostream os;
+};
+
+/// A class declaration is a class element that appears as part of its
+/// declaration.
+class ClassDeclaration {
+public:
+ virtual ~ClassDeclaration() = default;
+
+ /// Kinds for LLVM-style RTTI.
+ enum Kind {
+ Method,
+ UsingDeclaration,
+ VisibilityDeclaration,
+ Field,
+ ExtraClassDeclaration
+ };
+ /// Create a class declaration with a given kind.
+ ClassDeclaration(Kind kind) : kind(kind) {}
+
+ /// Get the class declaration kind.
+ Kind getKind() const { return kind; }
+
+ /// Write the declaration.
+ virtual void writeDeclTo(raw_indented_ostream &os) const = 0;
+
+ /// Write the definition, if any. `namePrefix` is the namespace prefix, which
+ /// may contain a class name.
+ virtual void writeDefTo(raw_indented_ostream &os,
+ StringRef namePrefix) const {}
+
+private:
+ /// The class declaration kind.
+ Kind kind;
+};
+
+/// Base class for class declarations.
+template <ClassDeclaration::Kind DeclKind>
+class ClassDeclarationBase : public ClassDeclaration {
+public:
+ using Base = ClassDeclarationBase<DeclKind>;
+ ClassDeclarationBase() : ClassDeclaration(DeclKind) {}
+
+ static bool classof(const ClassDeclaration *other) {
+ return other->getKind() == DeclKind;
+ }
};
/// Class for holding an op's method for C++ code emission
-class Method {
+class Method : public ClassDeclarationBase<ClassDeclaration::Method> {
public:
/// Properties (qualifiers) of class methods. Bitfield is used here to help
/// querying properties.
- enum Property {
- MP_None = 0x0,
- MP_Static = 0x1,
- MP_Constructor = 0x2,
- MP_Private = 0x4,
- MP_Declaration = 0x8,
- MP_Inline = 0x10,
- MP_Constexpr = 0x20 | MP_Inline,
- MP_StaticDeclaration = MP_Static | MP_Declaration,
+ enum Properties {
+ None = 0x0,
+ Static = 0x1,
+ Constructor = 0x2,
+ Private = 0x4,
+ Declaration = 0x8,
+ Inline = 0x10,
+ ConstexprValue = 0x20,
+ Const = 0x40,
+
+ Constexpr = ConstexprValue | Inline,
+ StaticDeclaration = Static | Declaration,
+ StaticInline = Static | Inline,
+ ConstInline = Const | Inline,
+ ConstDeclaration = Const | Declaration
};
- template <typename... Args>
- Method(StringRef retType, StringRef name, Property property, Args &&...args)
- : properties(property),
- methodSignature(retType, name, std::forward<Args>(args)...),
- methodBody(properties & MP_Declaration) {}
-
+ /// Create a method with a return type, a name, method properties, and some
+ /// parameters. The parameters may be passed as a list or as a variadic pack.
+ template <typename RetTypeT, typename NameT, typename... Args>
+ Method(RetTypeT &&retType, NameT &&name, Properties properties,
+ Args &&...args)
+ : properties(properties),
+ methodSignature(std::forward<RetTypeT>(retType),
+ std::forward<NameT>(name), std::forward<Args>(args)...),
+ methodBody(properties & Declaration) {}
+ /// Create a method with a return type, a name, method properties, and a list
+ /// of parameters.
+ Method(StringRef retType, StringRef name, Properties properties,
+ std::initializer_list<MethodParameter> params)
+ : properties(properties), methodSignature(retType, name, params),
+ methodBody(properties & Declaration) {}
+
+ // Define move constructor and assignment operator to prevent copying.
Method(Method &&) = default;
Method &operator=(Method &&) = default;
- virtual ~Method() = default;
-
+ /// Get the method body.
MethodBody &body() { return methodBody; }
/// Returns true if this is a static method.
- bool isStatic() const { return properties & MP_Static; }
+ bool isStatic() const { return properties & Static; }
/// Returns true if this is a private method.
- bool isPrivate() const { return properties & MP_Private; }
+ bool isPrivate() const { return properties & Private; }
/// Returns true if this is an inline method.
- bool isInline() const { return properties & MP_Inline; }
+ bool isInline() const { return properties & Inline; }
+
+ /// Returns true if this is a constructor.
+ bool isConstructor() const { return properties & Constructor; }
+
+ /// Returns true if this class method is const.
+ bool isConst() const { return properties & Const; }
/// Returns the name of this method.
StringRef getName() const { return methodSignature.getName(); }
@@ -220,158 +349,369 @@ class Method {
return methodSignature.makesRedundant(other.methodSignature);
}
- /// Writes the method as a declaration to the given `os`.
- virtual void writeDeclTo(raw_ostream &os) const;
+ /// Write the method declaration, including the definition if inline.
+ void writeDeclTo(raw_indented_ostream &os) const override;
- /// Writes the method as a definition to the given `os`. `namePrefix` is the
- /// prefix to be prepended to the method name (typically namespaces for
- /// qualifying the method definition).
- virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+ /// Write the method definition. This is a no-op for inline methods.
+ void writeDefTo(raw_indented_ostream &os,
+ StringRef namePrefix) const override;
protected:
/// A collection of method properties.
- Property properties;
+ Properties properties;
/// The signature of the method.
MethodSignature methodSignature;
/// The body of the method, if it has one.
MethodBody methodBody;
};
+/// This enum describes C++ inheritance visibility.
+enum class Visibility { Public, Protected, Private };
+
+/// Write "public", "protected", or "private".
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ mlir::tblgen::Visibility visibility);
+
+// Class for holding an op's constructor method for C++ code emission.
+class Constructor : public Method {
+public:
+ /// Create a constructor for a given class, with method properties, and
+ /// parameters specified either as a list or a variadic pack.
+ template <typename NameT, typename... Args>
+ Constructor(NameT &&className, Properties properties, Args &&...args)
+ : Method("", std::forward<NameT>(className), properties,
+ std::forward<Args>(args)...) {}
+
+ /// Add member initializer to constructor initializing `name` with `value`.
+ template <typename NameT, typename ValueT>
+ void addMemberInitializer(NameT &&name, ValueT &&value) {
+ initializers.emplace_back(stringify(std::forward<NameT>(name)),
+ stringify(std::forward<ValueT>(value)));
+ }
+
+ /// Write the declaration of the constructor, and its definition if inline.
+ void writeDeclTo(raw_indented_ostream &os) const override;
+
+ /// Write the definition of the constructor if it is not inline.
+ void writeDefTo(raw_indented_ostream &os,
+ StringRef namePrefix) const override;
+
+ /// Return true if a method is a constructor.
+ static bool classof(const ClassDeclaration *other) {
+ return isa<Method>(other) && cast<Method>(other)->isConstructor();
+ }
+
+ /// Initialization of a class field in a constructor.
+ class MemberInitializer {
+ public:
+ /// Create a member initializer in a constructor that initializes the class
+ /// field `name` with `value`.
+ MemberInitializer(std::string name, std::string value)
+ : name(std::move(name)), value(std::move(value)) {}
+
+ /// Write the member initializer.
+ void writeTo(raw_indented_ostream &os) const;
+
+ private:
+ /// The name of the class field.
+ std::string name;
+ /// The value with which to initialize it.
+ std::string value;
+ };
+
+private:
+ /// The list of member initializers.
+ SmallVector<MemberInitializer> initializers;
+};
+
} // end namespace tblgen
} // end namespace mlir
/// The OR of two method properties should return method properties. Ensure that
/// this function is visible to `Class`.
-inline constexpr mlir::tblgen::Method::Property
-operator|(mlir::tblgen::Method::Property lhs,
- mlir::tblgen::Method::Property rhs) {
- return mlir::tblgen::Method::Property(static_cast<unsigned>(lhs) |
- static_cast<unsigned>(rhs));
+inline constexpr mlir::tblgen::Method::Properties
+operator|(mlir::tblgen::Method::Properties lhs,
+ mlir::tblgen::Method::Properties rhs) {
+ return mlir::tblgen::Method::Properties(static_cast<unsigned>(lhs) |
+ static_cast<unsigned>(rhs));
}
namespace mlir {
namespace tblgen {
-/// Class for holding an op's constructor method for C++ code emission.
-class Constructor : public Method {
+/// This class describes a C++ parent class declaration.
+class ParentClass {
public:
- template <typename... Parameters>
- Constructor(StringRef className, Property property,
- Parameters &&...parameters)
- : Method("", className, property,
- std::forward<Parameters>(parameters)...) {}
+ /// Create a parent class with a class name and visibility.
+ template <typename NameT>
+ ParentClass(NameT &&name, Visibility visibility = Visibility::Public)
+ : name(stringify(std::forward<NameT>(name))), visibility(visibility) {}
+
+ /// Add a template parameter.
+ template <typename ParamT>
+ void addTemplateParam(ParamT param) {
+ templateParams.insert(stringify(param));
+ }
+ /// Add a list of template parameters.
+ template <typename ContainerT>
+ void addTemplateParams(ContainerT &&container) {
+ templateParams.insert(std::begin(container), std::end(container));
+ }
- /// Add member initializer to constructor initializing `name` with `value`.
- void addMemberInitializer(StringRef name, StringRef value);
+ /// Write the parent class declaration.
+ void writeTo(raw_indented_ostream &os) const;
+
+private:
+ /// The fully resolved C++ name of the parent class.
+ std::string name;
+ /// The visibility of the parent class.
+ Visibility visibility;
+ /// An optional list of class template parameters.
+ SetVector<std::string, SmallVector<std::string>, StringSet<>> templateParams;
+};
- /// Writes the method as a definition to the given `os`. `namePrefix` is the
- /// prefix to be prepended to the method name (typically namespaces for
- /// qualifying the method definition).
- void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
+/// This class describes a using-declaration for a class. E.g.
+///
+/// using Op::Op;
+/// using Adaptor = OpAdaptor;
+///
+class UsingDeclaration
+ : public ClassDeclarationBase<ClassDeclaration::UsingDeclaration> {
+public:
+ /// Create a using declaration that either aliases `name` to `value` or
+ /// inherits the parent methods `name`.
+ template <typename NameT, typename ValueT = std::string>
+ UsingDeclaration(NameT &&name, ValueT &&value = "")
+ : name(stringify(std::forward<NameT>(name))),
+ value(stringify(std::forward<ValueT>(value))) {}
+
+ /// Write the using declaration.
+ void writeDeclTo(raw_indented_ostream &os) const override;
+
+private:
+ /// The name of the declaration, or a resolved name to an inherited function.
+ std::string name;
+ /// The type that is being aliased. Leave empty for inheriting functions.
+ std::string value;
+};
+
+/// This class describes a class field.
+class Field : public ClassDeclarationBase<ClassDeclaration::Field> {
+public:
+ /// Create a class field with a type and variable name.
+ template <typename TypeT, typename NameT>
+ Field(TypeT &&type, NameT &&name)
+ : type(stringify(std::forward<TypeT>(type))),
+ name(stringify(std::forward<NameT>(name))) {}
+
+ /// Write the declaration of the field.
+ void writeDeclTo(raw_indented_ostream &os) const override;
+
+private:
+ /// The C++ type of the field.
+ std::string type;
+ /// The variable name of the class field.
+ std::string name;
+};
+
+/// A declaration for the visibility of subsequent declarations.
+class VisibilityDeclaration
+ : public ClassDeclarationBase<ClassDeclaration::VisibilityDeclaration> {
+public:
+ /// Create a declaration for the given visibility.
+ VisibilityDeclaration(Visibility visibility) : visibility(visibility) {}
+
+ /// Get the visibility.
+ Visibility getVisibility() const { return visibility; }
+
+ /// Write the visibility declaration.
+ void writeDeclTo(raw_indented_ostream &os) const override;
private:
- /// Member initializers.
- std::string memberInitializers;
+ /// The visibility of subsequent class declarations.
+ Visibility visibility;
+};
+
+/// Unstructured extra class declarations, from TableGen definitions. The
+/// default visibility of extra class declarations is up to the owning class.
+class ExtraClassDeclaration
+ : public ClassDeclarationBase<ClassDeclaration::ExtraClassDeclaration> {
+public:
+ /// Create an extra class declaration.
+ ExtraClassDeclaration(StringRef extraClassDeclaration)
+ : extraClassDeclaration(extraClassDeclaration) {}
+
+ /// Write the extra class declarations.
+ void writeDeclTo(raw_indented_ostream &os) const override;
+
+private:
+ /// The string of the extra class declarations. It is re-indented before
+ /// printed.
+ StringRef extraClassDeclaration;
};
/// A class used to emit C++ classes from Tablegen. Contains a list of public
/// methods and a list of private fields to be emitted.
class Class {
public:
- explicit Class(StringRef name);
+ virtual ~Class() = default;
+
+ /// Create a class with a name, and whether it should be declared as a `class`
+ /// or `struct`.
+ template <typename NameT>
+ Class(NameT &&name, bool isStruct = false)
+ : className(stringify(std::forward<NameT>(name))), isStruct(isStruct) {}
/// Add a new constructor to this class and prune and constructors made
/// redundant by it. Returns null if the constructor was not added. Else,
/// returns a pointer to the new constructor.
- template <typename... Parameters>
- Constructor *addConstructorAndPrune(Parameters &&...parameters) {
- return addConstructorAndPrune(
- Constructor(getClassName(), Method::MP_Constructor,
- std::forward<Parameters>(parameters)...));
+ template <Method::Properties Properties = Method::None, typename... Args>
+ Constructor *addConstructor(Args &&...args) {
+ return addConstructorAndPrune(Constructor(getClassName(),
+ Properties | Method::Constructor,
+ std::forward<Args>(args)...));
}
/// Add a new method to this class and prune any methods made redundant by it.
/// Returns null if the method was not added (because an existing method would
/// make it redundant). Else, returns a pointer to the new method.
- template <typename... Parameters>
- Method *addMethod(StringRef retType, StringRef name,
- Method::Property properties, Parameters &&...parameters) {
- return addMethodAndPrune(Method(retType, name, properties,
- std::forward<Parameters>(parameters)...));
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT, typename... Args>
+ Method *addMethod(RetTypeT &&retType, NameT &&name,
+ Method::Properties properties, Args &&...args) {
+ return addMethodAndPrune(
+ Method(std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ Properties | properties, std::forward<Args>(args)...));
}
/// Add a method with statically-known properties.
- template <Method::Property Properties = Method::MP_None,
- typename... Parameters>
- Method *addMethod(StringRef retType, StringRef name,
- Parameters &&...parameters) {
- return addMethod(retType, name, Properties,
- std::forward<Parameters>(parameters)...);
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT, typename... Args>
+ Method *addMethod(RetTypeT &&retType, NameT &&name, Args &&...args) {
+ return addMethod(std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ Properties, std::forward<Args>(args)...);
}
/// Add a static method.
- template <Method::Property Properties = Method::MP_None,
- typename... Parameters>
- Method *addStaticMethod(StringRef retType, StringRef name,
- Parameters &&...parameters) {
- return addMethod<Properties | Method::MP_Static>(
- retType, name, std::forward<Parameters>(parameters)...);
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT, typename... Args>
+ Method *addStaticMethod(RetTypeT &&retType, NameT &&name, Args &&...args) {
+ return addMethod<Properties | Method::Static>(
+ std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ std::forward<Args>(args)...);
}
/// Add an inline static method.
- template <Method::Property Properties = Method::MP_None,
- typename... Parameters>
- Method *addStaticInlineMethod(StringRef retType, StringRef name,
- Parameters &&...parameters) {
- return addMethod<Properties | Method::MP_Static | Method::MP_Inline>(
- retType, name, std::forward<Parameters>(parameters)...);
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT, typename... Args>
+ Method *addStaticInlineMethod(RetTypeT &&retType, NameT &&name,
+ Args &&...args) {
+ return addMethod<Properties | Method::StaticInline>(
+ std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ std::forward<Args>(args)...);
}
/// Add an inline method.
- template <Method::Property Properties = Method::MP_None,
- typename... Parameters>
- Method *addInlineMethod(StringRef retType, StringRef name,
- Parameters &&...parameters) {
- return addMethod<Properties | Method::MP_Inline>(
- retType, name, std::forward<Parameters>(parameters)...);
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT, typename... Args>
+ Method *addInlineMethod(RetTypeT &&retType, NameT &&name, Args &&...args) {
+ return addMethod<Properties | Method::Inline>(
+ std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ std::forward<Args>(args)...);
+ }
+
+ /// Add a const method.
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT, typename... Args>
+ Method *addConstMethod(RetTypeT &&retType, NameT &&name, Args &&...args) {
+ return addMethod<Properties | Method::Const>(
+ std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ std::forward<Args>(args)...);
}
/// Add a declaration for a method.
- template <Method::Property Properties = Method::MP_None,
- typename... Parameters>
- Method *declareMethod(StringRef retType, StringRef name,
- Parameters &&...parameters) {
- return addMethod<Properties | Method::MP_Declaration>(
- retType, name, std::forward<Parameters>(parameters)...);
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT, typename... Args>
+ Method *declareMethod(RetTypeT &&retType, NameT &&name, Args &&...args) {
+ return addMethod<Properties | Method::Declaration>(
+ std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ std::forward<Args>(args)...);
}
/// Add a declaration for a static method.
- template <Method::Property Properties = Method::MP_None,
- typename... Parameters>
- Method *declareStaticMethod(StringRef retType, StringRef name,
- Parameters &&...parameters) {
- return addMethod<Properties | Method::MP_StaticDeclaration>(
- retType, name, std::forward<Parameters>(parameters)...);
+ template <Method::Properties Properties = Method::None, typename RetTypeT,
+ typename NameT, typename... Args>
+ Method *declareStaticMethod(RetTypeT &&retType, NameT &&name,
+ Args &&...args) {
+ return addMethod<Properties | Method::StaticDeclaration>(
+ std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+ std::forward<Args>(args)...);
}
- /// Creates a new field in this class.
- void newField(StringRef type, StringRef name, StringRef defaultValue = "");
+ /// Add a new field to the class. Class fields added this way are always
+ /// private.
+ template <typename TypeT, typename NameT>
+ void addField(TypeT &&type, NameT &&name) {
+ fields.emplace_back(std::forward<TypeT>(type), std::forward<NameT>(name));
+ }
- /// Writes this op's class as a declaration to the given `os`.
- void writeDeclTo(raw_ostream &os) const;
- /// Writes the method definitions in this op's class to the given `os`.
- void writeDefTo(raw_ostream &os) const;
+ /// Add a parent class.
+ ParentClass &addParent(ParentClass parent);
- /// Returns the C++ class name of the op.
+ /// Return the C++ name of the class.
StringRef getClassName() const { return className; }
-protected:
- /// Get a list of all the methods to emit, filtering out hidden ones.
- void forAllMethods(llvm::function_ref<void(const Method &)> func) const {
- llvm::for_each(constructors, [&](auto &ctor) { func(ctor); });
- llvm::for_each(methods, [&](auto &method) { func(method); });
+ /// Write the declaration of this class, all declarations, and definitions of
+ /// inline functions. Wrap the output stream in an indented stream.
+ void writeDeclTo(raw_ostream &rawOs) const {
+ raw_indented_ostream os(rawOs);
+ writeDeclTo(os);
+ }
+ /// Write the definitions of this class's out-of-line constructors and
+ /// methods. Wrap the output stream in an indented stream.
+ void writeDefTo(raw_ostream &rawOs) const {
+ raw_indented_ostream os(rawOs);
+ writeDefTo(os);
+ }
+
+ /// Write the declaration of this class, all declarations, and definitions of
+ /// inline functions.
+ void writeDeclTo(raw_indented_ostream &os) const;
+ /// Write the definitions of this class's out-of-line constructors and
+ /// methods.
+ void writeDefTo(raw_indented_ostream &os) const;
+
+ /// Add a declaration. The declaration is appended directly to the list of
+ /// class declarations; a non-owning pointer to it is returned.
+ template <typename DeclT, typename... Args>
+ DeclT *declare(Args &&...args) {
+ auto decl = std::make_unique<DeclT>(std::forward<Args>(args)...);
+ auto *ret = decl.get();
+ declarations.push_back(std::move(decl));
+ return ret;
}
+ /// The declaration of a class needs to be "finalized".
+ ///
+ /// Class constructors, methods, and fields can be added in any order,
+ /// regardless of whether they are public or private. These are stored in
+ /// lists separate from the list of declarations, `declarations`.
+ ///
+ /// So that the generated C++ code is somewhat organised, public methods are
+ /// declared together, and so are private methods and class fields. This
+ /// function iterates through all the added methods and fields and organises
+ /// them into the list of declarations, adding visibility declarations as
+ /// needed, as follows:
+ ///
+ /// 1. public methods and constructors
+ /// 2. private methods and constructors
+ /// 3. class fields -- all are private
+ ///
+ /// `Class::finalize` clears the lists of pending methods and fields, and can
+ /// be called multiple times.
+ virtual void finalize();
+
+protected:
/// Add a new constructor if it is not made redundant by any existing
/// constructors and prune and existing constructors made redundant.
Constructor *addConstructorAndPrune(Constructor &&newCtor);
@@ -379,31 +719,22 @@ class Class {
/// prune and existing methods made redundant.
Method *addMethodAndPrune(Method &&newMethod);
+ /// Get the visibility in effect at the end of the declarations list.
+ Visibility getLastVisibilityDecl() const;
+
/// The C++ class name.
std::string className;
- /// The list of constructors.
- std::vector<Constructor> constructors;
- /// The list of class methods.
- std::vector<Method> methods;
- /// The list of class members.
- SmallVector<std::string, 4> fields;
-};
-
-// Class for holding an op for C++ code emission
-class OpClass : public Class {
-public:
- explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
-
- /// Adds an op trait.
- void addTrait(Twine trait);
-
- /// Writes this op's class as a declaration to the given `os`. Redefines
- /// Class::writeDeclTo to also emit traits and extra class declarations.
- void writeDeclTo(raw_ostream &os) const;
-
-private:
- StringRef extraClassDeclaration;
- llvm::SetVector<std::string, SmallVector<std::string>, StringSet<>> traits;
+ /// The list of parent classes.
+ SmallVector<ParentClass> parents;
+ /// The pending list of methods and constructors.
+ std::vector<std::unique_ptr<Method>> methods;
+ /// The pending list of private class fields.
+ SmallVector<Field> fields;
+ /// Whether this is a `class` or a `struct`.
+ bool isStruct;
+
+ /// A list of declarations in the class, emitted in order.
+ std::vector<std::unique_ptr<ClassDeclaration>> declarations;
};
} // namespace tblgen
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index c913e3514a8ad..b82f6460ae153 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -29,6 +29,12 @@ namespace tblgen {
class Constraint;
class DagLeaf;
+// Format `fmt` with llvm::formatv and the given parameters into a std::string.
+template <typename... Parameters>
+std::string strfmt(const char *fmt, Parameters &&...parameters) {
+ return llvm::formatv(fmt, std::forward<Parameters>(parameters)...).str();
+}
+
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
public:
@@ -58,7 +64,7 @@ class NamespaceEmitter {
~NamespaceEmitter() {
for (StringRef ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
+ os << "} // end namespace " << ns << "\n";
}
private:
@@ -230,6 +236,13 @@ template <> struct stringifier<Twine> {
return twine.str();
}
};
+template <typename OptionalT>
+struct stringifier<Optional<OptionalT>> {
+ static std::string apply(const Optional<OptionalT> &optional) {
+ return optional.hasValue() ? stringifier<OptionalT>::apply(*optional)
+ : std::string(); // An empty optional stringifies to "".
+ }
+};
} // end namespace detail
/// Generically convert a value to a std::string.
diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
index 3120f6ef5766c..6879dbf98fe76 100644
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -50,8 +50,11 @@ class FmtContext {
FmtContext() = default;
+ // Create a format context pre-populated with (placeholder, substitution) pairs.
+ FmtContext(ArrayRef<std::pair<StringRef, StringRef>> subs);
+
// Setter for custom placeholders
- FmtContext &addSubst(StringRef placeholder, Twine subst);
+ FmtContext &addSubst(StringRef placeholder, const Twine &subst);
// Setters for builtin placeholders
FmtContext &withBuilder(Twine subst);
diff --git a/mlir/lib/Support/IndentedOstream.cpp b/mlir/lib/Support/IndentedOstream.cpp
index bb3feef6c4458..470147cd5c526 100644
--- a/mlir/lib/Support/IndentedOstream.cpp
+++ b/mlir/lib/Support/IndentedOstream.cpp
@@ -15,20 +15,31 @@
using namespace mlir;
-raw_indented_ostream &mlir::raw_indented_ostream::reindent(StringRef str) {
- StringRef remaining = str;
- // Find leading whitespace indent.
- while (!remaining.empty()) {
- auto split = remaining.split('\n');
+raw_indented_ostream &
+mlir::raw_indented_ostream::printReindented(StringRef str) {
+ StringRef output = str;
+ // Skip leading blank (empty or whitespace-only) lines.
+ while (!output.empty()) {
+ auto split = output.split('\n');
size_t indent = split.first.find_first_not_of(" \t");
if (indent != StringRef::npos) {
+ // Seed the common indent with the first non-blank line's indent.
leadingWs = indent;
break;
}
+ output = split.second;
+ }
+ // Determine the common indent: the minimum indent over all non-blank lines.
+ StringRef remaining = output;
+ while (!remaining.empty()) {
+ auto split = remaining.split('\n');
+ size_t indent = split.first.find_first_not_of(" \t");
+ if (indent != StringRef::npos)
+ leadingWs = std::min(leadingWs, static_cast<int>(indent));
remaining = split.second;
}
// Print, skipping the empty lines.
- *this << remaining;
+ *this << output;
leadingWs = 0;
return *this;
}
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index f43949c30a222..e874c8648d41c 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -56,6 +56,12 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
if (traitSet.insert(traitInit).second)
traits.push_back(Trait::create(traitInit));
}
+
+ // Populate the parameters.
+ if (auto *parametersDag = def->getValueAsDag("parameters")) {
+ for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i)
+ parameters.push_back(AttrOrTypeParameter(parametersDag, i));
+ }
}
Dialect AttrOrTypeDef::getDialect() const {
@@ -107,14 +113,6 @@ bool AttrOrTypeDef::hasStorageCustomConstructor() const {
return def->getValueAsBit("hasStorageCustomConstructor");
}
-void AttrOrTypeDef::getParameters(
- SmallVectorImpl<AttrOrTypeParameter> &parameters) const {
- if (auto *parametersDag = def->getValueAsDag("parameters")) {
- for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i)
- parameters.push_back(AttrOrTypeParameter(parametersDag, i));
- }
-}
-
unsigned AttrOrTypeDef::getNumParameters() const {
auto *parametersDag = def->getValueAsDag("parameters");
return parametersDag ? parametersDag->getNumArgs() : 0;
diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
index 3fdba8c858b6a..0c38fa5b0067e 100644
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -7,21 +7,16 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Class.h"
-
#include "mlir/TableGen/Format.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
-#include <unordered_set>
-
-#define DEBUG_TYPE "mlir-tblgen-opclass"
using namespace mlir;
using namespace mlir::tblgen;
-// Returns space to be emitted after the given C++ `type`. return "" if the
-// ends with '&' or '*', or is empty, else returns " ".
+/// Returns the space to be emitted after the given C++ `type`: "" if the type
+/// ends with '&' or '*', or is empty, else " ".
static StringRef getSpaceAfterType(StringRef type) {
return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
}
@@ -30,23 +25,29 @@ static StringRef getSpaceAfterType(StringRef type) {
// MethodParameter definitions
//===----------------------------------------------------------------------===//
-void MethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
+void MethodParameter::writeDeclTo(raw_indented_ostream &os) const {
if (optional)
os << "/*optional*/";
os << type << getSpaceAfterType(type) << name;
- if (emitDefault && hasDefaultValue())
+ if (hasDefaultValue()) // Default values are emitted in declarations only.
os << " = " << defaultValue;
}
+void MethodParameter::writeDefTo(raw_indented_ostream &os) const {
+ if (optional)
+ os << "/*optional*/";
+ os << type << getSpaceAfterType(type) << name; // No default value here.
+}
+
//===----------------------------------------------------------------------===//
// MethodParameters definitions
//===----------------------------------------------------------------------===//
-void MethodParameters::writeDeclTo(raw_ostream &os) const {
+void MethodParameters::writeDeclTo(raw_indented_ostream &os) const {
llvm::interleaveComma(parameters, os,
[&os](auto &param) { param.writeDeclTo(os); });
}
-void MethodParameters::writeDefTo(raw_ostream &os) const {
+void MethodParameters::writeDefTo(raw_indented_ostream &os) const {
llvm::interleaveComma(parameters, os,
[&os](auto &param) { param.writeDefTo(os); });
}
@@ -78,13 +79,14 @@ bool MethodSignature::makesRedundant(const MethodSignature &other) const {
parameters.subsumes(other.parameters);
}
-void MethodSignature::writeDeclTo(raw_ostream &os) const {
+void MethodSignature::writeDeclTo(raw_indented_ostream &os) const {
os << returnType << getSpaceAfterType(returnType) << methodName << "(";
parameters.writeDeclTo(os);
os << ")";
}
-void MethodSignature::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
+void MethodSignature::writeDefTo(raw_indented_ostream &os,
+ StringRef namePrefix) const {
os << returnType << getSpaceAfterType(returnType) << namePrefix
<< (namePrefix.empty() ? "" : "::") << methodName << "(";
parameters.writeDefTo(os);
@@ -95,30 +97,15 @@ void MethodSignature::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
// MethodBody definitions
//===----------------------------------------------------------------------===//
-MethodBody::MethodBody(bool declOnly) : isEffective(!declOnly) {}
-
-MethodBody &MethodBody::operator<<(Twine content) {
- if (isEffective)
- body.append(content.str());
- return *this;
-}
-
-MethodBody &MethodBody::operator<<(int content) {
- if (isEffective)
- body.append(std::to_string(content));
- return *this;
-}
-
-MethodBody &MethodBody::operator<<(const FmtObjectBase &content) {
- if (isEffective)
- body.append(content.str());
- return *this;
-}
+MethodBody::MethodBody(bool declOnly)
+ : declOnly(declOnly), stringOs(body), os(stringOs) {} // NOTE(review): assumes `body` and `stringOs` are declared before `os` -- confirm.
-void MethodBody::writeTo(raw_ostream &os) const {
+void MethodBody::writeTo(raw_indented_ostream &os) const {
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
os << bodyRef;
- if (bodyRef.empty() || bodyRef.back() != '\n')
+ if (bodyRef.empty()) // An empty body prints nothing, not even a newline.
+ return;
+ if (bodyRef.back() != '\n')
os << "\n";
}
@@ -126,171 +113,252 @@ void MethodBody::writeTo(raw_ostream &os) const {
// Method definitions
//===----------------------------------------------------------------------===//
-void Method::writeDeclTo(raw_ostream &os) const {
- os.indent(2);
+void Method::writeDeclTo(raw_indented_ostream &os) const {
if (isStatic())
os << "static ";
- if ((properties & MP_Constexpr) == MP_Constexpr)
+ if (properties & ConstexprValue)
os << "constexpr ";
methodSignature.writeDeclTo(os);
+ if (isConst())
+ os << " const";
if (!isInline()) {
- os << ";";
- } else {
- os << " {\n";
- methodBody.writeTo(os.indent(2));
- os.indent(2) << "}";
+ os << ";\n"; // Any out-of-line definition is written by writeDefTo.
+ return;
}
+ os << " {\n"; // Inline: the body is emitted inside the class declaration.
+ methodBody.writeTo(os);
+ os << "}\n\n";
}
-void Method::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
- // Do not write definition if the method is decl only.
- if (properties & MP_Declaration)
- return;
- // Do not generate separate definition for inline method
- if (isInline())
+void Method::writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const {
+ // The method has no definition to write if it is declaration only or inline.
+ if (properties & Declaration || isInline())
return;
+
methodSignature.writeDefTo(os, namePrefix);
+ if (isConst())
+ os << " const";
os << " {\n";
methodBody.writeTo(os);
- os << "}";
+ os << "}\n\n";
}
//===----------------------------------------------------------------------===//
// Constructor definitions
//===----------------------------------------------------------------------===//
-void Constructor::addMemberInitializer(StringRef name, StringRef value) {
- memberInitializers.append(std::string(llvm::formatv(
- "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
+void Constructor::writeDeclTo(raw_indented_ostream &os) const {
+ if (properties & ConstexprValue)
+ os << "constexpr ";
+ methodSignature.writeDeclTo(os);
+ if (!isInline()) {
+ os << ";\n\n";
+ return;
+ }
+ os << ' ';
+ if (!initializers.empty()) // Emit ` : a(x), b(y)` only when present.
+ os << ": ";
+ llvm::interleaveComma(initializers, os,
+ [&](auto &initializer) { initializer.writeTo(os); });
+ if (!initializers.empty())
+ os << ' ';
+ os << "{"; // No newline: an empty body prints as `{}`.
+ methodBody.writeTo(os);
+ os << "}\n\n";
}
-void Constructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
- // Do not write definition if the method is decl only.
- if (properties & MP_Declaration)
+void Constructor::writeDefTo(raw_indented_ostream &os,
+ StringRef namePrefix) const {
+ // The method has no definition to write if it is declaration only or inline.
+ if (properties & Declaration || isInline())
return;
methodSignature.writeDefTo(os, namePrefix);
- os << " " << memberInitializers << " {\n";
+ os << ' ';
+ if (!initializers.empty()) // Emit ` : a(x), b(y)` only when present.
+ os << ": ";
+ llvm::interleaveComma(initializers, os,
+ [&](auto &initializer) { initializer.writeTo(os); });
+ if (!initializers.empty())
+ os << ' ';
+ os << "{"; // No newline: an empty body prints as `{}`.
methodBody.writeTo(os);
- os << "}\n";
+ os << "}\n\n";
+}
+
+void Constructor::MemberInitializer::writeTo(raw_indented_ostream &os) const {
+ os << name << '(' << value << ')'; // Prints `name(value)`.
}
//===----------------------------------------------------------------------===//
-// Class definitions
+// Visibility definitions
//===----------------------------------------------------------------------===//
-Class::Class(StringRef name) : className(name) {}
+namespace mlir {
+namespace tblgen {
+raw_ostream &operator<<(raw_ostream &os, Visibility visibility) {
+ switch (visibility) {
+ case Visibility::Public:
+ return os << "public";
+ case Visibility::Protected:
+ return os << "protected";
+ case Visibility::Private:
+ return os << "private";
+ }
+ return os; // Only reached for values outside the cases handled above.
+}
+} // end namespace tblgen
+} // end namespace mlir
-void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
- std::string varName = formatv("{0} {1}", type, name).str();
- std::string field = defaultValue.empty()
- ? varName
- : formatv("{0} = {1}", varName, defaultValue).str();
- fields.push_back(std::move(field));
+//===----------------------------------------------------------------------===//
+// ParentClass definitions
+//===----------------------------------------------------------------------===//
+
+void ParentClass::writeTo(raw_indented_ostream &os) const {
+ os << visibility << ' ' << name;
+ if (!templateParams.empty()) {
+ auto scope = os.scope("<", ">", /*indent=*/false); // `>` printed when `scope` ends.
+ llvm::interleaveComma(templateParams, os,
+ [&](auto &param) { os << param; });
+ }
}
-void Class::writeDeclTo(raw_ostream &os) const {
- bool hasPrivateMethod = false;
- os << "class " << className << " {\n";
- os << "public:\n";
-
- forAllMethods([&](const Method &method) {
- if (!method.isPrivate()) {
- method.writeDeclTo(os);
- os << '\n';
- } else {
- hasPrivateMethod = true;
- }
- });
+//===----------------------------------------------------------------------===//
+// UsingDeclaration definitions
+//===----------------------------------------------------------------------===//
+
+void UsingDeclaration::writeDeclTo(raw_indented_ostream &os) const {
+ os << "using " << name;
+ if (!value.empty()) // Without a value this prints `using name;`.
+ os << " = " << value;
+ os << ";\n";
+}
- os << '\n';
- os << "private:\n";
- if (hasPrivateMethod) {
- forAllMethods([&](const Method &method) {
- if (method.isPrivate()) {
- method.writeDeclTo(os);
- os << '\n';
- }
- });
- os << '\n';
+//===----------------------------------------------------------------------===//
+// Field definitions
+//===----------------------------------------------------------------------===//
+
+void Field::writeDeclTo(raw_indented_ostream &os) const {
+ os << type << ' ' << name << ";\n"; // Always one space, unlike getSpaceAfterType.
+}
+
+//===----------------------------------------------------------------------===//
+// VisibilityDeclaration definitions
+//===----------------------------------------------------------------------===//
+
+void VisibilityDeclaration::writeDeclTo(raw_indented_ostream &os) const {
+ os.unindent(); // Print the specifier one level out from the class members.
+ os << visibility << ":\n";
+ os.indent();
+}
+
+//===----------------------------------------------------------------------===//
+// ExtraClassDeclaration definitions
+//===----------------------------------------------------------------------===//
+
+void ExtraClassDeclaration::writeDeclTo(raw_indented_ostream &os) const {
+ os.printReindented(extraClassDeclaration);
+}
+
+//===----------------------------------------------------------------------===//
+// Class definitions
+//===----------------------------------------------------------------------===//
+
+ParentClass &Class::addParent(ParentClass parent) { // NOTE: a later addParent may invalidate the returned reference.
+ parents.push_back(std::move(parent));
+ return parents.back();
+}
+
+void Class::writeDeclTo(raw_indented_ostream &os) const {
+ // Declare the class.
+ os << (isStruct ? "struct" : "class") << ' ' << className << ' ';
+
+ // Declare the parent classes, if any.
+ if (!parents.empty()) {
+ os << ": ";
+ llvm::interleaveComma(parents, os,
+ [&](auto &parent) { parent.writeTo(os); });
+ os << ' ';
}
+ auto classScope = os.scope("{\n", "};\n", /*indent=*/true);
- for (const auto &field : fields)
- os.indent(2) << field << ";\n";
- os << "};\n";
+ // Print all the class declarations.
+ for (auto &decl : declarations)
+ decl->writeDeclTo(os);
}
-void Class::writeDefTo(raw_ostream &os) const {
- forAllMethods([&](const Method &method) {
- method.writeDefTo(os, className);
- os << "\n";
+void Class::writeDefTo(raw_indented_ostream &os) const {
+ // Print all the definitions.
+ for (auto &decl : declarations)
+ decl->writeDefTo(os, className);
+}
+
+void Class::finalize() {
+ // Sort the methods by public and private. Remove them from the pending list
+ // of methods.
+ SmallVector<std::unique_ptr<Method>> publicMethods, privateMethods;
+ for (auto &method : methods) {
+ if (method->isPrivate())
+ privateMethods.push_back(std::move(method));
+ else
+ publicMethods.push_back(std::move(method));
+ }
+ methods.clear(); // Drop the moved-from entries.
+
+ // If the last visibility declaration wasn't `public`, add one that is. Then,
+ // declare the public methods.
+ if (!publicMethods.empty() && getLastVisibilityDecl() != Visibility::Public)
+ declare<VisibilityDeclaration>(Visibility::Public);
+ for (auto &method : publicMethods)
+ declarations.push_back(std::move(method));
+
+ // If the last visibility declaration wasn't `private`, add one that is. Then,
+ // declare the private methods.
+ if (!privateMethods.empty() && getLastVisibilityDecl() != Visibility::Private)
+ declare<VisibilityDeclaration>(Visibility::Private);
+ for (auto &method : privateMethods)
+ declarations.push_back(std::move(method));
+
+ // All fields added to the pending list are private and declared at the bottom
+ // of the class. If the last visibility declaration wasn't `private`, add one
+ // that is, then declare the fields.
+ if (!fields.empty() && getLastVisibilityDecl() != Visibility::Private)
+ declare<VisibilityDeclaration>(Visibility::Private);
+ for (auto &field : fields)
+ declare<Field>(std::move(field));
+ fields.clear();
+}
+
+Visibility Class::getLastVisibilityDecl() const {
+ auto reverseDecls = llvm::reverse(declarations);
+ auto it = llvm::find_if(reverseDecls, [](auto &decl) {
+ return isa<VisibilityDeclaration>(decl);
});
+ return it == reverseDecls.end()
+ ? (isStruct ? Visibility::Public : Visibility::Private)
+ : cast<VisibilityDeclaration>(**it).getVisibility(); // Cast the pointee, not the unique_ptr itself.
}
-// Insert a new method into a list of methods, if it would not be pruned, and
-// prune and existing methods.
-template <typename ContainerT, typename MethodT>
-MethodT *insertAndPrune(ContainerT &methods, MethodT newMethod) {
+static Method *insertAndPruneMethods(std::vector<std::unique_ptr<Method>> &methods,
+ std::unique_ptr<Method> newMethod) {
if (llvm::any_of(methods, [&](auto &method) {
- return method.makesRedundant(newMethod);
+ return method->makesRedundant(*newMethod); // An existing method covers it.
}))
return nullptr;
- llvm::erase_if(
- methods, [&](auto &method) { return newMethod.makesRedundant(method); });
+ llvm::erase_if(methods, [&](auto &method) {
+ return newMethod->makesRedundant(*method); // Drop methods it subsumes.
+ });
methods.push_back(std::move(newMethod));
- return &methods.back();
+ return methods.back().get();
}
Method *Class::addMethodAndPrune(Method &&newMethod) {
- return insertAndPrune(methods, std::move(newMethod));
+ return insertAndPruneMethods(methods,
+ std::make_unique<Method>(std::move(newMethod)));
}
Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) {
- return insertAndPrune(constructors, std::move(newCtor));
-}
-
-//===----------------------------------------------------------------------===//
-// OpClass definitions
-//===----------------------------------------------------------------------===//
-
-OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
- : Class(name), extraClassDeclaration(extraClassDeclaration) {}
-
-void OpClass::addTrait(Twine trait) { traits.insert(trait.str()); }
-
-void OpClass::writeDeclTo(raw_ostream &os) const {
- os << "class " << className << " : public ::mlir::Op<" << className;
- for (const auto &trait : traits)
- os << ", " << trait;
- os << "> {\npublic:\n"
- << " using Op::Op;\n"
- << " using Op::print;\n"
- << " using Adaptor = " << className << "Adaptor;\n";
-
- bool hasPrivateMethod = false;
- forAllMethods([&](const Method &method) {
- if (!method.isPrivate()) {
- method.writeDeclTo(os);
- os << "\n";
- } else {
- hasPrivateMethod = true;
- }
- });
-
- // TODO: Add line control markers to make errors easier to debug.
- if (!extraClassDeclaration.empty())
- os << extraClassDeclaration << "\n";
-
- if (hasPrivateMethod) {
- os << "\nprivate:\n";
- forAllMethods([&](const Method &method) {
- if (method.isPrivate()) {
- method.writeDeclTo(os);
- os << "\n";
- }
- });
- }
-
- os << "};\n";
+ return dyn_cast_or_null<Constructor>(insertAndPruneMethods(
+ methods, std::make_unique<Constructor>(std::move(newCtor))));
}
diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
index 4a0bbdf7f346c..917d1f5b50fff 100644
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -21,7 +21,12 @@ using namespace mlir::tblgen;
// Marker to indicate an error happened when replacing a placeholder.
const char *const kMarkerForNoSubst = "<no-subst-found>";
-FmtContext &FmtContext::addSubst(StringRef placeholder, Twine subst) {
+FmtContext::FmtContext(ArrayRef<std::pair<StringRef, StringRef>> subs) {
+ for (auto &sub : subs)
+ addSubst(sub.first, sub.second); // A repeated placeholder keeps the last value.
+}
+
+FmtContext &FmtContext::addSubst(StringRef placeholder, const Twine &subst) {
customSubstMap[placeholder] = subst.str();
return *this;
}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 96ec74141910d..888b9856da761 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -35,7 +35,7 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
/// Check simple attribute parser and printer are generated correctly.
// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::AsmParser &parser,
-// ATTR: ::mlir::Type attrType) {
+// ATTR: ::mlir::Type type) {
// ATTR: FailureOr<IntegerAttr> _result_value;
// ATTR: FailureOr<TestParamA> _result_complex;
// ATTR: if (parser.parseKeyword("hello"))
@@ -47,7 +47,7 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
// ATTR: return {};
// ATTR: if (parser.parseComma())
// ATTR: return {};
-// ATTR: _result_complex = ::parseAttrParamA(parser, attrType);
+// ATTR: _result_complex = ::parseAttrParamA(parser, type);
// ATTR: if (failed(_result_complex))
// ATTR: return {};
// ATTR: if (parser.parseRParen())
@@ -81,7 +81,7 @@ def AttrA : TestAttr<"TestA"> {
/// Test simple struct parser and printer are generated correctly.
// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::AsmParser &parser,
-// ATTR: ::mlir::Type attrType) {
+// ATTR: ::mlir::Type type) {
// ATTR: bool _seen_v0 = false;
// ATTR: bool _seen_v1 = false;
// ATTR: for (unsigned _index = 0; _index < 2; ++_index) {
@@ -92,12 +92,12 @@ def AttrA : TestAttr<"TestA"> {
// ATTR: return {};
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
// ATTR: _seen_v0 = true;
-// ATTR: _result_v0 = ::parseAttrParamA(parser, attrType);
+// ATTR: _result_v0 = ::parseAttrParamA(parser, type);
// ATTR: if (failed(_result_v0))
// ATTR: return {};
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
// ATTR: _seen_v1 = true;
-// ATTR: _result_v1 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser);
+// ATTR: _result_v1 = type ? ::parseAttrWithType(parser, type) : ::parseAttrWithout(parser);
// ATTR: if (failed(_result_v1))
// ATTR: return {};
// ATTR: } else {
@@ -136,7 +136,7 @@ def AttrB : TestAttr<"TestB"> {
/// Test attribute with capture-all params has correct parser and printer.
// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::AsmParser &parser,
-// ATTR: ::mlir::Type attrType) {
+// ATTR: ::mlir::Type type) {
// ATTR: ::mlir::FailureOr<int> _result_v0;
// ATTR: ::mlir::FailureOr<int> _result_v1;
// ATTR: _result_v0 = ::mlir::FieldParser<int>::parse(parser);
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 64edaad23e4b0..ee455f5359b78 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -7,9 +7,9 @@ include "mlir/IR/OpBase.td"
// DECL: #undef GET_ATTRDEF_CLASSES
// DECL: namespace mlir {
-// DECL: class DialectAsmParser;
-// DECL: class DialectAsmPrinter;
-// DECL: } // namespace mlir
+// DECL: class AsmParser;
+// DECL: class AsmPrinter;
+// DECL: } // end namespace mlir
// DEF: #ifdef GET_ATTRDEF_LIST
// DEF: #undef GET_ATTRDEF_LIST
@@ -19,9 +19,9 @@ include "mlir/IR/OpBase.td"
// DEF: ::test::SingleParameterAttr
// DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser(
-// DEF-NEXT: ::mlir::AsmParser &parser,
-// DEF-NEXT: ::llvm::StringRef mnemonic, ::mlir::Type type,
-// DEF-NEXT: ::mlir::Attribute &value) {
+// DEF-SAME: ::mlir::AsmParser &parser,
+// DEF-SAME: ::llvm::StringRef mnemonic, ::mlir::Type type,
+// DEF-SAME: ::mlir::Attribute &value) {
// DEF: if (mnemonic == ::test::CompoundAAttr::getMnemonic()) {
// DEF-NEXT: value = ::test::CompoundAAttr::parse(parser, type);
// DEF-NEXT: return ::mlir::success(!!value);
@@ -61,10 +61,10 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
let genVerifyDecl = 1;
// DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
-// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static CompoundAAttr getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
-// DECL: return ::llvm::StringLiteral("cmpnd_a");
+// DECL: return "cmpnd_a";
// DECL: }
// DECL: static ::mlir::Attribute parse(
// DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type);
@@ -75,27 +75,23 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// Check that AttributeSelfTypeParameter is handled properly.
// DEF-LABEL: struct CompoundAAttrStorage
-// DEF: CompoundAAttrStorage (
-// DEF-NEXT: : ::mlir::AttributeStorage(inner),
+// DEF: CompoundAAttrStorage(
+// DEF-SAME: : ::mlir::AttributeStorage(inner),
// DEF: bool operator==(const KeyTy &tblgenKey) const {
-// DEF-NEXT: if (!(widthOfSomething == std::get<0>(tblgenKey)))
-// DEF-NEXT: return false;
-// DEF-NEXT: if (!(exampleTdType == std::get<1>(tblgenKey)))
-// DEF-NEXT: return false;
-// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))))
-// DEF-NEXT: return false;
-// DEF-NEXT: if (!(dims == std::get<3>(tblgenKey)))
-// DEF-NEXT: return false;
-// DEF-NEXT: if (!(getType() == std::get<4>(tblgenKey)))
-// DEF-NEXT: return false;
-// DEF-NEXT: return true;
+// DEF-NEXT: return
+// DEF-SAME: (widthOfSomething == std::get<0>(tblgenKey)) &&
+// DEF-SAME: (exampleTdType == std::get<1>(tblgenKey)) &&
+// DEF-SAME: (apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))) &&
+// DEF-SAME: (dims == std::get<3>(tblgenKey)) &&
+// DEF-SAME: (getType() == std::get<4>(tblgenKey));
// DEF: static CompoundAAttrStorage *construct
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
-// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
+// DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
-// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType().cast<::mlir::Type>(); }
+// DEF: ::mlir::Type CompoundAAttr::getInner() const {
+// DEF-NEXT: return getImpl()->getType().cast<::mlir::Type>();
}
def C_IndexAttr : TestAttr<"Index"> {
@@ -108,7 +104,7 @@ def C_IndexAttr : TestAttr<"Index"> {
// DECL-LABEL: class IndexAttr : public ::mlir::Attribute
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
-// DECL: return ::llvm::StringLiteral("index");
+// DECL: return "index";
// DECL: }
// DECL: static ::mlir::Attribute parse(
// DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type);
@@ -122,7 +118,7 @@ def D_SingleParameterAttr : TestAttr<"SingleParameter"> {
);
// DECL-LABEL: struct SingleParameterAttrStorage;
// DECL-LABEL: class SingleParameterAttr
-// DECL-NEXT: detail::SingleParameterAttrStorage
+// DECL-SAME: detail::SingleParameterAttrStorage
}
// An attribute testing AttributeSelfTypeParameter.
@@ -133,8 +129,8 @@ def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
}
// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage
-// DEF: AttrWithTypeBuilderAttrStorage (::mlir::IntegerAttr attr)
-// DEF-NEXT: : ::mlir::AttributeStorage(attr.getType()), attr(attr)
+// DEF: AttrWithTypeBuilderAttrStorage(::mlir::IntegerAttr attr)
+// DEF-SAME: : ::mlir::AttributeStorage(attr.getType()), attr(attr)
def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param);
@@ -143,6 +139,5 @@ def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
// DECL-LABEL: class ParamWithAccessorTypeAttr
// DECL: StringRef getParam()
// DEF: ParamWithAccessorTypeAttrStorage
-// DEF-NEXT: ParamWithAccessorTypeAttrStorage (std::string param)
+// DEF: ParamWithAccessorTypeAttrStorage(std::string param)
// DEF: StringRef ParamWithAccessorTypeAttr::getParam()
-
diff --git a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
index 2ec3f65188368..4b6a783725003 100644
--- a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
+++ b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td
@@ -45,7 +45,7 @@ def AttrA : TestAttr<"AttrA"> {
// ATTR: return;
// ATTR: }
-// ATTR: } // namespace test
+// ATTR: } // end namespace test
def TypeA : TestType<"TypeA"> {
let mnemonic = "type_a";
@@ -73,4 +73,4 @@ def TypeA : TestType<"TypeA"> {
// TYPE: return;
// TYPE: }
-// TYPE: } // namespace test
+// TYPE: } // end namespace test
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 2c03695409d1b..66c51d6e719b6 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -9,7 +9,7 @@ include "mlir/IR/OpBase.td"
// DECL: namespace mlir {
// DECL: class AsmParser;
// DECL: class AsmPrinter;
-// DECL: } // namespace mlir
+// DECL: } // end namespace mlir
// DEF: #ifdef GET_TYPEDEF_LIST
// DEF: #undef GET_TYPEDEF_LIST
@@ -20,9 +20,9 @@ include "mlir/IR/OpBase.td"
// DEF: ::test::IntegerType
// DEF-LABEL: ::mlir::OptionalParseResult generatedTypeParser(
-// DEF-NEXT: ::mlir::AsmParser &parser,
-// DEF-NEXT: ::llvm::StringRef mnemonic,
-// DEF-NEXT: ::mlir::Type &value) {
+// DEF-SAME: ::mlir::AsmParser &parser,
+// DEF-SAME: ::llvm::StringRef mnemonic,
+// DEF-SAME: ::mlir::Type &value) {
// DEF: if (mnemonic == ::test::CompoundAType::getMnemonic()) {
// DEF-NEXT: value = ::test::CompoundAType::parse(parser);
// DEF-NEXT: return ::mlir::success(!!value);
@@ -66,10 +66,10 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
let genVerifyDecl = 1;
// DECL-LABEL: class CompoundAType : public ::mlir::Type
-// DECL: static CompoundAType getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static CompoundAType getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
-// DECL: return ::llvm::StringLiteral("cmpnd_a");
+// DECL: return "cmpnd_a";
// DECL: }
// DECL: static ::mlir::Type parse(::mlir::AsmParser &parser);
// DECL: void print(::mlir::AsmPrinter &printer) const;
@@ -88,7 +88,7 @@ def C_IndexType : TestType<"Index"> {
// DECL-LABEL: class IndexType : public ::mlir::Type
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
-// DECL: return ::llvm::StringLiteral("index");
+// DECL: return "index";
// DECL: }
// DECL: static ::mlir::Type parse(::mlir::AsmParser &parser);
// DECL: void print(::mlir::AsmPrinter &printer) const;
@@ -101,7 +101,7 @@ def D_SingleParameterType : TestType<"SingleParameter"> {
);
// DECL-LABEL: struct SingleParameterTypeStorage;
// DECL-LABEL: class SingleParameterType
-// DECL-NEXT: detail::SingleParameterTypeStorage
+// DECL-SAME: detail::SingleParameterTypeStorage
}
def E_IntegerType : TestType<"Integer"> {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 3ae476123336f..8402877d9cb24 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -9,11 +9,13 @@
#include "AttrOrTypeFormatGen.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
@@ -43,104 +45,556 @@ static void collectAllDefs(StringRef selectedDialect,
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
auto defs = llvm::map_range(
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
- if (defs.empty())
- return;
-
- StringRef dialectName;
if (selectedDialect.empty()) {
- if (defs.empty())
- return;
-
- Dialect dialect(nullptr);
- for (const AttrOrTypeDef &typeDef : defs) {
- if (!dialect) {
- dialect = typeDef.getDialect();
- } else if (dialect != typeDef.getDialect()) {
- llvm::PrintFatalError("defs belonging to more than one dialect. Must "
- "select one via '--(attr|type)defs-dialect'");
- }
+ if (!llvm::is_splat(
+ llvm::map_range(defs, [](auto def) { return def.getDialect(); }))) {
+ llvm::PrintFatalError("defs belonging to more than one dialect. Must "
+ "select one via '--(attr|type)defs-dialect'");
}
-
- dialectName = dialect.getName();
+ resultDefs.assign(defs.begin(), defs.end());
} else {
- dialectName = selectedDialect;
+ auto dialectDefs = llvm::make_filter_range(defs, [&](auto def) {
+ return def.getDialect().getName().equals(selectedDialect);
+ });
+ resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
}
-
- for (const AttrOrTypeDef &def : defs)
- if (def.getDialect().getName().equals(dialectName))
- resultDefs.push_back(def);
}
//===----------------------------------------------------------------------===//
-// ParamCommaFormatter
+// DefGen
//===----------------------------------------------------------------------===//
namespace {
-
-/// Pass an instance of this class to llvm::formatv() to emit a comma separated
-/// list of parameters in the format by 'EmitFormat'.
-class ParamCommaFormatter : public llvm::detail::format_adapter {
+class DefGen {
public:
- /// Choose the output format
- enum EmitFormat {
- /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
- /// [...]".
- TypeNamePairs,
+ /// Create the attribute or type class.
+ DefGen(const AttrOrTypeDef &def);
- /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
- TypeNameInitializer,
-
- /// Emit "param1Name, param2Name, [...]".
- JustParams,
- };
-
- ParamCommaFormatter(EmitFormat emitFormat,
- ArrayRef<AttrOrTypeParameter> params,
- bool prependComma = true)
- : emitFormat(emitFormat), params(params), prependComma(prependComma) {}
-
- /// llvm::formatv will call this function when using an instance as a
- /// replacement value.
- void format(raw_ostream &os, StringRef options) override {
- if (!params.empty() && prependComma)
- os << ", ";
-
- switch (emitFormat) {
- case EmitFormat::TypeNamePairs:
- interleaveComma(params, os, [&](const AttrOrTypeParameter &p) {
- emitTypeNamePair(p, os);
- });
- break;
- case EmitFormat::TypeNameInitializer:
- interleaveComma(params, os, [&](const AttrOrTypeParameter &p) {
- emitTypeNameInitializer(p, os);
- });
- break;
- case EmitFormat::JustParams:
- interleaveComma(params, os,
- [&](const AttrOrTypeParameter &p) { os << p.getName(); });
- break;
+ void emitDecl(raw_ostream &os) const {
+ if (storageCls) {
+ NamespaceEmitter ns(os, def.getStorageNamespace());
+ os << "struct " << def.getStorageClassName() << ";\n";
+ }
+ defCls.writeDeclTo(os);
+ }
+ void emitDef(raw_ostream &os) const {
+ if (storageCls && def.genStorageClass()) {
+ NamespaceEmitter ns(os, def.getStorageNamespace());
+ storageCls->writeDeclTo(os); // everything is inline
}
+ defCls.writeDefTo(os);
}
private:
- // Emit "paramType paramName".
- static void emitTypeNamePair(const AttrOrTypeParameter &param,
- raw_ostream &os) {
- os << param.getCppType() << " " << param.getName();
+ /// Add traits from the TableGen definition to the class.
+ void createParentWithTraits();
+ /// Emit top-level declarations: using declarations and any extra class
+ /// declarations.
+ void emitTopLevelDeclarations();
+ /// Emit attribute or type builders.
+ void emitBuilders();
+ /// Emit a verifier for the def.
+ void emitVerifier();
+ /// Emit parsers and printers.
+ void emitParserPrinter();
+ /// Emit parameter accessors, if required.
+ void emitAccessors();
+ /// Emit interface methods.
+ void emitInterfaceMethods();
+
+ //===--------------------------------------------------------------------===//
+ // Builder Emission
+
+ /// Emit the default builder `Attribute::get`
+ void emitDefaultBuilder();
+ /// Emit the checked builder `Attribute::getChecked`
+ void emitCheckedBuilder();
+ /// Emit a custom builder.
+ void emitCustomBuilder(const AttrOrTypeBuilder &builder);
+ /// Emit a checked custom builder.
+ void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
+
+ //===--------------------------------------------------------------------===//
+ // Parser and Printer Emission
+ void emitParserPrinterBody(MethodBody &parser, MethodBody &printer);
+
+ //===--------------------------------------------------------------------===//
+ // Interface Method Emission
+
+ /// Emit methods for a trait.
+ void emitTraitMethods(const InterfaceTrait &trait);
+ /// Emit a trait method.
+ void emitTraitMethod(const InterfaceMethod &method);
+
+ //===--------------------------------------------------------------------===//
+ // Storage Class Emission
+ void emitStorageClass();
+ /// Generate the storage class constructor.
+ void emitStorageConstructor();
+ /// Emit the key type `KeyTy`.
+ void emitKeyType();
+ /// Emit the equality comparison operator.
+ void emitEquals();
+ /// Emit the key hash function.
+ void emitHashKey();
+ /// Emit the function to construct the storage class.
+ void emitConstruct();
+
+ //===--------------------------------------------------------------------===//
+ // Utility Function Declarations
+
+ /// Get the method parameters for a def builder, where the first several
+ /// parameters may be different.
+ SmallVector<MethodParameter>
+ getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
+
+ //===--------------------------------------------------------------------===//
+ // Class fields
+
+ /// The attribute or type definition.
+ const AttrOrTypeDef &def;
+ /// The list of attribute or type parameters.
+ ArrayRef<AttrOrTypeParameter> params;
+ /// The attribute or type class.
+ Class defCls;
+ /// An optional attribute or type storage class. The storage class will
+ /// exist if and only if the def has more than zero parameters.
+ Optional<Class> storageCls;
+
+ /// The C++ base value of the def, either "Attribute" or "Type".
+ StringRef valueType;
+ /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
+ StringRef defType;
+};
+} // end anonymous namespace
+
+DefGen::DefGen(const AttrOrTypeDef &def)
+ : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
+ valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
+ defType(isa<AttrDef>(def) ? "Attr" : "Type") {
+ // If a storage class is needed, create one.
+ if (def.getNumParameters() > 0)
+ storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true);
+
+ // Create the parent class with any indicated traits.
+ createParentWithTraits();
+ // Emit top-level declarations.
+ emitTopLevelDeclarations();
+ // Emit builders for defs with parameters
+ if (storageCls)
+ emitBuilders();
+ // Emit the verifier.
+ if (storageCls && def.genVerifyDecl())
+ emitVerifier();
+ // Emit the mnemonic, if there is one, and any associated parser and printer.
+ if (def.getMnemonic())
+ emitParserPrinter();
+ // Emit accessors
+ if (def.genAccessors())
+ emitAccessors();
+ // Emit trait interface methods
+ emitInterfaceMethods();
+ defCls.finalize();
+ // Emit a storage class if one is needed
+ if (storageCls && def.genStorageClass())
+ emitStorageClass();
+}
+
+void DefGen::createParentWithTraits() {
+ ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType));
+ defParent.addTemplateParam(def.getCppClassName());
+ defParent.addTemplateParam(def.getCppBaseClassName());
+ defParent.addTemplateParam(storageCls
+ ? strfmt("{0}::{1}", def.getStorageNamespace(),
+ def.getStorageClassName())
+ : strfmt("::mlir::{0}Storage", valueType));
+ for (auto &trait : def.getTraits()) {
+ defParent.addTemplateParam(
+ isa<NativeTrait>(&trait)
+ ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
+ : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
+ }
+ defCls.addParent(std::move(defParent));
+}
+
+void DefGen::emitTopLevelDeclarations() {
+ // Inherit constructors from the attribute or type class.
+ defCls.declare<VisibilityDeclaration>(Visibility::Public);
+ defCls.declare<UsingDeclaration>("Base::Base");
+
+ // Emit the extra declarations first in case there's a definition in there.
+ if (Optional<StringRef> extraDecl = def.getExtraDecls())
+ defCls.declare<ExtraClassDeclaration>(*extraDecl);
+}
+
+void DefGen::emitBuilders() {
+ if (!def.skipDefaultBuilders()) {
+ emitDefaultBuilder();
+ if (def.genVerifyDecl())
+ emitCheckedBuilder();
}
- // Emit "paramName(paramName)"
- void emitTypeNameInitializer(const AttrOrTypeParameter &param,
- raw_ostream &os) {
- os << param.getName() << "(" << param.getName() << ")";
+ for (auto &builder : def.getBuilders()) {
+ emitCustomBuilder(builder);
+ if (def.genVerifyDecl())
+ emitCheckedCustomBuilder(builder);
}
+}
- EmitFormat emitFormat;
- ArrayRef<AttrOrTypeParameter> params;
- bool prependComma;
-};
+void DefGen::emitVerifier() {
+ defCls.declare<UsingDeclaration>("Base::getChecked");
+ defCls.declareStaticMethod(
+ "::mlir::LogicalResult", "verify",
+ getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
+ "emitError"}}));
+}
-} // end anonymous namespace
+void DefGen::emitParserPrinter() {
+ auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
+ "::llvm::StringLiteral", "getMnemonic");
+ mnemonic->body().indent() << strfmt("return \"{0}\";", *def.getMnemonic());
+ // Declare the parser and printer, if needed.
+ if (!def.needsParserPrinter() && !def.hasGeneratedParser() &&
+ !def.hasGeneratedPrinter())
+ return;
+
+ // Declare the parser.
+ SmallVector<MethodParameter> parserParams;
+ parserParams.emplace_back("::mlir::AsmParser &", "parser");
+ if (isa<AttrDef>(&def))
+ parserParams.emplace_back("::mlir::Type", "type");
+ auto *parser = defCls.addMethod(
+ strfmt("::mlir::{0}", valueType), "parse",
+ def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration,
+ std::move(parserParams));
+ // Declare the printer.
+ auto props =
+ def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration;
+ Method *printer =
+ defCls.addMethod("void", "print", props,
+ MethodParameter("::mlir::AsmPrinter &", "printer"));
+ // Emit the bodies.
+ emitParserPrinterBody(parser->body(), printer->body());
+}
+
+void DefGen::emitAccessors() {
+ for (auto &param : params) {
+ Method *m = defCls.addMethod(
+ param.getCppAccessorType(), getParameterAccessorName(param.getName()),
+ def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
+ // Generate accessor definitions only if we also generate the storage
+ // class. Otherwise, let the user define the exact accessor definition.
+ if (!def.genStorageClass())
+ continue;
+ auto scope = m->body().indent().scope("return getImpl()->", ";");
+ if (isa<AttributeSelfTypeParameter>(param))
+ m->body() << formatv("getType().cast<{0}>()", param.getCppType());
+ else
+ m->body() << param.getName();
+ }
+}
+
+void DefGen::emitInterfaceMethods() {
+ for (auto &traitDef : def.getTraits())
+ if (auto *trait = dyn_cast<InterfaceTrait>(&traitDef))
+ if (trait->shouldDeclareMethods())
+ emitTraitMethods(*trait);
+}
+
+//===----------------------------------------------------------------------===//
+// Builder Emission
+
+SmallVector<MethodParameter>
+DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
+ SmallVector<MethodParameter> builderParams;
+ builderParams.append(prefix.begin(), prefix.end());
+ for (auto &param : params)
+ builderParams.emplace_back(param.getCppType(), param.getName());
+ return builderParams;
+}
+
+void DefGen::emitDefaultBuilder() {
+ Method *m = defCls.addStaticMethod(
+ def.getCppClassName(), "get",
+ getBuilderParams({{"::mlir::MLIRContext *", "context"}}));
+ MethodBody &body = m->body().indent();
+ auto scope = body.scope("return Base::get(context", ");");
+ llvm::for_each(params, [&](auto &param) { body << ", " << param.getName(); });
+}
+
+void DefGen::emitCheckedBuilder() {
+ Method *m = defCls.addStaticMethod(
+ def.getCppClassName(), "getChecked",
+ getBuilderParams(
+ {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
+ {"::mlir::MLIRContext *", "context"}}));
+ MethodBody &body = m->body().indent();
+ auto scope = body.scope("return Base::getChecked(emitError, context", ");");
+ llvm::for_each(params, [&](auto &param) { body << ", " << param.getName(); });
+}
+
+static SmallVector<MethodParameter>
+getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
+ const AttrOrTypeBuilder &builder) {
+ auto params = builder.getParameters();
+ SmallVector<MethodParameter> builderParams;
+ builderParams.append(prefix.begin(), prefix.end());
+ if (!builder.hasInferredContextParameter())
+ builderParams.emplace_back("::mlir::MLIRContext *", "context");
+ for (auto &param : params) {
+ builderParams.emplace_back(param.getCppType(), *param.getName(),
+ param.getDefaultValue());
+ }
+ return builderParams;
+}
+
+void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
+ // Don't emit a body if there isn't one.
+ auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
+ Method *m = defCls.addMethod(def.getCppClassName(), "get", props,
+ getCustomBuilderParams({}, builder));
+ if (!builder.getBody())
+ return;
+
+ // Format the body and emit it.
+ FmtContext ctx;
+ ctx.addSubst("_get", "Base::get");
+ if (!builder.hasInferredContextParameter())
+ ctx.addSubst("_ctxt", "context");
+ std::string bodyStr = tgfmt(*builder.getBody(), &ctx);
+ m->body().indent().getStream().printReindented(bodyStr);
+}
+
+/// Replace all instances of 'from' to 'to' in `str` and return the new string.
+static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
+ size_t pos = 0;
+ while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos)
+ str.replace(pos, from.size(), to.data(), to.size());
+ return str;
+}
+
+void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
+ // Don't emit a body if there isn't one.
+ auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
+ Method *m = defCls.addMethod(
+ def.getCppClassName(), "getChecked", props,
+ getCustomBuilderParams(
+ {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
+ builder));
+ if (!builder.getBody())
+ return;
+
+ // Format the body and emit it. Replace $_get(...) with
+ // Base::getChecked(emitError, ...)
+ FmtContext ctx;
+ if (!builder.hasInferredContextParameter())
+ ctx.addSubst("_ctxt", "context");
+ std::string bodyStr = replaceInStr(builder.getBody()->str(), "$_get(",
+ "Base::getChecked(emitError, ");
+ bodyStr = tgfmt(bodyStr, &ctx);
+ m->body().indent().getStream().printReindented(bodyStr);
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and Printer Emission
+
+void DefGen::emitParserPrinterBody(MethodBody &parser, MethodBody &printer) {
+ Optional<StringRef> parserCode = def.getParserCode();
+ Optional<StringRef> printerCode = def.getPrinterCode();
+ Optional<StringRef> asmFormat = def.getAssemblyFormat();
+ // Verify the parser-printer specification first.
+ if (asmFormat && (parserCode || printerCode)) {
+ PrintFatalError(def.getLoc(),
+ def.getName() + ": assembly format cannot be specified at "
+ "the same time as printer or parser code");
+ }
+ // Specified code cannot be empty.
+ if (parserCode && parserCode->empty())
+ PrintFatalError(def.getLoc(), def.getName() + ": parser cannot be empty");
+ if (printerCode && printerCode->empty())
+ PrintFatalError(def.getLoc(), def.getName() + ": printer cannot be empty");
+ // Assembly format requires accessors to be generated.
+ if (asmFormat && !def.genAccessors()) {
+ PrintFatalError(def.getLoc(),
+ def.getName() +
+ ": the generated printer from 'assemblyFormat' "
+ "requires 'genAccessors' to be true");
+ }
+
+ // Generate the parser and printer bodies.
+ if (asmFormat)
+ return generateAttrOrTypeFormat(def, parser, printer);
+
+ FmtContext ctx = FmtContext(
+ {{"_parser", "parser"}, {"_printer", "printer"}, {"_type", "type"}});
+ if (parserCode) {
+ ctx.addSubst("_ctxt", "parser.getContext()");
+ parser.indent().getStream().printReindented(tgfmt(*parserCode, &ctx).str());
+ }
+ if (printerCode) {
+ ctx.addSubst("_ctxt", "printer.getContext()");
+ printer.indent().getStream().printReindented(
+ tgfmt(*printerCode, &ctx).str());
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Interface Method Emission
+
+void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
+ // Get the set of methods that should always be declared.
+ auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
+ StringSet<> alwaysDeclared;
+ alwaysDeclared.insert(alwaysDeclaredMethods.begin(),
+ alwaysDeclaredMethods.end());
+
+ Interface iface = trait.getInterface(); // causes strange bugs if elided
+ for (auto &method : iface.getMethods()) {
+ // Don't declare if the method has a body. Or if the method has a default
+ // implementation and the def didn't request that it always be declared.
+ if (method.getBody() || (method.getDefaultImplementation() &&
+ !alwaysDeclared.count(method.getName())))
+ continue;
+ emitTraitMethod(method);
+ }
+}
+
+void DefGen::emitTraitMethod(const InterfaceMethod &method) {
+ // All interface methods are declaration-only.
+ auto props =
+ method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
+ SmallVector<MethodParameter> params;
+ for (auto &param : method.getArguments())
+ params.emplace_back(param.type, param.name);
+ defCls.addMethod(method.getReturnType(), method.getName(), props,
+ std::move(params));
+}
+
+//===----------------------------------------------------------------------===//
+// Storage Class Emission
+
+void DefGen::emitStorageConstructor() {
+ Constructor *ctor =
+ storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
+ if (auto *attrDef = dyn_cast<AttrDef>(&def)) {
+ // For attributes, a parameter marked with AttributeSelfTypeParameter is
+ // the type initializer that must be passed to the parent constructor.
+ const auto isSelfType = [](const AttrOrTypeParameter &param) {
+ return isa<AttributeSelfTypeParameter>(param);
+ };
+ auto *selfTypeParam = llvm::find_if(params, isSelfType);
+ if (std::count_if(selfTypeParam, params.end(), isSelfType) > 1) {
+ PrintFatalError(def.getLoc(),
+ "Only one attribute parameter can be marked as "
+ "AttributeSelfTypeParameter");
+ }
+ // Alternatively, if a type builder was specified, use that instead.
+ std::string attrStorageInit =
+ selfTypeParam == params.end() ? "" : selfTypeParam->getName().str();
+ if (attrDef->getTypeBuilder()) {
+ FmtContext ctx;
+ for (auto &param : params)
+ ctx.addSubst(strfmt("_{0}", param.getName()), param.getName());
+ attrStorageInit = tgfmt(*attrDef->getTypeBuilder(), &ctx);
+ }
+ ctor->addMemberInitializer("::mlir::AttributeStorage",
+ std::move(attrStorageInit));
+ // Initialize members that aren't the attribute's type.
+ for (auto &param : params)
+ if (selfTypeParam == params.end() || *selfTypeParam != param)
+ ctor->addMemberInitializer(param.getName(), param.getName());
+ } else {
+ for (auto &param : params)
+ ctor->addMemberInitializer(param.getName(), param.getName());
+ }
+}
+
+void DefGen::emitKeyType() {
+ std::string keyType("std::tuple<");
+ llvm::raw_string_ostream os(keyType);
+ llvm::interleaveComma(params, os,
+ [&](auto &param) { os << param.getCppType(); });
+ os << '>';
+ storageCls->declare<UsingDeclaration>("KeyTy", std::move(os.str()));
+}
+
+void DefGen::emitEquals() {
+ Method *eq = storageCls->addConstMethod<Method::Inline>(
+ "bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey"));
+ auto &body = eq->body().indent();
+ auto scope = body.scope("return (", ");");
+ const auto eachFn = [&](auto it) {
+ FmtContext ctx({{"_lhs", isa<AttributeSelfTypeParameter>(it.value())
+ ? "getType()"
+ : it.value().getName()},
+ {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
+ Optional<StringRef> comparator = it.value().getComparator();
+ body << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &ctx);
+ };
+ llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
+}
+
+void DefGen::emitHashKey() {
+ Method *hash = storageCls->addStaticInlineMethod(
+ "::llvm::hash_code", "hashKey",
+ MethodParameter("const KeyTy &", "tblgenKey"));
+ auto &body = hash->body().indent();
+ auto scope = body.scope("return ::llvm::hash_combine(", ");");
+ llvm::interleaveComma(llvm::enumerate(params), body, [&](auto it) {
+ body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
+ });
+}
+
+void DefGen::emitConstruct() {
+ Method *construct = storageCls->addMethod<Method::Inline>(
+ strfmt("{0} *", def.getStorageClassName()), "construct",
+ def.hasStorageCustomConstructor() ? Method::StaticDeclaration
+ : Method::Static,
+ MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
+ "allocator"),
+ MethodParameter("const KeyTy &", "tblgenKey"));
+ if (!def.hasStorageCustomConstructor()) {
+ auto &body = construct->body().indent();
+ for (auto it : llvm::enumerate(params)) {
+ body << formatv("auto {0} = std::get<{1}>(tblgenKey);\n",
+ it.value().getName(), it.index());
+ }
+ // Use the parameters' custom allocator code, if provided.
+ FmtContext ctx = FmtContext().addSubst("_allocator", "allocator");
+ for (auto &param : params) {
+ if (Optional<StringRef> allocCode = param.getAllocator()) {
+ ctx.withSelf(param.getName()).addSubst("_dst", param.getName());
+ body << tgfmt(*allocCode, &ctx) << '\n';
+ }
+ }
+ auto scope =
+ body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
+ def.getStorageClassName()),
+ ");");
+ llvm::interleaveComma(params, body,
+ [&](auto &param) { body << param.getName(); });
+ }
+}
+
+void DefGen::emitStorageClass() {
+ // Add the appropriate parent class.
+ storageCls->addParent(strfmt("::mlir::{0}Storage", valueType));
+ // Add the constructor.
+ emitStorageConstructor();
+ // Declare the key type.
+ emitKeyType();
+ // Add the comparison method.
+ emitEquals();
+ // Emit the key hash method.
+ emitHashKey();
+ // Emit the storage constructor. Just declare it if the user wants to define
+ // it themself.
+ emitConstruct();
+ // Emit the storage class members as public, at the very end of the struct.
+ storageCls->finalize();
+ for (auto &param : params)
+ if (!isa<AttributeSelfTypeParameter>(param))
+ storageCls->declare<Field>(param.getCppType(), param.getName());
+}
//===----------------------------------------------------------------------===//
// DefGenerator
@@ -154,28 +608,23 @@ class DefGenerator {
bool emitDefs(StringRef selectedDialect);
protected:
- DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
- : defRecords(std::move(defs)), os(os), isAttrGenerator(false) {}
+ DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
+ StringRef defType, StringRef valueType, bool isAttrGenerator)
+ : defRecords(std::move(defs)), os(os), defType(defType),
+ valueType(valueType), isAttrGenerator(isAttrGenerator) {}
- /// Emit the declaration of a single def.
- void emitDefDecl(const AttrOrTypeDef &def);
/// Emit the list of def type names.
void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
/// Emit the code to dispatch between different defs during parsing/printing.
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
- /// Emit the definition of a single def.
- void emitDefDef(const AttrOrTypeDef &def);
- /// Emit the storage class for the given def.
- void emitStorageClass(const AttrOrTypeDef &def);
- /// Emit the parser/printer for the given def.
- void emitParsePrint(const AttrOrTypeDef &def);
/// The set of def records to emit.
std::vector<llvm::Record *> defRecords;
+ /// The attribute or type class to emit.
/// The stream to emit to.
raw_ostream &os;
/// The prefix of the tablegen def name, e.g. Attr or Type.
- StringRef defTypePrefix;
+ StringRef defType;
/// The C++ base value type of the def, e.g. Attribute or Type.
StringRef valueType;
/// Flag indicating if this generator is for Attributes. False if the
@@ -186,19 +635,14 @@ class DefGenerator {
/// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
- : DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os) {
- isAttrGenerator = true;
- defTypePrefix = "Attr";
- valueType = "Attribute";
- }
+ : DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os, "Attr",
+ "Attribute", /*isAttrGenerator=*/true) {}
};
/// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
- : DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os) {
- defTypePrefix = "Type";
- valueType = "Type";
- }
+ : DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os, "Type",
+ "Type", /*isAttrGenerator=*/false) {}
};
} // end anonymous namespace
@@ -211,240 +655,13 @@ struct TypeDefGenerator : public DefGenerator {
static const char *const typeDefDeclHeader = R"(
namespace mlir {
class AsmParser;
-class DialectAsmParser;
class AsmPrinter;
-class DialectAsmPrinter;
-} // namespace mlir
-)";
-
-/// The code block for the start of a typeDef class declaration -- singleton
-/// case.
-///
-/// {0}: The name of the def class.
-/// {1}: The name of the type base class.
-/// {2}: The name of the base value type, e.g. Attribute or Type.
-/// {3}: The tablegen record type prefix, e.g. Attr or Type.
-/// {4}: The traits of the def class.
-static const char *const defDeclSingletonBeginStr = R"(
- class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage{4}> {{
- public:
- /// Inherit some necessary constructors from '{3}Base'.
- using Base::Base;
+} // end namespace mlir
)";
-/// The code block for the start of a class declaration -- parametric case.
-///
-/// {0}: The name of the def class.
-/// {1}: The name of the base class.
-/// {2}: The def storage class namespace.
-/// {3}: The storage class name.
-/// {4}: The name of the base value type, e.g. Attribute or Type.
-/// {5}: The tablegen record type prefix, e.g. Attr or Type.
-/// {6}: The traits of the def class.
-static const char *const defDeclParametricBeginStr = R"(
- namespace {2} {
- struct {3};
- } // end namespace {2}
- class {0} : public ::mlir::{4}::{5}Base<{0}, {1},
- {2}::{3}{6}> {{
- public:
- /// Inherit some necessary constructors from '{5}Base'.
- using Base::Base;
-
-)";
-
-/// The code snippet for print/parse of an Attribute/Type.
-///
-/// {0}: The name of the base value type, e.g. Attribute or Type.
-/// {1}: Extra parser parameters.
-static const char *const defDeclParsePrintStr = R"(
- static ::mlir::{0} parse(::mlir::AsmParser &parser{1});
- void print(::mlir::AsmPrinter &printer) const;
-)";
-
-/// The code block for the verify method declaration.
-///
-/// {0}: List of parameters, parameters style.
-static const char *const defDeclVerifyStr = R"(
- using Base::getChecked;
- static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0});
-)";
-
-/// Emit the builders for the given def.
-static void emitBuilderDecls(const AttrOrTypeDef &def, raw_ostream &os,
- ParamCommaFormatter ¶mTypes) {
- StringRef typeClass = def.getCppClassName();
- bool genCheckedMethods = def.genVerifyDecl();
- if (!def.skipDefaultBuilders()) {
- os << llvm::formatv(
- " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
- paramTypes);
- if (genCheckedMethods) {
- os << llvm::formatv(" static {0} "
- "getChecked(llvm::function_ref<::mlir::"
- "InFlightDiagnostic()> emitError, "
- "::mlir::MLIRContext *context{1});\n",
- typeClass, paramTypes);
- }
- }
-
- // Generate the builders specified by the user.
- for (const AttrOrTypeBuilder &builder : def.getBuilders()) {
- std::string paramStr;
- llvm::raw_string_ostream paramOS(paramStr);
- llvm::interleaveComma(
- builder.getParameters(), paramOS,
- [&](const AttrOrTypeBuilder::Parameter ¶m) {
- // Note: AttrOrTypeBuilder parameters are guaranteed to have names.
- paramOS << param.getCppType() << " " << *param.getName();
- if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
- paramOS << " = " << *defaultParamValue;
- });
- paramOS.flush();
-
- // Generate the `get` variant of the builder.
- os << " static " << typeClass << " get(";
- if (!builder.hasInferredContextParameter()) {
- os << "::mlir::MLIRContext *context";
- if (!paramStr.empty())
- os << ", ";
- }
- os << paramStr << ");\n";
-
- // Generate the `getChecked` variant of the builder.
- if (genCheckedMethods) {
- os << " static " << typeClass
- << " getChecked(llvm::function_ref<mlir::InFlightDiagnostic()> "
- "emitError";
- if (!builder.hasInferredContextParameter())
- os << ", ::mlir::MLIRContext *context";
- if (!paramStr.empty())
- os << ", ";
- os << paramStr << ");\n";
- }
- }
-}
-
-static void emitInterfaceMethodDecls(const InterfaceTrait *trait,
- raw_ostream &os) {
- Interface interface = trait->getInterface();
-
- // Get the set of methods that should always be declared.
- auto alwaysDeclaredMethodsVec = trait->getAlwaysDeclaredMethods();
- llvm::StringSet<> alwaysDeclaredMethods;
- alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
- alwaysDeclaredMethodsVec.end());
-
- for (const InterfaceMethod &method : interface.getMethods()) {
- // Don't declare if the method has a body.
- if (method.getBody())
- continue;
- // Don't declare if the method has a default implementation and the def
- // didn't request that it always be declared.
- if (method.getDefaultImplementation() &&
- !alwaysDeclaredMethods.count(method.getName()))
- continue;
-
- // Emit the method declaration.
- os << " " << (method.isStatic() ? "static " : "")
- << method.getReturnType() << " " << method.getName() << "(";
- llvm::interleaveComma(method.getArguments(), os,
- [&](const InterfaceMethod::Argument &arg) {
- os << arg.type << " " << arg.name;
- });
- os << ")" << (method.isStatic() ? "" : " const") << ";\n";
- }
-}
-
-void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
- SmallVector<AttrOrTypeParameter, 4> params;
- def.getParameters(params);
-
- // Build the trait list for this def.
- std::vector<std::string> traitList;
- StringSet<> traitSet;
- for (const Trait &baseTrait : def.getTraits()) {
- std::string traitStr;
- if (const auto *trait = dyn_cast<NativeTrait>(&baseTrait))
- traitStr = trait->getFullyQualifiedTraitName();
- else if (const auto *trait = dyn_cast<InterfaceTrait>(&baseTrait))
- traitStr = trait->getFullyQualifiedTraitName();
- else
- llvm_unreachable("unexpected Attribute/Type trait type");
-
- if (traitSet.insert(traitStr).second)
- traitList.emplace_back(std::move(traitStr));
- }
- std::string traitStr;
- if (!traitList.empty())
- traitStr = ", " + llvm::join(traitList, ", ");
-
- // Emit the beginning string template: either the singleton or parametric
- // template.
- if (def.getNumParameters() == 0) {
- os << formatv(defDeclSingletonBeginStr, def.getCppClassName(),
- def.getCppBaseClassName(), valueType, defTypePrefix,
- traitStr);
- } else {
- os << formatv(defDeclParametricBeginStr, def.getCppClassName(),
- def.getCppBaseClassName(), def.getStorageNamespace(),
- def.getStorageClassName(), valueType, defTypePrefix,
- traitStr);
- }
-
- // Emit the extra declarations first in case there's a definition in there.
- if (Optional<StringRef> extraDecl = def.getExtraDecls())
- os << *extraDecl << "\n";
-
- ParamCommaFormatter emitTypeNamePairsAfterComma(
- ParamCommaFormatter::EmitFormat::TypeNamePairs, params);
- if (!params.empty()) {
- emitBuilderDecls(def, os, emitTypeNamePairsAfterComma);
-
- // Emit the verify invariants declaration.
- if (def.genVerifyDecl())
- os << llvm::formatv(defDeclVerifyStr, emitTypeNamePairsAfterComma);
- }
-
- // Emit the mnenomic, if specified.
- if (auto mnenomic = def.getMnemonic()) {
- os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n"
- << " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n"
- << " }\n";
-
- // If mnemonic specified, emit print/parse declarations.
- if (def.getParserCode() || def.getPrinterCode() ||
- def.getAssemblyFormat() || !params.empty()) {
- os << llvm::formatv(defDeclParsePrintStr, valueType,
- isAttrGenerator ? ", ::mlir::Type type" : "");
- }
- }
-
- if (def.genAccessors()) {
- SmallVector<AttrOrTypeParameter, 4> parameters;
- def.getParameters(parameters);
-
- for (AttrOrTypeParameter ¶meter : parameters) {
- os << formatv(" {0} {1}() const;\n", parameter.getCppAccessorType(),
- getParameterAccessorName(parameter.getName()));
- }
- }
-
- // Emit any interface method declarations.
- for (const Trait &trait : def.getTraits()) {
- if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait)) {
- if (traitDef->shouldDeclareMethods())
- emitInterfaceMethodDecls(traitDef, os);
- }
- }
-
- // End the decl.
- os << " };\n";
-}
-
bool DefGenerator::emitDecls(StringRef selectedDialect) {
- emitSourceFileHeader((defTypePrefix + "Def Declarations").str(), os);
- IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
+ emitSourceFileHeader((defType + "Def Declarations").str(), os);
+ IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
// Output the common "header".
os << typeDefDeclHeader;
@@ -458,11 +675,11 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) {
// Declare all the def classes first (in case they reference each other).
for (const AttrOrTypeDef &def : defs)
- os << " class " << def.getCppClassName() << ";\n";
+ os << "class " << def.getCppClassName() << ";\n";
// Emit the declarations.
for (const AttrOrTypeDef &def : defs)
- emitDefDecl(def);
+ DefGen(def).emitDecl(os);
}
// Emit the TypeID explicit specializations to have a single definition for
// each of these.
@@ -479,7 +696,7 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) {
//===----------------------------------------------------------------------===//
void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
- IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_LIST", os);
+ IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
auto interleaveFn = [&](const AttrOrTypeDef &def) {
os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
};
@@ -491,17 +708,6 @@ void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
// GEN: Definitions
//===----------------------------------------------------------------------===//
-/// The code block used to start the auto-generated parser function.
-///
-/// {0}: The name of the base value type, e.g. Attribute or Type.
-/// {1}: Additional parser parameters.
-static const char *const defParserDispatchStartStr = R"(
-static ::mlir::OptionalParseResult generated{0}Parser(
- ::mlir::AsmParser &parser,
- ::llvm::StringRef mnemonic{1},
- ::mlir::{0} &value) {{
-)";
-
/// The code block for default attribute parser/printer dispatch boilerplate.
/// {0}: the dialect fully qualified class name.
static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
@@ -555,412 +761,6 @@ void {0}::printType(::mlir::Type type,
}
)";
-/// The code block used to start the auto-generated printer function.
-///
-/// {0}: The name of the base value type, e.g. Attribute or Type.
-static const char *const defPrinterDispatchStartStr = R"(
-static ::mlir::LogicalResult generated{0}Printer(
- ::mlir::{0} def, ::mlir::AsmPrinter &printer) {{
- return ::llvm::TypeSwitch<::mlir::{0}, ::mlir::LogicalResult>(def)
-)";
-
-/// Beginning of storage class.
-/// {0}: Storage class namespace.
-/// {1}: Storage class c++ name.
-/// {2}: Parameters parameters.
-/// {3}: Parameter initializer string.
-/// {4}: Parameter types.
-/// {5}: The name of the base value type, e.g. Attribute or Type.
-static const char *const defStorageClassBeginStr = R"(
-namespace {0} {{
- struct {1} : public ::mlir::{5}Storage {{
- {1} ({2})
- : {3} {{ }
-
- /// The hash key is a tuple of the parameter types.
- using KeyTy = std::tuple<{4}>;
-)";
-
-/// The storage class' constructor template.
-///
-/// {0}: storage class name.
-/// {1}: The name of the base value type, e.g. Attribute or Type.
-static const char *const defStorageClassConstructorBeginStr = R"(
- /// Define a construction method for creating a new instance of this
- /// storage.
- static {0} *construct(::mlir::{1}StorageAllocator &allocator,
- const KeyTy &tblgenKey) {{
-)";
-
-/// The storage class' constructor return template.
-///
-/// {0}: storage class name.
-/// {1}: list of parameters.
-static const char *const defStorageClassConstructorEndStr = R"(
- return new (allocator.allocate<{0}>())
- {0}({1});
- }
-)";
-
-/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
-static void emitStorageParameterAllocation(const AttrOrTypeDef &def,
- raw_ostream &os) {
- SmallVector<AttrOrTypeParameter> parameters;
- def.getParameters(parameters);
- FmtContext fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
- for (AttrOrTypeParameter ¶meter : parameters) {
- if (Optional<StringRef> allocCode = parameter.getAllocator()) {
- fmtCtxt.withSelf(parameter.getName());
- fmtCtxt.addSubst("_dst", parameter.getName());
- os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
- }
- }
-}
-
-/// Builds a code block that initializes the attribute storage of 'def'.
-/// Attribute initialization is separated from Type initialization given that
-/// the Attribute also needs to initialize its self-type, which has multiple
-/// means of initialization.
-static std::string buildAttributeStorageParamInitializer(
- const AttrOrTypeDef &def, ArrayRef<AttrOrTypeParameter> parameters) {
- std::string paramInitializer;
- llvm::raw_string_ostream paramOS(paramInitializer);
- paramOS << "::mlir::AttributeStorage(";
-
- // If this is an attribute, we need to check for value type initialization.
- Optional<size_t> selfParamIndex;
- for (auto it : llvm::enumerate(parameters)) {
- const auto *selfParam = dyn_cast<AttributeSelfTypeParameter>(&it.value());
- if (!selfParam)
- continue;
- if (selfParamIndex) {
- llvm::PrintFatalError(def.getLoc(),
- "Only one attribute parameter can be marked as "
- "AttributeSelfTypeParameter");
- }
- paramOS << selfParam->getName();
- selfParamIndex = it.index();
- }
-
- // If we didn't find a self param, but the def has a type builder we use that
- // to construct the type.
- if (!selfParamIndex) {
- const AttrDef &attrDef = cast<AttrDef>(def);
- if (Optional<StringRef> typeBuilder = attrDef.getTypeBuilder()) {
- FmtContext fmtContext;
- for (const AttrOrTypeParameter ¶m : parameters)
- fmtContext.addSubst(("_" + param.getName()).str(), param.getName());
- paramOS << tgfmt(*typeBuilder, &fmtContext);
- }
- }
- paramOS << ")";
-
- // Append the parameters to the initializer.
- for (auto it : llvm::enumerate(parameters))
- if (it.index() != selfParamIndex)
- paramOS << llvm::formatv(", {0}({0})", it.value().getName());
-
- return paramOS.str();
-}
-
-void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
- SmallVector<AttrOrTypeParameter, 4> params;
- def.getParameters(params);
-
- // Collect the parameter types.
- auto parameterTypes =
- llvm::map_range(params, [](const AttrOrTypeParameter ¶meter) {
- return parameter.getCppType();
- });
- std::string parameterTypeList = llvm::join(parameterTypes, ", ");
-
- // Collect the parameter initializer.
- std::string paramInitializer;
- if (isAttrGenerator) {
- paramInitializer = buildAttributeStorageParamInitializer(def, params);
-
- } else {
- llvm::raw_string_ostream initOS(paramInitializer);
- llvm::interleaveComma(params, initOS, [&](const AttrOrTypeParameter &it) {
- initOS << llvm::formatv("{0}({0})", it.getName());
- });
- }
-
- // * Emit most of the storage class up until the hashKey body.
- os << formatv(
- defStorageClassBeginStr, def.getStorageNamespace(),
- def.getStorageClassName(),
- ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
- params, /*prependComma=*/false),
- paramInitializer, parameterTypeList, valueType);
-
- // * Emit the comparison method.
- os << " bool operator==(const KeyTy &tblgenKey) const {\n";
- for (auto it : llvm::enumerate(params)) {
- os << " if (!(";
-
- // Build the comparator context.
- bool isSelfType = isa<AttributeSelfTypeParameter>(it.value());
- FmtContext context;
- context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName())
- .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(tblgenKey)");
-
- // Use the parameter specified comparator if possible, otherwise default to
- // operator==.
- Optional<StringRef> comparator = it.value().getComparator();
- os << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &context);
- os << "))\n return false;\n";
- }
- os << " return true;\n }\n";
-
- // * Emit the haskKey method.
- os << " static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {\n";
-
- // Extract each parameter from the key.
- os << " return ::llvm::hash_combine(";
- llvm::interleaveComma(
- llvm::seq<unsigned>(0, params.size()), os,
- [&](unsigned it) { os << "std::get<" << it << ">(tblgenKey)"; });
- os << ");\n }\n";
-
- // * Emit the construct method.
-
- // If user wants to build the storage constructor themselves, declare it
- // here and then they can write the definition elsewhere.
- if (def.hasStorageCustomConstructor()) {
- os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator "
- "&allocator, const KeyTy &tblgenKey);\n",
- def.getStorageClassName(), valueType);
-
- // Otherwise, generate one.
- } else {
- // First, unbox the parameters.
- os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(),
- valueType);
- for (unsigned i = 0, e = params.size(); i < e; ++i) {
- os << formatv(" auto {0} = std::get<{1}>(tblgenKey);\n",
- params[i].getName(), i);
- }
-
- // Second, reassign the parameter variables with allocation code, if it's
- // specified.
- emitStorageParameterAllocation(def, os);
-
- // Last, return an allocated copy.
- auto parameterNames = llvm::map_range(
- params, [](const auto ¶m) { return param.getName(); });
- os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(),
- llvm::join(parameterNames, ", "));
- }
-
- // * Emit the parameters as storage class members.
- for (const AttrOrTypeParameter ¶meter : params) {
- // Attribute value types are not stored as fields in the storage.
- if (!isa<AttributeSelfTypeParameter>(parameter))
- os << " " << parameter.getCppType() << " " << parameter.getName()
- << ";\n";
- }
- os << " };\n";
-
- os << "} // namespace " << def.getStorageNamespace() << "\n";
-}
-
-void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
- auto printerCode = def.getPrinterCode();
- auto parserCode = def.getParserCode();
- auto assemblyFormat = def.getAssemblyFormat();
- if (assemblyFormat && (printerCode || parserCode)) {
- // Custom assembly format cannot be specified at the same time as either
- // custom printer or parser code.
- PrintFatalError(def.getLoc(),
- def.getName() + ": assembly format cannot be specified at "
- "the same time as printer or parser code");
- }
-
- // Generate a parser and printer based on the assembly format, if specified.
- if (assemblyFormat) {
- // A custom assembly format requires accessors to be generated for the
- // generated printer.
- if (!def.genAccessors()) {
- PrintFatalError(def.getLoc(),
- def.getName() +
- ": the generated printer from 'assemblyFormat' "
- "requires 'genAccessors' to be true");
- }
- return generateAttrOrTypeFormat(def, os);
- }
-
- // Emit the printer code, if specified.
- if (printerCode) {
- // Both the mnenomic and printerCode must be defined (for parity with
- // parserCode).
- os << "void " << def.getCppClassName()
- << "::print(::mlir::AsmPrinter &printer) const {\n";
- if (printerCode->empty()) {
- // If no code specified, emit error.
- PrintFatalError(def.getLoc(),
- def.getName() +
- ": printer (if specified) must have non-empty code");
- }
- FmtContext fmtCtxt = FmtContext().addSubst("_printer", "printer");
- os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
- }
-
- // Emit the parser code, if specified.
- if (parserCode) {
- FmtContext fmtCtxt;
- fmtCtxt.addSubst("_parser", "parser")
- .addSubst("_ctxt", "parser.getContext()");
-
- // The mnenomic must be defined so the dispatcher knows how to dispatch.
- os << llvm::formatv("::mlir::{0} {1}::parse("
- "::mlir::AsmParser &parser",
- valueType, def.getCppClassName());
- if (isAttrGenerator) {
- // Attributes also accept a type parameter instead of a context.
- os << ", ::mlir::Type type";
- fmtCtxt.addSubst("_type", "type");
- }
- os << ") {\n";
-
- if (parserCode->empty()) {
- PrintFatalError(def.getLoc(),
- def.getName() +
- ": parser (if specified) must have non-empty code");
- }
- os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
- }
-}
-
-/// Replace all instances of 'from' to 'to' in `str` and return the new string.
-static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
- size_t pos = 0;
- while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos)
- str.replace(pos, from.size(), to.data(), to.size());
- return str;
-}
-
-/// Emit the builders for the given def.
-static void emitBuilderDefs(const AttrOrTypeDef &def, raw_ostream &os,
- ArrayRef<AttrOrTypeParameter> params) {
- bool genCheckedMethods = def.genVerifyDecl();
- StringRef className = def.getCppClassName();
- if (!def.skipDefaultBuilders()) {
- os << llvm::formatv(
- "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
- " return Base::get(context{2});\n}\n",
- className,
- ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
- params),
- ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams,
- params));
- if (genCheckedMethods) {
- os << llvm::formatv(
- "{0} {0}::getChecked("
- "llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, "
- "::mlir::MLIRContext *context{1}) {{\n"
- " return Base::getChecked(emitError, context{2});\n}\n",
- className,
- ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
- params),
- ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams,
- params));
- }
- }
-
- auto builderFmtCtx =
- FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get");
- auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get");
- auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context");
-
- // Generate the builders specified by the user.
- for (const AttrOrTypeBuilder &builder : def.getBuilders()) {
- Optional<StringRef> body = builder.getBody();
- if (!body)
- continue;
- std::string paramStr;
- llvm::raw_string_ostream paramOS(paramStr);
- llvm::interleaveComma(builder.getParameters(), paramOS,
- [&](const AttrOrTypeBuilder::Parameter ¶m) {
- // Note: AttrOrTypeBuilder parameters are guaranteed
- // to have names.
- paramOS << param.getCppType() << " "
- << *param.getName();
- });
- paramOS.flush();
-
- // Emit the `get` variant of the builder.
- os << llvm::formatv("{0} {0}::get(", className);
- if (!builder.hasInferredContextParameter()) {
- os << "::mlir::MLIRContext *context";
- if (!paramStr.empty())
- os << ", ";
- os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
- tgfmt(*body, &builderFmtCtx).str());
- } else {
- os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
- tgfmt(*body, &inferredCtxBuilderFmtCtx).str());
- }
-
- // Emit the `getChecked` variant of the builder.
- if (genCheckedMethods) {
- os << llvm::formatv("{0} "
- "{0}::getChecked(llvm::function_ref<::mlir::"
- "InFlightDiagnostic()> emitErrorFn",
- className);
- std::string checkedBody =
- replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, ");
- if (!builder.hasInferredContextParameter()) {
- os << ", ::mlir::MLIRContext *context";
- checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str();
- }
- if (!paramStr.empty())
- os << ", ";
- os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody);
- }
- }
-}
-
-/// Print all the def-specific definition code.
-void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
- NamespaceEmitter ns(os, def.getDialect());
-
- SmallVector<AttrOrTypeParameter, 4> parameters;
- def.getParameters(parameters);
- if (!parameters.empty()) {
- // Emit the storage class, if requested and necessary.
- if (def.genStorageClass())
- emitStorageClass(def);
-
- // Emit the builders for this def.
- emitBuilderDefs(def, os, parameters);
-
- // Generate accessor definitions only if we also generate the storage class.
- // Otherwise, let the user define the exact accessor definition.
- if (def.genAccessors() && def.genStorageClass()) {
- for (const AttrOrTypeParameter ¶m : parameters) {
- SmallString<32> paramStorageName;
- if (isa<AttributeSelfTypeParameter>(param)) {
- Twine("getType().cast<" + param.getCppType() + ">()")
- .toVector(paramStorageName);
- } else {
- paramStorageName = param.getName();
- }
-
- os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n",
- param.getCppAccessorType(),
- getParameterAccessorName(param.getName()),
- paramStorageName, def.getCppClassName());
- }
- }
- }
-
- // If mnemonic is specified maybe print definitions for the parser and printer
- // code, if they're specified.
- if (def.getMnemonic())
- emitParsePrint(def);
-}
-
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
@@ -969,59 +769,66 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
})) {
return;
}
+ // Declare the parser.
+ SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
+ {"::llvm::StringRef", "mnemonic"}};
+ if (isAttrGenerator)
+ params.emplace_back("::mlir::Type", "type");
+ params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
+ Method parse("::mlir::OptionalParseResult",
+ strfmt("generated{0}Parser", valueType), Method::StaticInline,
+ std::move(params));
+ // Declare the printer.
+ Method printer("::mlir::LogicalResult",
+ strfmt("generated{0}Printer", valueType), Method::StaticInline,
+ {{strfmt("::mlir::{0}", valueType), "def"},
+ {"::mlir::AsmPrinter &", "printer"}});
// The parser dispatch is just a list of if-elses, matching on the mnemonic
// and calling the def's parse function.
- os << llvm::formatv(defParserDispatchStartStr, valueType,
- isAttrGenerator ? ", ::mlir::Type type" : "");
- for (const AttrOrTypeDef &def : defs) {
- if (def.getMnemonic()) {
- os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) { \n"
- " value = {0}::{1}::",
- def.getDialect().getCppNamespace(), def.getCppClassName());
-
- // If the def has no parameters and no parser code, just invoke a normal
- // `get`.
- if (def.getNumParameters() == 0 && !def.getParserCode()) {
- os << "get(parser.getContext());\n";
- os << " return ::mlir::success(!!value);\n }\n";
- continue;
- }
-
- os << "parse(parser" << (isAttrGenerator ? ", type" : "")
- << ");\n return ::mlir::success(!!value);\n }\n";
- }
+ const char *const getValueForMnemonic =
+ R"( if (mnemonic == {0}::getMnemonic()) {{
+ value = {0}::{1};
+ return ::mlir::success(!!value);
}
- os << " return {};\n";
- os << "}\n\n";
-
+)";
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
// printer.
- os << llvm::formatv(defPrinterDispatchStartStr, valueType);
- for (const AttrOrTypeDef &def : defs) {
- Optional<StringRef> mnemonic = def.getMnemonic();
- if (!mnemonic)
+ printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
+ << ", ::mlir::LogicalResult>(def)";
+ const char *const printValue = R"( .Case<{0}>([&](auto t) {{
+ printer << {0}::getMnemonic();{1}
+ return ::mlir::success();
+ })
+)";
+ for (auto &def : defs) {
+ if (!def.getMnemonic())
continue;
+ std::string defClass = strfmt(
+ "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());
+ // If the def has no parameters or parser code, invoke a normal `get`.
+ std::string parseOrGet =
+ def.needsParserPrinter() || def.hasGeneratedParser()
+ ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
+ : "get(parser.getContext())";
+ parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);
- StringRef cppNamespace = def.getDialect().getCppNamespace();
- StringRef cppClassName = def.getCppClassName();
- os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ",
- cppNamespace, cppClassName);
-
- os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
- cppClassName);
// If the def has no parameters and no printer, just print the mnemonic.
- if (def.getNumParameters() != 0 || def.getPrinterCode())
- os << "t.print(printer);";
- os << "\n return ::mlir::success();\n })\n";
+ StringRef printDef = "";
+ if (def.needsParserPrinter() || def.hasGeneratedPrinter())
+ printDef = "\nt.print(printer);";
+ printer.body() << llvm::formatv(printValue, defClass, printDef);
}
- os << llvm::formatv(
- " .Default([](::mlir::{0}) {{ return ::mlir::failure(); });\n}\n\n",
- valueType);
+ parse.body() << " return {};";
+ printer.body() << " .Default([](auto) { return ::mlir::failure(); });";
+
+ raw_indented_ostream indentedOs(os);
+ parse.writeDeclTo(indentedOs);
+ printer.writeDeclTo(indentedOs);
}
bool DefGenerator::emitDefs(StringRef selectedDialect) {
- emitSourceFileHeader((defTypePrefix + "Def Definitions").str(), os);
+ emitSourceFileHeader((defType + "Def Definitions").str(), os);
SmallVector<AttrOrTypeDef, 16> defs;
collectAllDefs(selectedDialect, defRecords, defs);
@@ -1029,10 +836,14 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
return false;
emitTypeDefList(defs);
- IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
+ IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
emitParsePrintDispatch(defs);
for (const AttrOrTypeDef &def : defs) {
- emitDefDef(def);
+ {
+ NamespaceEmitter ns(os, def.getDialect());
+ DefGen gen(def);
+ gen.emitDef(os);
+ }
// Emit the TypeID explicit specializations to have a single symbol def.
if (!def.getDialect().getCppNamespace().empty())
os << "DEFINE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace()
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 339fba7bedccd..3c6035b7baf5b 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -158,21 +158,6 @@ class StructDirective
// Format Strings
//===----------------------------------------------------------------------===//
-/// Format for defining an attribute parser.
-///
-/// $0: The attribute C++ class name.
-static const char *const attrParserDefn = R"(
-::mlir::Attribute $0::parse(::mlir::AsmParser &$_parser,
- ::mlir::Type $_type) {
-)";
-
-/// Format for defining a type parser.
-///
-/// $0: The type C++ class name.
-static const char *const typeParserDefn = R"(
-::mlir::Type $0::parse(::mlir::AsmParser &$_parser) {
-)";
-
/// Default parser for attribute or type parameters.
static const char *const defaultParameterParser =
"::mlir::FieldParser<$0>::parse($_parser)";
@@ -186,13 +171,6 @@ static const char *const defaultParameterPrinter = "$_printer << $_self";
static const char *const parseErrorStr =
"$_parser.emitError($_parser.getCurrentLocation(), ";
-/// Format for defining an attribute or type printer.
-///
-/// $0: The attribute or type C++ class name.
-static const char *const attrOrTypePrinterDefn = R"(
-void $0::print(::mlir::AsmPrinter &$_printer) const {
-)";
-
/// Loop declaration for struct parser.
///
/// $0: Number of expected parameters.
@@ -212,12 +190,12 @@ static const char *const structParseLoopStart = R"(
/// {0}: Code template for printing an error.
/// {1}: Number of elements in the struct.
static const char *const structParseLoopEnd = R"({{
- {0}"duplicate or unknown struct parameter name: ") << _paramKey;
- return {{};
- }
- if ((_index != {1} - 1) && parser.parseComma())
- return {{};
+ {0}"duplicate or unknown struct parameter name: ") << _paramKey;
+ return {{};
}
+ if ((_index != {1} - 1) && parser.parseComma())
+ return {{};
+}
)";
/// Code format to parse a variable. Separate by lines because variable parsers
@@ -228,26 +206,14 @@ static const char *const structParseLoopEnd = R"({{
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
-static const char *const variableParser[] = {
- " // Parse variable '{0}'",
- " _result_{0} = {1};",
- " if (failed(_result_{0})) {{",
- " {2}\"failed to parse {3} parameter '{0}' which is to be a `{4}`\");",
- " return {{};",
- " }",
-};
-
-//===----------------------------------------------------------------------===//
-// Utility Functions
-//===----------------------------------------------------------------------===//
-
-/// Get a list of an attribute's or type's parameters. These can be wrapper
-/// objects around `AttrOrTypeParameter` or string inits.
-static auto getParameters(const AttrOrTypeDef &def) {
- SmallVector<AttrOrTypeParameter> params;
- def.getParameters(params);
- return params;
+static const char *const variableParser = R"(
+// Parse variable '{0}'
+_result_{0} = {1};
+if (failed(_result_{0})) {{
+ {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
+ return {{};
}
+)";
//===----------------------------------------------------------------------===//
// AttrOrTypeFormat
@@ -261,35 +227,34 @@ class AttrOrTypeFormat {
: def(def), elements(std::move(elements)) {}
/// Generate the attribute or type parser.
- void genParser(raw_ostream &os);
+ void genParser(MethodBody &os);
/// Generate the attribute or type printer.
- void genPrinter(raw_ostream &os);
+ void genPrinter(MethodBody &os);
private:
/// Generate the parser code for a specific format element.
- void genElementParser(Element *el, FmtContext &ctx, raw_ostream &os);
+ void genElementParser(Element *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a literal.
- void genLiteralParser(StringRef value, FmtContext &ctx, raw_ostream &os,
- unsigned indent = 0);
+ void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a variable.
void genVariableParser(const AttrOrTypeParameter ¶m, FmtContext &ctx,
- raw_ostream &os, unsigned indent = 0);
+ MethodBody &os);
/// Generate the parser code for a `params` directive.
- void genParamsParser(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
+ void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `struct` directive.
- void genStructParser(StructDirective *el, FmtContext &ctx, raw_ostream &os);
+ void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a specific format element.
- void genElementPrinter(Element *el, FmtContext &ctx, raw_ostream &os);
+ void genElementPrinter(Element *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a literal.
- void genLiteralPrinter(StringRef value, FmtContext &ctx, raw_ostream &os);
+ void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a variable.
void genVariablePrinter(const AttrOrTypeParameter ¶m, FmtContext &ctx,
- raw_ostream &os);
+ MethodBody &os);
/// Generate the printer code for a `params` directive.
- void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
+ void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
- void genStructPrinter(StructDirective *el, FmtContext &ctx, raw_ostream &os);
+ void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// The ODS definition of the attribute or type whose format is being used to
/// generate a parser and printer.
@@ -308,23 +273,18 @@ class AttrOrTypeFormat {
// ParserGen
//===----------------------------------------------------------------------===//
-void AttrOrTypeFormat::genParser(raw_ostream &os) {
+void AttrOrTypeFormat::genParser(MethodBody &os) {
FmtContext ctx;
ctx.addSubst("_parser", "parser");
-
- /// Generate the definition.
- if (isa<AttrDef>(def)) {
- ctx.addSubst("_type", "attrType");
- os << tgfmt(attrParserDefn, &ctx, def.getCppClassName());
- } else {
- os << tgfmt(typeParserDefn, &ctx, def.getCppClassName());
- }
+ if (isa<AttrDef>(def))
+ ctx.addSubst("_type", "type");
+ os.indent();
/// Declare variables to store all of the parameters. Allocated parameters
/// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
/// FailureOr<T> to defer type construction for parameters that are parsed in
/// a loop (parsers return FailureOr anyways).
- SmallVector<AttrOrTypeParameter> params = getParameters(def);
+ ArrayRef<AttrOrTypeParameter> params = def.getParameters();
for (const AttrOrTypeParameter ¶m : params) {
os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n",
param.getCppStorageType(), param.getName());
@@ -332,8 +292,8 @@ void AttrOrTypeFormat::genParser(raw_ostream &os) {
/// Store the initial location of the parser.
ctx.addSubst("_loc", "loc");
- os << tgfmt(" ::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
- " (void) $_loc;\n",
+ os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
+ "(void) $_loc;\n",
&ctx);
/// Generate call to each parameter parser.
@@ -343,19 +303,19 @@ void AttrOrTypeFormat::genParser(raw_ostream &os) {
/// Generate call to the attribute or type builder. Use the checked getter
/// if one was generated.
if (def.genVerifyDecl()) {
- os << tgfmt(" return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
+ os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
&ctx, def.getCppClassName());
} else {
- os << tgfmt(" return $0::get($_parser.getContext()", &ctx,
+ os << tgfmt("return $0::get($_parser.getContext()", &ctx,
def.getCppClassName());
}
for (const AttrOrTypeParameter ¶m : params)
os << formatv(",\n _result_{0}.getValue()", param.getName());
- os << ");\n}\n\n";
+ os << ");";
}
void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx,
- raw_ostream &os) {
+ MethodBody &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralParser(literal->getSpelling(), ctx, os);
if (auto *var = dyn_cast<VariableElement>(el))
@@ -369,9 +329,9 @@ void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx,
}
void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
- raw_ostream &os, unsigned indent) {
- os.indent(indent) << " // Parse literal '" << value << "'\n";
- os.indent(indent) << tgfmt(" if ($_parser.parse", &ctx);
+ MethodBody &os) {
+ os << "// Parse literal '" << value << "'\n";
+ os << tgfmt("if ($_parser.parse", &ctx);
if (value.front() == '_' || isalpha(value.front())) {
os << "Keyword(\"" << value << "\")";
} else {
@@ -395,28 +355,23 @@ void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
}
os << ")\n";
// Parser will emit an error
- os.indent(indent) << " return {};\n";
+ os << " return {};\n";
}
void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter ¶m,
- FmtContext &ctx, raw_ostream &os,
- unsigned indent) {
+ FmtContext &ctx, MethodBody &os) {
/// Check for a custom parser. Use the default attribute parser otherwise.
auto customParser = param.getParser();
auto parser =
customParser ? *customParser : StringRef(defaultParameterParser);
- for (const char *line : variableParser) {
- os.indent(indent) << formatv(line, param.getName(),
- tgfmt(parser, &ctx, param.getCppStorageType()),
- tgfmt(parseErrorStr, &ctx), def.getName(),
- param.getCppType())
- << "\n";
- }
+ os << formatv(variableParser, param.getName(),
+ tgfmt(parser, &ctx, param.getCppStorageType()),
+ tgfmt(parseErrorStr, &ctx), def.getName(), param.getCppType());
}
void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
- raw_ostream &os) {
- os << " // Parse parameter list\n";
+ MethodBody &os) {
+ os << "// Parse parameter list\n";
llvm::interleave(
el->getParams(),
[&](auto param) { this->genVariableParser(param, ctx, os); },
@@ -424,28 +379,30 @@ void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
}
void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
- raw_ostream &os) {
- os << " // Parse parameter struct\n";
+ MethodBody &os) {
+ os << "// Parse parameter struct\n";
/// Declare a "seen" variable for each key.
for (const AttrOrTypeParameter ¶m : el->getParams())
- os << formatv(" bool _seen_{0} = false;\n", param.getName());
+ os << formatv("bool _seen_{0} = false;\n", param.getName());
/// Generate the parsing loop.
- os << tgfmt(structParseLoopStart, &ctx, el->getNumParams());
- genLiteralParser("=", ctx, os, 2);
- os << " ";
+ os.getStream().printReindented(
+ tgfmt(structParseLoopStart, &ctx, el->getNumParams()).str());
+ os.indent();
+ genLiteralParser("=", ctx, os);
for (const AttrOrTypeParameter ¶m : el->getParams()) {
os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
- " _seen_{0} = true;\n",
+ " _seen_{0} = true;\n",
param.getName());
- genVariableParser(param, ctx, os, 4);
- os << " } else ";
+ genVariableParser(param, ctx, os.indent());
+ os.unindent() << "} else ";
}
+ os.unindent();
/// Duplicate or unknown parameter.
- os << formatv(structParseLoopEnd, tgfmt(parseErrorStr, &ctx),
- el->getNumParams());
+ os.getStream().printReindented(strfmt(
+ structParseLoopEnd, tgfmt(parseErrorStr, &ctx), el->getNumParams()));
/// Because the loop loops N times and each non-failing iteration sets 1 of
/// N flags, successfully exiting the loop means that all parameters have been
@@ -457,24 +414,19 @@ void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
// PrinterGen
//===----------------------------------------------------------------------===//
-void AttrOrTypeFormat::genPrinter(raw_ostream &os) {
+void AttrOrTypeFormat::genPrinter(MethodBody &os) {
FmtContext ctx;
ctx.addSubst("_printer", "printer");
- /// Generate the definition.
- os << tgfmt(attrOrTypePrinterDefn, &ctx, def.getCppClassName());
-
/// Generate printers.
shouldEmitSpace = true;
lastWasPunctuation = false;
for (auto &el : elements)
genElementPrinter(el.get(), ctx, os);
-
- os << "}\n\n";
}
void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
- raw_ostream &os) {
+ MethodBody &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralPrinter(literal->getSpelling(), ctx, os);
if (auto *params = dyn_cast<ParamsDirective>(el))
@@ -488,7 +440,7 @@ void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
}
void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
- raw_ostream &os) {
+ MethodBody &os) {
/// Don't insert a space before certain punctuation.
bool needSpace =
shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
@@ -502,7 +454,7 @@ void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
}
void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter ¶m,
- FmtContext &ctx, raw_ostream &os) {
+ FmtContext &ctx, MethodBody &os) {
/// Insert a space before the next parameter, if necessary.
if (shouldEmitSpace || !lastWasPunctuation)
os << tgfmt(" $_printer << ' ';\n", &ctx);
@@ -518,7 +470,7 @@ void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter ¶m,
}
void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
- raw_ostream &os) {
+ MethodBody &os) {
llvm::interleave(
el->getParams(),
[&](auto param) { this->genVariablePrinter(param, ctx, os); },
@@ -526,13 +478,12 @@ void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
}
void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
- raw_ostream &os) {
+ MethodBody &os) {
llvm::interleave(
el->getParams(),
[&](auto param) {
this->genLiteralPrinter(param.getName(), ctx, os);
this->genLiteralPrinter("=", ctx, os);
- os << tgfmt(" $_printer << ' ';\n", &ctx);
this->genVariablePrinter(param, ctx, os);
},
[&]() { this->genLiteralPrinter(",", ctx, os); });
@@ -624,8 +575,7 @@ FailureOr<AttrOrTypeFormat> FormatParser::parse() {
}
/// Check that all parameters have been seen.
- SmallVector<AttrOrTypeParameter> params = getParameters(def);
- for (auto it : llvm::enumerate(params)) {
+ for (auto &it : llvm::enumerate(def.getParameters())) {
if (!seenParams.test(it.index())) {
return emitError("format is missing reference to parameter: " +
it.value().getName());
@@ -669,7 +619,7 @@ FormatParser::parseVariable(ParserContext ctx) {
auto name = curToken.getSpelling().drop_front();
/// Lookup the parameter.
- SmallVector<AttrOrTypeParameter> params = getParameters(def);
+ ArrayRef<AttrOrTypeParameter> params = def.getParameters();
auto *it = llvm::find_if(
params, [&](auto ¶m) { return param.getName() == name; });
@@ -705,10 +655,9 @@ FormatParser::parseDirective(ParserContext ctx) {
FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() {
consumeToken();
/// Collect all of the attribute's or type's parameters.
- SmallVector<AttrOrTypeParameter> params = getParameters(def);
SmallVector<std::unique_ptr<Element>> vars;
/// Ensure that none of the parameters have already been captured.
- for (auto it : llvm::enumerate(params)) {
+ for (auto it : llvm::enumerate(def.getParameters())) {
if (seenParams.test(it.index())) {
return emitError("`params` captures duplicate parameter: " +
it.value().getName());
@@ -759,15 +708,16 @@ FailureOr<std::unique_ptr<Element>> FormatParser::parseStructDirective() {
//===----------------------------------------------------------------------===//
void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
- raw_ostream &os) {
+ MethodBody &parser,
+ MethodBody &printer) {
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(
llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()),
llvm::SMLoc());
/// Parse the custom assembly format>
- FormatParser parser(mgr, def);
- FailureOr<AttrOrTypeFormat> format = parser.parse();
+ FormatParser fmtParser(mgr, def);
+ FailureOr<AttrOrTypeFormat> format = fmtParser.parse();
if (failed(format)) {
if (formatErrorIsFatal)
PrintFatalError(def.getLoc(), "failed to parse assembly format");
@@ -775,6 +725,6 @@ void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
}
/// Generate the parser and printer.
- format->genParser(os);
- format->genPrinter(os);
+ format->genParser(parser);
+ format->genPrinter(printer);
}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
index 2a10a157dfc90..6f24de5ea4f62 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
@@ -9,9 +9,7 @@
#ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
#define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
-#include "llvm/Support/raw_ostream.h"
-
-#include <string>
+#include "mlir/TableGen/Class.h"
namespace mlir {
namespace tblgen {
@@ -19,7 +17,8 @@ class AttrOrTypeDef;
/// Generate a parser and printer based on a custom assembly format for an
/// attribute or type.
-void generateAttrOrTypeFormat(const AttrOrTypeDef &def, llvm::raw_ostream &os);
+void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser,
+ MethodBody &printer);
/// From the parameter name, get the name of the accessor function in camelcase.
/// The first letter of the parameter is upper-cased and prefixed with "get".
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index a937a9d89a1d3..86f2ea0335933 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -15,6 +15,7 @@ add_tablegen(mlir-tblgen MLIR
LLVMIRConversionGen.cpp
LLVMIRIntrinsicGen.cpp
mlir-tblgen.cpp
+ OpClass.cpp
OpDefinitionsGen.cpp
OpDocGen.cpp
OpFormatGen.cpp
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
new file mode 100644
index 0000000000000..9524dc9210b82
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -0,0 +1,34 @@
+//===- OpClass.cpp - Implementation of an Op Class ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "OpClass.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// OpClass definitions
+//===----------------------------------------------------------------------===//
+
+OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
+ : Class(name.str()), extraClassDeclaration(extraClassDeclaration),
+ parent(addParent("::mlir::Op")) {
+ parent.addTemplateParam(getClassName().str());
+ declare<VisibilityDeclaration>(Visibility::Public);
+ /// Inherit functions from Op.
+ declare<UsingDeclaration>("Op::Op");
+ declare<UsingDeclaration>("Op::print");
+ /// Type alias for the adaptor class.
+ declare<UsingDeclaration>("Adaptor", className + "Adaptor");
+}
+
+void OpClass::finalize() {
+ Class::finalize();
+ declare<VisibilityDeclaration>(Visibility::Public);
+ declare<ExtraClassDeclaration>(extraClassDeclaration);
+}
diff --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h
new file mode 100644
index 0000000000000..5258bd19e302f
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/OpClass.h
@@ -0,0 +1,49 @@
+//===- OpClass.h - Implementation of an Op Class --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_
+#define MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_
+
+#include "mlir/TableGen/Class.h"
+
+namespace mlir {
+namespace tblgen {
+
+/// Class for holding an op for C++ code emission. The class is specialized to
+/// add Op-specific declarations to the class.
+class OpClass : public Class {
+public:
+ /// Create an operation class with extra class declarations, whose default
+ /// visibility is public. Also declares at the top of the class:
+ ///
+ /// - inheritance of constructors from `Op`
+ /// - inheritance of `print`
+ /// - a type alias for the associated adaptor class
+ ///
+ OpClass(StringRef name, StringRef extraClassDeclaration);
+
+ /// Add an op trait.
+ void addTrait(Twine trait) { parent.addTemplateParam(trait.str()); }
+
+ /// The operation class is finalized by calling `Class::finalize` to declare
+ /// all pending private and public methods (ops don't have custom constructors
+ /// or fields). Then, the extra class declarations are appended to the end of
+ /// the class declaration.
+ void finalize() override;
+
+private:
+ /// Hand-written extra class declarations.
+ StringRef extraClassDeclaration;
+ /// The parent class, which also contains the traits to be inherited.
+ ParentClass &parent;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index e822b7819a532..f845845f8cb24 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "OpClass.h"
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
#include "mlir/TableGen/Class.h"
@@ -42,21 +43,21 @@ static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
-// Code for an Op to lookup an attribute. Uses cached identifiers.
-//
-// {0}: The attribute's getter name.
+/// Code for an Op to lookup an attribute. Uses cached identifiers.
+///
+/// {0}: The attribute's getter name.
static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())";
-// The logic to calculate the actual value range for a declared operand/result
-// of an op with variadic operands/results. Note that this logic is not for
-// general use; it assumes all variadic operands/results must have the same
-// number of values.
-//
-// {0}: The list of whether each declared operand/result is variadic.
-// {1}: The total number of non-variadic operands/results.
-// {2}: The total number of variadic operands/results.
-// {3}: The total number of actual values.
-// {4}: "operand" or "result".
+/// The logic to calculate the actual value range for a declared operand/result
+/// of an op with variadic operands/results. Note that this logic is not for
+/// general use; it assumes all variadic operands/results must have the same
+/// number of values.
+///
+/// {0}: The list of whether each declared operand/result is variadic.
+/// {1}: The total number of non-variadic operands/results.
+/// {2}: The total number of variadic operands/results.
+/// {3}: The total number of actual values.
+/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
@@ -75,12 +76,12 @@ static const char *const sameVariadicSizeValueRangeCalcCode = R"(
return {{start, size};
)";
-// The logic to calculate the actual value range for a declared operand/result
-// of an op with variadic operands/results. Note that this logic is assumes
-// the op has an attribute specifying the size of each operand/result segment
-// (variadic or not).
-//
-// {0}: The name of the attribute specifying the segment sizes.
+/// The logic to calculate the actual value range for a declared operand/result
+/// of an op with variadic operands/results. Note that this logic assumes
+/// the op has an attribute specifying the size of each operand/result segment
+/// (variadic or not).
+///
+/// {0}: The name of the attribute specifying the segment sizes.
static const char *const adapterSegmentSizeAttrInitCode = R"(
assert(odsAttrs && "missing segment size attribute for op");
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
@@ -99,11 +100,12 @@ static const char *const attrSizedSegmentValueRangeCalcCode = R"(
start += sizeAttrValueIt[i];
return {start, sizeAttrValueIt[index]};
)";
-// The logic to calculate the actual value range for a declared operand
-// of an op with variadic of variadic operands within the OpAdaptor.
-//
-// {0}: The name of the segment attribute.
-// {1}: The index of the main operand.
+
+/// The logic to calculate the actual value range for a declared operand
+/// of an op with variadic of variadic operands within the OpAdaptor.
+///
+/// {0}: The name of the segment attribute.
+/// {1}: The index of the main operand.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
auto tblgenTmpOperands = getODSOperands({1});
auto sizeAttrValues = {0}().getValues<uint32_t>();
@@ -117,16 +119,20 @@ static const char *const variadicOfVariadicAdaptorCalcCode = R"(
return tblgenTmpOperandGroups;
)";
-// The logic to build a range of either operand or result values.
-//
-// {0}: The begin iterator of the actual values.
-// {1}: The call to generate the start and length of the value range.
+/// The logic to build a range of either operand or result values.
+///
+/// {0}: The begin iterator of the actual values.
+/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
auto valueRange = {1};
return {{std::next({0}, valueRange.first),
std::next({0}, valueRange.first + valueRange.second)};
)";
+/// A header for indicating code sections.
+///
+/// {0}: Some text, or a class name.
+/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
@@ -597,9 +603,15 @@ void OpEmitter::emitDef(
OpEmitter(op, staticVerifierEmitter).emitDef(os);
}
-void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
+void OpEmitter::emitDecl(raw_ostream &os) {
+ opClass.finalize();
+ opClass.writeDeclTo(os);
+}
-void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
+void OpEmitter::emitDef(raw_ostream &os) {
+ opClass.finalize();
+ opClass.writeDefTo(os);
+}
static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
const Operator &op) {
@@ -654,7 +666,7 @@ void OpEmitter::genAttrNameGetters() {
// Emit the getAttributeNameForIndex methods.
{
- auto *method = opClass.addInlineMethod<Method::MP_Private>(
+ auto *method = opClass.addInlineMethod<Method::Private>(
"::mlir::StringAttr", "getAttributeNameForIndex",
MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
@@ -662,15 +674,17 @@ void OpEmitter::genAttrNameGetters() {
<< " return getAttributeNameForIndex((*this)->getName(), index);";
}
{
- auto *method = opClass.addStaticInlineMethod<Method::MP_Private>(
+ auto *method = opClass.addStaticInlineMethod<Method::Private>(
"::mlir::StringAttr", "getAttributeNameForIndex",
MethodParameter("::mlir::OperationName", "name"),
MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
- method->body() << "assert(index < " << attributeNames.size()
- << " && \"invalid attribute index\");\n"
- " return name.getRegisteredInfo()"
- "->getAttributeNames()[index];";
+
+ const char *const getAttrName = R"(
+ assert(index < {0} && "invalid attribute index");
+ return name.getRegisteredInfo()->getAttributeNames()[index];
+)";
+ method->body() << formatv(getAttrName, attributeNames.size());
}
// Generate the <attr>AttrName methods, that expose the attribute names to
@@ -685,8 +699,7 @@ void OpEmitter::genAttrNameGetters() {
auto *method =
opClass.addInlineMethod("::mlir::StringAttr", methodName);
ERROR_IF_PRUNED(method, methodName, op);
- method->body()
- << llvm::formatv(attrNameMethodBody, attrIt.second).str();
+ method->body() << llvm::formatv(attrNameMethodBody, attrIt.second);
}
// Generate the static variant.
@@ -696,8 +709,7 @@ void OpEmitter::genAttrNameGetters() {
MethodParameter("::mlir::OperationName", "name"));
ERROR_IF_PRUNED(method, methodName, op);
method->body() << llvm::formatv(attrNameMethodBody,
- "name, " + Twine(attrIt.second))
- .str();
+ "name, " + Twine(attrIt.second));
}
}
}
@@ -739,8 +751,7 @@ void OpEmitter::genAttrGetters() {
// that allows referring to the attributes via accessors instead of having to
// use the string interface for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
- auto *method =
- opClass.addMethod(attr.getStorageType(), (name + "Attr").str());
+ auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr");
if (!method)
return;
method->body() << formatv(
@@ -838,7 +849,7 @@ void OpEmitter::genAttrSetters() {
auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
Attribute attr) {
auto *method =
- opClass.addMethod("void", (setterName + "Attr").str(),
+ opClass.addMethod("void", setterName + "Attr",
MethodParameter(attr.getStorageType(), "attr"));
if (method)
method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);",
@@ -861,8 +872,8 @@ void OpEmitter::genOptionalAttrRemovers() {
auto emitRemoveAttr = [&](StringRef name) {
auto upperInitial = name.take_front().upper();
auto suffix = name.drop_front();
- auto *method = opClass.addMethod(
- "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
+ auto *method = opClass.addMethod("::mlir::Attribute",
+ "remove" + upperInitial + suffix + "Attr");
if (!method)
return;
method->body() << formatv(" return (*this)->removeAttr({0}AttrName());",
@@ -1504,8 +1515,7 @@ void OpEmitter::genBuilder() {
SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
Optional<StringRef> body = builder.getBody();
- Method::Property properties =
- body ? Method::MP_Static : Method::MP_StaticDeclaration;
+ auto properties = body ? Method::Static : Method::StaticDeclaration;
auto *method =
opClass.addMethod("void", "build", properties, std::move(arguments));
if (body)
@@ -1715,7 +1725,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
i >= defaultValuedAttrStartIndex) {
defaultValue += attr.getDefaultValue();
}
- paramList.emplace_back(type, namedAttr.name, defaultValue,
+ paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue),
attr.isOptional());
}
@@ -1874,7 +1884,7 @@ void OpEmitter::genCanonicalizerDecls() {
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::RewritePatternSet &", "results");
paramList.emplace_back("::mlir::MLIRContext *", "context");
- auto kind = hasBody ? Method::MP_Static : Method::MP_StaticDeclaration;
+ auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
std::move(paramList));
@@ -1937,11 +1947,9 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
for (const InterfaceMethod::Argument &arg : method.getArguments())
paramList.emplace_back(arg.type, arg.name);
- auto properties = method.isStatic() ? Method::MP_Static : Method::MP_None;
- if (declaration)
- properties =
- static_cast<Method::Property>(properties | Method::MP_Declaration);
- return opClass.addMethod(method.getReturnType(), method.getName(), properties,
+ auto props = (method.isStatic() ? Method::Static : Method::None) |
+ (declaration ? Method::Declaration : Method::None);
+ return opClass.addMethod(method.getReturnType(), method.getName(), props,
std::move(paramList));
}
@@ -1960,10 +1968,10 @@ void OpEmitter::genSideEffectInterfaceMethods() {
SideEffect effect;
/// The index if the kind is not static.
- unsigned index : 30;
+ unsigned index;
/// The kind of the location.
- unsigned kind : 2;
+ unsigned kind;
};
StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
@@ -2360,7 +2368,7 @@ void OpEmitter::genSuccessorVerifier(MethodBody &body) {
body << " {\n unsigned index = 0; (void)index;\n";
- for (auto it : llvm::enumerate(successors)) {
+ for (auto &it : llvm::enumerate(successors)) {
const auto &successor = it.value();
if (canSkip(successor))
continue;
@@ -2461,7 +2469,7 @@ void OpEmitter::genTraits() {
}
void OpEmitter::genOpNameGetter() {
- auto *method = opClass.addStaticMethod<Method::MP_Constexpr>(
+ auto *method = opClass.addStaticMethod<Method::Constexpr>(
"::llvm::StringLiteral", "getOperationName");
ERROR_IF_PRUNED(method, "getOperationName", op);
method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
@@ -2537,18 +2545,18 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter)
: op(op), adaptor(op.getAdaptorName()),
staticVerifierEmitter(staticVerifierEmitter) {
- adaptor.newField("::mlir::ValueRange", "odsOperands");
- adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
- adaptor.newField("::mlir::RegionRange", "odsRegions");
+ adaptor.addField("::mlir::ValueRange", "odsOperands");
+ adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
+ adaptor.addField("::mlir::RegionRange", "odsRegions");
const auto *attrSizedOperands =
- op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
+ op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
{
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::ValueRange", "values");
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
attrSizedOperands ? "" : "nullptr");
paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
- auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
+ auto *constructor = adaptor.addConstructor(std::move(paramList));
constructor->addMemberInitializer("odsOperands", "values");
constructor->addMemberInitializer("odsAttrs", "attrs");
@@ -2556,7 +2564,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
}
{
- auto *constructor = adaptor.addConstructorAndPrune(
+ auto *constructor = adaptor.addConstructor(
MethodParameter(op.getCppClassName() + " &", "op"));
constructor->addMemberInitializer("odsOperands", "op->getOperands()");
constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
@@ -2646,6 +2654,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
// Add verification function.
addVerification();
+ adaptor.finalize();
}
void OpOperandAdaptorEmitter::addVerification() {
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index 52d2feb0d6cd0..fc2f814176303 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -45,7 +45,7 @@ extern llvm::cl::opt<std::string> selectedDialect;
// nested in the op definition.
void mlir::tblgen::emitDescription(StringRef description, raw_ostream &os) {
raw_indented_ostream ros(os);
- ros.reindent(description.rtrim(" \t"));
+ ros.printReindented(description.rtrim(" \t"));
}
// Emits `str` with trailing newline if not empty.
@@ -226,8 +226,7 @@ static void emitTypeDoc(const Type &type, raw_ostream &os) {
static void emitAttrOrTypeDefAssemblyFormat(const AttrOrTypeDef &def,
raw_ostream &os) {
- SmallVector<AttrOrTypeParameter, 4> parameters;
- def.getParameters(parameters);
+ ArrayRef<AttrOrTypeParameter> parameters = def.getParameters();
if (parameters.empty()) {
os << "\nSyntax: `!" << def.getDialect().getName() << "."
<< def.getMnemonic() << "`\n";
@@ -265,8 +264,7 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
}
// Emit parameter documentation.
- SmallVector<AttrOrTypeParameter, 4> parameters;
- def.getParameters(parameters);
+ ArrayRef<AttrOrTypeParameter> parameters = def.getParameters();
if (!parameters.empty()) {
os << "\n#### Parameters:\n\n";
os << "| Parameter | C++ type | Description |\n"
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0fa0cd78d4816..dfe671c9b7a93 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -8,6 +8,7 @@
#include "OpFormatGen.h"
#include "FormatGen.h"
+#include "OpClass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/Format.h"
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index b217298690c8d..76f925072d0b0 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -953,7 +953,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Emit matchAndRewrite() function.
{
auto classScope = os.scope();
- os.reindent(R"(
+ os.printReindented(R"(
::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
::mlir::PatternRewriter &rewriter) const override {)")
<< '\n';
diff --git a/mlir/unittests/Support/IndentedOstreamTest.cpp b/mlir/unittests/Support/IndentedOstreamTest.cpp
index 0271eb73e8897..11b6e573df680 100644
--- a/mlir/unittests/Support/IndentedOstreamTest.cpp
+++ b/mlir/unittests/Support/IndentedOstreamTest.cpp
@@ -98,7 +98,7 @@ TEST(FormatTest, Reindent) {
)";
- ros.reindent(desc);
+ ros.printReindented(desc);
ros.flush();
const auto *expected =
R"(First line
More information about the Mlir-commits
mailing list