[Mlir-commits] [mlir] 2696a95 - [mlir][ods] Cleanup of Class Codegen helper
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 12 13:22:05 PST 2021
Author: Mogball
Date: 2021-11-12T21:22:01Z
New Revision: 2696a9529eed0fe60017b46ba8823d8efcddf571
URL: https://github.com/llvm/llvm-project/commit/2696a9529eed0fe60017b46ba8823d8efcddf571
DIFF: https://github.com/llvm/llvm-project/commit/2696a9529eed0fe60017b46ba8823d8efcddf571.diff
LOG: [mlir][ods] Cleanup of Class Codegen helper
Depends on D113331
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D113714
Added:
mlir/include/mlir/TableGen/Class.h
mlir/lib/TableGen/Class.cpp
Modified:
mlir/include/mlir/TableGen/CodeGenHelpers.h
mlir/lib/TableGen/CMakeLists.txt
mlir/tools/mlir-tblgen/DialectGen.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
mlir/include/mlir/TableGen/OpClass.h
mlir/lib/TableGen/OpClass.cpp
################################################################################
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
new file mode 100644
index 0000000000000..9eaf066f7f2c4
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -0,0 +1,412 @@
+//===- Class.h - Helper classes for C++ code emission -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines several classes for Op C++ code emission. They are only
+// expected to be used by MLIR TableGen backends.
+//
+// We emit the op declaration and definition into separate files: *Ops.h.inc
+// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
+// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
+//
+// In order to do this split, we need to track method signature and
+// implementation logic separately. Signature information is used for both
+// declaration and definition, while implementation logic is only for
+// definition. So we have the following classes for C++ code emission.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_CLASS_H_
+#define MLIR_TABLEGEN_CLASS_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <set>
+#include <string>
+
+namespace mlir {
+namespace tblgen {
+class FmtObjectBase;
+
+/// This class contains a single method parameter for a C++ function.
+class MethodParameter {
+public:
+ /// Create a method parameter with a C++ type, parameter name, and an optional
+ /// default value. Marking a parameter as "optional" is a cosmetic effect on
+ /// the generated code.
+ template <typename TypeT, typename NameT, typename DefaultT>
+ MethodParameter(TypeT &&type, NameT &&name, DefaultT &&defaultValue,
+ bool optional = false)
+ : type(stringify(std::forward<TypeT>(type))),
+ name(stringify(std::forward<NameT>(name))),
+ defaultValue(stringify(std::forward<DefaultT>(defaultValue))),
+ optional(optional) {}
+
+ /// Create a method parameter with a C++ type, parameter name, and no default
+ /// value.
+ template <typename TypeT, typename NameT>
+ MethodParameter(TypeT &&type, NameT &&name, bool optional = false)
+ : MethodParameter(std::forward<TypeT>(type), std::forward<NameT>(name),
+ /*defaultValue=*/"", optional) {}
+
+ /// Write the parameter as part of a method declaration.
+ void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
+ /// Write the parameter as part of a method definition.
+ void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
+
+ /// Get the C++ type.
+ const std::string &getType() const { return type; }
+ /// Returns true if the parameter has a default value.
+ bool hasDefaultValue() const { return !defaultValue.empty(); }
+
+private:
+ void writeTo(raw_ostream &os, bool emitDefault) const;
+
+ /// The C++ type.
+ std::string type;
+ /// The variable name.
+ std::string name;
+ /// An optional default value. The default value exists if the string is not
+ /// empty.
+ std::string defaultValue;
+ /// Whether the parameter should be indicated as "optional".
+ bool optional;
+};
+
+/// This class contains a list of method parameters for constructor, class
+/// methods, and method signatures.
+class MethodParameters {
+public:
+ /// Create a list of method parameters.
+ MethodParameters(std::initializer_list<MethodParameter> parameters)
+ : parameters(parameters) {}
+ MethodParameters(SmallVector<MethodParameter> parameters)
+ : parameters(std::move(parameters)) {}
+
+ /// Write the parameters as part of a method declaration.
+ void writeDeclTo(raw_ostream &os) const;
+ /// Write the parameters as part of a method definition.
+ void writeDefTo(raw_ostream &os) const;
+
+ /// Determine whether this list of parameters "subsumes" another, which occurs
+ /// when this parameter list is identical to the other and has zero or more
+ /// additional default-valued parameters.
+ bool subsumes(const MethodParameters &other) const;
+
+ /// Return the number of parameters.
+ unsigned getNumParameters() const { return parameters.size(); }
+
+private:
+ llvm::SmallVector<MethodParameter> parameters;
+};
+
+/// This class contains the signature of a C++ method, including the return
+/// type. method name, and method parameters.
+class MethodSignature {
+public:
+ MethodSignature(StringRef retType, StringRef name,
+ SmallVector<MethodParameter> &&parameters)
+ : returnType(retType), methodName(name),
+ parameters(std::move(parameters)) {}
+ template <typename... Parameters>
+ MethodSignature(StringRef retType, StringRef name, Parameters &&...parameters)
+ : returnType(retType), methodName(name),
+ parameters({std::forward<Parameters>(parameters)...}) {}
+
+ /// Determine whether a method with this signature makes a method with
+ /// `other` signature redundant. This occurs if the signatures have the same
+ /// name and this signature's parameteres subsume the other's.
+ ///
+ /// A method that makes another method redundant with a different return type
+ /// can replace the other, the assumption being that the subsuming method
+ /// provides a more resolved return type, e.g. IntegerAttr vs. Attribute.
+ bool makesRedundant(const MethodSignature &other) const;
+
+ /// Get the name of the method.
+ StringRef getName() const { return methodName; }
+
+ /// Get the number of parameters.
+ unsigned getNumParameters() const { return parameters.getNumParameters(); }
+
+ /// Write the signature as part of a method declaration.
+ void writeDeclTo(raw_ostream &os) const;
+
+ /// Write the signature as part of a method definition. `namePrefix` is to be
+ /// prepended to the method name (typically namespaces for qualifying the
+ /// method definition).
+ void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+
+private:
+ /// The method's C++ return type.
+ std::string returnType;
+ /// The method name.
+ std::string methodName;
+ /// The method's parameter list.
+ MethodParameters parameters;
+};
+
+/// Class for holding the body of an op's method for C++ code emission
+class MethodBody {
+public:
+ explicit MethodBody(bool declOnly);
+
+ MethodBody &operator<<(Twine content);
+ MethodBody &operator<<(int content);
+ MethodBody &operator<<(const FmtObjectBase &content);
+
+ void writeTo(raw_ostream &os) const;
+
+private:
+ /// Whether this class should record method body.
+ bool isEffective;
+ /// The body of the method.
+ std::string body;
+};
+
+/// Class for holding an op's method for C++ code emission
+class Method {
+public:
+ /// Properties (qualifiers) of class methods. Bitfield is used here to help
+ /// querying properties.
+ enum Property {
+ MP_None = 0x0,
+ MP_Static = 0x1,
+ MP_Constructor = 0x2,
+ MP_Private = 0x4,
+ MP_Declaration = 0x8,
+ MP_Inline = 0x10,
+ MP_Constexpr = 0x20 | MP_Inline,
+ MP_StaticDeclaration = MP_Static | MP_Declaration,
+ };
+
+ template <typename... Args>
+ Method(StringRef retType, StringRef name, Property property, Args &&...args)
+ : properties(property),
+ methodSignature(retType, name, std::forward<Args>(args)...),
+ methodBody(properties & MP_Declaration) {}
+
+ Method(Method &&) = default;
+ Method &operator=(Method &&) = default;
+
+ virtual ~Method() = default;
+
+ MethodBody &body() { return methodBody; }
+
+ /// Returns true if this is a static method.
+ bool isStatic() const { return properties & MP_Static; }
+
+ /// Returns true if this is a private method.
+ bool isPrivate() const { return properties & MP_Private; }
+
+ /// Returns true if this is an inline method.
+ bool isInline() const { return properties & MP_Inline; }
+
+ /// Returns the name of this method.
+ StringRef getName() const { return methodSignature.getName(); }
+
+ /// Returns if this method makes the `other` method redundant.
+ bool makesRedundant(const Method &other) const {
+ return methodSignature.makesRedundant(other.methodSignature);
+ }
+
+ /// Writes the method as a declaration to the given `os`.
+ virtual void writeDeclTo(raw_ostream &os) const;
+
+ /// Writes the method as a definition to the given `os`. `namePrefix` is the
+ /// prefix to be prepended to the method name (typically namespaces for
+ /// qualifying the method definition).
+ virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+
+protected:
+ /// A collection of method properties.
+ Property properties;
+ /// The signature of the method.
+ MethodSignature methodSignature;
+ /// The body of the method, if it has one.
+ MethodBody methodBody;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+/// The OR of two method properties should return method properties. Ensure that
+/// this function is visible to `Class`.
+inline constexpr mlir::tblgen::Method::Property
+operator|(mlir::tblgen::Method::Property lhs,
+ mlir::tblgen::Method::Property rhs) {
+ return mlir::tblgen::Method::Property(static_cast<unsigned>(lhs) |
+ static_cast<unsigned>(rhs));
+}
+
+namespace mlir {
+namespace tblgen {
+
+/// Class for holding an op's constructor method for C++ code emission.
+class Constructor : public Method {
+public:
+ template <typename... Parameters>
+ Constructor(StringRef className, Property property,
+ Parameters &&...parameters)
+ : Method("", className, property,
+ std::forward<Parameters>(parameters)...) {}
+
+ /// Add member initializer to constructor initializing `name` with `value`.
+ void addMemberInitializer(StringRef name, StringRef value);
+
+ /// Writes the method as a definition to the given `os`. `namePrefix` is the
+ /// prefix to be prepended to the method name (typically namespaces for
+ /// qualifying the method definition).
+ void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
+
+private:
+ /// Member initializers.
+ std::string memberInitializers;
+};
+
+/// A class used to emit C++ classes from Tablegen. Contains a list of public
+/// methods and a list of private fields to be emitted.
+class Class {
+public:
+ explicit Class(StringRef name);
+
+ /// Add a new constructor to this class and prune and constructors made
+ /// redundant by it. Returns null if the constructor was not added. Else,
+ /// returns a pointer to the new constructor.
+ template <typename... Parameters>
+ Constructor *addConstructorAndPrune(Parameters &&...parameters) {
+ return addConstructorAndPrune(
+ Constructor(getClassName(), Method::MP_Constructor,
+ std::forward<Parameters>(parameters)...));
+ }
+
+ /// Add a new method to this class and prune any methods made redundant by it.
+ /// Returns null if the method was not added (because an existing method would
+ /// make it redundant). Else, returns a pointer to the new method.
+ template <typename... Parameters>
+ Method *addMethod(StringRef retType, StringRef name,
+ Method::Property properties, Parameters &&...parameters) {
+ return addMethodAndPrune(Method(retType, name, properties,
+ std::forward<Parameters>(parameters)...));
+ }
+
+ /// Add a method with statically-known properties.
+ template <Method::Property Properties = Method::MP_None,
+ typename... Parameters>
+ Method *addMethod(StringRef retType, StringRef name,
+ Parameters &&...parameters) {
+ return addMethod(retType, name, Properties,
+ std::forward<Parameters>(parameters)...);
+ }
+
+ /// Add a static method.
+ template <Method::Property Properties = Method::MP_None,
+ typename... Parameters>
+ Method *addStaticMethod(StringRef retType, StringRef name,
+ Parameters &&...parameters) {
+ return addMethod<Properties | Method::MP_Static>(
+ retType, name, std::forward<Parameters>(parameters)...);
+ }
+
+ /// Add an inline static method.
+ template <Method::Property Properties = Method::MP_None,
+ typename... Parameters>
+ Method *addStaticInlineMethod(StringRef retType, StringRef name,
+ Parameters &&...parameters) {
+ return addMethod<Properties | Method::MP_Static | Method::MP_Inline>(
+ retType, name, std::forward<Parameters>(parameters)...);
+ }
+
+ /// Add an inline method.
+ template <Method::Property Properties = Method::MP_None,
+ typename... Parameters>
+ Method *addInlineMethod(StringRef retType, StringRef name,
+ Parameters &&...parameters) {
+ return addMethod<Properties | Method::MP_Inline>(
+ retType, name, std::forward<Parameters>(parameters)...);
+ }
+
+ /// Add a declaration for a method.
+ template <Method::Property Properties = Method::MP_None,
+ typename... Parameters>
+ Method *declareMethod(StringRef retType, StringRef name,
+ Parameters &&...parameters) {
+ return addMethod<Properties | Method::MP_Declaration>(
+ retType, name, std::forward<Parameters>(parameters)...);
+ }
+
+ /// Add a declaration for a static method.
+ template <Method::Property Properties = Method::MP_None,
+ typename... Parameters>
+ Method *declareStaticMethod(StringRef retType, StringRef name,
+ Parameters &&...parameters) {
+ return addMethod<Properties | Method::MP_StaticDeclaration>(
+ retType, name, std::forward<Parameters>(parameters)...);
+ }
+
+ /// Creates a new field in this class.
+ void newField(StringRef type, StringRef name, StringRef defaultValue = "");
+
+ /// Writes this op's class as a declaration to the given `os`.
+ void writeDeclTo(raw_ostream &os) const;
+ /// Writes the method definitions in this op's class to the given `os`.
+ void writeDefTo(raw_ostream &os) const;
+
+ /// Returns the C++ class name of the op.
+ StringRef getClassName() const { return className; }
+
+protected:
+ /// Get a list of all the methods to emit, filtering out hidden ones.
+ void forAllMethods(llvm::function_ref<void(const Method &)> func) const {
+ llvm::for_each(constructors, [&](auto &ctor) { func(ctor); });
+ llvm::for_each(methods, [&](auto &method) { func(method); });
+ }
+
+ /// Add a new constructor if it is not made redundant by any existing
+ /// constructors and prune and existing constructors made redundant.
+ Constructor *addConstructorAndPrune(Constructor &&newCtor);
+ /// Add a new method if it is not made redundant by any existing methods and
+ /// prune and existing methods made redundant.
+ Method *addMethodAndPrune(Method &&newMethod);
+
+ /// The C++ class name.
+ std::string className;
+ /// The list of constructors.
+ std::vector<Constructor> constructors;
+ /// The list of class methods.
+ std::vector<Method> methods;
+ /// The list of class members.
+ SmallVector<std::string, 4> fields;
+};
+
+// Class for holding an op for C++ code emission
+class OpClass : public Class {
+public:
+ explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
+
+ /// Adds an op trait.
+ void addTrait(Twine trait);
+
+ /// Writes this op's class as a declaration to the given `os`. Redefines
+ /// Class::writeDeclTo to also emit traits and extra class declarations.
+ void writeDeclTo(raw_ostream &os) const;
+
+private:
+ StringRef extraClassDeclaration;
+ llvm::SetVector<std::string, SmallVector<std::string>, StringSet<>> traits;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_CLASS_H_
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index 14af7d3380e48..c913e3514a8ad 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -216,9 +216,28 @@ class StaticVerifierFunctionEmitter {
ConstraintMap regionConstraints;
};
-// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar.
+/// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar.
std::string escapeString(StringRef value);
+namespace detail {
+template <typename> struct stringifier {
+ template <typename T> static std::string apply(T &&t) {
+ return std::string(std::forward<T>(t));
+ }
+};
+template <> struct stringifier<Twine> {
+ static std::string apply(const Twine &twine) {
+ return twine.str();
+ }
+};
+} // end namespace detail
+
+/// Generically convert a value to a std::string.
+template <typename T> std::string stringify(T &&t) {
+ return detail::stringifier<std::remove_reference_t<std::remove_const_t<T>>>::
+ apply(std::forward<T>(t));
+}
+
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h
deleted file mode 100644
index 243e7fa876a97..0000000000000
--- a/mlir/include/mlir/TableGen/OpClass.h
+++ /dev/null
@@ -1,442 +0,0 @@
-//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines several classes for Op C++ code emission. They are only
-// expected to be used by MLIR TableGen backends.
-//
-// We emit the op declaration and definition into separate files: *Ops.h.inc
-// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
-// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
-//
-// In order to do this split, we need to track method signature and
-// implementation logic separately. Signature information is used for both
-// declaration and definition, while implementation logic is only for
-// definition. So we have the following classes for C++ code emission.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TABLEGEN_OPCLASS_H_
-#define MLIR_TABLEGEN_OPCLASS_H_
-
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/raw_ostream.h"
-
-#include <set>
-#include <string>
-
-namespace mlir {
-namespace tblgen {
-class FmtObjectBase;
-
-// Class for holding a single parameter of an op's method for C++ code emission.
-class OpMethodParameter {
-public:
- // Properties (qualifiers) for the parameter.
- enum Property {
- PP_None = 0x0,
- PP_Optional = 0x1,
- };
-
- OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "",
- Property properties = PP_None)
- : type(type), name(name), defaultValue(defaultValue),
- properties(properties) {}
-
- OpMethodParameter(StringRef type, StringRef name, Property property)
- : OpMethodParameter(type, name, "", property) {}
-
- // Writes the parameter as a part of a method declaration to `os`.
- void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
-
- // Writes the parameter as a part of a method definition to `os`
- void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
-
- const std::string &getType() const { return type; }
- bool hasDefaultValue() const { return !defaultValue.empty(); }
-
-private:
- void writeTo(raw_ostream &os, bool emitDefault) const;
-
- std::string type;
- std::string name;
- std::string defaultValue;
- Property properties;
-};
-
-// Base class for holding parameters of an op's method for C++ code emission.
-class OpMethodParameters {
-public:
- // Discriminator for LLVM-style RTTI.
- enum ParamsKind {
- // Separate type and name for each parameter is not known.
- PK_Unresolved,
- // Each parameter is resolved to a type and name.
- PK_Resolved,
- };
-
- OpMethodParameters(ParamsKind kind) : kind(kind) {}
- virtual ~OpMethodParameters() {}
-
- // LLVM-style RTTI support.
- ParamsKind getKind() const { return kind; }
-
- // Writes the parameters as a part of a method declaration to `os`.
- virtual void writeDeclTo(raw_ostream &os) const = 0;
-
- // Writes the parameters as a part of a method definition to `os`
- virtual void writeDefTo(raw_ostream &os) const = 0;
-
- // Factory methods to create the correct type of `OpMethodParameters`
- // object based on the arguments.
- static std::unique_ptr<OpMethodParameters> create();
-
- static std::unique_ptr<OpMethodParameters> create(StringRef params);
-
- static std::unique_ptr<OpMethodParameters>
- create(llvm::SmallVectorImpl<OpMethodParameter> &&params);
-
- static std::unique_ptr<OpMethodParameters>
- create(StringRef type, StringRef name, StringRef defaultValue = "");
-
-private:
- const ParamsKind kind;
-};
-
-// Class for holding unresolved parameters.
-class OpMethodUnresolvedParameters : public OpMethodParameters {
-public:
- OpMethodUnresolvedParameters(StringRef params)
- : OpMethodParameters(PK_Unresolved), parameters(params) {}
-
- // write the parameters as a part of a method declaration to the given `os`.
- void writeDeclTo(raw_ostream &os) const override;
-
- // write the parameters as a part of a method definition to the given `os`
- void writeDefTo(raw_ostream &os) const override;
-
- // LLVM-style RTTI support.
- static bool classof(const OpMethodParameters *params) {
- return params->getKind() == PK_Unresolved;
- }
-
-private:
- std::string parameters;
-};
-
-// Class for holding resolved parameters.
-class OpMethodResolvedParameters : public OpMethodParameters {
-public:
- OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {}
-
- OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &&params)
- : OpMethodParameters(PK_Resolved) {
- for (OpMethodParameter &param : params)
- parameters.emplace_back(std::move(param));
- }
-
- OpMethodResolvedParameters(StringRef type, StringRef name,
- StringRef defaultValue)
- : OpMethodParameters(PK_Resolved) {
- parameters.emplace_back(type, name, defaultValue);
- }
-
- // Returns the number of parameters.
- size_t getNumParameters() const { return parameters.size(); }
-
- // Returns if this method makes the `other` method redundant. Note that this
- // is more than just finding conflicting methods. This method determines if
- // the 2 set of parameters are conflicting and if so, returns true if this
- // method has a more general set of parameters that can replace all possible
- // calls to the `other` method.
- bool makesRedundant(const OpMethodResolvedParameters &other) const;
-
- // write the parameters as a part of a method declaration to the given `os`.
- void writeDeclTo(raw_ostream &os) const override;
-
- // write the parameters as a part of a method definition to the given `os`
- void writeDefTo(raw_ostream &os) const override;
-
- // LLVM-style RTTI support.
- static bool classof(const OpMethodParameters *params) {
- return params->getKind() == PK_Resolved;
- }
-
-private:
- llvm::SmallVector<OpMethodParameter, 4> parameters;
-};
-
-// Class for holding the signature of an op's method for C++ code emission
-class OpMethodSignature {
-public:
- template <typename... Args>
- OpMethodSignature(StringRef retType, StringRef name, Args &&...args)
- : returnType(retType), methodName(name),
- parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {}
- OpMethodSignature(OpMethodSignature &&) = default;
-
- // Returns if a method with this signature makes a method with `other`
- // signature redundant. Only supports resolved parameters.
- bool makesRedundant(const OpMethodSignature &other) const;
-
- // Returns the number of parameters (for resolved parameters).
- size_t getNumParameters() const {
- return cast<OpMethodResolvedParameters>(parameters.get())
- ->getNumParameters();
- }
-
- // Returns the name of the method.
- StringRef getName() const { return methodName; }
-
- // Writes the signature as a method declaration to the given `os`.
- void writeDeclTo(raw_ostream &os) const;
-
- // Writes the signature as the start of a method definition to the given `os`.
- // `namePrefix` is the prefix to be prepended to the method name (typically
- // namespaces for qualifying the method definition).
- void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
-
-private:
- std::string returnType;
- std::string methodName;
- std::unique_ptr<OpMethodParameters> parameters;
-};
-
-// Class for holding the body of an op's method for C++ code emission
-class OpMethodBody {
-public:
- explicit OpMethodBody(bool declOnly);
-
- OpMethodBody &operator<<(Twine content);
- OpMethodBody &operator<<(int content);
- OpMethodBody &operator<<(const FmtObjectBase &content);
-
- void writeTo(raw_ostream &os) const;
-
-private:
- // Whether this class should record method body.
- bool isEffective;
- std::string body;
-};
-
-// Class for holding an op's method for C++ code emission
-class OpMethod {
-public:
- // Properties (qualifiers) of class methods. Bitfield is used here to help
- // querying properties.
- enum Property {
- MP_None = 0x0,
- MP_Static = 0x1,
- MP_Constructor = 0x2,
- MP_Private = 0x4,
- MP_Declaration = 0x8,
- MP_Inline = 0x10,
- MP_Constexpr = 0x20 | MP_Inline,
- MP_StaticDeclaration = MP_Static | MP_Declaration,
- };
-
- template <typename... Args>
- OpMethod(StringRef retType, StringRef name, Property property, unsigned id,
- Args &&...args)
- : properties(property),
- methodSignature(retType, name, std::forward<Args>(args)...),
- methodBody(properties & MP_Declaration), id(id) {}
-
- OpMethod(OpMethod &&) = default;
-
- virtual ~OpMethod() = default;
-
- OpMethodBody &body() { return methodBody; }
-
- // Returns true if this is a static method.
- bool isStatic() const { return properties & MP_Static; }
-
- // Returns true if this is a private method.
- bool isPrivate() const { return properties & MP_Private; }
-
- // Returns true if this is an inline method.
- bool isInline() const { return properties & MP_Inline; }
-
- // Returns the name of this method.
- StringRef getName() const { return methodSignature.getName(); }
-
- // Returns the ID for this method
- unsigned getID() const { return id; }
-
- // Returns if this method makes the `other` method redundant.
- bool makesRedundant(const OpMethod &other) const {
- return methodSignature.makesRedundant(other.methodSignature);
- }
-
- // Writes the method as a declaration to the given `os`.
- virtual void writeDeclTo(raw_ostream &os) const;
-
- // Writes the method as a definition to the given `os`. `namePrefix` is the
- // prefix to be prepended to the method name (typically namespaces for
- // qualifying the method definition).
- virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
-
-protected:
- Property properties;
- OpMethodSignature methodSignature;
- OpMethodBody methodBody;
- const unsigned id;
-};
-
-// Class for holding an op's constructor method for C++ code emission.
-class OpConstructor : public OpMethod {
-public:
- template <typename... Args>
- OpConstructor(StringRef className, Property property, unsigned id,
- Args &&...args)
- : OpMethod("", className, property, id, std::forward<Args>(args)...) {}
-
- // Add member initializer to constructor initializing `name` with `value`.
- void addMemberInitializer(StringRef name, StringRef value);
-
- // Writes the method as a definition to the given `os`. `namePrefix` is the
- // prefix to be prepended to the method name (typically namespaces for
- // qualifying the method definition).
- void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
-
-private:
- // Member initializers.
- std::string memberInitializers;
-};
-
-// A class used to emit C++ classes from Tablegen. Contains a list of public
-// methods and a list of private fields to be emitted.
-class Class {
-public:
- explicit Class(StringRef name);
-
- // Adds a new method to this class and prune redundant methods. Returns null
- // if the method was not added (because an existing method would make it
- // redundant), else returns a pointer to the added method. Note that this call
- // may also delete existing methods that are made redundant by a method to the
- // class.
- template <typename... Args>
- OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
- OpMethod::Property properties, Args &&...args) {
- auto newMethod = std::make_unique<OpMethod>(
- retType, name, properties, nextMethodID++, std::forward<Args>(args)...);
- return addMethodAndPrune(methods, std::move(newMethod));
- }
-
- template <typename... Args>
- OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
- Args &&...args) {
- return addMethodAndPrune(retType, name, OpMethod::MP_None,
- std::forward<Args>(args)...);
- }
-
- template <typename... Args>
- OpConstructor *addConstructorAndPrune(Args &&...args) {
- auto newConstructor = std::make_unique<OpConstructor>(
- getClassName(), OpMethod::MP_Constructor, nextMethodID++,
- std::forward<Args>(args)...);
- return addMethodAndPrune(constructors, std::move(newConstructor));
- }
-
- // Creates a new field in this class.
- void newField(StringRef type, StringRef name, StringRef defaultValue = "");
-
- // Writes this op's class as a declaration to the given `os`.
- void writeDeclTo(raw_ostream &os) const;
- // Writes the method definitions in this op's class to the given `os`.
- void writeDefTo(raw_ostream &os) const;
-
- // Returns the C++ class name of the op.
- StringRef getClassName() const { return className; }
-
-protected:
- // Get a list of all the methods to emit, filtering out hidden ones.
- void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const {
- using ConsRef = const std::unique_ptr<OpConstructor> &;
- using MethodRef = const std::unique_ptr<OpMethod> &;
- llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); });
- llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); });
- }
-
- // For deterministic code generation, keep methods sorted in the order in
- // which they were generated.
- template <typename MethodTy>
- struct MethodCompare {
- bool operator()(const std::unique_ptr<MethodTy> &x,
- const std::unique_ptr<MethodTy> &y) const {
- return x->getID() < y->getID();
- }
- };
-
- template <typename MethodTy>
- using MethodSet =
- std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>;
-
- template <typename MethodTy>
- MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set,
- std::unique_ptr<MethodTy> &&newMethod) {
- // Check if the new method will be made redundant by existing methods.
- for (auto &method : set)
- if (method->makesRedundant(*newMethod))
- return nullptr;
-
- // We can add this a method to the set. Prune any existing methods that will
- // be made redundant by adding this new method. Note that the redundant
- // check between two methods is more than a conflict check. makesRedundant()
- // below will check if the new method conflicts with an existing method and
- // if so, returns true if the new method makes the existing method redundant
- // because all calls to the existing method can be subsumed by the new
- // method. So makesRedundant() does a combined job of finding conflicts and
- // deciding which of the 2 conflicting methods survive.
- //
- // Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing
- // it manually here.
- for (auto it = set.begin(), end = set.end(); it != end;) {
- if (newMethod->makesRedundant(*(it->get())))
- it = set.erase(it);
- else
- ++it;
- }
-
- MethodTy *ret = newMethod.get();
- set.insert(std::move(newMethod));
- return ret;
- }
-
- std::string className;
- MethodSet<OpConstructor> constructors;
- MethodSet<OpMethod> methods;
- unsigned nextMethodID = 0;
- SmallVector<std::string, 4> fields;
-};
-
-// Class for holding an op for C++ code emission
-class OpClass : public Class {
-public:
- explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
-
- // Adds an op trait.
- void addTrait(Twine trait);
-
- // Writes this op's class as a declaration to the given `os`. Redefines
- // Class::writeDeclTo to also emit traits and extra class declarations.
- void writeDeclTo(raw_ostream &os) const;
-
-private:
- StringRef extraClassDeclaration;
- SmallVector<std::string, 4> traitsVec;
- StringSet<> traitsSet;
-};
-
-} // namespace tblgen
-} // namespace mlir
-
-#endif // MLIR_TABLEGEN_OPCLASS_H_
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index a97419b193216..bb522d7d03f48 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -13,12 +13,12 @@ llvm_add_library(MLIRTableGen STATIC
Attribute.cpp
AttrOrTypeDef.cpp
Builder.cpp
+ Class.cpp
Constraint.cpp
Dialect.cpp
Format.cpp
Interfaces.cpp
Operator.cpp
- OpClass.cpp
Pass.cpp
Pattern.cpp
Predicate.cpp
diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/Class.cpp
similarity index 50%
rename from mlir/lib/TableGen/OpClass.cpp
rename to mlir/lib/TableGen/Class.cpp
index d1453bacdd05f..3fdba8c858b6a 100644
--- a/mlir/lib/TableGen/OpClass.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -1,4 +1,4 @@
-//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
+//===- Class.cpp - Helper classes for C++ code emission -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/TableGen/OpClass.h"
+#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/Format.h"
#include "llvm/ADT/Sequence.h"
@@ -20,173 +20,102 @@
using namespace mlir;
using namespace mlir::tblgen;
-namespace {
-
// Returns space to be emitted after the given C++ `type`. return "" if the
// ends with '&' or '*', or is empty, else returns " ".
-StringRef getSpaceAfterType(StringRef type) {
+static StringRef getSpaceAfterType(StringRef type) {
return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
}
-} // namespace
-
//===----------------------------------------------------------------------===//
-// OpMethodParameter definitions
+// MethodParameter definitions
//===----------------------------------------------------------------------===//
-void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
- if (properties & PP_Optional)
+void MethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
+ if (optional)
os << "/*optional*/";
os << type << getSpaceAfterType(type) << name;
- if (emitDefault && !defaultValue.empty())
+ if (emitDefault && hasDefaultValue())
os << " = " << defaultValue;
}
//===----------------------------------------------------------------------===//
-// OpMethodParameters definitions
+// MethodParameters definitions
//===----------------------------------------------------------------------===//
-// Factory methods to construct the correct type of `OpMethodParameters`
-// object based on the arguments.
-std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
- return std::make_unique<OpMethodResolvedParameters>();
+void MethodParameters::writeDeclTo(raw_ostream &os) const {
+ llvm::interleaveComma(parameters, os,
+ [&os](auto &param) { param.writeDeclTo(os); });
}
-
-std::unique_ptr<OpMethodParameters>
-OpMethodParameters::create(StringRef params) {
- return std::make_unique<OpMethodUnresolvedParameters>(params);
-}
-
-std::unique_ptr<OpMethodParameters>
-OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &&params) {
- return std::make_unique<OpMethodResolvedParameters>(std::move(params));
+void MethodParameters::writeDefTo(raw_ostream &os) const {
+ llvm::interleaveComma(parameters, os,
+ [&os](auto &param) { param.writeDefTo(os); });
}
-std::unique_ptr<OpMethodParameters>
-OpMethodParameters::create(StringRef type, StringRef name,
- StringRef defaultValue) {
- return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
-}
-
-//===----------------------------------------------------------------------===//
-// OpMethodUnresolvedParameters definitions
-//===----------------------------------------------------------------------===//
-void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
- os << parameters;
-}
-
-void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
- // We need to remove the default values for parameters in method definition.
- // TODO: We are using '=' and ',' as delimiters for parameter
- // initializers. This is incorrect for initializer list with more than one
- // element. Change to a more robust approach.
- llvm::SmallVector<StringRef, 4> tokens;
- StringRef params = parameters;
- while (!params.empty()) {
- std::pair<StringRef, StringRef> parts = params.split("=");
- tokens.push_back(parts.first);
- params = parts.second.split(',').second;
- }
- llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
-}
-
-//===----------------------------------------------------------------------===//
-// OpMethodResolvedParameters definitions
-//===----------------------------------------------------------------------===//
-
-// Returns true if a method with these parameters makes a method with parameters
-// `other` redundant. This should return true only if all possible calls to the
-// other method can be replaced by calls to this method.
-bool OpMethodResolvedParameters::makesRedundant(
- const OpMethodResolvedParameters &other) const {
- const size_t otherNumParams = other.getNumParameters();
- const size_t thisNumParams = getNumParameters();
-
- // All calls to the other method can be replaced this method only if this
- // method has the same or more arguments number of arguments as the other, and
- // the common arguments have the same type.
- if (thisNumParams < otherNumParams)
+bool MethodParameters::subsumes(const MethodParameters &other) const {
+ // These parameters do not subsume the others if there are fewer parameters
+ // or their types do not match.
+ if (parameters.size() < other.parameters.size())
+ return false;
+ if (!std::equal(
+ other.parameters.begin(), other.parameters.end(), parameters.begin(),
+ [](auto &lhs, auto &rhs) { return lhs.getType() == rhs.getType(); }))
return false;
- for (int idx : llvm::seq<int>(0, otherNumParams))
- if (parameters[idx].getType() != other.parameters[idx].getType())
- return false;
-
- // If all the common arguments have the same type, we can elide the other
- // method if this method has the same number of arguments as other or the
- // first argument after the common ones has a default value (and by C++
- // requirement, all the later ones will also have a default value).
- return thisNumParams == otherNumParams ||
- parameters[otherNumParams].hasDefaultValue();
-}
-void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
- llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
- param.writeDeclTo(os);
- });
-}
-
-void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
- llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
- param.writeDefTo(os);
- });
+ // If all the common parameters have the same type, we can elide the other
+ // method if this method has the same number of parameters as other or if the
+ // first parameter after the common parameters has a default value (and, as
+ // required by C++, subsequent parameters will have default values too).
+ return parameters.size() == other.parameters.size() ||
+ parameters[other.parameters.size()].hasDefaultValue();
}
//===----------------------------------------------------------------------===//
-// OpMethodSignature definitions
+// MethodSignature definitions
//===----------------------------------------------------------------------===//
-// Returns if a method with this signature makes a method with `other` signature
-// redundant. Only supports resolved parameters.
-bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
- if (methodName != other.methodName)
- return false;
- auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
- auto *resolvedOther =
- dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
- if (resolvedThis && resolvedOther)
- return resolvedThis->makesRedundant(*resolvedOther);
- return false;
+bool MethodSignature::makesRedundant(const MethodSignature &other) const {
+ return methodName == other.methodName &&
+ parameters.subsumes(other.parameters);
}
-void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
+void MethodSignature::writeDeclTo(raw_ostream &os) const {
os << returnType << getSpaceAfterType(returnType) << methodName << "(";
- parameters->writeDeclTo(os);
+ parameters.writeDeclTo(os);
os << ")";
}
-void OpMethodSignature::writeDefTo(raw_ostream &os,
- StringRef namePrefix) const {
+void MethodSignature::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
os << returnType << getSpaceAfterType(returnType) << namePrefix
<< (namePrefix.empty() ? "" : "::") << methodName << "(";
- parameters->writeDefTo(os);
+ parameters.writeDefTo(os);
os << ")";
}
//===----------------------------------------------------------------------===//
-// OpMethodBody definitions
+// MethodBody definitions
//===----------------------------------------------------------------------===//
-OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
+MethodBody::MethodBody(bool declOnly) : isEffective(!declOnly) {}
-OpMethodBody &OpMethodBody::operator<<(Twine content) {
+MethodBody &MethodBody::operator<<(Twine content) {
if (isEffective)
body.append(content.str());
return *this;
}
-OpMethodBody &OpMethodBody::operator<<(int content) {
+MethodBody &MethodBody::operator<<(int content) {
if (isEffective)
body.append(std::to_string(content));
return *this;
}
-OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
+MethodBody &MethodBody::operator<<(const FmtObjectBase &content) {
if (isEffective)
body.append(content.str());
return *this;
}
-void OpMethodBody::writeTo(raw_ostream &os) const {
+void MethodBody::writeTo(raw_ostream &os) const {
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
os << bodyRef;
if (bodyRef.empty() || bodyRef.back() != '\n')
@@ -194,10 +123,10 @@ void OpMethodBody::writeTo(raw_ostream &os) const {
}
//===----------------------------------------------------------------------===//
-// OpMethod definitions
+// Method definitions
//===----------------------------------------------------------------------===//
-void OpMethod::writeDeclTo(raw_ostream &os) const {
+void Method::writeDeclTo(raw_ostream &os) const {
os.indent(2);
if (isStatic())
os << "static ";
@@ -213,7 +142,7 @@ void OpMethod::writeDeclTo(raw_ostream &os) const {
}
}
-void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
+void Method::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
// Do not write definition if the method is decl only.
if (properties & MP_Declaration)
return;
@@ -227,15 +156,15 @@ void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
}
//===----------------------------------------------------------------------===//
-// OpConstructor definitions
+// Constructor definitions
//===----------------------------------------------------------------------===//
-void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
+void Constructor::addMemberInitializer(StringRef name, StringRef value) {
memberInitializers.append(std::string(llvm::formatv(
"{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
}
-void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
+void Constructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
// Do not write definition if the method is decl only.
if (properties & MP_Declaration)
return;
@@ -243,7 +172,7 @@ void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
methodSignature.writeDefTo(os, namePrefix);
os << " " << memberInitializers << " {\n";
methodBody.writeTo(os);
- os << "}";
+ os << "}\n";
}
//===----------------------------------------------------------------------===//
@@ -259,12 +188,13 @@ void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
: formatv("{0} = {1}", varName, defaultValue).str();
fields.push_back(std::move(field));
}
+
void Class::writeDeclTo(raw_ostream &os) const {
bool hasPrivateMethod = false;
os << "class " << className << " {\n";
os << "public:\n";
- forAllMethods([&](const OpMethod &method) {
+ forAllMethods([&](const Method &method) {
if (!method.isPrivate()) {
method.writeDeclTo(os);
os << '\n';
@@ -276,7 +206,7 @@ void Class::writeDeclTo(raw_ostream &os) const {
os << '\n';
os << "private:\n";
if (hasPrivateMethod) {
- forAllMethods([&](const OpMethod &method) {
+ forAllMethods([&](const Method &method) {
if (method.isPrivate()) {
method.writeDeclTo(os);
os << '\n';
@@ -291,12 +221,35 @@ void Class::writeDeclTo(raw_ostream &os) const {
}
void Class::writeDefTo(raw_ostream &os) const {
- forAllMethods([&](const OpMethod &method) {
+ forAllMethods([&](const Method &method) {
method.writeDefTo(os, className);
- os << "\n\n";
+ os << "\n";
});
}
+// Insert a new method into a list of methods, if it would not be pruned, and
+// prune any existing methods it makes redundant. Returns null if pruned.
+template <typename ContainerT, typename MethodT>
+static MethodT *insertAndPrune(ContainerT &methods, MethodT newMethod) {
+ if (llvm::any_of(methods, [&](auto &method) {
+ return method.makesRedundant(newMethod);
+ }))
+ return nullptr;
+
+ llvm::erase_if(
+ methods, [&](auto &method) { return newMethod.makesRedundant(method); });
+ methods.push_back(std::move(newMethod));
+ return &methods.back();
+}
+
+Method *Class::addMethodAndPrune(Method &&newMethod) {
+ return insertAndPrune(methods, std::move(newMethod));
+}
+
+Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) {
+ return insertAndPrune(constructors, std::move(newCtor));
+}
+
//===----------------------------------------------------------------------===//
// OpClass definitions
//===----------------------------------------------------------------------===//
@@ -304,15 +257,11 @@ void Class::writeDefTo(raw_ostream &os) const {
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
: Class(name), extraClassDeclaration(extraClassDeclaration) {}
-void OpClass::addTrait(Twine trait) {
- auto traitStr = trait.str();
- if (traitsSet.insert(traitStr).second)
- traitsVec.push_back(std::move(traitStr));
-}
+void OpClass::addTrait(Twine trait) { traits.insert(trait.str()); }
void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " : public ::mlir::Op<" << className;
- for (const auto &trait : traitsVec)
+ for (const auto &trait : traits)
os << ", " << trait;
os << "> {\npublic:\n"
<< " using Op::Op;\n"
@@ -320,7 +269,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
<< " using Adaptor = " << className << "Adaptor;\n";
bool hasPrivateMethod = false;
- forAllMethods([&](const OpMethod &method) {
+ forAllMethods([&](const Method &method) {
if (!method.isPrivate()) {
method.writeDeclTo(os);
os << "\n";
@@ -335,7 +284,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
if (hasPrivateMethod) {
os << "\nprivate:\n";
- forAllMethods([&](const OpMethod &method) {
+ forAllMethods([&](const Method &method) {
if (method.isPrivate()) {
method.writeDeclTo(os);
os << "\n";
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 7e3a96e5980e2..33c15fb81fbb5 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -10,11 +10,11 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
-#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/Optional.h"
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index b60d8a2c98dd8..5e1cb842f1bc0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -13,11 +13,11 @@
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
+#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
-#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
@@ -361,7 +361,7 @@ class OpEmitter {
// types. `inferredAttributes` is populated with any attributes that are
// elided from the build list. The given `typeParamKind` and `attrParamKind`
// controls how result types and attributes are placed in the parameter list.
- void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
+ void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
llvm::StringSet<> &inferredAttributes,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
@@ -369,7 +369,7 @@ class OpEmitter {
// Adds op arguments and regions into operation state for build() methods.
void
- genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
+ genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
llvm::StringSet<> &inferredAttributes,
bool isRawValueAttr = false);
@@ -390,17 +390,16 @@ class OpEmitter {
// Generates verify statements for operands and results in the operation.
// The generated code will be attached to `body`.
- void genOperandResultVerifier(OpMethodBody &body,
- Operator::value_range values,
+ void genOperandResultVerifier(MethodBody &body, Operator::value_range values,
StringRef valueKind);
// Generates verify statements for regions in the operation.
// The generated code will be attached to `body`.
- void genRegionVerifier(OpMethodBody &body);
+ void genRegionVerifier(MethodBody &body);
// Generates verify statements for successors in the operation.
// The generated code will be attached to `body`.
- void genSuccessorVerifier(OpMethodBody &body);
+ void genSuccessorVerifier(MethodBody &body);
// Generates the traits used by the object.
void genTraits();
@@ -413,8 +412,8 @@ class OpEmitter {
// Generate op interface method for the given interface method. If
// 'declaration' is true, generates a declaration, else a definition.
- OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
- bool declaration = true);
+ Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
+ bool declaration = true);
// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();
@@ -470,7 +469,7 @@ static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
static void genAttributeVerifier(
- const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, OpMethodBody &body,
+ const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
// Check that a required attribute exists.
//
@@ -602,7 +601,7 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
-static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName,
+static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
const Operator &op) {
if (m)
return;
@@ -627,18 +626,15 @@ void OpEmitter::genAttrNameGetters() {
for (const NamedAttribute &namedAttr : op.getAttributes())
addAttrName(namedAttr.name);
// Include key attributes from several traits as implicitly registered.
- std::string operandSizes = "operand_segment_sizes";
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
- addAttrName(operandSizes);
- std::string attrSizes = "result_segment_sizes";
+ addAttrName("operand_segment_sizes");
if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
- addAttrName(attrSizes);
+ addAttrName("result_segment_sizes");
// Emit the getAttributeNames method.
{
- auto *method = opClass.addMethodAndPrune(
- "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames",
- OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline));
+ auto *method = opClass.addStaticInlineMethod(
+ "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
ERROR_IF_PRUNED(method, "getAttributeNames", op);
auto &body = method->body();
if (attributeNames.empty()) {
@@ -658,20 +654,18 @@ void OpEmitter::genAttrNameGetters() {
// Emit the getAttributeNameForIndex methods.
{
- auto *method = opClass.addMethodAndPrune(
+ auto *method = opClass.addInlineMethod<Method::MP_Private>(
"::mlir::Identifier", "getAttributeNameForIndex",
- OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline),
- "unsigned", "index");
+ MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
method->body()
<< " return getAttributeNameForIndex((*this)->getName(), index);";
}
{
- auto *method = opClass.addMethodAndPrune(
+ auto *method = opClass.addStaticInlineMethod<Method::MP_Private>(
"::mlir::Identifier", "getAttributeNameForIndex",
- OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline |
- OpMethod::MP_Static),
- "::mlir::OperationName name, unsigned index");
+ MethodParameter("::mlir::OperationName", "name"),
+ MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
method->body() << "assert(index < " << attributeNames.size()
<< " && \"invalid attribute index\");\n"
@@ -689,8 +683,7 @@ void OpEmitter::genAttrNameGetters() {
// Generate the non-static variant.
{
auto *method =
- opClass.addMethodAndPrune("::mlir::Identifier", methodName,
- OpMethod::Property(OpMethod::MP_Inline));
+ opClass.addInlineMethod("::mlir::Identifier", methodName);
ERROR_IF_PRUNED(method, methodName, op);
method->body()
<< llvm::formatv(attrNameMethodBody, attrIt.second).str();
@@ -698,10 +691,9 @@ void OpEmitter::genAttrNameGetters() {
// Generate the static variant.
{
- auto *method = opClass.addMethodAndPrune(
+ auto *method = opClass.addStaticInlineMethod(
"::mlir::Identifier", methodName,
- OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static),
- "::mlir::OperationName", "name");
+ MethodParameter("::mlir::OperationName", "name"));
ERROR_IF_PRUNED(method, methodName, op);
method->body() << llvm::formatv(attrNameMethodBody,
"name, " + Twine(attrIt.second))
@@ -717,13 +709,13 @@ void OpEmitter::genAttrGetters() {
// Emit the derived attribute body.
auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
- if (auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name))
+ if (auto *method = opClass.addMethod(attr.getReturnType(), name))
method->body() << " " << attr.getDerivedCodeBody() << "\n";
};
// Emit with return type specified.
auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
- auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
+ auto *method = opClass.addMethod(attr.getReturnType(), name);
ERROR_IF_PRUNED(method, name, op);
auto &body = method->body();
body << " auto attr = " << name << "Attr();\n";
@@ -748,7 +740,7 @@ void OpEmitter::genAttrGetters() {
// use the string interface for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
auto *method =
- opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
+ opClass.addMethod(attr.getStorageType(), (name + "Attr").str());
if (!method)
return;
method->body() << formatv(
@@ -773,68 +765,69 @@ void OpEmitter::genAttrGetters() {
[](const NamedAttribute &namedAttr) {
return namedAttr.attr.isDerivedAttr();
});
- if (!derivedAttrs.empty()) {
- opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
- // Generate helper method to query whether a named attribute is a derived
- // attribute. This enables, for example, avoiding adding an attribute that
- // overlaps with a derived attribute.
- {
- auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
- OpMethod::MP_Static,
- "::llvm::StringRef", "name");
- ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
- auto &body = method->body();
- for (auto namedAttr : derivedAttrs)
- body << " if (name == \"" << namedAttr.name << "\") return true;\n";
- body << " return false;";
- }
- // Generate method to materialize derived attributes as a DictionaryAttr.
- {
- auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
- "materializeDerivedAttributes");
- ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
- auto &body = method->body();
-
- auto nonMaterializable =
- make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
- return namedAttr.attr.getConvertFromStorageCall().empty();
- });
- if (!nonMaterializable.empty()) {
- std::string attrs;
- llvm::raw_string_ostream os(attrs);
- interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
- os << op.getGetterName(attr.name);
- });
- PrintWarning(
- op.getLoc(),
- formatv(
- "op has non-materializable derived attributes '{0}', skipping",
- os.str()));
- body << formatv(" emitOpError(\"op has non-materializable derived "
- "attributes '{0}'\");\n",
- attrs);
- body << " return nullptr;";
- return;
- }
+ if (derivedAttrs.empty())
+ return;
- body << " ::mlir::MLIRContext* ctx = getContext();\n";
- body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
- body << " return ::mlir::DictionaryAttr::get(";
- body << " ctx, {\n";
- interleave(
- derivedAttrs, body,
- [&](const NamedAttribute &namedAttr) {
- auto tmpl = namedAttr.attr.getConvertFromStorageCall();
- std::string name = op.getGetterName(namedAttr.name);
- body << " {" << name << "AttrName(),\n"
- << tgfmt(tmpl, &fctx.withSelf(name + "()")
- .withBuilder("odsBuilder")
- .addSubst("_ctx", "ctx"))
- << "}";
- },
- ",\n");
- body << "});";
+ opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
+ // Generate helper method to query whether a named attribute is a derived
+ // attribute. This enables, for example, avoiding adding an attribute that
+ // overlaps with a derived attribute.
+ {
+ auto *method =
+ opClass.addStaticMethod("bool", "isDerivedAttribute",
+ MethodParameter("::llvm::StringRef", "name"));
+ ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
+ auto &body = method->body();
+ for (auto namedAttr : derivedAttrs)
+ body << " if (name == \"" << namedAttr.name << "\") return true;\n";
+ body << " return false;";
+ }
+ // Generate method to materialize derived attributes as a DictionaryAttr.
+ {
+ auto *method = opClass.addMethod("::mlir::DictionaryAttr",
+ "materializeDerivedAttributes");
+ ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
+ auto &body = method->body();
+
+ auto nonMaterializable =
+ make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
+ return namedAttr.attr.getConvertFromStorageCall().empty();
+ });
+ if (!nonMaterializable.empty()) {
+ std::string attrs;
+ llvm::raw_string_ostream os(attrs);
+ interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
+ os << op.getGetterName(attr.name);
+ });
+ PrintWarning(
+ op.getLoc(),
+ formatv(
+ "op has non-materializable derived attributes '{0}', skipping",
+ os.str()));
+ body << formatv(" emitOpError(\"op has non-materializable derived "
+ "attributes '{0}'\");\n",
+ attrs);
+ body << " return nullptr;";
+ return;
}
+
+ body << " ::mlir::MLIRContext* ctx = getContext();\n";
+ body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
+ body << " return ::mlir::DictionaryAttr::get(";
+ body << " ctx, {\n";
+ interleave(
+ derivedAttrs, body,
+ [&](const NamedAttribute &namedAttr) {
+ auto tmpl = namedAttr.attr.getConvertFromStorageCall();
+ std::string name = op.getGetterName(namedAttr.name);
+ body << " {" << name << "AttrName(),\n"
+ << tgfmt(tmpl, &fctx.withSelf(name + "()")
+ .withBuilder("odsBuilder")
+ .addSubst("_ctx", "ctx"))
+ << "}";
+ },
+ ",\n");
+ body << "});";
}
}
@@ -844,19 +837,21 @@ void OpEmitter::genAttrSetters() {
// for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
Attribute attr) {
- auto *method = opClass.addMethodAndPrune(
- "void", (setterName + "Attr").str(), attr.getStorageType(), "attr");
+ auto *method =
+ opClass.addMethod("void", (setterName + "Attr").str(),
+ MethodParameter(attr.getStorageType(), "attr"));
if (method)
method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);",
getterName);
};
for (const NamedAttribute &namedAttr : op.getAttributes()) {
- if (!namedAttr.attr.isDerivedAttr())
- for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
- op.getGetterNames(namedAttr.name)))
- emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
- namedAttr.attr);
+ if (namedAttr.attr.isDerivedAttr())
+ continue;
+ for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
+ op.getGetterNames(namedAttr.name)))
+ emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
+ namedAttr.attr);
}
}
@@ -866,7 +861,7 @@ void OpEmitter::genOptionalAttrRemovers() {
auto emitRemoveAttr = [&](StringRef name) {
auto upperInitial = name.take_front().upper();
auto suffix = name.drop_front();
- auto *method = opClass.addMethodAndPrune(
+ auto *method = opClass.addMethod(
"::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
if (!method)
return;
@@ -887,8 +882,8 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
int numVariadic, int numNonVariadic,
StringRef rangeSizeCall, bool hasAttrSegmentSize,
StringRef sizeAttrInit, RangeT &&odsValues) {
- auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
- methodName, "unsigned", "index");
+ auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
+ MethodParameter("unsigned", "index"));
if (!method)
return;
auto &body = method->body();
@@ -900,7 +895,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
// Because the op can have arbitrarily interleaved variadic and non-variadic
// operands, we need to embed a list in the "sink" getter method for
// calculation at run-time.
- llvm::SmallVector<StringRef, 4> isVariadic;
+ SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(llvm::size(odsValues));
for (auto &it : odsValues)
isVariadic.push_back(it.isVariableLength() ? "true" : "false");
@@ -959,8 +954,8 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
rangeSizeCall, attrSizedOperands, sizeAttrInit,
const_cast<Operator &>(op).getOperands());
- auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
- "index");
+ auto *m = opClass.addMethod(rangeType, "getODSOperands",
+ MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(m, "getODSOperands", op);
auto &body = m->body();
body << formatv(valueRangeReturnCode, rangeBeginCall,
@@ -974,7 +969,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
continue;
for (StringRef name : op.getGetterNames(operand.name)) {
if (operand.isOptional()) {
- m = opClass.addMethodAndPrune("::mlir::Value", name);
+ m = opClass.addMethod("::mlir::Value", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? ::mlir::Value() : "
@@ -983,24 +978,24 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
std::string segmentAttr = op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
if (isAdaptor) {
- m = opClass.addMethodAndPrune(
- "::llvm::SmallVector<::mlir::ValueRange>", name);
+ m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>",
+ name);
ERROR_IF_PRUNED(m, name, op);
m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
segmentAttr, i);
continue;
}
- m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", name);
+ m = opClass.addMethod("::mlir::OperandRangeRange", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ").split("
<< segmentAttr << "Attr());";
} else if (operand.isVariadic()) {
- m = opClass.addMethodAndPrune(rangeType, name);
+ m = opClass.addMethod(rangeType, name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ");";
} else {
- m = opClass.addMethodAndPrune("::mlir::Value", name);
+ m = opClass.addMethod("::mlir::Value", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSOperands(" << i << ").begin();";
}
@@ -1035,10 +1030,10 @@ void OpEmitter::genNamedOperandSetters() {
if (operand.name.empty())
continue;
for (StringRef name : op.getGetterNames(operand.name)) {
- auto *m = opClass.addMethodAndPrune(
- operand.isVariadicOfVariadic() ? "::mlir::MutableOperandRangeRange"
- : "::mlir::MutableOperandRange",
- (name + "Mutable").str());
+ auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
+ ? "::mlir::MutableOperandRangeRange"
+ : "::mlir::MutableOperandRange",
+ (name + "Mutable").str());
ERROR_IF_PRUNED(m, name, op);
auto &body = m->body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
@@ -1110,8 +1105,9 @@ void OpEmitter::genNamedResultGetters() {
numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
attrSizeInitCode, op.getResults());
- auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
- "getODSResults", "unsigned", "index");
+ auto *m =
+ opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
+ MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(m, "getODSResults", op);
m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
"getODSResultIndexAndLength(index)");
@@ -1122,17 +1118,17 @@ void OpEmitter::genNamedResultGetters() {
continue;
for (StringRef name : op.getGetterNames(result.name)) {
if (result.isOptional()) {
- m = opClass.addMethodAndPrune("::mlir::Value", name);
+ m = opClass.addMethod("::mlir::Value", name);
ERROR_IF_PRUNED(m, name, op);
m->body()
<< " auto results = getODSResults(" << i << ");\n"
<< " return results.empty() ? ::mlir::Value() : *results.begin();";
} else if (result.isVariadic()) {
- m = opClass.addMethodAndPrune("::mlir::Operation::result_range", name);
+ m = opClass.addMethod("::mlir::Operation::result_range", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSResults(" << i << ");";
} else {
- m = opClass.addMethodAndPrune("::mlir::Value", name);
+ m = opClass.addMethod("::mlir::Value", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSResults(" << i << ").begin();";
}
@@ -1150,15 +1146,15 @@ void OpEmitter::genNamedRegionGetters() {
for (StringRef name : op.getGetterNames(region.name)) {
// Generate the accessors for a variadic region.
if (region.isVariadic()) {
- auto *m = opClass.addMethodAndPrune(
- "::mlir::MutableArrayRef<::mlir::Region>", name);
+ auto *m =
+ opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
i);
continue;
}
- auto *m = opClass.addMethodAndPrune("::mlir::Region &", name);
+ auto *m = opClass.addMethod("::mlir::Region &", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getRegion({0});", i);
}
@@ -1175,7 +1171,7 @@ void OpEmitter::genNamedSuccessorGetters() {
for (StringRef name : op.getGetterNames(successor.name)) {
// Generate the accessors for a variadic successor list.
if (successor.isVariadic()) {
- auto *m = opClass.addMethodAndPrune("::mlir::SuccessorRange", name);
+ auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(
" return {std::next((*this)->successor_begin(), {0}), "
@@ -1184,7 +1180,7 @@ void OpEmitter::genNamedSuccessorGetters() {
continue;
}
- auto *m = opClass.addMethodAndPrune("::mlir::Block *", name);
+ auto *m = opClass.addMethod("::mlir::Block *", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getSuccessor({0});", i);
}
@@ -1227,14 +1223,13 @@ void OpEmitter::genSeparateArgParamBuilder() {
// inferring result type.
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
bool inferType) {
- llvm::SmallVector<OpMethodParameter, 4> paramList;
- llvm::SmallVector<std::string, 4> resultNames;
+ SmallVector<MethodParameter> paramList;
+ SmallVector<std::string, 4> resultNames;
llvm::StringSet<> inferredAttributes;
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
attrType);
- auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
- std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method.
if (!m)
return;
@@ -1308,7 +1303,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
int numResults = op.getNumResults();
// Signature
- llvm::SmallVector<OpMethodParameter, 4> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
@@ -1319,8 +1314,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
- auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
- std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@@ -1348,14 +1342,13 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO: Expand to support regions.
- SmallVector<OpMethodParameter, 4> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
- auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
- std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@@ -1407,14 +1400,13 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
}
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
- llvm::SmallVector<OpMethodParameter, 4> paramList;
- llvm::SmallVector<std::string, 4> resultNames;
+ SmallVector<MethodParameter> paramList;
+ SmallVector<std::string, 4> resultNames;
llvm::StringSet<> inferredAttributes;
buildParamList(paramList, inferredAttributes, resultNames,
TypeParamKind::None);
- auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
- std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@@ -1436,14 +1428,13 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
}
void OpEmitter::genUseAttrAsResultTypeBuilder() {
- SmallVector<OpMethodParameter, 4> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
- auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
- std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@@ -1480,16 +1471,15 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
/// Returns a signature of the builder. Updates the context `fctx` to enable
/// replacement of $_builder and $_state in the body.
-static std::string getBuilderSignature(const Builder &builder) {
+static SmallVector<MethodParameter>
+getBuilderSignature(const Builder &builder) {
ArrayRef<Builder::Parameter> params(builder.getParameters());
// Inject builder and state arguments.
- llvm::SmallVector<std::string, 8> arguments;
+ SmallVector<MethodParameter> arguments;
arguments.reserve(params.size() + 2);
- arguments.push_back(
- llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
- arguments.push_back(
- llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
+ arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
+ arguments.emplace_back("::mlir::OperationState &", builderOpState);
for (unsigned i = 0, e = params.size(); i < e; ++i) {
// If no name is provided, generate one.
@@ -1497,27 +1487,27 @@ static std::string getBuilderSignature(const Builder &builder) {
std::string name =
paramName ? paramName->str() : "odsArg" + std::to_string(i);
- std::string defaultValue;
+ StringRef defaultValue;
if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
- defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
- arguments.push_back(
- llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
- .str());
+ defaultValue = *defaultParamValue;
+
+ arguments.emplace_back(params[i].getCppType(), std::move(name),
+ defaultValue);
}
- return llvm::join(arguments, ", ");
+ return arguments;
}
void OpEmitter::genBuilder() {
// Handle custom builders if provided.
for (const Builder &builder : op.getBuilders()) {
- std::string paramStr = getBuilderSignature(builder);
+ SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
Optional<StringRef> body = builder.getBody();
- OpMethod::Property properties =
- body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
+ Method::Property properties =
+ body ? Method::MP_Static : Method::MP_StaticDeclaration;
auto *method =
- opClass.addMethodAndPrune("void", "build", properties, paramStr);
+ opClass.addMethod("void", "build", properties, std::move(arguments));
if (body)
ERROR_IF_PRUNED(method, "build", op);
@@ -1561,7 +1551,7 @@ void OpEmitter::genCollectiveParamBuilder() {
int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
- SmallVector<OpMethodParameter, 4> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::TypeRange", "resultTypes");
@@ -1573,8 +1563,7 @@ void OpEmitter::genCollectiveParamBuilder() {
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
- auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
- std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@@ -1612,7 +1601,7 @@ void OpEmitter::genCollectiveParamBuilder() {
genInferredTypeCollectiveParamBuilder();
}
-void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
+void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
llvm::StringSet<> &inferredAttributes,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
@@ -1637,11 +1626,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
StringRef type =
result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
- OpMethodParameter::Property properties = OpMethodParameter::PP_None;
- if (result.isOptional())
- properties = OpMethodParameter::PP_Optional;
- paramList.emplace_back(type, resultName, properties);
+ paramList.emplace_back(type, resultName, result.isOptional());
resultTypeNames.emplace_back(std::move(resultName));
}
} break;
@@ -1699,11 +1685,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
else
type = "::mlir::Value";
- OpMethodParameter::Property properties = OpMethodParameter::PP_None;
- if (operand->isOptional())
- properties = OpMethodParameter::PP_Optional;
paramList.emplace_back(type, getArgumentName(op, numOperands++),
- properties);
+ operand->isOptional());
continue;
}
const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
@@ -1713,10 +1696,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
if (inferredAttributes.contains(namedAttr.name))
continue;
- OpMethodParameter::Property properties = OpMethodParameter::PP_None;
- if (attr.isOptional())
- properties = OpMethodParameter::PP_Optional;
-
StringRef type;
switch (attrParamKind) {
case AttrParamKind::WrappedAttr:
@@ -1736,7 +1715,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
i >= defaultValuedAttrStartIndex) {
defaultValue += attr.getDefaultValue();
}
- paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
+ paramList.emplace_back(type, namedAttr.name, defaultValue,
+ attr.isOptional());
}
/// Insert parameters for each successor.
@@ -1754,7 +1734,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
}
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
- OpMethodBody &body, llvm::StringSet<> &inferredAttributes,
+ MethodBody &body, llvm::StringSet<> &inferredAttributes,
bool isRawValueAttr) {
// Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
@@ -1871,12 +1851,11 @@ void OpEmitter::genCanonicalizerDecls() {
if (hasCanonicalizeMethod) {
// static LogicResult FooOp::
// canonicalize(FooOp op, PatternRewriter &rewriter);
- SmallVector<OpMethodParameter, 2> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back(op.getCppClassName(), "op");
paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
- auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
- OpMethod::MP_StaticDeclaration,
- std::move(paramList));
+ auto *m = opClass.declareStaticMethod("::mlir::LogicalResult",
+ "canonicalize", std::move(paramList));
ERROR_IF_PRUNED(m, "canonicalize", op);
}
@@ -1892,12 +1871,12 @@ void OpEmitter::genCanonicalizerDecls() {
// Add a signature for getCanonicalizationPatterns if implemented by the
// dialect or if synthesized to call 'canonicalize'.
- SmallVector<OpMethodParameter, 2> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::RewritePatternSet &", "results");
paramList.emplace_back("::mlir::MLIRContext *", "context");
- auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
- auto *method = opClass.addMethodAndPrune(
- "void", "getCanonicalizationPatterns", kind, std::move(paramList));
+ auto kind = hasBody ? Method::MP_Static : Method::MP_StaticDeclaration;
+ auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
+ std::move(paramList));
// If synthesizing the method, fill it it.
if (hasBody) {
@@ -1912,18 +1891,17 @@ void OpEmitter::genFolderDecls() {
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
- auto *m = opClass.addMethodAndPrune(
- "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
- "::llvm::ArrayRef<::mlir::Attribute>", "operands");
+ auto *m = opClass.declareMethod(
+ "::mlir::OpFoldResult", "fold",
+ MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
ERROR_IF_PRUNED(m, "operands", op);
} else {
- SmallVector<OpMethodParameter, 2> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
"results");
- auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
- OpMethod::MP_Declaration,
- std::move(paramList));
+ auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
+ std::move(paramList));
ERROR_IF_PRUNED(m, "fold", op);
}
}
@@ -1953,18 +1931,18 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
}
}
-OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
- bool declaration) {
- SmallVector<OpMethodParameter, 4> paramList;
+Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
+ bool declaration) {
+ SmallVector<MethodParameter> paramList;
for (const InterfaceMethod::Argument &arg : method.getArguments())
paramList.emplace_back(arg.type, arg.name);
- auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
+ auto properties = method.isStatic() ? Method::MP_Static : Method::MP_None;
if (declaration)
properties =
- static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
- return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
- properties, std::move(paramList));
+ static_cast<Method::Property>(properties | Method::MP_Declaration);
+ return opClass.addMethod(method.getReturnType(), method.getName(), properties,
+ std::move(paramList));
}
void OpEmitter::genOpInterfaceMethods() {
@@ -2039,8 +2017,8 @@ void OpEmitter::genSideEffectInterfaceMethods() {
"SideEffects::EffectInstance<{0}>> &",
it.first())
.str();
- auto *getEffects =
- opClass.addMethodAndPrune("void", "getEffects", type, "effects");
+ auto *getEffects = opClass.addMethod("void", "getEffects",
+ MethodParameter(type, "effects"));
ERROR_IF_PRUNED(getEffects, "getEffects", op);
auto &body = getEffects->body();
@@ -2082,7 +2060,7 @@ void OpEmitter::genTypeInterfaceMethods() {
const auto *trait = dyn_cast<InterfaceTrait>(
op.getTrait("::mlir::InferTypeOpInterface::Trait"));
Interface interface = trait->getInterface();
- OpMethod *method = [&]() -> OpMethod * {
+ Method *method = [&]() -> Method * {
for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
if (interfaceMethod.getName() == "inferReturnTypes") {
return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
@@ -2099,8 +2077,7 @@ void OpEmitter::genTypeInterfaceMethods() {
fctx.withBuilder("odsBuilder");
body << " ::mlir::Builder odsBuilder(context);\n";
- auto emitType =
- [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
+ auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
if (!type.isArg())
return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
auto argIndex = type.getArg();
@@ -2129,12 +2106,11 @@ void OpEmitter::genParser() {
hasStringAttribute(def, "assemblyFormat"))
return;
- SmallVector<OpMethodParameter, 2> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
paramList.emplace_back("::mlir::OperationState &", "result");
- auto *method =
- opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
- OpMethod::MP_Static, std::move(paramList));
+ auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
+ std::move(paramList));
ERROR_IF_PRUNED(method, "parse", op);
FmtContext fctx;
@@ -2152,8 +2128,8 @@ void OpEmitter::genPrinter() {
if (!stringInit)
return;
- auto *method =
- opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
+ auto *method = opClass.addMethod(
+ "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
ERROR_IF_PRUNED(method, "print", op);
FmtContext fctx;
fctx.addSubst("cppClass", opClass.getClassName());
@@ -2162,7 +2138,7 @@ void OpEmitter::genPrinter() {
}
/// Generate verification on native traits requiring attributes.
-static void genNativeTraitAttrVerifier(OpMethodBody &body,
+static void genNativeTraitAttrVerifier(MethodBody &body,
const OpOrAdaptorHelper &emitHelper) {
// Check that the variadic segment sizes attribute exists and contains the
// expected number of elements.
@@ -2209,7 +2185,7 @@ static void genNativeTraitAttrVerifier(OpMethodBody &body,
}
void OpEmitter::genVerifier() {
- auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
+ auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
ERROR_IF_PRUNED(method, "verify", op);
auto &body = method->body();
@@ -2247,7 +2223,7 @@ void OpEmitter::genVerifier() {
}
}
-void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
+void OpEmitter::genOperandResultVerifier(MethodBody &body,
Operator::value_range values,
StringRef valueKind) {
// Check that an optional value is at most 1 element.
@@ -2321,7 +2297,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
body << " }\n";
}
-void OpEmitter::genRegionVerifier(OpMethodBody &body) {
+void OpEmitter::genRegionVerifier(MethodBody &body) {
/// Code to verify a region.
///
/// {0}: Getter for the regions.
@@ -2363,7 +2339,7 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) {
body << " }\n";
}
-void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
+void OpEmitter::genSuccessorVerifier(MethodBody &body) {
const char *const verifySuccessor = R"(
for (auto *successor : {0})
if (::mlir::failed({1}(*this, successor, "{2}", index++)))
@@ -2485,9 +2461,8 @@ void OpEmitter::genTraits() {
}
void OpEmitter::genOpNameGetter() {
- auto *method = opClass.addMethodAndPrune(
- "::llvm::StringLiteral", "getOperationName",
- OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
+ auto *method = opClass.addStaticMethod<Method::MP_Constexpr>(
+ "::llvm::StringLiteral", "getOperationName");
ERROR_IF_PRUNED(method, "getOperationName", op);
method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
<< "\");";
@@ -2514,8 +2489,9 @@ void OpEmitter::genOpAsmInterface() {
opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
// Generate the right accessor for the number of results.
- auto *method = opClass.addMethodAndPrune(
- "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
+ auto *method = opClass.addMethod(
+ "void", "getAsmResultNames",
+ MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
ERROR_IF_PRUNED(method, "getAsmResultNames", op);
auto &body = method->body();
for (int i = 0; i != numResults; ++i) {
@@ -2567,7 +2543,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
const auto *attrSizedOperands =
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
{
- SmallVector<OpMethodParameter, 2> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::ValueRange", "values");
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
attrSizedOperands ? "" : "nullptr");
@@ -2581,14 +2557,14 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
{
auto *constructor = adaptor.addConstructorAndPrune(
- llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
+ MethodParameter(op.getCppClassName() + " &", "op"));
constructor->addMemberInitializer("odsOperands", "op->getOperands()");
constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
constructor->addMemberInitializer("odsRegions", "op->getRegions()");
}
{
- auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
+ auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands");
ERROR_IF_PRUNED(m, "getOperands", op);
m->body() << " return odsOperands;";
}
@@ -2605,7 +2581,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
auto emitAttr = [&](StringRef name, StringRef emitName, Attribute attr) {
- auto *method = adaptor.addMethodAndPrune(attr.getStorageType(), emitName);
+ auto *method = adaptor.addMethod(attr.getStorageType(), emitName);
ERROR_IF_PRUNED(method, "Adaptor::" + emitName, op);
auto &body = method->body();
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
@@ -2629,8 +2605,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
};
{
- auto *m =
- adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
+ auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
m->body() << " return odsAttrs;";
}
@@ -2645,7 +2620,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
unsigned numRegions = op.getNumRegions();
if (numRegions > 0) {
- auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
+ auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions");
ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
m->body() << " return odsRegions;";
}
@@ -2657,13 +2632,13 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
// Generate the accessors for a variadic region.
for (StringRef name : op.getGetterNames(region.name)) {
if (region.isVariadic()) {
- auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", name);
+ auto *m = adaptor.addMethod("::mlir::RegionRange", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return odsRegions.drop_front({0});", i);
continue;
}
- auto *m = adaptor.addMethodAndPrune("::mlir::Region &", name);
+ auto *m = adaptor.addMethod("::mlir::Region &", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return *odsRegions[{0}];", i);
}
@@ -2674,8 +2649,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
}
void OpOperandAdaptorEmitter::addVerification() {
- auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
- "::mlir::Location", "loc");
+ auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify",
+ MethodParameter("::mlir::Location", "loc"));
ERROR_IF_PRUNED(method, "verify", op);
auto &body = method->body();
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 19dd6fa7c1016..59d76f0a8e4b9 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -9,10 +9,10 @@
#include "OpFormatGen.h"
#include "FormatGen.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
-#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/MapVector.h"
@@ -140,8 +140,7 @@ using SuccessorVariable =
namespace {
/// This class implements single kind directives.
-template <Element::Kind type>
-class DirectiveElement : public Element {
+template <Element::Kind type> class DirectiveElement : public Element {
public:
DirectiveElement() : Element(type){};
static bool classof(const Element *ele) { return ele->getKind() == type; }
@@ -422,23 +421,23 @@ struct OperationFormat {
/// Generate the operation parser from this format.
void genParser(Operator &op, OpClass &opClass);
/// Generate the parser code for a specific format element.
- void genElementParser(Element *element, OpMethodBody &body,
+ void genElementParser(Element *element, MethodBody &body,
FmtContext &attrTypeCtx);
/// Generate the c++ to resolve the types of operands and results during
/// parsing.
- void genParserTypeResolution(Operator &op, OpMethodBody &body);
+ void genParserTypeResolution(Operator &op, MethodBody &body);
/// Generate the c++ to resolve regions during parsing.
- void genParserRegionResolution(Operator &op, OpMethodBody &body);
+ void genParserRegionResolution(Operator &op, MethodBody &body);
/// Generate the c++ to resolve successors during parsing.
- void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
+ void genParserSuccessorResolution(Operator &op, MethodBody &body);
/// Generate the c++ to handling variadic segment size traits.
- void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body);
+ void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
/// Generate the operation printer from this format.
void genPrinter(Operator &op, OpClass &opClass);
/// Generate the printer code for a specific format element.
- void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
+ void genElementPrinter(Element *element, MethodBody &body, Operator &op,
bool &shouldEmitSpace, bool &lastWasPunctuation);
/// The various elements in this format.
@@ -813,7 +812,7 @@ static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
}
/// Generate the parser for a literal value.
-static void genLiteralParser(StringRef value, OpMethodBody &body) {
+static void genLiteralParser(StringRef value, MethodBody &body) {
// Handle the case of a keyword/identifier.
if (value.front() == '_' || isalpha(value.front())) {
body << "Keyword(\"" << value << "\")";
@@ -839,7 +838,7 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
/// Generate the storage code required for parsing the given element.
static void genElementParserStorage(Element *element, const Operator &op,
- OpMethodBody &body) {
+ MethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements = optional->getThenElements();
@@ -937,7 +936,7 @@ static void genElementParserStorage(Element *element, const Operator &op,
}
/// Generate the parser for a parameter to a custom directive.
-static void genCustomParameterParser(Element ¶m, OpMethodBody &body) {
+static void genCustomParameterParser(Element ¶m, MethodBody &body) {
if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
body << attr->getVar()->name << "Attr";
} else if (isa<AttrDictDirective>(¶m)) {
@@ -988,7 +987,7 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
}
/// Generate the parser for a custom directive.
-static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
+static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
body << " {\n";
// Preprocess the directive variables.
@@ -1098,7 +1097,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
}
/// Generate the parser for a enum attribute.
-static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
+static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
FmtContext &attrTypeCtx) {
Attribute baseAttr = var->attr.getBaseAttr();
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
@@ -1141,13 +1140,12 @@ static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
}
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
- llvm::SmallVector<OpMethodParameter, 4> paramList;
+ SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
paramList.emplace_back("::mlir::OperationState &", "result");
- auto *method =
- opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
- OpMethod::MP_Static, std::move(paramList));
+ auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
+ std::move(paramList));
auto &body = method->body();
// Generate variables to store the operands and type within the format. This
@@ -1174,7 +1172,7 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
body << " return ::mlir::success();\n";
}
-void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
+void OperationFormat::genElementParser(Element *element, MethodBody &body,
FmtContext &attrTypeCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
@@ -1353,8 +1351,7 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
}
}
-void OperationFormat::genParserTypeResolution(Operator &op,
- OpMethodBody &body) {
+void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
// If any of type resolutions use transformed variables, make sure that the
// types of those variables are resolved.
SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
@@ -1528,7 +1525,7 @@ void OperationFormat::genParserTypeResolution(Operator &op,
}
void OperationFormat::genParserRegionResolution(Operator &op,
- OpMethodBody &body) {
+ MethodBody &body) {
// Check for the case where all regions were parsed.
bool hasAllRegions = llvm::any_of(
elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
@@ -1547,7 +1544,7 @@ void OperationFormat::genParserRegionResolution(Operator &op,
}
void OperationFormat::genParserSuccessorResolution(Operator &op,
- OpMethodBody &body) {
+ MethodBody &body) {
// Check for the case where all successors were parsed.
bool hasAllSuccessors = llvm::any_of(
elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
@@ -1566,7 +1563,7 @@ void OperationFormat::genParserSuccessorResolution(Operator &op,
}
void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
- OpMethodBody &body) {
+ MethodBody &body) {
if (!allOperands) {
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
body << " result.addAttribute(\"operand_segment_sizes\", "
@@ -1641,7 +1638,7 @@ const char *enumAttrBeginPrinterCode = R"(
/// Generate the printer for the 'attr-dict' directive.
static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
- OpMethodBody &body, bool withKeyword) {
+ MethodBody &body, bool withKeyword) {
body << " _odsPrinter.printOptionalAttrDict"
<< (withKeyword ? "WithKeyword" : "")
<< "((*this)->getAttrs(), /*elidedAttrs=*/{";
@@ -1665,7 +1662,7 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
/// Generate the printer for a literal value. `shouldEmitSpace` is true if a
/// space should be emitted before this element. `lastWasPunctuation` is true if
/// the previous element was a punctuation literal.
-static void genLiteralPrinter(StringRef value, OpMethodBody &body,
+static void genLiteralPrinter(StringRef value, MethodBody &body,
bool &shouldEmitSpace, bool &lastWasPunctuation) {
body << " _odsPrinter";
@@ -1682,8 +1679,8 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
/// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
/// are set to false.
-static void genSpacePrinter(bool value, OpMethodBody &body,
- bool &shouldEmitSpace, bool &lastWasPunctuation) {
+static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
+ bool &lastWasPunctuation) {
if (value) {
body << " _odsPrinter << ' ';\n";
lastWasPunctuation = false;
@@ -1696,7 +1693,7 @@ static void genSpacePrinter(bool value, OpMethodBody &body,
/// Generate the printer for a custom directive parameter.
static void genCustomDirectiveParameterPrinter(Element *element,
const Operator &op,
- OpMethodBody &body) {
+ MethodBody &body) {
if (auto *attr = dyn_cast<AttributeVariable>(element)) {
body << op.getGetterName(attr->getVar()->name) << "Attr()";
@@ -1734,7 +1731,7 @@ static void genCustomDirectiveParameterPrinter(Element *element,
/// Generate the printer for a custom directive.
static void genCustomDirectivePrinter(CustomDirective *customDir,
- const Operator &op, OpMethodBody &body) {
+ const Operator &op, MethodBody &body) {
body << " print" << customDir->getName() << "(_odsPrinter, *this";
for (Element &param : customDir->getArguments()) {
body << ", ";
@@ -1744,7 +1741,7 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
}
/// Generate the printer for a region with the given variable name.
-static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
+static void genRegionPrinter(const Twine &regionName, MethodBody &body,
bool hasImplicitTermTrait) {
if (hasImplicitTermTrait)
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
@@ -1753,7 +1750,7 @@ static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
body << " _odsPrinter.printRegion(" << regionName << ");\n";
}
static void genVariadicRegionPrinter(const Twine &regionListName,
- OpMethodBody &body,
+ MethodBody &body,
bool hasImplicitTermTrait) {
body << " llvm::interleaveComma(" << regionListName
<< ", _odsPrinter, [&](::mlir::Region &region) {\n ";
@@ -1762,8 +1759,8 @@ static void genVariadicRegionPrinter(const Twine &regionListName,
}
/// Generate the C++ for an operand to a (*-)type directive.
-static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
- OpMethodBody &body) {
+static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
+ MethodBody &body) {
if (isa<OperandsDirective>(arg))
return body << "getOperation()->getOperandTypes()";
if (isa<ResultsDirective>(arg))
@@ -1786,7 +1783,7 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
/// Generate the printer for an enum attribute.
static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
- OpMethodBody &body) {
+ MethodBody &body) {
Attribute baseAttr = var->attr.getBaseAttr();
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
@@ -1864,7 +1861,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
/// Generate the check for the anchor of an optional group.
static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
- OpMethodBody &body) {
+ MethodBody &body) {
TypeSwitch<Element *>(anchor)
.Case<OperandVariable, ResultVariable>([&](auto *element) {
const NamedTypeConstraint *var = element->getVar();
@@ -1892,7 +1889,7 @@ static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
});
}
-void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
+void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
Operator &op, bool &shouldEmitSpace,
bool &lastWasPunctuation) {
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
@@ -2047,8 +2044,9 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
}
void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
- auto *method = opClass.addMethodAndPrune("void", "print",
- "::mlir::OpAsmPrinter &_odsPrinter");
+ auto *method = opClass.addMethod(
+ "void", "print",
+ MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
auto &body = method->body();
// Flags for if we should emit a space, and if the last element was
@@ -2065,8 +2063,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
/// Function to find an element within the given range that has the same name as
/// 'name'.
-template <typename RangeT>
-static auto findArg(RangeT &&range, StringRef name) {
+template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
}
More information about the Mlir-commits
mailing list