[Mlir-commits] [mlir] [NFC][mlir][tblgen] Refactor TableGen into a Cpp exposed lib (PR #189689)
Fabian Mora
llvmlistbot at llvm.org
Fri Apr 17 04:17:33 PDT 2026
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/189689
>From 50251118a9b314ef6f530ecc09650b3ea537c849 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Tue, 31 Mar 2026 14:49:48 +0000
Subject: [PATCH] [mlir][tblgen] Refactor TableGen into a Cpp exposed lib
Co-Authored-By: Claude Sonnet 4.6 <noreply at anthropic.com>
Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
.../TableGen/Generators/AttrOrTypeDefGen.h | 233 ++++++++
.../TableGen/Generators/AttrOrTypeFormatGen.h | 242 ++++++++
.../TableGen/Generators/BytecodeDialectGen.h | 30 +
.../TableGen/Generators}/CppGenUtilities.h | 20 +-
.../mlir/TableGen/Generators/DialectGen.h | 67 +++
.../Generators/DialectInterfacesGen.h | 55 ++
.../TableGen/Generators}/DocGenUtilities.h | 28 +-
.../Generators/EnumPythonBindingGen.h | 30 +
.../mlir/TableGen/Generators/EnumsGen.h | 42 ++
.../mlir/TableGen/Generators}/FormatGen.h | 84 +--
.../TableGen/Generators/OpAdaptorHelper.h | 225 ++++++++
.../mlir/TableGen/Generators}/OpClass.h | 10 +-
.../TableGen/Generators/OpDefinitionsGen.h | 305 ++++++++++
.../mlir/TableGen/Generators/OpDocGen.h | 102 ++++
.../mlir/TableGen/Generators/OpFormatGen.h | 180 ++++++
.../mlir/TableGen/Generators}/OpGenHelpers.h | 26 +-
.../TableGen/Generators/OpInterfacesGen.h | 96 ++++
.../TableGen/Generators/OpPythonBindingGen.h | 33 ++
.../mlir/TableGen/Generators/PassCAPIGen.h | 40 ++
.../mlir/TableGen/Generators/PassDocGen.h | 31 +
.../mlir/TableGen/Generators/PassGen.h | 68 +++
.../mlir/TableGen/Generators/RewriterGen.h | 31 +
mlir/lib/TableGen/CMakeLists.txt | 2 +
.../TableGen/Generators}/AttrOrTypeDefGen.cpp | 409 ++++---------
.../Generators}/AttrOrTypeFormatGen.cpp | 335 +++--------
.../Generators}/BytecodeDialectGen.cpp | 27 +-
mlir/lib/TableGen/Generators/CMakeLists.txt | 39 ++
.../TableGen/Generators}/CppGenUtilities.cpp | 13 +-
.../TableGen/Generators}/DialectGen.cpp | 243 ++++----
.../Generators}/DialectInterfacesGen.cpp | 60 +-
.../TableGen/Generators/DocGenUtilities.cpp | 54 ++
.../Generators}/EnumPythonBindingGen.cpp | 26 +-
.../TableGen/Generators}/EnumsGen.cpp | 48 +-
.../TableGen/Generators}/FormatGen.cpp | 93 ++-
.../TableGen/Generators}/OpClass.cpp | 2 +-
.../TableGen/Generators}/OpDefinitionsGen.cpp | 540 +++---------------
.../TableGen/Generators}/OpDocGen.cpp | 239 ++------
.../TableGen/Generators}/OpFormatGen.cpp | 152 +----
.../TableGen/Generators}/OpGenHelpers.cpp | 49 +-
.../TableGen/Generators}/OpInterfacesGen.cpp | 245 ++------
.../Generators}/OpPythonBindingGen.cpp | 51 +-
.../TableGen/Generators}/PassCAPIGen.cpp | 78 +--
.../TableGen/Generators}/PassDocGen.cpp | 31 +-
.../TableGen/Generators}/PassGen.cpp | 275 ++++-----
.../TableGen/Generators}/RewriterGen.cpp | 357 ++++++------
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h | 26 -
mlir/tools/mlir-tblgen/CMakeLists.txt | 22 +-
mlir/tools/mlir-tblgen/DialectGenUtilities.h | 24 -
mlir/tools/mlir-tblgen/Generators.cpp | 429 ++++++++++++++
mlir/tools/mlir-tblgen/OpFormatGen.h | 29 -
50 files changed, 3413 insertions(+), 2463 deletions(-)
create mode 100644 mlir/include/mlir/TableGen/Generators/AttrOrTypeDefGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/AttrOrTypeFormatGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/BytecodeDialectGen.h
rename mlir/{tools/mlir-tblgen => include/mlir/TableGen/Generators}/CppGenUtilities.h (60%)
create mode 100644 mlir/include/mlir/TableGen/Generators/DialectGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/DialectInterfacesGen.h
rename mlir/{tools/mlir-tblgen => include/mlir/TableGen/Generators}/DocGenUtilities.h (55%)
create mode 100644 mlir/include/mlir/TableGen/Generators/EnumPythonBindingGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/EnumsGen.h
rename mlir/{tools/mlir-tblgen => include/mlir/TableGen/Generators}/FormatGen.h (88%)
create mode 100644 mlir/include/mlir/TableGen/Generators/OpAdaptorHelper.h
rename mlir/{tools/mlir-tblgen => include/mlir/TableGen/Generators}/OpClass.h (91%)
create mode 100644 mlir/include/mlir/TableGen/Generators/OpDefinitionsGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/OpDocGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/OpFormatGen.h
rename mlir/{tools/mlir-tblgen => include/mlir/TableGen/Generators}/OpGenHelpers.h (53%)
create mode 100644 mlir/include/mlir/TableGen/Generators/OpInterfacesGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/OpPythonBindingGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/PassCAPIGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/PassDocGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/PassGen.h
create mode 100644 mlir/include/mlir/TableGen/Generators/RewriterGen.h
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/AttrOrTypeDefGen.cpp (76%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/AttrOrTypeFormatGen.cpp (77%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/BytecodeDialectGen.cpp (95%)
create mode 100644 mlir/lib/TableGen/Generators/CMakeLists.txt
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/CppGenUtilities.cpp (74%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/DialectGen.cpp (73%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/DialectInterfacesGen.cpp (72%)
create mode 100644 mlir/lib/TableGen/Generators/DocGenUtilities.cpp
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/EnumPythonBindingGen.cpp (89%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/EnumsGen.cpp (96%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/FormatGen.cpp (87%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/OpClass.cpp (97%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/OpDefinitionsGen.cpp (90%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/OpDocGen.cpp (69%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/OpFormatGen.cpp (96%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/OpGenHelpers.cpp (61%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/OpInterfacesGen.cpp (70%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/OpPythonBindingGen.cpp (97%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/PassCAPIGen.cpp (50%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/PassDocGen.cpp (75%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/PassGen.cpp (79%)
rename mlir/{tools/mlir-tblgen => lib/TableGen/Generators}/RewriterGen.cpp (86%)
delete mode 100644 mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
delete mode 100644 mlir/tools/mlir-tblgen/DialectGenUtilities.h
create mode 100644 mlir/tools/mlir-tblgen/Generators.cpp
delete mode 100644 mlir/tools/mlir-tblgen/OpFormatGen.h
diff --git a/mlir/include/mlir/TableGen/Generators/AttrOrTypeDefGen.h b/mlir/include/mlir/TableGen/Generators/AttrOrTypeDefGen.h
new file mode 100644
index 0000000000000..87f1d6420ef38
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/AttrOrTypeDefGen.h
@@ -0,0 +1,233 @@
+//===- AttrOrTypeDefGen.h - AttrDef/TypeDef code generator ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares classes and functions for generating C++ definitions and
+// declarations for MLIR attribute and type definitions from TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_ATTRORTYPEDEFGEN_H
+#define MLIR_TABLEGEN_GENERATORS_ATTRORTYPEDEFGEN_H
+
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Class.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+#include <vector>
+
+namespace llvm {
+class Record;
+class RecordKeeper;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeDefEmitter
+//===----------------------------------------------------------------------===//
+
+/// Generates C++ class declarations and definitions for a single
+/// attribute or type definition derived from an AttrOrTypeDef TableGen record.
+class AttrOrTypeDefEmitter {
+public:
+ /// Create the attribute or type class. If fatalOnError is true, assembly
+ /// format parse failures are reported as fatal errors.
+ AttrOrTypeDefEmitter(const AttrOrTypeDef &def, bool fatalOnError = true);
+
+ virtual ~AttrOrTypeDefEmitter() = default;
+
+ void emitDecl(llvm::raw_ostream &os) const;
+ void emitDef(llvm::raw_ostream &os) const;
+
+protected:
+ /// Add traits from the TableGen definition to the class.
+ virtual void createParentWithTraits();
+ /// Emit top-level declarations: using declarations and any extra class
+ /// declarations.
+ virtual void emitTopLevelDeclarations();
+ /// Emit the function that returns the type or attribute name.
+ virtual void emitName();
+ /// Emit the dialect name as a static member variable.
+ virtual void emitDialectName();
+ /// Emit attribute or type builders.
+ virtual void emitBuilders();
+ /// Emit a verifier declaration for custom verification (impl. provided by
+ /// the users).
+ virtual void emitVerifierDecl();
+ /// Emit a verifier that checks type constraints.
+ virtual void emitInvariantsVerifierImpl();
+ /// Emit an entry point for verification that calls the invariants and
+ /// custom verifier.
+ virtual void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
+ /// Emit parsers and printers.
+ virtual void emitParserPrinter();
+ /// Emit parameter accessors, if required.
+ virtual void emitAccessors();
+ /// Emit interface methods.
+ virtual void emitInterfaceMethods();
+
+ //===--------------------------------------------------------------------===//
+ // Builder Emission
+
+ /// Emit the default builder `Attribute::get`.
+ virtual void emitDefaultBuilder();
+ /// Emit the checked builder `Attribute::getChecked`.
+ virtual void emitCheckedBuilder();
+ /// Emit a custom builder.
+ virtual void emitCustomBuilder(const AttrOrTypeBuilder &builder);
+ /// Emit a checked custom builder.
+ virtual void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
+
+ //===--------------------------------------------------------------------===//
+ // Interface Method Emission
+
+ /// Emit methods for a trait.
+ virtual void emitTraitMethods(const InterfaceTrait &trait);
+ /// Emit a trait method.
+ virtual void emitTraitMethod(const InterfaceMethod &method);
+ /// Generate a using declaration for a trait method.
+ virtual void genTraitMethodUsingDecl(const InterfaceTrait &trait,
+ const InterfaceMethod &method);
+
+ //===--------------------------------------------------------------------===//
+ // OpAsm{Type,Attr}Interface Default Method Emission
+
+ /// Emit 'getAlias' method using mnemonic as alias.
+ virtual void emitMnemonicAliasMethod();
+
+ //===--------------------------------------------------------------------===//
+ // Storage Class Emission
+ virtual void emitStorageClass();
+ /// Generate the storage class constructor.
+ virtual void emitStorageConstructor();
+ /// Emit the key type `KeyTy`.
+ virtual void emitKeyType();
+ /// Emit the equality comparison operator.
+ virtual void emitEquals();
+ /// Emit the key hash function.
+ virtual void emitHashKey();
+ /// Emit the function to construct the storage class.
+ virtual void emitConstruct();
+
+ //===--------------------------------------------------------------------===//
+ // Utility Function Declarations
+
+ /// Get the method parameters for a def builder, where the first several
+ /// parameters may be different.
+ SmallVector<MethodParameter>
+ getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
+
+ //===--------------------------------------------------------------------===//
+ // Class fields
+
+ /// The attribute or type definition.
+ const AttrOrTypeDef &def;
+ /// The list of attribute or type parameters.
+ ArrayRef<AttrOrTypeParameter> params;
+ /// The attribute or type class.
+ Class defCls;
+ /// An optional attribute or type storage class. The storage class will
+ /// exist if and only if the def has more than zero parameters.
+ std::optional<Class> storageCls;
+
+ /// The C++ base value of the def, either "Attribute" or "Type".
+ StringRef valueType;
+ /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
+ StringRef defType;
+
+ /// The set of using declarations for trait methods.
+ llvm::StringSet<> interfaceUsingNames;
+
+ /// Whether assembly format parse failures are fatal errors.
+ bool fatalOnError;
+};
+
+//===----------------------------------------------------------------------===//
+// AttrTypeDefGenerator
+//===----------------------------------------------------------------------===//
+
+/// Base generator for TableGen attr/type defs; only subclasses construct it.
+class AttrTypeDefGenerator {
+public:
+ virtual ~AttrTypeDefGenerator() = default;
+
+ virtual bool emitDecls(llvm::StringRef selectedDialect);
+ virtual bool emitDefs(llvm::StringRef selectedDialect);
+
+protected:
+ AttrTypeDefGenerator(llvm::ArrayRef<const llvm::Record *> defs,
+ llvm::raw_ostream &os, llvm::StringRef defType,
+ llvm::StringRef valueType, bool isAttrGenerator,
+ bool fatalOnError = true);
+
+ /// Emit the list of def type names.
+ virtual void emitTypeDefList(llvm::ArrayRef<AttrOrTypeDef> defs);
+ /// Emit the code to dispatch between different defs during parsing/printing.
+ virtual void emitParsePrintDispatch(llvm::ArrayRef<AttrOrTypeDef> defs);
+
+ /// The set of def records to emit.
+ std::vector<const llvm::Record *> defRecords;
+ /// The stream to emit to.
+ llvm::raw_ostream &os;
+ /// The prefix of the TableGen def name, e.g. Attr or Type.
+ llvm::StringRef defType;
+ /// The C++ base value type of the def, e.g. Attribute or Type.
+ llvm::StringRef valueType;
+ /// Flag indicating if this generator is for attributes. False if the
+ /// generator is for types.
+ bool isAttrGenerator;
+ /// Whether assembly format parse failures are fatal errors.
+ bool fatalOnError;
+};
+
+/// A specialized generator for AttrDefs.
+struct AttrDefGenerator : public AttrTypeDefGenerator {
+ AttrDefGenerator(const llvm::RecordKeeper &records, llvm::raw_ostream &os,
+ bool fatalOnError = true);
+};
+
+/// A specialized generator for TypeDefs.
+struct TypeDefGenerator : public AttrTypeDefGenerator {
+ TypeDefGenerator(const llvm::RecordKeeper &records, llvm::raw_ostream &os,
+ bool fatalOnError = true);
+};
+
+//===----------------------------------------------------------------------===//
+// Constraint Functions
+//===----------------------------------------------------------------------===//
+
+/// Emit declarations for all type constraints in records that have a C++
+/// function name set.
+void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+
+/// Emit declarations for all attribute constraints in records that have a
+/// C++ function name set.
+void emitAttrConstraintDecls(const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+
+/// Emit definitions for all type constraints in records that have a C++
+/// function name set.
+void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+
+/// Emit definitions for all attribute constraints in records that have a
+/// C++ function name set.
+void emitAttrConstraintDefs(const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_ATTRORTYPEDEFGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/AttrOrTypeFormatGen.h b/mlir/include/mlir/TableGen/Generators/AttrOrTypeFormatGen.h
new file mode 100644
index 0000000000000..c2592a3cd9e96
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/AttrOrTypeFormatGen.h
@@ -0,0 +1,242 @@
+//===- AttrOrTypeFormatGen.h - Attr/type format generator -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_ATTRORTYPEFORMATGEN_H
+#define MLIR_TABLEGEN_GENERATORS_ATTRORTYPEFORMATGEN_H
+
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Class.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Generators/FormatGen.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/SourceMgr.h"
+#include <vector>
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// ParameterElement
+//===----------------------------------------------------------------------===//
+
+/// Represents a variable element referring to an attribute or type parameter.
+class ParameterElement
+ : public VariableElementBase<VariableElement::Parameter> {
+public:
+ ParameterElement(AttrOrTypeParameter param) : param(param) {}
+
+ /// Get the parameter in the element.
+ const AttrOrTypeParameter &getParam() const { return param; }
+
+ /// Indicate if this variable is printed "qualified" (that is, it is
+ /// prefixed with the `#dialect.mnemonic`).
+ bool shouldBeQualified() const { return shouldBeQualifiedFlag; }
+ void setShouldBeQualified(bool qualified = true) {
+ shouldBeQualifiedFlag = qualified;
+ }
+
+ /// Returns true if the element contains an optional parameter.
+ bool isOptional() const { return param.isOptional(); }
+
+ /// Returns the name of the parameter.
+ llvm::StringRef getName() const { return param.getName(); }
+
+ /// Return code checking an optional parameter is present; mutates ctx.
+ auto genIsPresent(FmtContext &ctx, const llvm::Twine &self) const {
+ assert(isOptional() && "cannot guard on a mandatory parameter");
+ std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str();
+ ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
+ return tgfmt(getParam().getComparator(), &ctx);
+ }
+
+ /// Generate the code to check whether the parameter should be printed.
+ MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const;
+
+private:
+ bool shouldBeQualifiedFlag = false;
+ AttrOrTypeParameter param;
+};
+
+//===----------------------------------------------------------------------===//
+// ParamsDirective
+//===----------------------------------------------------------------------===//
+
+/// Represents a `params` directive that refers to all parameters of an
+/// attribute or type.
+class ParamsDirective
+ : public VectorDirectiveBase<DirectiveElement::Params, ParameterElement *> {
+public:
+ using Base::Base;
+
+ /// Returns true if there are optional parameters present.
+ bool hasOptionalElements() const;
+};
+
+//===----------------------------------------------------------------------===//
+// StructDirective
+//===----------------------------------------------------------------------===//
+
+/// Represents a `struct` directive that generates a struct format.
+class StructDirective
+ : public VectorDirectiveBase<DirectiveElement::Struct, FormatElement *> {
+public:
+ using Base::Base;
+
+ /// Returns true if there are optional format elements present.
+ bool hasOptionalElements() const;
+};
+
+//===----------------------------------------------------------------------===//
+// AttrTypeDefFormat
+//===----------------------------------------------------------------------===//
+
+/// Holds the parsed assembly format for an attribute or type and generates
+/// parser and printer code.
+class AttrTypeDefFormat {
+public:
+ AttrTypeDefFormat(const AttrOrTypeDef &def,
+ std::vector<FormatElement *> &&elements)
+ : def(def), elements(std::move(elements)) {}
+
+ virtual ~AttrTypeDefFormat() = default;
+
+ /// Generate the attribute or type parser.
+ virtual void genParser(MethodBody &os);
+ /// Generate the attribute or type printer.
+ virtual void genPrinter(MethodBody &os);
+
+protected:
+ /// Generate the parser code for a specific format element.
+ void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
+ /// Generate the parser code for a literal.
+ void genLiteralParser(llvm::StringRef value, FmtContext &ctx, MethodBody &os,
+ bool isOptional = false);
+ /// Generate the parser code for a variable.
+ void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
+ /// Generate the parser code for a `params` directive.
+ virtual void genParamsParser(ParamsDirective *el, FmtContext &ctx,
+ MethodBody &os);
+ /// Generate the parser code for a `struct` directive.
+ virtual void genStructParser(StructDirective *el, FmtContext &ctx,
+ MethodBody &os);
+ /// Generate the parser code for a `custom` directive.
+ virtual void genCustomParser(CustomDirective *el, FmtContext &ctx,
+ MethodBody &os, bool isOptional = false);
+ /// Generate the parser code for an optional group.
+ void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
+ MethodBody &os);
+
+ /// Generate the printer code for a specific format element.
+ void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
+ /// Generate the printer code for a literal.
+ void genLiteralPrinter(llvm::StringRef value, FmtContext &ctx,
+ MethodBody &os);
+ /// Generate the printer code for a variable.
+ void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
+ bool skipGuard = false);
+ /// Generate a printer for comma-separated format elements.
+ void genCommaSeparatedPrinter(
+ llvm::ArrayRef<FormatElement *> params, FmtContext &ctx, MethodBody &os,
+ llvm::function_ref<void(FormatElement *)> extra,
+ llvm::function_ref<void(FormatElement *)> extraPost = nullptr);
+ /// Generate the printer code for a `params` directive.
+ virtual void genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
+ MethodBody &os);
+ /// Generate the printer code for a `struct` directive.
+ virtual void genStructPrinter(StructDirective *el, FmtContext &ctx,
+ MethodBody &os);
+ /// Generate the printer code for a `custom` directive.
+ virtual void genCustomPrinter(CustomDirective *el, FmtContext &ctx,
+ MethodBody &os);
+ /// Generate the printer code for an optional group.
+ void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
+ MethodBody &os);
+ /// Generate a printer (or space eraser) for a whitespace element.
+ void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
+ MethodBody &os);
+
+ /// The ODS definition of the attribute or type whose format is being used to
+ /// generate a parser and printer.
+ const AttrOrTypeDef &def;
+ /// The list of top-level format elements returned by the assembly format
+ /// parser.
+ std::vector<FormatElement *> elements;
+
+ /// Flags for printing spaces.
+ bool shouldEmitSpace = false;
+ bool lastWasPunctuation = false;
+};
+
+//===----------------------------------------------------------------------===//
+// AttrTypeDefFormatParser
+//===----------------------------------------------------------------------===//
+
+/// Parser for the assembly format of one attribute or type definition.
+class AttrTypeDefFormatParser : public FormatParser {
+public:
+ AttrTypeDefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
+ : FormatParser(mgr, def.getLoc()[0]), def(def),
+ seenParams(def.getNumParameters()) {}
+
+ /// Parse the attribute or type format and create the format elements.
+ FailureOr<AttrTypeDefFormat> parse();
+
+protected:
+ /// Verify the parsed elements.
+ LogicalResult verify(SMLoc loc,
+ llvm::ArrayRef<FormatElement *> elements) override;
+ /// Verify the arguments of a custom directive.
+ LogicalResult verifyCustomDirectiveArguments(
+ SMLoc loc, llvm::ArrayRef<FormatElement *> arguments) override;
+ /// Verify the elements of an optional group.
+ LogicalResult
+ verifyOptionalGroupElements(SMLoc loc,
+ llvm::ArrayRef<FormatElement *> elements,
+ FormatElement *anchor) override;
+ /// Verify the arguments to a struct directive.
+ LogicalResult
+ verifyStructArguments(SMLoc loc, llvm::ArrayRef<FormatElement *> arguments);
+
+ LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
+
+ /// Parse an attribute or type variable.
+ FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, llvm::StringRef name,
+ Context ctx) override;
+ /// Parse an attribute or type format directive.
+ FailureOr<FormatElement *>
+ parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
+
+private:
+ /// Parse a `params` directive.
+ FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
+ /// Parse a `struct` directive.
+ FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
+
+ /// The attribute or type TableGen def passed to the constructor.
+ const AttrOrTypeDef &def;
+
+ /// Seen attribute or type parameters; sized to def's parameter count.
+ llvm::BitVector seenParams;
+};
+
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+/// Generate a parser and printer based on a custom assembly format for an
+/// attribute or type. If fatalOnError is true, a parse failure is a fatal
+/// error; otherwise it is silently ignored.
+void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser,
+ MethodBody &printer, bool fatalOnError = true);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_ATTRORTYPEFORMATGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/BytecodeDialectGen.h b/mlir/include/mlir/TableGen/Generators/BytecodeDialectGen.h
new file mode 100644
index 0000000000000..495f76a08d56d
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/BytecodeDialectGen.h
@@ -0,0 +1,30 @@
+//===- BytecodeDialectGen.h - Dialect bytecode reader/writer ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_BYTECODEDIALECTGEN_H
+#define MLIR_TABLEGEN_GENERATORS_BYTECODEDIALECTGEN_H
+
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class RecordKeeper;
+class raw_ostream;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Emit bytecode dialect readers/writers. If dialectName is non-empty,
+/// only emit code for that dialect.
+bool emitBytecodeDialect(const llvm::RecordKeeper &records,
+ llvm::StringRef dialectName, llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_BYTECODEDIALECTGEN_H
diff --git a/mlir/tools/mlir-tblgen/CppGenUtilities.h b/mlir/include/mlir/TableGen/Generators/CppGenUtilities.h
similarity index 60%
rename from mlir/tools/mlir-tblgen/CppGenUtilities.h
rename to mlir/include/mlir/TableGen/Generators/CppGenUtilities.h
index 69d8cd85ecf70..22b12e3a06638 100644
--- a/mlir/tools/mlir-tblgen/CppGenUtilities.h
+++ b/mlir/include/mlir/TableGen/Generators/CppGenUtilities.h
@@ -1,4 +1,4 @@
-//===- CppGenUtilities.h - MLIR cpp gen utilities ---------------*- C++ -*-===//
+//===- CppGenUtilities.h - MLIR C++ gen utilities ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,26 +6,30 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines common utilities for generating cpp files from tablegen
+// This file declares common utilities for generating C++ files from TableGen
// structures.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TOOLS_MLIRTBLGEN_CPPGENUTILITIES_H_
-#define MLIR_TOOLS_MLIRTBLGEN_CPPGENUTILITIES_H_
+#ifndef MLIR_TABLEGEN_GENERATORS_CPPGENUTILITIES_H
+#define MLIR_TABLEGEN_GENERATORS_CPPGENUTILITIES_H
#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+class raw_ostream;
+} // namespace llvm
namespace mlir {
namespace tblgen {
-// Emit the summary and description as a C++ comment. If `terminateComment` is
-// true, terminates the comment with a `\n`.
+/// Emit the summary and description as a C++ comment. If terminateComment
+/// is true, terminates the comment with a newline.
void emitSummaryAndDescComments(llvm::raw_ostream &os, llvm::StringRef summary,
llvm::StringRef description,
bool terminateComment = true);
+
} // namespace tblgen
} // namespace mlir
-#endif // MLIR_TOOLS_MLIRTBLGEN_CPPGENUTILITIES_H_
+#endif // MLIR_TABLEGEN_GENERATORS_CPPGENUTILITIES_H
diff --git a/mlir/include/mlir/TableGen/Generators/DialectGen.h b/mlir/include/mlir/TableGen/Generators/DialectGen.h
new file mode 100644
index 0000000000000..bbb1bf5c20d62
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/DialectGen.h
@@ -0,0 +1,67 @@
+//===- DialectGen.h - MLIR dialect C++ generation utilities -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for generating C++ definitions for dialects from
+// TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_DIALECTGEN_H
+#define MLIR_TABLEGEN_GENERATORS_DIALECTGEN_H
+
+#include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/TableGen/Record.h"
+#include <optional>
+#include <string>
+#include <utility>
+
+namespace llvm {
+class RecordKeeper;
+class raw_ostream;
+class DagInit;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Populate discardableAttributes from the given DagInit of discardable
+/// attribute descriptors.
+void populateDiscardableAttributes(
+ Dialect &dialect, const llvm::DagInit *discardableAttrDag,
+ llvm::SmallVectorImpl<std::pair<std::string, std::string>>
+ &discardableAttributes);
+
+/// Find the dialect to generate from dialects. If selectedDialect is
+/// empty, the dialect is auto-detected (succeeds only when exactly one dialect
+/// is present). Returns std::nullopt and prints an error on failure.
+std::optional<Dialect> findDialectToGenerate(llvm::ArrayRef<Dialect> dialects,
+ llvm::StringRef selectedDialect);
+
+/// Emit the C++ class declaration for dialect.
+void emitDialectDecl(Dialect &dialect, llvm::raw_ostream &os);
+
+/// Emit the C++ class declarations for all dialects in records, selecting
+/// the one identified by selectedDialect.
+bool emitDialectDecls(const llvm::RecordKeeper &records,
+ llvm::StringRef selectedDialect, llvm::raw_ostream &os);
+
+/// Emit the C++ constructor and destructor definitions for dialect.
+void emitDialectDef(Dialect &dialect, const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+
+/// Emit the C++ constructor and destructor definitions for all dialects in
+/// records, selecting the one identified by selectedDialect.
+bool emitDialectDefs(const llvm::RecordKeeper &records,
+ llvm::StringRef selectedDialect, llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_DIALECTGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/DialectInterfacesGen.h b/mlir/include/mlir/TableGen/Generators/DialectInterfacesGen.h
new file mode 100644
index 0000000000000..241155986637d
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/DialectInterfacesGen.h
@@ -0,0 +1,55 @@
+//===- DialectInterfacesGen.h - Dialect interface generator -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_DIALECTINTERFACESGEN_H
+#define MLIR_TABLEGEN_GENERATORS_DIALECTINTERFACESGEN_H
+
+#include "mlir/TableGen/Interfaces.h"
+#include "llvm/Support/raw_ostream.h"
+#include <vector>
+
+namespace llvm {
+class Record;
+class RecordKeeper;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Get all DialectInterface definitions from the given records, excluding those
+/// defined outside the top-level file.
+std::vector<const llvm::Record *>
+getAllDialectInterfaceDefinitions(const llvm::RecordKeeper &records);
+
+//===----------------------------------------------------------------------===//
+// DialectInterfaceGenerator
+//===----------------------------------------------------------------------===//
+
+/// Generator for dialect interface declarations from TableGen records.
+class DialectInterfaceGenerator {
+public:
+ DialectInterfaceGenerator(const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+
+ virtual ~DialectInterfaceGenerator() = default;
+
+ virtual bool emitInterfaceDecls();
+
+protected:
+ virtual void emitInterfaceDecl(const DialectInterface &interface);
+
+ /// The set of interface records to emit.
+ std::vector<const llvm::Record *> defs;
+ /// The stream to emit to.
+ llvm::raw_ostream &os;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_DIALECTINTERFACESGEN_H
diff --git a/mlir/tools/mlir-tblgen/DocGenUtilities.h b/mlir/include/mlir/TableGen/Generators/DocGenUtilities.h
similarity index 55%
rename from mlir/tools/mlir-tblgen/DocGenUtilities.h
rename to mlir/include/mlir/TableGen/Generators/DocGenUtilities.h
index dd1dbbe243911..6dedbce207c45 100644
--- a/mlir/tools/mlir-tblgen/DocGenUtilities.h
+++ b/mlir/include/mlir/TableGen/Generators/DocGenUtilities.h
@@ -6,13 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines common utilities for generating documents from tablegen
-// structures.
+// This file declares common utilities for generating documentation from
+// TableGen structures.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TOOLS_MLIRTBLGEN_DOCGENUTILITIES_H_
-#define MLIR_TOOLS_MLIRTBLGEN_DOCGENUTILITIES_H_
+#ifndef MLIR_TABLEGEN_GENERATORS_DOCGENUTILITIES_H
+#define MLIR_TABLEGEN_GENERATORS_DOCGENUTILITIES_H
#include "llvm/ADT/StringRef.h"
@@ -23,23 +23,23 @@ class raw_ostream;
namespace mlir {
namespace tblgen {
-// Emit the summary. To avoid confusion, the summary is styled differently from
-// the description.
+/// Emit the summary. To avoid confusion, the summary is styled differently
+/// from the description.
void emitSummary(llvm::StringRef summary, llvm::raw_ostream &os);
-// Emit the description by aligning the text to the left per line (e.g.
-// removing the minimum indentation across the block).
-//
-// This expects that the description in the tablegen file is already formatted
-// in a way the user wanted but has some additional indenting due to being
-// nested.
+/// Emit the description by aligning the text to the left per line (e.g.
+/// removing the minimum indentation across the block).
+///
+/// This expects that the description in the tablegen file is already formatted
+/// in a way the user wanted but has some additional indenting due to being
+/// nested.
void emitDescription(llvm::StringRef description, llvm::raw_ostream &os);
-// Emit the description as a C++ comment while realigning it.
+/// Emit the description as a C++ comment while realigning it.
void emitDescriptionComment(llvm::StringRef description, llvm::raw_ostream &os,
llvm::StringRef prefix = "");
} // namespace tblgen
} // namespace mlir
-#endif // MLIR_TOOLS_MLIRTBLGEN_DOCGENUTILITIES_H_
+#endif // MLIR_TABLEGEN_GENERATORS_DOCGENUTILITIES_H
diff --git a/mlir/include/mlir/TableGen/Generators/EnumPythonBindingGen.h b/mlir/include/mlir/TableGen/Generators/EnumPythonBindingGen.h
new file mode 100644
index 0000000000000..29e1bef280d6a
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/EnumPythonBindingGen.h
@@ -0,0 +1,30 @@
+//===- EnumPythonBindingGen.h - Python enum bindings generator --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_ENUMPYTHONBINDINGGEN_H
+#define MLIR_TABLEGEN_GENERATORS_ENUMPYTHONBINDINGGEN_H
+
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class RecordKeeper;
+class raw_ostream;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Emit Python bindings for enum attributes. If dialectName is non-empty,
+/// only emit builders for EnumAttr records belonging to that dialect.
+bool emitPythonEnums(const llvm::RecordKeeper &records,
+ llvm::StringRef dialectName, llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_ENUMPYTHONBINDINGGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/EnumsGen.h b/mlir/include/mlir/TableGen/Generators/EnumsGen.h
new file mode 100644
index 0000000000000..cbe973e507a87
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/EnumsGen.h
@@ -0,0 +1,42 @@
+//===- EnumsGen.h - Enum utility generator ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions for generating enum utility declarations and
+// definitions from TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_ENUMSGEN_H
+#define MLIR_TABLEGEN_GENERATORS_ENUMSGEN_H
+
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+class Record;
+class RecordKeeper;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Emit declarations for a single enum defined by enumDef.
+void emitEnumDecl(const llvm::Record &enumDef, llvm::raw_ostream &os);
+
+/// Emit declarations for all enums in records.
+bool emitEnumDecls(const llvm::RecordKeeper &records, llvm::raw_ostream &os);
+
+/// Emit definitions for a single enum defined by enumDef.
+void emitEnumDef(const llvm::Record &enumDef, llvm::raw_ostream &os);
+
+/// Emit definitions for all enums in records.
+bool emitEnumDefs(const llvm::RecordKeeper &records, llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_ENUMSGEN_H
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/include/mlir/TableGen/Generators/FormatGen.h
similarity index 88%
rename from mlir/tools/mlir-tblgen/FormatGen.h
rename to mlir/include/mlir/TableGen/Generators/FormatGen.h
index 8e7d49bb37e71..2bc80d2580b26 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/include/mlir/TableGen/Generators/FormatGen.h
@@ -11,14 +11,14 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
-#define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+#ifndef MLIR_TABLEGEN_GENERATORS_FORMATGEN_H
+#define MLIR_TABLEGEN_GENERATORS_FORMATGEN_H
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Allocator.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SMLoc.h"
#include <vector>
@@ -81,24 +81,22 @@ class FormatToken {
string,
};
- FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+ FormatToken(Kind kind, StringRef spelling);
/// Return the bytes that make up this token.
- StringRef getSpelling() const { return spelling; }
+ StringRef getSpelling() const;
/// Return the kind of this token.
- Kind getKind() const { return kind; }
+ Kind getKind() const;
/// Return a location for this token.
SMLoc getLoc() const;
/// Returns true if the token is of the given kind.
- bool is(Kind kind) { return getKind() == kind; }
+ bool is(Kind kind);
/// Return if this token is a keyword.
- bool isKeyword() const {
- return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end;
- }
+ bool isKeyword() const;
private:
/// Discriminator that indicates the kind of token this is.
@@ -138,9 +136,7 @@ class FormatLexer {
FormatToken lexString(const char *tokStart);
/// Create a token with the current pointer and a start pointer.
- FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
- return FormatToken(kind, StringRef(tokStart, curPtr - tokStart));
- }
+ FormatToken formToken(FormatToken::Kind kind, const char *tokStart);
/// The source manager containing the format string.
llvm::SourceMgr &mgr;
@@ -415,30 +411,21 @@ class OptionalElement : public FormatElementBase<FormatElement::Optional> {
OptionalElement(std::vector<FormatElement *> &&thenElements,
std::vector<FormatElement *> &&elseElements,
unsigned thenParseStart, unsigned elseParseStart,
- FormatElement *anchor, bool inverted)
- : thenElements(std::move(thenElements)),
- elseElements(std::move(elseElements)), thenParseStart(thenParseStart),
- elseParseStart(elseParseStart), anchor(anchor), inverted(inverted) {}
+ FormatElement *anchor, bool inverted);
/// Return the `then` elements of the optional group. Drops the first
/// `thenParseStart` whitespace elements if `parseable` is true.
- ArrayRef<FormatElement *> getThenElements(bool parseable = false) const {
- return llvm::ArrayRef(thenElements)
- .drop_front(parseable ? thenParseStart : 0);
- }
+ ArrayRef<FormatElement *> getThenElements(bool parseable = false) const;
/// Return the `else` elements of the optional group. Drops the first
/// `elseParseStart` whitespace elements if `parseable` is true.
- ArrayRef<FormatElement *> getElseElements(bool parseable = false) const {
- return llvm::ArrayRef(elseElements)
- .drop_front(parseable ? elseParseStart : 0);
- }
+ ArrayRef<FormatElement *> getElseElements(bool parseable = false) const;
/// Return the anchor of the optional group.
- FormatElement *getAnchor() const { return anchor; }
+ FormatElement *getAnchor() const;
/// Return true if the optional group is inverted.
- bool isInverted() const { return inverted; }
+ bool isInverted() const;
private:
/// The child elements emitted when the anchor is present.
@@ -491,15 +478,11 @@ class FormatParser {
};
/// Create a format parser with the given source manager and a location.
- explicit FormatParser(llvm::SourceMgr &mgr, llvm::SMLoc loc)
- : lexer(mgr, loc), curToken(lexer.lexToken()) {}
+ explicit FormatParser(llvm::SourceMgr &mgr, llvm::SMLoc loc);
/// Allocate and construct a format element.
template <typename FormatElementT, typename... Args>
FormatElementT *create(Args &&...args) {
- // FormatElementT *ptr = allocator.Allocate<FormatElementT>();
- // ::new (ptr) FormatElementT(std::forward<Args>(args)...);
- // return ptr;
auto mem = std::make_unique<FormatElementT>(std::forward<Args>(args)...);
FormatElementT *ptr = mem.get();
allocator.push_back(std::move(mem));
@@ -518,11 +501,12 @@ class FormatParser {
/// Parse a variable.
FailureOr<FormatElement *> parseVariable(Context ctx);
/// Parse a directive.
- FailureOr<FormatElement *> parseDirective(Context ctx);
+ virtual FailureOr<FormatElement *> parseDirective(Context ctx);
/// Parse an optional group.
FailureOr<FormatElement *> parseOptionalGroup(Context ctx);
/// Parse a custom directive.
- FailureOr<FormatElement *> parseCustomDirective(llvm::SMLoc loc, Context ctx);
+ virtual FailureOr<FormatElement *> parseCustomDirective(llvm::SMLoc loc,
+ Context ctx);
/// Parse a ref directive.
FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context context);
/// Parse a qualified directive.
@@ -560,43 +544,26 @@ class FormatParser {
// Lexer Utilities
/// Emit an error at the given location.
- LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
- lexer.emitError(loc, msg);
- return failure();
- }
+ LogicalResult emitError(llvm::SMLoc loc, const Twine &msg);
/// Emit an error and a note at the given notation.
LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
- const Twine &note) {
- lexer.emitErrorAndNote(loc, msg, note);
- return failure();
- }
+ const Twine &note);
/// Parse a single token of the expected kind.
- FailureOr<FormatToken> parseToken(FormatToken::Kind kind, const Twine &msg) {
- if (!curToken.is(kind))
- return emitError(curToken.getLoc(), msg);
- FormatToken tok = curToken;
- consumeToken();
- return tok;
- }
+ FailureOr<FormatToken> parseToken(FormatToken::Kind kind, const Twine &msg);
/// Advance the lexer to the next token.
- void consumeToken() {
- assert(!curToken.is(FormatToken::eof) && !curToken.is(FormatToken::error) &&
- "shouldn't advance past EOF or errors");
- curToken = lexer.lexToken();
- }
+ void consumeToken();
/// Get the current token.
- FormatToken peekToken() { return curToken; }
+ FormatToken peekToken();
private:
/// The format parser retains ownership of the format elements in a bump
/// pointer allocator.
// FIXME: FormatElement with `std::vector` need to be converted to use
// trailing objects.
- // llvm::BumpPtrAllocator allocator;
std::vector<std::unique_ptr<FormatElement>> allocator;
/// The format lexer to use.
FormatLexer lexer;
@@ -622,10 +589,7 @@ bool canFormatStringAsKeyword(StringRef value,
bool isValidLiteral(StringRef value,
function_ref<void(Twine)> emitError = nullptr);
-/// Whether a failure in parsing the assembly format should be a fatal error.
-extern llvm::cl::opt<bool> formatErrorIsFatal;
-
} // namespace tblgen
} // namespace mlir
-#endif // MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+#endif // MLIR_TABLEGEN_GENERATORS_FORMATGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/OpAdaptorHelper.h b/mlir/include/mlir/TableGen/Generators/OpAdaptorHelper.h
new file mode 100644
index 0000000000000..daf82243e50be
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/OpAdaptorHelper.h
@@ -0,0 +1,225 @@
+//===- OpAdaptorHelper.h - Helper for Op/OpAdaptor code gen -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares AttributeMetadata and OpOrAdaptorHelper, which are used
+// to share attribute-access code generation between Op and OpAdaptor emitters.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_OPADAPTORHELPER_H
+#define MLIR_TABLEGEN_GENERATORS_OPADAPTORHELPER_H
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Property.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <functional>
+#include <optional>
+#include <string>
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// AttributeMetadata
+//===----------------------------------------------------------------------===//
+
+/// Metadata on a registered attribute. Given that attributes are stored in
+/// sorted order on operations, we can use information from ODS to deduce the
+/// number of required attributes less than and greater than each attribute,
+/// allowing us to search only a subrange of the attributes in ODS-generated
+/// getters.
+struct AttributeMetadata {
+  /// The attribute name.
+  llvm::StringRef attrName;
+  /// Whether the attribute is required. Defaults to false so that a
+  /// default-initialized instance never carries an indeterminate flag; this
+  /// matches the explicit defaults on the bounds below.
+  bool isRequired = false;
+  /// The ODS attribute constraint. Not present for implicit attributes.
+  std::optional<Attribute> constraint;
+  /// The number of required attributes less than this attribute.
+  unsigned lowerBound = 0;
+  /// The number of required attributes greater than this attribute.
+  unsigned upperBound = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// OpOrAdaptorHelper
+//===----------------------------------------------------------------------===//
+
+/// Helper class to select between OpAdaptor and Op code templates for
+/// attribute-access code generation.
+class OpOrAdaptorHelper {
+public:
+ OpOrAdaptorHelper(const Operator &op, bool emitForOp)
+ : op(op), emitForOp(emitForOp) {
+ computeAttrMetadata();
+ }
+
+ /// Object that wraps a functor in a stream operator for interop with
+ /// llvm::formatv.
+ class Formatter {
+ public:
+ template <typename Functor>
+ Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
+
+ std::string str() const {
+ std::string result;
+ llvm::raw_string_ostream os(result);
+ os << *this;
+ return os.str();
+ }
+
+ private:
+ std::function<llvm::raw_ostream &(llvm::raw_ostream &)> func;
+
+ friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const Formatter &fmt) {
+ return fmt.func(os);
+ }
+ };
+
+ /// Generate code for getting an attribute. The definition is in the .cpp
+ /// file because it references a file-local format string constant.
+ Formatter getAttr(llvm::StringRef attrName, bool isNamed = false) const;
+
+ /// Generate code for getting the name of an attribute.
+ Formatter getAttrName(llvm::StringRef attrName) const {
+ return [this, attrName](llvm::raw_ostream &os) -> llvm::raw_ostream & {
+ if (emitForOp)
+ return os << op.getGetterName(attrName) << "AttrName()";
+ return os << llvm::formatv("{0}::{1}AttrName(*odsOpName)",
+ op.getCppClassName(),
+ op.getGetterName(attrName));
+ };
+ }
+
+ /// Get the code snippet for getting the named attribute range.
+ llvm::StringRef getAttrRange() const {
+ return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
+ }
+
+ /// Get the prefix code for emitting an error.
+ Formatter emitErrorPrefix() const {
+ return [this](llvm::raw_ostream &os) -> llvm::raw_ostream & {
+ if (emitForOp)
+ return os << "emitOpError(\"";
+ return os << llvm::formatv("emitError(loc, \"'{0}' op ",
+ op.getOperationName());
+ };
+ }
+
+ /// Get the call to get an operand or segment of operands.
+ Formatter getOperand(unsigned index) const {
+ return [this, index](llvm::raw_ostream &os) -> llvm::raw_ostream & {
+ return os << llvm::formatv(op.getOperand(index).isVariadic()
+ ? "this->getODSOperands({0})"
+ : "(*this->getODSOperands({0}).begin())",
+ index);
+ };
+ }
+
+ /// Get the call to get a result or segment of results.
+ Formatter getResult(unsigned index) const {
+ return [this, index](llvm::raw_ostream &os) -> llvm::raw_ostream & {
+ if (!emitForOp)
+ return os << "<no results should be generated>";
+ return os << llvm::formatv(op.getResult(index).isVariadic()
+ ? "this->getODSResults({0})"
+ : "(*this->getODSResults({0}).begin())",
+ index);
+ };
+ }
+
+ /// Return whether an op instance is available.
+ bool isEmittingForOp() const { return emitForOp; }
+
+ /// Return the ODS operation wrapper.
+ const Operator &getOp() const { return op; }
+
+ /// Get the attribute metadata sorted by name.
+ const llvm::MapVector<llvm::StringRef, AttributeMetadata> &
+ getAttrMetadata() const {
+ return attrMetadata;
+ }
+
+  /// Returns whether to emit a Properties struct for this operation or not.
+  /// This is unconditionally true: the former check on op.getProperties() and
+  /// its fallback both returned true, so the branch was redundant.
+  /// NOTE(review): if a dialect- or op-level opt-out was intended for the
+  /// fallback path, reinstate it here rather than in the callers.
+  bool hasProperties() const { return true; }
+
+ /// Returns whether the operation will have a non-empty Properties struct.
+ bool hasNonEmptyPropertiesStruct() const {
+ if (!op.getProperties().empty())
+ return true;
+ if (!hasProperties())
+ return false;
+ if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") ||
+ op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
+ return true;
+ return llvm::any_of(
+ getAttrMetadata(),
+ [](const std::pair<llvm::StringRef, AttributeMetadata> &it) {
+ return !it.second.constraint ||
+ !it.second.constraint->isDerivedAttr();
+ });
+ }
+
+ std::optional<NamedProperty> &getOperandSegmentsSize() {
+ return operandSegmentsSize;
+ }
+
+ std::optional<NamedProperty> &getResultSegmentsSize() {
+ return resultSegmentsSize;
+ }
+
+  /// Return the stored legacy index of the operand segment sizes attribute.
+  /// Const-qualified because it only reads the member, so it can be called
+  /// through a const helper.
+  uint32_t getOperandSegmentSizesLegacyIndex() const {
+    return operandSegmentSizesLegacyIndex;
+  }
+
+  /// Return the stored legacy index of the result segment sizes attribute.
+  /// Const-qualified because it only reads the member, so it can be called
+  /// through a const helper.
+  uint32_t getResultSegmentSizesLegacyIndex() const {
+    return resultSegmentSizesLegacyIndex;
+  }
+
+private:
+ /// Compute the attribute metadata.
+ void computeAttrMetadata();
+
+ /// The operation ODS wrapper.
+ const Operator &op;
+ /// True if code is being generated for an op, false for an adaptor.
+ const bool emitForOp;
+
+ /// The attribute metadata, mapped by name.
+ llvm::MapVector<llvm::StringRef, AttributeMetadata> attrMetadata;
+
+ std::optional<NamedProperty> operandSegmentsSize;
+ std::string operandSegmentsSizeStorage;
+ std::string operandSegmentsSizeParser;
+ std::optional<NamedProperty> resultSegmentsSize;
+ std::string resultSegmentsSizeStorage;
+ std::string resultSegmentsSizeParser;
+
+ /// Indices storing the position in the emission order of the operand/result
+ /// segment sizes attribute if emitted as part of the properties for legacy
+ /// bytecode encodings (versions less than 6).
+ uint32_t operandSegmentSizesLegacyIndex = 0;
+ uint32_t resultSegmentSizesLegacyIndex = 0;
+
+  /// The number of required attributes. Zero-initialized, like the legacy
+  /// index counters in this class, so it is never read while indeterminate.
+  unsigned numRequired = 0;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_OPADAPTORHELPER_H
diff --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/include/mlir/TableGen/Generators/OpClass.h
similarity index 91%
rename from mlir/tools/mlir-tblgen/OpClass.h
rename to mlir/include/mlir/TableGen/Generators/OpClass.h
index 20b96baf868c7..5647b7d338521 100644
--- a/mlir/tools/mlir-tblgen/OpClass.h
+++ b/mlir/include/mlir/TableGen/Generators/OpClass.h
@@ -1,4 +1,4 @@
-//===- OpClass.h - Implementation of an Op Class --------------------------===//
+//===- OpClass.h - Implementation of an Op Class ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_
-#define MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_
+#ifndef MLIR_TABLEGEN_GENERATORS_OPCLASS_H
+#define MLIR_TABLEGEN_GENERATORS_OPCLASS_H
#include "mlir/TableGen/Class.h"
@@ -31,7 +31,7 @@ class OpClass : public Class {
/// Add an op trait.
void addTrait(Twine trait) { parent.addTemplateParam(trait.str()); }
- /// The operation class is finalized by calling `Class::finalize` to delcare
+ /// The operation class is finalized by calling `Class::finalize` to declare
/// all pending private and public methods (ops don't have custom constructors
/// or fields). Then, the extra class declarations are appended to the end of
/// the class declaration.
@@ -49,4 +49,4 @@ class OpClass : public Class {
} // namespace tblgen
} // namespace mlir
-#endif // MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_
+#endif // MLIR_TABLEGEN_GENERATORS_OPCLASS_H
diff --git a/mlir/include/mlir/TableGen/Generators/OpDefinitionsGen.h b/mlir/include/mlir/TableGen/Generators/OpDefinitionsGen.h
new file mode 100644
index 0000000000000..c34f8822925d0
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/OpDefinitionsGen.h
@@ -0,0 +1,305 @@
+//===- OpDefinitionsGen.h - Op definitions generator -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_OPDEFINITIONSGEN_H
+#define MLIR_TABLEGEN_GENERATORS_OPDEFINITIONSGEN_H
+
+#include "mlir/TableGen/Class.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Generators/OpAdaptorHelper.h"
+#include "mlir/TableGen/Generators/OpClass.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Property.h"
+#include "mlir/TableGen/Trait.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+class Record;
+class RecordKeeper;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// OpEmitter
+//===----------------------------------------------------------------------===//
+
+/// Generates C++ declarations and definitions for a single operation record.
+class OpEmitter {
+public:
+ using ConstArgument =
+ llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
+
+ /// Emit C++ declarations for op.
+ static void
+ emitDecl(const Operator &op, llvm::raw_ostream &os,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool fatalOnError = true);
+ /// Emit C++ definitions for op.
+ static void
+ emitDef(const Operator &op, llvm::raw_ostream &os,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool fatalOnError = true);
+
+ virtual ~OpEmitter() = default;
+
+protected:
+ OpEmitter(const Operator &op,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool fatalOnError = true);
+
+ void emitDecl(llvm::raw_ostream &os);
+ void emitDef(llvm::raw_ostream &os);
+
+ /// Generate methods for accessing the attribute names of this operation.
+ virtual void genAttrNameGetters();
+
+ /// Generate the OpAsmOpInterface for this operation if possible.
+ virtual void genOpAsmInterface();
+
+ /// Generate the getOperationName method for this op.
+ virtual void genOpNameGetter();
+
+ /// Generate code to manage the properties, if any.
+ virtual void genPropertiesSupport();
+
+ /// Generate code to manage the encoding of properties to bytecode.
+ virtual void genPropertiesSupportForBytecode(
+ llvm::ArrayRef<ConstArgument> attrOrProperties);
+
+ /// Generate getters for the properties.
+ virtual void genPropGetters();
+
+ /// Generate setters for the properties.
+ virtual void genPropSetters();
+
+ /// Generate getters for the attributes.
+ virtual void genAttrGetters();
+
+ /// Generate setters for the attributes.
+ virtual void genAttrSetters();
+
+ /// Generate removers for optional attributes.
+ virtual void genOptionalAttrRemovers();
+
+ /// Generate getters for named operands.
+ virtual void genNamedOperandGetters();
+
+ /// Generate setters for named operands.
+ virtual void genNamedOperandSetters();
+
+ /// Generate getters for named results.
+ virtual void genNamedResultGetters();
+
+ /// Generate getters for named regions.
+ virtual void genNamedRegionGetters();
+
+ /// Generate getters for named successors.
+ virtual void genNamedSuccessorGetters();
+
+ /// Generate the method to populate default attributes.
+ virtual void genPopulateDefaultAttributes();
+
+ /// Generate builder methods for the operation.
+ virtual void genBuilder();
+
+ /// Generate the build() method that takes each operand/attribute as a
+ /// stand-alone parameter.
+ virtual void genSeparateArgParamBuilder();
+  /// NOTE(review): undocumented in this header; presumably emits the body of
+  /// an inline create method from paramList -- confirm against the .cpp file.
+  virtual void
+  genInlineCreateBody(const SmallVector<MethodParameter> &paramList);
+
+ /// Generate the build() method that uses the first operand's type as all
+ /// results' types, with stand-alone parameters.
+ virtual void genUseOperandAsResultTypeSeparateParamBuilder();
+
+ /// The kind of collective builder to generate.
+ enum class CollectiveBuilderKind {
+ /// Inherent attributes/properties are passed by const Properties&.
+ PropStruct,
+ /// Inherent attributes/properties are passed by attribute dictionary.
+ AttrDict,
+ };
+
+ /// Generate the build() method that uses the first operand's type as all
+ /// results' types, with collective parameters.
+ virtual void
+ genUseOperandAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
+
+ /// Generate the build() method that uses inferred types as result types.
+ /// Requires InferTypeOpInterface.
+ virtual void
+ genInferredTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
+
+ /// Generate the build() method that uses the first attribute's type as all
+ /// result types, with collective parameters.
+ virtual void
+ genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
+
+ /// Generate the build() method with collective result-type and
+ /// operand/attribute parameters.
+ virtual void genCollectiveParamBuilder(CollectiveBuilderKind kind);
+
+ /// The kind of parameter to generate for result types in builders.
+ enum class TypeParamKind {
+ /// No result type in the parameter list.
+ None,
+ /// A separate parameter for each result type.
+ Separate,
+ /// An ArrayRef<Type> for all result types.
+ Collective,
+ };
+
+ /// The kind of parameter to generate for attributes in builders.
+ enum class AttrParamKind {
+ /// A wrapped MLIR Attribute instance.
+ WrappedAttr,
+ /// A raw value without MLIR Attribute wrapper.
+ UnwrappedValue,
+ };
+
+  /// Build the parameter list for a build() method. Writes to paramList and
+  /// updates resultTypeNames. inferredAttributes is populated with attributes
+  /// elided from the build list. typeParamKind and attrParamKind control how
+  /// result types and attributes are placed in the parameter list.
+  virtual void
+  buildParamList(SmallVectorImpl<MethodParameter> &paramList,
+                 llvm::StringSet<> &inferredAttributes,
+                 SmallVectorImpl<std::string> &resultTypeNames,
+                 TypeParamKind typeParamKind,
+                 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
+
+ /// Add op arguments and regions into the operation state for build() methods.
+ virtual void
+ genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
+ llvm::StringSet<> &inferredAttributes,
+ bool isRawValueAttr = false);
+
+ /// Generate canonicalizer declarations for the operation.
+ virtual void genCanonicalizerDecls();
+
+ /// Generate the folder declaration for the operation.
+ virtual void genFolderDecls();
+
+ /// Generate the parser for the operation.
+ virtual void genParser();
+
+ /// Generate the printer for the operation.
+ virtual void genPrinter();
+
+ /// Generate the verify method for the operation.
+ virtual void genVerifier();
+
+ /// Generate custom verify methods for the operation.
+ virtual void genCustomVerifier();
+
+ /// Generate verify statements for operands and results. The generated code
+ /// is attached to body.
+ virtual void genOperandResultVerifier(MethodBody &body,
+ Operator::const_value_range values,
+ llvm::StringRef valueKind);
+
+ /// Generate verify statements for regions. The generated code is attached to
+ /// body.
+ virtual void genRegionVerifier(MethodBody &body);
+
+ /// Generate verify statements for successors. The generated code is attached
+ /// to body.
+ virtual void genSuccessorVerifier(MethodBody &body);
+
+ /// Generate the traits used by the object.
+ virtual void genTraits();
+
+ /// Generate OpInterface methods for all interfaces.
+ virtual void genOpInterfaceMethods();
+
+ /// Generate OpInterface methods for the given interface.
+ virtual void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
+
+ /// Generate an op interface method for the given interface method. If
+ /// declaration is true, generates a declaration, else a definition.
+ virtual Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
+ bool declaration = true);
+
+ /// Generate a using declaration for an op interface method to include the
+ /// default implementation from the interface trait. This is needed when the
+ /// interface defines multiple methods with the same name but some have a
+ /// default implementation and some don't.
+ virtual UsingDeclaration *
+ genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+ const tblgen::InterfaceMethod &method);
+
+ /// Generate the side-effect interface methods.
+ virtual void genSideEffectInterfaceMethods();
+
+ /// Generate the type inference interface methods.
+ virtual void genTypeInterfaceMethods();
+
+ // The TableGen record for this op.
+ // TODO: OpEmitter should not have a Record directly,
+ // it should rather go through the Operator for better abstraction.
+ const llvm::Record &def;
+
+ // The wrapper operator class for querying information from this op.
+ const Operator &op;
+
+ // The C++ code builder for this op.
+ OpClass opClass;
+
+ // The format context for verification code generation.
+ FmtContext verifyCtx;
+
+ // The emitter containing all of the locally emitted verification functions.
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter;
+
+ // Helper for emitting op code.
+ OpOrAdaptorHelper emitHelper;
+
+ // Keep track of interface using declarations generated to avoid duplicates.
+ llvm::StringSet<> interfaceUsingNames;
+
+ // Whether to emit fatal errors or not.
+ bool fatalOnError;
+};
+
+//===----------------------------------------------------------------------===//
+// Top-level entry points
+//===----------------------------------------------------------------------===//
+
+/// Emit op declarations for all op records in defs. If fatalOnError is
+/// true, assembly format parse errors are fatal; otherwise they are ignored.
+bool emitOpDecls(const llvm::RecordKeeper &records,
+ llvm::ArrayRef<const llvm::Record *> defs, unsigned shardCount,
+ llvm::raw_ostream &os, bool fatalOnError = true);
+
+/// Generate the dialect op registration hook and op class definitions for a
+/// shard of ops.
+void emitOpDefShard(const llvm::RecordKeeper &records,
+ llvm::ArrayRef<const llvm::Record *> shardDefs,
+ const Dialect &dialect, unsigned shardIndex,
+ unsigned shardCount, llvm::raw_ostream &os,
+ bool fatalOnError = true);
+
+/// Emit op definitions for all op records in defs. If fatalOnError is
+/// true, assembly format parse errors are fatal; otherwise they are ignored.
+bool emitOpDefs(const llvm::RecordKeeper &records,
+ llvm::ArrayRef<const llvm::Record *> defs, unsigned shardCount,
+ llvm::raw_ostream &os, bool fatalOnError = true);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_OPDEFINITIONSGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/OpDocGen.h b/mlir/include/mlir/TableGen/Generators/OpDocGen.h
new file mode 100644
index 0000000000000..57c20ce0a31d0
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/OpDocGen.h
@@ -0,0 +1,102 @@
+//===- OpDocGen.h - Op/dialect documentation generator ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions for generating documentation for MLIR dialects,
+// operations, attributes, types, and enums from TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_OPDOCGEN_H
+#define MLIR_TABLEGEN_GENERATORS_OPDOCGEN_H
+
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/EnumInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace llvm {
+class Record;
+class RecordKeeper;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// A group of operations that share a documentation section.
+struct OpDocGroup {
+ const Dialect &getDialect() const { return ops.front().getDialect(); }
+
+ /// Summary description of the section.
+ std::string summary = "";
+
+ /// Description of the section.
+ llvm::StringRef description = "";
+
+ /// Ops in the section; must be non-empty before calling getDialect().
+ std::vector<Operator> ops;
+};
+
+/// Holds all records collected from a dialect relevant for documentation
+/// generation.
+struct DialectRecords {
+ DialectRecords(Dialect dialect, llvm::StringRef inputFilename)
+ : dialect(dialect), inputFilename(inputFilename) {}
+
+ Dialect dialect;
+ std::string inputFilename;
+ std::vector<Attribute> attributes;
+ std::vector<AttrDef> attrDefs;
+ std::vector<OpDocGroup> ops;
+ std::vector<Type> types;
+ std::vector<TypeDef> typeDefs;
+ std::vector<EnumInfo> enums;
+};
+
+/// Collect, filter, and organize all records relevant for dialect documentation
+/// generation. opDefs are the op definitions to include (e.g. filtered by
+/// the caller). dialect is the dialect to collect records for.
+/// keepOpSourceOrder disables alphabetical sorting of ops.
+std::optional<DialectRecords>
+collectRecords(const llvm::RecordKeeper &records,
+ llvm::ArrayRef<const llvm::Record *> opDefs,
+ const Dialect &dialect, bool keepOpSourceOrder);
+
+/// Emit documentation for a single operation. stripPrefix is stripped from
+/// the fully qualified class name. allowHugoSpecificFeatures enables
+/// Hugo-specific markup in attribute descriptions.
+void emitOpDoc(const Operator &op, llvm::StringRef stripPrefix,
+ bool allowHugoSpecificFeatures, llvm::raw_ostream &os);
+
+/// Emit operation documentation for all ops in records.
+bool emitOpDoc(const DialectRecords &records, llvm::StringRef stripPrefix,
+ bool allowHugoSpecificFeatures, llvm::raw_ostream &os);
+
+/// Emit attribute definition documentation for all attrDefs in records.
+bool emitAttrDefDoc(const DialectRecords &records, llvm::raw_ostream &os);
+
+/// Emit type definition documentation for all typeDefs in records.
+bool emitTypeDefDoc(const DialectRecords &records, llvm::raw_ostream &os);
+
+/// Emit enum documentation for all enums in records.
+bool emitEnumDoc(const DialectRecords &records, llvm::raw_ostream &os);
+
+/// Emit full dialect documentation including all ops, attrs, types, and enums.
+bool emitDialectDoc(const DialectRecords &records, llvm::StringRef stripPrefix,
+ bool allowHugoSpecificFeatures, llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_OPDOCGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/OpFormatGen.h b/mlir/include/mlir/TableGen/Generators/OpFormatGen.h
new file mode 100644
index 0000000000000..4863e75b98171
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/OpFormatGen.h
@@ -0,0 +1,180 @@
+//===- OpFormatGen.h - MLIR operation format generator ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the interface for generating parsers and printers from the
+// declarative format.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_OPFORMATGEN_H
+#define MLIR_TABLEGEN_GENERATORS_OPFORMATGEN_H
+
+#include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/Class.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Generators/FormatGen.h"
+#include "mlir/TableGen/Generators/OpClass.h"
+#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Property.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include <optional>
+#include <vector>
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// OperationFormat
+//===----------------------------------------------------------------------===//
+
+/// Holds the parsed assembly format for an operation and drives generation of
+/// the corresponding parser and printer methods.
+struct OperationFormat {
+ using ConstArgument =
+ llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
+
+ /// Represents a specific resolver for an operand or result type.
+ class TypeResolution {
+ public:
+ TypeResolution() = default;
+
+ /// Get the index into the buildable types for this type, or std::nullopt.
+ std::optional<int> getBuilderIdx() const { return builderIdx; }
+ void setBuilderIdx(int idx) { builderIdx = idx; }
+
+ /// Get the variable this type is resolved to, or nullptr.
+ const NamedTypeConstraint *getVariable() const {
+ return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
+ }
+ /// Get the attribute this type is resolved to, or nullptr.
+ const NamedAttribute *getAttribute() const {
+ return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
+ }
+ /// Get the transformer for the type of the variable, or std::nullopt.
+ std::optional<llvm::StringRef> getVarTransformer() const {
+ return variableTransformer;
+ }
+ void setResolver(ConstArgument arg,
+ std::optional<llvm::StringRef> transformer) {
+ resolver = arg;
+ variableTransformer = transformer;
+ assert(getVariable() || getAttribute());
+ }
+
+ private:
+ /// If the type is resolved with a buildable type, this is the index into
+ /// 'buildableTypes' in the parent format.
+ std::optional<int> builderIdx;
+ /// If the type is resolved based upon another operand or result, this is
+ /// the variable or the attribute that this type is resolved to.
+ ConstArgument resolver;
+ /// If the type is resolved based upon another operand or result, this is
+ /// a transformer to apply to the variable when resolving.
+ std::optional<llvm::StringRef> variableTransformer;
+ };
+
+ /// The context in which an element is generated.
+ enum class GenContext {
+ /// The element is generated at the top-level or with the same behaviour.
+ Normal,
+ /// The element is generated inside an optional group.
+ Optional
+ };
+
+ OperationFormat(const Operator &op, bool hasProperties);
+ virtual ~OperationFormat() = default;
+
+ /// Generate the operation parser from this format.
+ virtual void genParser(Operator &op, OpClass &opClass);
+ /// Generate the parser code for a specific format element.
+ void genElementParser(FormatElement *element, MethodBody &body,
+ FmtContext &attrTypeCtx,
+ GenContext genCtx = GenContext::Normal);
+ /// Generate the C++ to resolve the types of operands and results during
+ /// parsing.
+ virtual void genParserTypeResolution(Operator &op, MethodBody &body);
+ /// Generate the C++ to resolve the types of the operands during parsing.
+ virtual void genParserOperandTypeResolution(
+ Operator &op, MethodBody &body,
+ llvm::function_ref<void(TypeResolution &, llvm::StringRef)>
+ emitTypeResolver);
+ /// Generate the C++ to resolve regions during parsing.
+ virtual void genParserRegionResolution(Operator &op, MethodBody &body);
+ /// Generate the C++ to resolve successors during parsing.
+ virtual void genParserSuccessorResolution(Operator &op, MethodBody &body);
+ /// Generate the C++ to handle variadic segment size traits.
+ virtual void genParserVariadicSegmentResolution(Operator &op,
+ MethodBody &body);
+
+ /// Generate the operation printer from this format.
+ virtual void genPrinter(Operator &op, OpClass &opClass);
+ /// Generate the printer code for a specific format element.
+ virtual void genElementPrinter(FormatElement *element, MethodBody &body,
+ Operator &op, bool &shouldEmitSpace,
+ bool &lastWasPunctuation);
+
+ /// The various elements in this format.
+ std::vector<FormatElement *> elements;
+
+ /// Flags indicating if all operands/operand types/result types were seen. If
+ /// the format contains these, it cannot contain individual type resolvers.
+ bool allOperands = false, allOperandTypes = false, allResultTypes = false;
+
+ /// A flag indicating if this operation infers its result types.
+ bool infersResultTypes = false;
+
+ /// A flag indicating if this operation has the SingleBlockImplicitTerminator
+ /// trait.
+ bool hasImplicitTermTrait;
+
+ /// A flag indicating if this operation has the SingleBlock trait.
+ bool hasSingleBlockTrait;
+
+ /// Indicate whether we need to use properties for the current operator.
+ bool useProperties;
+
+ /// Indicate whether prop-dict is used in the format.
+ bool hasPropDict;
+
+ /// The Operation class name.
+ llvm::StringRef opCppClassName;
+
+ /// A map of buildable types to indices.
+ llvm::MapVector<llvm::StringRef, int, llvm::StringMap<int>> buildableTypes;
+
+ /// The type resolver for every operand and result.
+ std::vector<TypeResolution> operandTypes, resultTypes;
+
+ /// The set of attributes explicitly used within the format.
+ llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes;
+ llvm::StringSet<> inferredAttributes;
+
+ /// The set of properties explicitly used within the format.
+ llvm::SmallSetVector<const NamedProperty *, 8> usedProperties;
+};
+
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+/// Generate the assembly format for the given operator. If fatalOnError is
+/// true, format parse errors cause the process to exit; otherwise they are
+/// silently ignored.
+void generateOpFormat(const Operator &constOp, OpClass &opClass,
+ bool hasProperties, bool fatalOnError = true);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_OPFORMATGEN_H
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/include/mlir/TableGen/Generators/OpGenHelpers.h
similarity index 53%
rename from mlir/tools/mlir-tblgen/OpGenHelpers.h
rename to mlir/include/mlir/TableGen/Generators/OpGenHelpers.h
index 3cd171d10c08a..b6362517857e6 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/include/mlir/TableGen/Generators/OpGenHelpers.h
@@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
-#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
+#ifndef MLIR_TABLEGEN_GENERATORS_OPGENHELPERS_H
+#define MLIR_TABLEGEN_GENERATORS_OPGENHELPERS_H
#include "mlir/Support/LLVM.h"
#include "llvm/TableGen/Record.h"
@@ -20,21 +20,27 @@
namespace mlir {
namespace tblgen {
-/// Returns all the op definitions filtered by the user. The filtering is via
-/// command-line option "op-include-regex" and "op-exclude-regex".
+/// Returns all op definitions from records whose operation name matches the
+/// optional include/exclude regex filters. Pass empty strings to skip
+/// filtering.
std::vector<const llvm::Record *>
-getRequestedOpDefinitions(const llvm::RecordKeeper &records);
+getRequestedOpDefinitions(const llvm::RecordKeeper &records,
+ llvm::StringRef includeRegex,
+ llvm::StringRef excludeRegex);
-/// Checks whether `str` is a Python keyword or would shadow builtin function.
-/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
+/// Checks whether str is a Python keyword or would shadow a builtin
+/// function. Regenerate using:
+/// python -c"print(set(sorted(__import__('keyword').kwlist)))"
bool isPythonReserved(llvm::StringRef str);
-/// Shard the op definitions into the number of shards set by "op-shard-count".
+/// Shard defs into shardCount approximately equal-sized shards and
+/// append them to shardedDefs.
void shardOpDefinitions(
ArrayRef<const llvm::Record *> defs,
- SmallVectorImpl<ArrayRef<const llvm::Record *>> &shardedDefs);
+ SmallVectorImpl<ArrayRef<const llvm::Record *>> &shardedDefs,
+ unsigned shardCount);
} // namespace tblgen
} // namespace mlir
-#endif // MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
+#endif // MLIR_TABLEGEN_GENERATORS_OPGENHELPERS_H
diff --git a/mlir/include/mlir/TableGen/Generators/OpInterfacesGen.h b/mlir/include/mlir/TableGen/Generators/OpInterfacesGen.h
new file mode 100644
index 0000000000000..36e27a8077614
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/OpInterfacesGen.h
@@ -0,0 +1,96 @@
+//===- OpInterfacesGen.h - Op/Attr/Type interface generator -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_OPINTERFACESGEN_H
+#define MLIR_TABLEGEN_GENERATORS_OPINTERFACESGEN_H
+
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include <vector>
+
+namespace llvm {
+class Record;
+class RecordKeeper;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Get all interface definitions of the given kind, excluding those that
+/// subclass "Declare<kind>InterfaceMethods".
+std::vector<const llvm::Record *>
+getAllInterfaceDefinitions(const llvm::RecordKeeper &records,
+ llvm::StringRef name);
+
+//===----------------------------------------------------------------------===//
+// InterfaceGenerator
+//===----------------------------------------------------------------------===//
+
+/// Base generator for processing TableGen interface definitions.
+class InterfaceGenerator {
+public:
+ virtual ~InterfaceGenerator() = default;
+
+ virtual bool emitInterfaceDefs();
+ virtual bool emitInterfaceDecls();
+ virtual bool emitInterfaceDocs();
+
+protected:
+ InterfaceGenerator(std::vector<const llvm::Record *> &&defs,
+ llvm::raw_ostream &os)
+ : defs(std::move(defs)), os(os) {}
+
+ virtual void emitConceptDecl(const Interface &interface);
+ virtual void emitModelDecl(const Interface &interface);
+ virtual void emitModelMethodsDef(const Interface &interface);
+ virtual void forwardDeclareInterface(const Interface &interface);
+ virtual void emitInterfaceDecl(const Interface &interface);
+ virtual void emitInterfaceTraitDecl(const Interface &interface);
+
+ /// The set of interface records to emit.
+ std::vector<const llvm::Record *> defs;
+ /// The stream to emit to.
+ llvm::raw_ostream &os;
+ /// The C++ value type of the interface, e.g. Operation*.
+ llvm::StringRef valueType;
+ /// The C++ base interface type.
+ llvm::StringRef interfaceBaseType;
+ /// The name of the typename for the value template.
+ llvm::StringRef valueTemplate;
+ /// The name of the substitution variable for the value.
+ llvm::StringRef substVar;
+ /// The format contexts to use for methods.
+ FmtContext nonStaticMethodFmt;
+ FmtContext traitMethodFmt;
+ FmtContext extraDeclsFmt;
+};
+
+/// A specialized generator for attribute interfaces.
+struct AttrInterfaceGenerator : public InterfaceGenerator {
+ AttrInterfaceGenerator(const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+};
+
+/// A specialized generator for operation interfaces.
+struct OpInterfaceGenerator : public InterfaceGenerator {
+ OpInterfaceGenerator(const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+};
+
+/// A specialized generator for type interfaces.
+struct TypeInterfaceGenerator : public InterfaceGenerator {
+ TypeInterfaceGenerator(const llvm::RecordKeeper &records,
+ llvm::raw_ostream &os);
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_OPINTERFACESGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/OpPythonBindingGen.h b/mlir/include/mlir/TableGen/Generators/OpPythonBindingGen.h
new file mode 100644
index 0000000000000..2ed94f8b803b9
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/OpPythonBindingGen.h
@@ -0,0 +1,33 @@
+//===- OpPythonBindingGen.h - Python op bindings generator ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_OPPYTHONBINDINGGEN_H
+#define MLIR_TABLEGEN_GENERATORS_OPPYTHONBINDINGGEN_H
+
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class RecordKeeper;
+class raw_ostream;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Emit Python bindings for all ops belonging to dialectName. If
+/// dialectExtensionName is non-empty, emit an extension binding instead of a
+/// dialect class declaration.
+bool emitPythonOpBindings(const llvm::RecordKeeper &records,
+ llvm::StringRef dialectName,
+ llvm::StringRef dialectExtensionName,
+ llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_OPPYTHONBINDINGGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/PassCAPIGen.h b/mlir/include/mlir/TableGen/Generators/PassCAPIGen.h
new file mode 100644
index 0000000000000..b2e5f44ca0f85
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/PassCAPIGen.h
@@ -0,0 +1,40 @@
+//===- PassCAPIGen.h - MLIR pass C API generation utilities -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for generating C API bindings for passes from
+// TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_PASSCAPIGEN_H
+#define MLIR_TABLEGEN_GENERATORS_PASSCAPIGEN_H
+
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class RecordKeeper;
+class raw_ostream;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Emit a C header declaring the C API for all passes in records.
+/// prefix is used to namespace the generated function names.
+void emitPassCAPIHeader(const llvm::RecordKeeper &records,
+ llvm::StringRef prefix, llvm::raw_ostream &os);
+
+/// Emit the C implementation for the C API of all passes in records.
+/// prefix is used to namespace the generated function names.
+void emitPassCAPIImpl(const llvm::RecordKeeper &records, llvm::StringRef prefix,
+ llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_PASSCAPIGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/PassDocGen.h b/mlir/include/mlir/TableGen/Generators/PassDocGen.h
new file mode 100644
index 0000000000000..535648a971718
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/PassDocGen.h
@@ -0,0 +1,31 @@
+//===- PassDocGen.h - MLIR pass documentation generation utilities --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for generating documentation for passes from
+// TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_PASSDOCGEN_H
+#define MLIR_TABLEGEN_GENERATORS_PASSDOCGEN_H
+
+namespace llvm {
+class RecordKeeper;
+class raw_ostream;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Emit documentation for all passes derived from PassBase in records.
+void emitPassDocs(const llvm::RecordKeeper &records, llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_PASSDOCGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/PassGen.h b/mlir/include/mlir/TableGen/Generators/PassGen.h
new file mode 100644
index 0000000000000..bc427ca430cd8
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/PassGen.h
@@ -0,0 +1,68 @@
+//===- PassGen.h - MLIR pass C++ generation utilities -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for generating C++ code for pass declarations
+// and definitions from TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_PASSGEN_H
+#define MLIR_TABLEGEN_GENERATORS_PASSGEN_H
+
+#include "mlir/TableGen/Pass.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include <vector>
+
+namespace llvm {
+class RecordKeeper;
+class raw_ostream;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Extract the list of passes from the TableGen records.
+std::vector<Pass> getPasses(const llvm::RecordKeeper &records);
+
+/// Emit the struct definition used to set pass options programmatically.
+/// Emits nothing if the pass has no options.
+void emitPassOptionsStruct(const Pass &pass, llvm::raw_ostream &os);
+
+/// Emit the public declarations for a single pass (guarded by
+/// GEN_PASS_DECL_<PASSNAME>).
+void emitPassDecls(const Pass &pass, llvm::raw_ostream &os);
+
+/// Emit the base class definition for a single pass (guarded by
+/// GEN_PASS_DEF_<PASSNAME>).
+void emitPassDefs(const Pass &pass, llvm::raw_ostream &os);
+
+/// Emit the option member declarations for a single pass.
+void emitPassOptionDecls(const Pass &pass, llvm::raw_ostream &os);
+
+/// Emit the statistic member declarations for a single pass.
+void emitPassStatisticDecls(const Pass &pass, llvm::raw_ostream &os);
+
+/// Emit registration code for all passes. groupName is the name of the pass
+/// group used in the generated `register<GroupName>Passes()` function.
+void emitRegistrations(llvm::ArrayRef<Pass> passes, llvm::StringRef groupName,
+ llvm::raw_ostream &os);
+
+/// Emit the complete header content (declarations + definitions) for a single
+/// pass.
+void emitPass(const Pass &pass, llvm::raw_ostream &os);
+
+/// Emit the complete header content for all passes in records, including
+/// registration code. groupName names the generated registration function.
+void emitPasses(const llvm::RecordKeeper &records, llvm::StringRef groupName,
+ llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_PASSGEN_H
diff --git a/mlir/include/mlir/TableGen/Generators/RewriterGen.h b/mlir/include/mlir/TableGen/Generators/RewriterGen.h
new file mode 100644
index 0000000000000..2680413e5a69f
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Generators/RewriterGen.h
@@ -0,0 +1,31 @@
+//===- RewriterGen.h - MLIR pattern rewriter generator ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for generating C++ pattern rewriters from
+// TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENERATORS_REWRITERGEN_H
+#define MLIR_TABLEGEN_GENERATORS_REWRITERGEN_H
+
+namespace llvm {
+class RecordKeeper;
+class raw_ostream;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Emit pattern rewriters for all Pattern definitions in records.
+void emitRewriters(const llvm::RecordKeeper &records, llvm::raw_ostream &os);
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_GENERATORS_REWRITERGEN_H
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index a90c55847718e..dde84c07db795 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -43,6 +43,8 @@ llvm_add_library(MLIRTableGen STATIC
)
set_target_properties(MLIRTableGen PROPERTIES FOLDER "MLIR/Tablegenning")
+add_subdirectory(Generators)
+
mlir_check_all_link_libraries(MLIRTableGen)
add_mlir_library_install(MLIRTableGen)
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/lib/TableGen/Generators/AttrOrTypeDefGen.cpp
similarity index 76%
rename from mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
rename to mlir/lib/TableGen/Generators/AttrOrTypeDefGen.cpp
index 64f35e7fef6d3..fd9ff525cd4bf 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/lib/TableGen/Generators/AttrOrTypeDefGen.cpp
@@ -6,16 +6,15 @@
//
//===----------------------------------------------------------------------===//
-#include "AttrOrTypeFormatGen.h"
-#include "CppGenUtilities.h"
+#include "mlir/TableGen/Generators/AttrOrTypeDefGen.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/Format.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/AttrOrTypeFormatGen.h"
+#include "mlir/TableGen/Generators/CppGenUtilities.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/TableGen/CodeGenHelpers.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
@@ -61,135 +60,14 @@ static void collectAllDefs(StringRef selectedDialect,
}
//===----------------------------------------------------------------------===//
-// DefGen
+// AttrOrTypeDefEmitter
//===----------------------------------------------------------------------===//
-namespace {
-class DefGen {
-public:
- /// Create the attribute or type class.
- DefGen(const AttrOrTypeDef &def);
-
- void emitDecl(raw_ostream &os) const {
- if (storageCls && def.genStorageClass()) {
- llvm::NamespaceEmitter ns(os, def.getStorageNamespace());
- os << "struct " << def.getStorageClassName() << ";\n";
- }
- defCls.writeDeclTo(os);
- }
- void emitDef(raw_ostream &os) const {
- if (storageCls && def.genStorageClass()) {
- llvm::NamespaceEmitter ns(os, def.getStorageNamespace());
- storageCls->writeDeclTo(os); // everything is inline
- }
- defCls.writeDefTo(os);
- }
-
-private:
- /// Add traits from the TableGen definition to the class.
- void createParentWithTraits();
- /// Emit top-level declarations: using declarations and any extra class
- /// declarations.
- void emitTopLevelDeclarations();
- /// Emit the function that returns the type or attribute name.
- void emitName();
- /// Emit the dialect name as a static member variable.
- void emitDialectName();
- /// Emit attribute or type builders.
- void emitBuilders();
- /// Emit a verifier declaration for custom verification (impl. provided by
- /// the users).
- void emitVerifierDecl();
- /// Emit a verifier that checks type constraints.
- void emitInvariantsVerifierImpl();
- /// Emit an entry poiunt for verification that calls the invariants and
- /// custom verifier.
- void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
- /// Emit parsers and printers.
- void emitParserPrinter();
- /// Emit parameter accessors, if required.
- void emitAccessors();
- /// Emit interface methods.
- void emitInterfaceMethods();
-
- //===--------------------------------------------------------------------===//
- // Builder Emission
-
- /// Emit the default builder `Attribute::get`
- void emitDefaultBuilder();
- /// Emit the checked builder `Attribute::getChecked`
- void emitCheckedBuilder();
- /// Emit a custom builder.
- void emitCustomBuilder(const AttrOrTypeBuilder &builder);
- /// Emit a checked custom builder.
- void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
-
- //===--------------------------------------------------------------------===//
- // Interface Method Emission
-
- /// Emit methods for a trait.
- void emitTraitMethods(const InterfaceTrait &trait);
- /// Emit a trait method.
- void emitTraitMethod(const InterfaceMethod &method);
- /// Generate a using declaration for a trait method.
- void genTraitMethodUsingDecl(const InterfaceTrait &trait,
- const InterfaceMethod &method);
-
- //===--------------------------------------------------------------------===//
- // OpAsm{Type,Attr}Interface Default Method Emission
-
- /// Emit 'getAlias' method using mnemonic as alias.
- void emitMnemonicAliasMethod();
-
- //===--------------------------------------------------------------------===//
- // Storage Class Emission
- void emitStorageClass();
- /// Generate the storage class constructor.
- void emitStorageConstructor();
- /// Emit the key type `KeyTy`.
- void emitKeyType();
- /// Emit the equality comparison operator.
- void emitEquals();
- /// Emit the key hash function.
- void emitHashKey();
- /// Emit the function to construct the storage class.
- void emitConstruct();
-
- //===--------------------------------------------------------------------===//
- // Utility Function Declarations
-
- /// Get the method parameters for a def builder, where the first several
- /// parameters may be different.
- SmallVector<MethodParameter>
- getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
-
- //===--------------------------------------------------------------------===//
- // Class fields
-
- /// The attribute or type definition.
- const AttrOrTypeDef &def;
- /// The list of attribute or type parameters.
- ArrayRef<AttrOrTypeParameter> params;
- /// The attribute or type class.
- Class defCls;
- /// An optional attribute or type storage class. The storage class will
- /// exist if and only if the def has more than zero parameters.
- std::optional<Class> storageCls;
-
- /// The C++ base value of the def, either "Attribute" or "Type".
- StringRef valueType;
- /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
- StringRef defType;
-
- /// The set of using declarations for trait methods.
- llvm::StringSet<> interfaceUsingNames;
-};
-} // namespace
-
-DefGen::DefGen(const AttrOrTypeDef &def)
+AttrOrTypeDefEmitter::AttrOrTypeDefEmitter(const AttrOrTypeDef &def,
+ bool fatalOnError)
: def(def), params(def.getParameters()), defCls(def.getCppClassName()),
valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
- defType(isa<AttrDef>(def) ? "Attr" : "Type") {
+ defType(isa<AttrDef>(def) ? "Attr" : "Type"), fatalOnError(fatalOnError) {
// Check that all parameters have names.
for (const AttrOrTypeParameter &param : def.getParameters())
if (param.isAnonymous())
@@ -238,7 +116,23 @@ DefGen::DefGen(const AttrOrTypeDef &def)
emitStorageClass();
}
-void DefGen::createParentWithTraits() {
+void AttrOrTypeDefEmitter::emitDecl(llvm::raw_ostream &os) const {
+ if (storageCls && def.genStorageClass()) {
+ llvm::NamespaceEmitter ns(os, def.getStorageNamespace());
+ os << "struct " << def.getStorageClassName() << ";\n";
+ }
+ defCls.writeDeclTo(os);
+}
+
+void AttrOrTypeDefEmitter::emitDef(llvm::raw_ostream &os) const {
+ if (storageCls && def.genStorageClass()) {
+ llvm::NamespaceEmitter ns(os, def.getStorageNamespace());
+ storageCls->writeDeclTo(os); // everything is inline
+ }
+ defCls.writeDefTo(os);
+}
+
+void AttrOrTypeDefEmitter::createParentWithTraits() {
ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType));
defParent.addTemplateParam(def.getCppClassName());
defParent.addTemplateParam(def.getCppBaseClassName());
@@ -308,7 +202,7 @@ static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
}
-void DefGen::emitTopLevelDeclarations() {
+void AttrOrTypeDefEmitter::emitTopLevelDeclarations() {
// Inherit constructors from the attribute or type class.
defCls.declare<VisibilityDeclaration>(Visibility::Public);
defCls.declare<UsingDeclaration>("Base::Base");
@@ -320,7 +214,7 @@ void DefGen::emitTopLevelDeclarations() {
std::move(extraDef));
}
-void DefGen::emitName() {
+void AttrOrTypeDefEmitter::emitName() {
StringRef name;
if (auto *attrDef = dyn_cast<AttrDef>(&def)) {
name = attrDef->getAttrName();
@@ -333,14 +227,14 @@ void DefGen::emitName() {
defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
}
-void DefGen::emitDialectName() {
+void AttrOrTypeDefEmitter::emitDialectName() {
std::string decl =
strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
def.getDialect().getName());
defCls.declare<ExtraClassDeclaration>(std::move(decl));
}
-void DefGen::emitBuilders() {
+void AttrOrTypeDefEmitter::emitBuilders() {
if (!def.skipDefaultBuilders()) {
emitDefaultBuilder();
if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
@@ -353,7 +247,7 @@ void DefGen::emitBuilders() {
}
}
-void DefGen::emitVerifierDecl() {
+void AttrOrTypeDefEmitter::emitVerifierDecl() {
defCls.declareStaticMethod(
"::llvm::LogicalResult", "verify",
getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
@@ -367,7 +261,7 @@ if (!({0})) {
}
)";
-void DefGen::emitInvariantsVerifierImpl() {
+void AttrOrTypeDefEmitter::emitInvariantsVerifierImpl() {
SmallVector<MethodParameter> builderParams = getBuilderParams(
{{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
Method *verifier =
@@ -413,7 +307,8 @@ void DefGen::emitInvariantsVerifierImpl() {
verifier->body() << "return ::mlir::success();";
}
-void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
+void AttrOrTypeDefEmitter::emitInvariantsVerifier(bool hasImpl,
+ bool hasCustomVerifier) {
if (!hasImpl && !hasCustomVerifier)
return;
defCls.declare<UsingDeclaration>("Base::getChecked");
@@ -445,7 +340,7 @@ void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
verifier->body() << "return ::mlir::success();";
}
-void DefGen::emitParserPrinter() {
+void AttrOrTypeDefEmitter::emitParserPrinter() {
auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
"::llvm::StringLiteral", "getMnemonic");
mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());
@@ -471,10 +366,11 @@ void DefGen::emitParserPrinter() {
MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
// Emit the bodies if we are using the declarative format.
if (hasAssemblyFormat)
- return generateAttrOrTypeFormat(def, parser->body(), printer->body());
+ return generateAttrOrTypeFormat(def, parser->body(), printer->body(),
+ fatalOnError);
}
-void DefGen::emitAccessors() {
+void AttrOrTypeDefEmitter::emitAccessors() {
for (auto &param : params) {
Method *m = defCls.addMethod(
param.getCppAccessorType(), param.getAccessorName(),
@@ -487,7 +383,7 @@ void DefGen::emitAccessors() {
}
}
-void DefGen::emitInterfaceMethods() {
+void AttrOrTypeDefEmitter::emitInterfaceMethods() {
for (auto &traitDef : def.getTraits())
if (auto *trait = dyn_cast<InterfaceTrait>(&traitDef))
if (trait->shouldDeclareMethods())
@@ -498,8 +394,8 @@ void DefGen::emitInterfaceMethods() {
// Builder Emission
//===----------------------------------------------------------------------===//
-SmallVector<MethodParameter>
-DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
+SmallVector<MethodParameter> AttrOrTypeDefEmitter::getBuilderParams(
+ std::initializer_list<MethodParameter> prefix) const {
SmallVector<MethodParameter> builderParams;
builderParams.append(prefix.begin(), prefix.end());
for (auto &param : params)
@@ -507,7 +403,7 @@ DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
return builderParams;
}
-void DefGen::emitDefaultBuilder() {
+void AttrOrTypeDefEmitter::emitDefaultBuilder() {
Method *m = defCls.addStaticMethod(
def.getCppClassName(), "get",
getBuilderParams({{"::mlir::MLIRContext *", "context"}}));
@@ -517,7 +413,7 @@ void DefGen::emitDefaultBuilder() {
body << ", std::move(" << param.getName() << ")";
}
-void DefGen::emitCheckedBuilder() {
+void AttrOrTypeDefEmitter::emitCheckedBuilder() {
Method *m = defCls.addStaticMethod(
def.getCppClassName(), "getChecked",
getBuilderParams(
@@ -578,7 +474,7 @@ static void emitDuplicatedBuilderError(const Method &currentMethod,
PrintFatalError(loc, "Failed to generate builder " + methodName);
}
-void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
+void AttrOrTypeDefEmitter::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
// Don't emit a body if there isn't one.
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
StringRef returnType = def.getCppClassName();
@@ -615,7 +511,8 @@ static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
return str;
}
-void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
+void AttrOrTypeDefEmitter::emitCheckedCustomBuilder(
+ const AttrOrTypeBuilder &builder) {
// Don't emit a body if there isn't one.
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
StringRef returnType = def.getCppClassName();
@@ -652,7 +549,7 @@ void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
// Interface Method Emission
//===----------------------------------------------------------------------===//
-void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
+void AttrOrTypeDefEmitter::emitTraitMethods(const InterfaceTrait &trait) {
// Get the set of methods that should always be declared.
auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
StringSet<> alwaysDeclared;
@@ -673,7 +570,7 @@ void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
}
}
-void DefGen::emitTraitMethod(const InterfaceMethod &method) {
+void AttrOrTypeDefEmitter::emitTraitMethod(const InterfaceMethod &method) {
// All interface methods are declaration-only.
auto props =
method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
@@ -684,8 +581,8 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
std::move(params));
}
-void DefGen::genTraitMethodUsingDecl(const InterfaceTrait &trait,
- const InterfaceMethod &method) {
+void AttrOrTypeDefEmitter::genTraitMethodUsingDecl(
+ const InterfaceTrait &trait, const InterfaceMethod &method) {
std::string name = (llvm::Twine(trait.getFullyQualifiedTraitName()) + "<" +
def.getCppClassName() + ">::" + method.getName())
.str();
@@ -696,7 +593,7 @@ void DefGen::genTraitMethodUsingDecl(const InterfaceTrait &trait,
//===----------------------------------------------------------------------===//
// OpAsm{Type,Attr}Interface Default Method Emission
-void DefGen::emitMnemonicAliasMethod() {
+void AttrOrTypeDefEmitter::emitMnemonicAliasMethod() {
// If the mnemonic is not set, there is nothing to do.
if (!def.getMnemonic())
return;
@@ -713,7 +610,7 @@ void DefGen::emitMnemonicAliasMethod() {
// Storage Class Emission
//===----------------------------------------------------------------------===//
-void DefGen::emitStorageConstructor() {
+void AttrOrTypeDefEmitter::emitStorageConstructor() {
Constructor *ctor =
storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
for (auto &param : params) {
@@ -722,7 +619,7 @@ void DefGen::emitStorageConstructor() {
}
}
-void DefGen::emitKeyType() {
+void AttrOrTypeDefEmitter::emitKeyType() {
std::string keyType("std::tuple<");
llvm::raw_string_ostream os(keyType);
llvm::interleaveComma(params, os,
@@ -738,7 +635,7 @@ void DefGen::emitKeyType() {
m->body() << ");";
}
-void DefGen::emitEquals() {
+void AttrOrTypeDefEmitter::emitEquals() {
Method *eq = storageCls->addConstMethod<Method::Inline>(
"bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey"));
auto &body = eq->body().indent();
@@ -751,7 +648,7 @@ void DefGen::emitEquals() {
llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
}
-void DefGen::emitHashKey() {
+void AttrOrTypeDefEmitter::emitHashKey() {
Method *hash = storageCls->addStaticInlineMethod(
"::llvm::hash_code", "hashKey",
MethodParameter("const KeyTy &", "tblgenKey"));
@@ -762,7 +659,7 @@ void DefGen::emitHashKey() {
});
}
-void DefGen::emitConstruct() {
+void AttrOrTypeDefEmitter::emitConstruct() {
Method *construct = storageCls->addMethod(
strfmt("{0} *", def.getStorageClassName()), "construct",
def.hasStorageCustomConstructor() ? Method::StaticDeclaration
@@ -794,7 +691,7 @@ void DefGen::emitConstruct() {
}
}
-void DefGen::emitStorageClass() {
+void AttrOrTypeDefEmitter::emitStorageClass() {
// Add the appropriate parent class.
storageCls->addParent(strfmt("::mlir::{0}Storage", valueType));
// Add the constructor.
@@ -825,59 +722,20 @@ void DefGen::emitStorageClass() {
}
//===----------------------------------------------------------------------===//
-// DefGenerator
+// AttrTypeDefGenerator
//===----------------------------------------------------------------------===//
-namespace {
-/// This struct is the base generator used when processing tablegen interfaces.
-class DefGenerator {
-public:
- bool emitDecls(StringRef selectedDialect);
- bool emitDefs(StringRef selectedDialect);
-
-protected:
- DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
- StringRef defType, StringRef valueType, bool isAttrGenerator)
- : defRecords(defs), os(os), defType(defType), valueType(valueType),
- isAttrGenerator(isAttrGenerator) {
- // Sort by occurrence in file.
- llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
- return lhs->getID() < rhs->getID();
- });
- }
+AttrDefGenerator::AttrDefGenerator(const RecordKeeper &records, raw_ostream &os,
+ bool fatalOnError)
+ : AttrTypeDefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"),
+ os, "Attr", "Attribute", /*isAttrGenerator=*/true,
+ fatalOnError) {}
- /// Emit the list of def type names.
- void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
- /// Emit the code to dispatch between different defs during parsing/printing.
- void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
-
- /// The set of def records to emit.
- std::vector<const Record *> defRecords;
- /// The attribute or type class to emit.
- /// The stream to emit to.
- raw_ostream &os;
- /// The prefix of the tablegen def name, e.g. Attr or Type.
- StringRef defType;
- /// The C++ base value type of the def, e.g. Attribute or Type.
- StringRef valueType;
- /// Flag indicating if this generator is for Attributes. False if the
- /// generator is for types.
- bool isAttrGenerator;
-};
-
-/// A specialized generator for AttrDefs.
-struct AttrDefGenerator : public DefGenerator {
- AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
- : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
- "Attr", "Attribute", /*isAttrGenerator=*/true) {}
-};
-/// A specialized generator for TypeDefs.
-struct TypeDefGenerator : public DefGenerator {
- TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
- : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
- "Type", "Type", /*isAttrGenerator=*/false) {}
-};
-} // namespace
+TypeDefGenerator::TypeDefGenerator(const RecordKeeper &records, raw_ostream &os,
+ bool fatalOnError)
+ : AttrTypeDefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"),
+ os, "Type", "Type", /*isAttrGenerator=*/false,
+ fatalOnError) {}
//===----------------------------------------------------------------------===//
// GEN: Declarations
@@ -892,7 +750,19 @@ class AsmPrinter;
} // namespace mlir
)";
-bool DefGenerator::emitDecls(StringRef selectedDialect) {
+AttrTypeDefGenerator::AttrTypeDefGenerator(
+ ArrayRef<const Record *> defs, llvm::raw_ostream &os, StringRef defType,
+ StringRef valueType, bool isAttrGenerator, bool fatalOnError)
+ : defRecords(defs.begin(), defs.end()), os(os), defType(defType),
+ valueType(valueType), isAttrGenerator(isAttrGenerator),
+ fatalOnError(fatalOnError) {
+ // Sort by occurrence in file.
+ llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
+ return lhs->getID() < rhs->getID();
+ });
+}
+
+bool AttrTypeDefGenerator::emitDecls(StringRef selectedDialect) {
emitSourceFileHeader((defType + "Def Declarations").str(), os);
llvm::IfDefEmitter scope(os, "GET_" + defType.upper() + "DEF_CLASSES");
@@ -915,7 +785,7 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) {
// Emit the declarations.
for (const AttrOrTypeDef &def : defs)
- DefGen(def).emitDecl(os);
+ AttrOrTypeDefEmitter(def, fatalOnError).emitDecl(os);
}
// Emit the TypeID explicit specializations to have a single definition for
// each of these.
@@ -932,7 +802,7 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) {
// GEN: Def List
//===----------------------------------------------------------------------===//
-void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
+void AttrTypeDefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
llvm::IfDefEmitter scope(os, "GET_" + defType.upper() + "DEF_LIST");
auto interleaveFn = [&](const AttrOrTypeDef &def) {
os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
@@ -1046,13 +916,13 @@ static const char *const dialectDynamicTypePrinterDispatch = R"(
/// directive that should attach directly to the mnemonic (e.g., `<`, `(`,
/// ``).
///
-/// This inspects the raw format string rather than parsing it into a DefFormat.
-/// Parsing would require access to the format element types (which are local to
-/// `AttrOrTypeFormatGen.cpp`) and would re-parse every format string just to
-/// check its first token -- an overkill for a simple spacing heuristic.
-/// The only case this cannot distinguish structurally is an optional group
-/// whose then-branch starts with punctuation, but parsing has the same
-/// limitation since the group's anchor is not known at codegen time.
+/// This inspects the raw format string rather than parsing it into an
+/// AttrTypeDefFormat. Parsing would require access to the format element types
+/// (which are local to `AttrOrTypeFormatGen.cpp`) and would re-parse every
+/// format string just to check its first token -- an overkill for a simple
+/// spacing heuristic. The only case this cannot distinguish structurally is an
+/// optional group whose then-branch starts with punctuation, but parsing has
+/// the same limitation since the group's anchor is not known at codegen time.
static bool needsLeadingSpace(const AttrOrTypeDef &def) {
StringRef fmtStr = def.getAssemblyFormat()->trim();
if (fmtStr.empty())
@@ -1073,7 +943,8 @@ static bool needsLeadingSpace(const AttrOrTypeDef &def) {
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
-void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
+void AttrTypeDefGenerator::emitParsePrintDispatch(
+ ArrayRef<AttrOrTypeDef> defs) {
if (llvm::none_of(defs, [](const AttrOrTypeDef &def) {
return def.getMnemonic().has_value();
})) {
@@ -1163,7 +1034,7 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
printer.writeDeclTo(indentedOs);
}
-bool DefGenerator::emitDefs(StringRef selectedDialect) {
+bool AttrTypeDefGenerator::emitDefs(StringRef selectedDialect) {
emitSourceFileHeader((defType + "Def Definitions").str(), os);
SmallVector<AttrOrTypeDef, 16> defs;
@@ -1177,7 +1048,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
for (const AttrOrTypeDef &def : defs) {
{
DialectNamespaceEmitter ns(os, def.getDialect());
- DefGen gen(def);
+ AttrOrTypeDefEmitter gen(def, fatalOnError);
gen.emitDef(os);
}
// Emit the TypeID explicit specializations to have a single symbol def.
@@ -1264,14 +1135,14 @@ static void emitConstraintDecls(ArrayRef<Constraint> constraints,
parameterTypeName, parameterName);
}
-static void emitTypeConstraintDecls(const RecordKeeper &records,
- raw_ostream &os) {
+void mlir::tblgen::emitTypeConstraintDecls(const RecordKeeper &records,
+ raw_ostream &os) {
emitConstraintDecls(getAllCppTypeConstraints(records), os, "::mlir::Type",
"type");
}
-static void emitAttrConstraintDecls(const RecordKeeper &records,
- raw_ostream &os) {
+void mlir::tblgen::emitAttrConstraintDecls(const RecordKeeper &records,
+ raw_ostream &os) {
emitConstraintDecls(getAllCppAttrConstraints(records), os,
"::mlir::Attribute", "attr");
}
@@ -1299,94 +1170,14 @@ return ({3});
}
}
-static void emitTypeConstraintDefs(const RecordKeeper &records,
- raw_ostream &os) {
+void mlir::tblgen::emitTypeConstraintDefs(const RecordKeeper &records,
+ raw_ostream &os) {
emitConstraintDefs(getAllCppTypeConstraints(records), os, "::mlir::Type",
"type");
}
-static void emitAttrConstraintDefs(const RecordKeeper &records,
- raw_ostream &os) {
+void mlir::tblgen::emitAttrConstraintDefs(const RecordKeeper &records,
+ raw_ostream &os) {
emitConstraintDefs(getAllCppAttrConstraints(records), os, "::mlir::Attribute",
"attr");
}
-
-//===----------------------------------------------------------------------===//
-// GEN: Registration hooks
-//===----------------------------------------------------------------------===//
-
-//===----------------------------------------------------------------------===//
-// AttrDef
-//===----------------------------------------------------------------------===//
-
-static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
-static llvm::cl::opt<std::string>
- attrDialect("attrdefs-dialect",
- llvm::cl::desc("Generate attributes for this dialect"),
- llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
-
-static mlir::GenRegistration
- genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
- [](const RecordKeeper &records, raw_ostream &os) {
- AttrDefGenerator generator(records, os);
- return generator.emitDefs(attrDialect);
- });
-static mlir::GenRegistration
- genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
- [](const RecordKeeper &records, raw_ostream &os) {
- AttrDefGenerator generator(records, os);
- return generator.emitDecls(attrDialect);
- });
-
-static mlir::GenRegistration
- genAttrConstrDefs("gen-attr-constraint-defs",
- "Generate attribute constraint definitions",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitAttrConstraintDefs(records, os);
- return false;
- });
-static mlir::GenRegistration
- genAttrConstrDecls("gen-attr-constraint-decls",
- "Generate attribute constraint declarations",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitAttrConstraintDecls(records, os);
- return false;
- });
-
-//===----------------------------------------------------------------------===//
-// TypeDef
-//===----------------------------------------------------------------------===//
-
-static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
-static llvm::cl::opt<std::string>
- typeDialect("typedefs-dialect",
- llvm::cl::desc("Generate types for this dialect"),
- llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
-
-static mlir::GenRegistration
- genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
- [](const RecordKeeper &records, raw_ostream &os) {
- TypeDefGenerator generator(records, os);
- return generator.emitDefs(typeDialect);
- });
-static mlir::GenRegistration
- genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
- [](const RecordKeeper &records, raw_ostream &os) {
- TypeDefGenerator generator(records, os);
- return generator.emitDecls(typeDialect);
- });
-
-static mlir::GenRegistration
- genTypeConstrDefs("gen-type-constraint-defs",
- "Generate type constraint definitions",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitTypeConstraintDefs(records, os);
- return false;
- });
-static mlir::GenRegistration
- genTypeConstrDecls("gen-type-constraint-decls",
- "Generate type constraint declarations",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitTypeConstraintDecls(records, os);
- return false;
- });
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/lib/TableGen/Generators/AttrOrTypeFormatGen.cpp
similarity index 77%
rename from mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
rename to mlir/lib/TableGen/Generators/AttrOrTypeFormatGen.cpp
index a9bca471cf5bd..1774412c5e0f8 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/lib/TableGen/Generators/AttrOrTypeFormatGen.cpp
@@ -6,12 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "AttrOrTypeFormatGen.h"
-#include "FormatGen.h"
+#include "mlir/TableGen/Generators/AttrOrTypeFormatGen.h"
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Format.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/FormatGen.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
@@ -32,49 +31,12 @@ using llvm::formatv;
// Element
//===----------------------------------------------------------------------===//
-namespace {
-/// This class represents an instance of a variable element. A variable refers
-/// to an attribute or type parameter.
-class ParameterElement
- : public VariableElementBase<VariableElement::Parameter> {
-public:
- ParameterElement(AttrOrTypeParameter param) : param(param) {}
-
- /// Get the parameter in the element.
- const AttrOrTypeParameter &getParam() const { return param; }
-
- /// Indicate if this variable is printed "qualified" (that is it is
- /// prefixed with the `#dialect.mnemonic`).
- bool shouldBeQualified() { return shouldBeQualifiedFlag; }
- void setShouldBeQualified(bool qualified = true) {
- shouldBeQualifiedFlag = qualified;
- }
-
- /// Returns true if the element contains an optional parameter.
- bool isOptional() const { return param.isOptional(); }
-
- /// Returns the name of the parameter.
- StringRef getName() const { return param.getName(); }
-
- /// Return the code to check whether the parameter is present.
- auto genIsPresent(FmtContext &ctx, const Twine &self) const {
- assert(isOptional() && "cannot guard on a mandatory parameter");
- std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str();
- ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
- return tgfmt(getParam().getComparator(), &ctx);
- }
-
- /// Generate the code to check whether the parameter should be printed.
- MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
- assert(isOptional() && "cannot guard on a mandatory parameter");
- std::string self = param.getAccessorName() + "()";
- return os << "!(" << genIsPresent(ctx, self) << ")";
- }
-
-private:
- bool shouldBeQualifiedFlag = false;
- AttrOrTypeParameter param;
-};
+MethodBody &ParameterElement::genPrintGuard(FmtContext &ctx,
+ MethodBody &os) const {
+ assert(isOptional() && "cannot guard on a mandatory parameter");
+ std::string self = param.getAccessorName() + "()";
+ return os << "!(" << genIsPresent(ctx, self) << ")";
+}
/// Utility to return the encapsulated parameter element for the provided format
/// element. This parameter can originate from either a `ParameterElement`,
@@ -120,43 +82,13 @@ static bool formatNotOptional(FormatElement *el) {
return !formatIsOptional(el);
}
-/// This class represents a `params` directive that refers to all parameters
-/// of an attribute or type. When used as a top-level directive, it generates
-/// a format of the form:
-///
-/// (param-value (`,` param-value)*)?
-///
-/// When used as an argument to another directive that accepts variables,
-/// `params` can be used in place of manually listing all parameters of an
-/// attribute or type.
-class ParamsDirective
- : public VectorDirectiveBase<DirectiveElement::Params, ParameterElement *> {
-public:
- using Base::Base;
-
- /// Returns true if there are optional parameters present.
- bool hasOptionalElements() const {
- return llvm::any_of(getElements(), paramIsOptional);
- }
-};
-
-/// This class represents a `struct` directive that generates a struct format
-/// of the form:
-///
-/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
-///
-class StructDirective
- : public VectorDirectiveBase<DirectiveElement::Struct, FormatElement *> {
-public:
- using Base::Base;
-
- /// Returns true if there are optional format elements present.
- bool hasOptionalElements() const {
- return llvm::any_of(getElements(), formatIsOptional);
- }
-};
+bool ParamsDirective::hasOptionalElements() const {
+ return llvm::any_of(getElements(), paramIsOptional);
+}
-} // namespace
+bool StructDirective::hasOptionalElements() const {
+ return llvm::any_of(getElements(), formatIsOptional);
+}
//===----------------------------------------------------------------------===//
// Format Strings
@@ -199,77 +131,9 @@ if (::mlir::failed(_result_{0})) {{
)";
//===----------------------------------------------------------------------===//
-// DefFormat
+// AttrTypeDefFormat
//===----------------------------------------------------------------------===//
-namespace {
-class DefFormat {
-public:
- DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
- : def(def), elements(std::move(elements)) {}
-
- /// Generate the attribute or type parser.
- void genParser(MethodBody &os);
- /// Generate the attribute or type printer.
- void genPrinter(MethodBody &os);
-
-private:
- /// Generate the parser code for a specific format element.
- void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
- /// Generate the parser code for a literal.
- void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
- bool isOptional = false);
- /// Generate the parser code for a variable.
- void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
- /// Generate the parser code for a `params` directive.
- void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
- /// Generate the parser code for a `struct` directive.
- void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
- /// Generate the parser code for a `custom` directive.
- void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os,
- bool isOptional = false);
- /// Generate the parser code for an optional group.
- void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
- MethodBody &os);
-
- /// Generate the printer code for a specific format element.
- void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
- /// Generate the printer code for a literal.
- void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
- /// Generate the printer code for a variable.
- void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
- bool skipGuard = false);
- /// Generate a printer for comma-separated format elements.
- void genCommaSeparatedPrinter(
- ArrayRef<FormatElement *> params, FmtContext &ctx, MethodBody &os,
- function_ref<void(FormatElement *)> extra,
- function_ref<void(FormatElement *)> extraPost = nullptr);
- /// Generate the printer code for a `params` directive.
- void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
- /// Generate the printer code for a `struct` directive.
- void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
- /// Generate the printer code for a `custom` directive.
- void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
- /// Generate the printer code for an optional group.
- void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
- MethodBody &os);
- /// Generate a printer (or space eraser) for a whitespace element.
- void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
- MethodBody &os);
-
- /// The ODS definition of the attribute or type whose format is being used to
- /// generate a parser and printer.
- const AttrOrTypeDef &def;
- /// The list of top-level format elements returned by the assembly format
- /// parser.
- std::vector<FormatElement *> elements;
-
- /// Flags for printing spaces.
- bool shouldEmitSpace = false;
- bool lastWasPunctuation = false;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// ParserGen
//===----------------------------------------------------------------------===//
@@ -307,7 +171,7 @@ if ($_type) {
os << "\n";
}
-void DefFormat::genParser(MethodBody &os) {
+void AttrTypeDefFormat::genParser(MethodBody &os) {
FmtContext ctx;
ctx.addSubst("_parser", "odsParser");
ctx.addSubst("_ctxt", "odsParser.getContext()");
@@ -374,8 +238,8 @@ void DefFormat::genParser(MethodBody &os) {
os << ");";
}
-void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
+ MethodBody &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralParser(literal->getSpelling(), ctx, os);
if (auto *var = dyn_cast<ParameterElement>(el))
@@ -394,8 +258,8 @@ void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
llvm_unreachable("unknown format element");
}
-void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
- MethodBody &os, bool isOptional) {
+void AttrTypeDefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
+ MethodBody &os, bool isOptional) {
os << "// Parse literal '" << value << "'\n";
os << tgfmt("if ($_parser.parse", &ctx);
if (isOptional)
@@ -431,8 +295,8 @@ void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
os << ") return {};\n";
}
-void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
+ MethodBody &os) {
// Check for a custom parser. Use the default attribute parser otherwise.
const AttrOrTypeParameter &param = el->getParam();
auto customParser = param.getParser();
@@ -464,8 +328,8 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
dialectLoading);
}
-void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
+ MethodBody &os) {
os << "// Parse parameter list\n";
// If there are optional parameters, we need to switch to `parseOptionalComma`
@@ -522,8 +386,8 @@ void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
os.unindent() << "} while(false);\n";
}
-void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
+ MethodBody &os) {
// Loop declaration for struct parser with only required parameters.
//
// $0: Number of expected parameters.
@@ -671,8 +535,8 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
os.unindent() << "}\n";
}
-void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
- MethodBody &os, bool isOptional) {
+void AttrTypeDefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
+ MethodBody &os, bool isOptional) {
os << "{\n";
os.indent();
@@ -716,8 +580,9 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
os.unindent() << "}\n";
}
-void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genOptionalGroupParser(OptionalElement *el,
+ FmtContext &ctx,
+ MethodBody &os) {
ArrayRef<FormatElement *> thenElements =
el->getThenElements(/*parseable=*/true);
@@ -774,7 +639,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
// PrinterGen
//===----------------------------------------------------------------------===//
-void DefFormat::genPrinter(MethodBody &os) {
+void AttrTypeDefFormat::genPrinter(MethodBody &os) {
FmtContext ctx;
ctx.addSubst("_printer", "odsPrinter");
ctx.addSubst("_ctxt", "getContext()");
@@ -791,8 +656,8 @@ void DefFormat::genPrinter(MethodBody &os) {
genElementPrinter(el, ctx, os);
}
-void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
+ MethodBody &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralPrinter(literal->getSpelling(), ctx, os);
if (auto *params = dyn_cast<ParamsDirective>(el))
@@ -811,8 +676,8 @@ void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
llvm::PrintFatalError("unsupported format element");
}
-void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
+ MethodBody &os) {
// Don't insert a space before certain punctuation.
bool needSpace =
shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
@@ -825,8 +690,9 @@ void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
}
-void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
- MethodBody &os, bool skipGuard) {
+void AttrTypeDefFormat::genVariablePrinter(ParameterElement *el,
+ FmtContext &ctx, MethodBody &os,
+ bool skipGuard) {
const AttrOrTypeParameter ¶m = el->getParam();
ctx.withSelf(param.getAccessorName() + "()");
@@ -882,7 +748,7 @@ static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
inverted);
}
-void DefFormat::genCommaSeparatedPrinter(
+void AttrTypeDefFormat::genCommaSeparatedPrinter(
ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
function_ref<void(FormatElement *)> extra,
function_ref<void(FormatElement *)> extraPost) {
@@ -922,8 +788,8 @@ void DefFormat::genCommaSeparatedPrinter(
os.unindent() << "}\n";
}
-void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
+ MethodBody &os) {
SmallVector<FormatElement *> args = llvm::map_to_vector(
el->getElements(), [](ParameterElement *param) -> FormatElement * {
return static_cast<FormatElement *>(param);
@@ -931,8 +797,8 @@ void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
genCommaSeparatedPrinter(args, ctx, os, [&](FormatElement *param) {});
}
-void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
+ MethodBody &os) {
ArrayRef<FormatElement *> elems = el->getElements();
// An `ArrayRefParameter` without a custom printer in a non-last struct
// position must be wrapped in `[...]` to avoid ambiguity with the
@@ -958,8 +824,8 @@ void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
});
}
-void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
+ MethodBody &os) {
// Insert a space before the custom directive, if necessary.
if (shouldEmitSpace || !lastWasPunctuation)
os << tgfmt("$_printer << ' ';\n", &ctx);
@@ -982,8 +848,9 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
os.unindent() << ");\n";
}
-void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genOptionalGroupPrinter(OptionalElement *el,
+ FmtContext &ctx,
+ MethodBody &os) {
FormatElement *anchor = el->getAnchor();
if (auto *param = dyn_cast<ParameterElement>(anchor)) {
guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted());
@@ -1009,8 +876,8 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
os.unindent() << "}\n";
}
-void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
- MethodBody &os) {
+void AttrTypeDefFormat::genWhitespacePrinter(WhitespaceElement *el,
+ FmtContext &ctx, MethodBody &os) {
if (el->getValue() == "\\n") {
os << tgfmt("$_printer.printNewline();\n", &ctx);
} else if (!el->getValue().empty()) {
@@ -1022,59 +889,11 @@ void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
}
//===----------------------------------------------------------------------===//
-// DefFormatParser
+// AttrTypeDefFormatParser
//===----------------------------------------------------------------------===//
-namespace {
-class DefFormatParser : public FormatParser {
-public:
- DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
- : FormatParser(mgr, def.getLoc()[0]), def(def),
- seenParams(def.getNumParameters()) {}
-
- /// Parse the attribute or type format and create the format elements.
- FailureOr<DefFormat> parse();
-
-protected:
- /// Verify the parsed elements.
- LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
- /// Verify the elements of a custom directive.
- LogicalResult
- verifyCustomDirectiveArguments(SMLoc loc,
- ArrayRef<FormatElement *> arguments) override;
- /// Verify the elements of an optional group.
- LogicalResult verifyOptionalGroupElements(SMLoc loc,
- ArrayRef<FormatElement *> elements,
- FormatElement *anchor) override;
- /// Verify the arguments to a struct directive.
- LogicalResult verifyStructArguments(SMLoc loc,
- ArrayRef<FormatElement *> arguments);
-
- LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
-
- /// Parse an attribute or type variable.
- FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
- Context ctx) override;
- /// Parse an attribute or type format directive.
- FailureOr<FormatElement *>
- parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
-
-private:
- /// Parse a `params` directive.
- FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
- /// Parse a `struct` directive.
- FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
-
- /// Attribute or type tablegen def.
- const AttrOrTypeDef &def;
-
- /// Seen attribute or type parameters.
- BitVector seenParams;
-};
-} // namespace
-
-LogicalResult DefFormatParser::verify(SMLoc loc,
- ArrayRef<FormatElement *> elements) {
+LogicalResult
+AttrTypeDefFormatParser::verify(SMLoc loc, ArrayRef<FormatElement *> elements) {
// Check that all parameters are referenced in the format.
for (auto [index, param] : llvm::enumerate(def.getParameters())) {
if (param.isOptional())
@@ -1107,16 +926,15 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
return success();
}
-LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
+/// Verification hook for the arguments of a custom directive. It performs no
+/// checks of its own and always reports success.
+LogicalResult AttrTypeDefFormatParser::verifyCustomDirectiveArguments(
SMLoc loc, ArrayRef<FormatElement *> arguments) {
// Arguments are fully verified by the parser context.
return success();
}
-LogicalResult
-DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
- ArrayRef<FormatElement *> elements,
- FormatElement *anchor) {
+LogicalResult AttrTypeDefFormatParser::verifyOptionalGroupElements(
+ llvm::SMLoc loc, ArrayRef<FormatElement *> elements,
+ FormatElement *anchor) {
// `params` and `struct` directives are allowed only if all the contained
// parameters are optional.
for (FormatElement *el : elements) {
@@ -1168,9 +986,8 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
return success();
}
-LogicalResult
-DefFormatParser::verifyStructArguments(SMLoc loc,
- ArrayRef<FormatElement *> arguments) {
+LogicalResult AttrTypeDefFormatParser::verifyStructArguments(
+ SMLoc loc, ArrayRef<FormatElement *> arguments) {
for (FormatElement *el : arguments) {
if (!isa<ParameterElement, CustomDirective, ParamsDirective>(el)) {
return emitError(loc, "expected a parameter, custom directive or params "
@@ -1190,23 +1007,24 @@ DefFormatParser::verifyStructArguments(SMLoc loc,
return success();
}
-LogicalResult DefFormatParser::markQualified(SMLoc loc,
- FormatElement *element) {
+/// Handle an element wrapped in a `qualified` directive. Only parameter
+/// elements are accepted: any other element kind produces an error at `loc`;
+/// otherwise `setShouldBeQualified()` is called on the parameter element.
+LogicalResult AttrTypeDefFormatParser::markQualified(SMLoc loc,
+ FormatElement *element) {
if (!isa<ParameterElement>(element))
return emitError(loc, "`qualified` argument list expected a variable");
cast<ParameterElement>(element)->setShouldBeQualified();
return success();
}
-FailureOr<DefFormat> DefFormatParser::parse() {
+/// Run the base `FormatParser::parse()` and, on success, bundle the parsed
+/// elements with `def` into an `AttrTypeDefFormat`. A failure from the base
+/// parser is propagated unchanged.
+FailureOr<AttrTypeDefFormat> AttrTypeDefFormatParser::parse() {
FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
if (failed(elements))
return failure();
- return DefFormat(def, std::move(*elements));
+ return AttrTypeDefFormat(def, std::move(*elements));
}
FailureOr<FormatElement *>
-DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
+AttrTypeDefFormatParser::parseVariableImpl(SMLoc loc, StringRef name,
+ Context ctx) {
// Lookup the parameter.
ArrayRef<AttrOrTypeParameter> params = def.getParameters();
auto *it = llvm::find_if(
@@ -1235,8 +1053,8 @@ DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
}
FailureOr<FormatElement *>
-DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
- Context ctx) {
+AttrTypeDefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
+ Context ctx) {
switch (kind) {
case FormatToken::kw_qualified:
@@ -1250,8 +1068,8 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
}
}
-FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
- Context ctx) {
+FailureOr<FormatElement *>
+AttrTypeDefFormatParser::parseParamsDirective(SMLoc loc, Context ctx) {
// It doesn't make sense to allow references to all parameters in a custom
// directive because parameters are the only things that can be bound.
if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
@@ -1277,8 +1095,8 @@ FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
return create<ParamsDirective>(std::move(vars));
}
-FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
- Context ctx) {
+FailureOr<FormatElement *>
+AttrTypeDefFormatParser::parseStructDirective(SMLoc loc, Context ctx) {
if (ctx != TopLevelContext)
return emitError(loc, "`struct` can only be used at the top-level context");
@@ -1331,16 +1149,17 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
MethodBody &parser,
- MethodBody &printer) {
+ MethodBody &printer,
+ bool fatalOnError) {
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(
llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
// Parse the custom assembly format.
- DefFormatParser fmtParser(mgr, def);
- FailureOr<DefFormat> format = fmtParser.parse();
+ AttrTypeDefFormatParser fmtParser(mgr, def);
+ FailureOr<AttrTypeDefFormat> format = fmtParser.parse();
if (failed(format)) {
- if (formatErrorIsFatal)
+ if (fatalOnError)
PrintFatalError(def.getLoc(), "failed to parse assembly format");
return;
}
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/lib/TableGen/Generators/BytecodeDialectGen.cpp
similarity index 95%
rename from mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
rename to mlir/lib/TableGen/Generators/BytecodeDialectGen.cpp
index dd178b5e5d232..0c7a826d60760 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/lib/TableGen/Generators/BytecodeDialectGen.cpp
@@ -6,12 +6,11 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/TableGen/Generators/BytecodeDialectGen.h"
#include "mlir/Support/IndentedOstream.h"
-#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -19,11 +18,6 @@
using namespace llvm;
-static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
-static cl::opt<std::string>
- selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
- cl::cat(dialectGenCat), cl::CommaSeparated);
-
namespace {
/// Helper class to generate C++ bytecode parser helpers.
@@ -437,19 +431,21 @@ struct AttrOrType {
};
} // namespace
-static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
+namespace mlir {
+namespace tblgen {
+
+bool emitBytecodeDialect(const RecordKeeper &records, StringRef dialectName,
+ raw_ostream &os) {
MapVector<StringRef, AttrOrType> dialectAttrOrType;
for (const Record *it :
records.getAllDerivedDefinitions("DialectAttributes")) {
- if (!selectedBcDialect.empty() &&
- it->getValueAsString("dialect") != selectedBcDialect)
+ if (!dialectName.empty() && it->getValueAsString("dialect") != dialectName)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].attr =
it->getValueAsListOfDefs("elems");
}
for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
- if (!selectedBcDialect.empty() &&
- it->getValueAsString("dialect") != selectedBcDialect)
+ if (!dialectName.empty() && it->getValueAsString("dialect") != dialectName)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].type =
it->getValueAsListOfDefs("elems");
@@ -490,8 +486,5 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
return false;
}
-static mlir::GenRegistration
- genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
- [](const RecordKeeper &records, raw_ostream &os) {
- return emitBCRW(records, os);
- });
+} // namespace tblgen
+} // namespace mlir
diff --git a/mlir/lib/TableGen/Generators/CMakeLists.txt b/mlir/lib/TableGen/Generators/CMakeLists.txt
new file mode 100644
index 0000000000000..29b057ccb1d66
--- /dev/null
+++ b/mlir/lib/TableGen/Generators/CMakeLists.txt
@@ -0,0 +1,39 @@
+# This library has the same linking constraints as MLIRTableGen: it is used by
+# mlir-tblgen (built with DISABLE_LLVM_LINK_LLVM_DYLIB) and must therefore also
+# be built with that flag.
+llvm_add_library(MLIRTableGenGenerators STATIC
+ AttrOrTypeDefGen.cpp
+ AttrOrTypeFormatGen.cpp
+ BytecodeDialectGen.cpp
+ CppGenUtilities.cpp
+ DialectGen.cpp
+ DialectInterfacesGen.cpp
+ DocGenUtilities.cpp
+ EnumsGen.cpp
+ EnumPythonBindingGen.cpp
+ FormatGen.cpp
+ OpClass.cpp
+ OpDefinitionsGen.cpp
+ OpDocGen.cpp
+ OpFormatGen.cpp
+ OpGenHelpers.cpp
+ OpInterfacesGen.cpp
+ OpPythonBindingGen.cpp
+ PassCAPIGen.cpp
+ PassDocGen.cpp
+ PassGen.cpp
+ RewriterGen.cpp
+
+ DISABLE_LLVM_LINK_LLVM_DYLIB
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/TableGen/Generators
+
+ LINK_LIBS PUBLIC
+ MLIRTableGen
+)
+set_target_properties(MLIRTableGenGenerators PROPERTIES FOLDER "MLIR/Tablegenning")
+
+mlir_check_all_link_libraries(MLIRTableGenGenerators)
+
+add_mlir_library_install(MLIRTableGenGenerators)
diff --git a/mlir/tools/mlir-tblgen/CppGenUtilities.cpp b/mlir/lib/TableGen/Generators/CppGenUtilities.cpp
similarity index 74%
rename from mlir/tools/mlir-tblgen/CppGenUtilities.cpp
rename to mlir/lib/TableGen/Generators/CppGenUtilities.cpp
index 6c05d5336224d..335e2adb06f51 100644
--- a/mlir/tools/mlir-tblgen/CppGenUtilities.cpp
+++ b/mlir/lib/TableGen/Generators/CppGenUtilities.cpp
@@ -1,4 +1,4 @@
-//===- CppGenUtilities.cpp - MLIR cpp gen utilities --------------===//
+//===- CppGenUtilities.cpp - MLIR C++ gen utilities -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,17 +6,18 @@
//
//===----------------------------------------------------------------------===//
//
-// Defines common utilities for generating cpp files from tablegen
-// structures.
+// Defines common utilities for generating C++ files from TableGen structures.
//
//===----------------------------------------------------------------------===//
-#include "CppGenUtilities.h"
+#include "mlir/TableGen/Generators/CppGenUtilities.h"
#include "mlir/Support/IndentedOstream.h"
+using llvm::StringRef;
+
void mlir::tblgen::emitSummaryAndDescComments(llvm::raw_ostream &os,
- llvm::StringRef summary,
- llvm::StringRef description,
+ StringRef summary,
+ StringRef description,
bool terminateComment) {
StringRef trimmedSummary = summary.rtrim();
StringRef trimmedDesc = description.rtrim();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/lib/TableGen/Generators/DialectGen.cpp
similarity index 73%
rename from mlir/tools/mlir-tblgen/DialectGen.cpp
rename to mlir/lib/TableGen/Generators/DialectGen.cpp
index 8eecad39f49f3..8d1c4f6d2b5fc 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/lib/TableGen/Generators/DialectGen.cpp
@@ -1,4 +1,4 @@
-//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
+//===- DialectGen.cpp - MLIR dialect C++ definitions generator ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,96 +10,27 @@
//
//===----------------------------------------------------------------------===//
-#include "CppGenUtilities.h"
-#include "DialectGenUtilities.h"
+#include "mlir/TableGen/Generators/DialectGen.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/CppGenUtilities.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Signals.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
-#define DEBUG_TYPE "mlir-tblgen-opdefgen"
-
using namespace mlir;
using namespace mlir::tblgen;
using llvm::Record;
using llvm::RecordKeeper;
-static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
-static llvm::cl::opt<std::string>
- selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
- llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
-
-/// Utility iterator used for filtering records for a specific dialect.
-namespace {
-using DialectFilterIterator =
- llvm::filter_iterator<ArrayRef<Record *>::iterator,
- std::function<bool(const Record *)>>;
-} // namespace
-
-static void populateDiscardableAttributes(
- Dialect &dialect, const llvm::DagInit *discardableAttrDag,
- SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
- for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
- const llvm::Init *arg = discardableAttrDag->getArg(i);
-
- StringRef givenName = discardableAttrDag->getArgNameStr(i);
- if (givenName.empty())
- PrintFatalError(dialect.getDef()->getLoc(),
- "discardable attributes must be named");
- discardableAttributes.push_back(
- {givenName.str(), arg->getAsUnquotedString()});
- }
-}
-
-/// Given a set of records for a T, filter the ones that correspond to
-/// the given dialect.
-template <typename T>
-static iterator_range<DialectFilterIterator>
-filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
- auto filterFn = [&](const Record *record) {
- return T(record).getDialect() == dialect;
- };
- return {DialectFilterIterator(records.begin(), records.end(), filterFn),
- DialectFilterIterator(records.end(), records.end(), filterFn)};
-}
-
-std::optional<Dialect>
-tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
- if (dialects.empty()) {
- llvm::errs() << "no dialect was found\n";
- return std::nullopt;
- }
-
- // Select the dialect to gen for.
- if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
- return dialects.front();
-
- if (selectedDialect.getNumOccurrences() == 0) {
- llvm::errs() << "when more than 1 dialect is present, one must be selected "
- "via '-dialect'\n";
- return std::nullopt;
- }
-
- const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) {
- return dialect.getName() == selectedDialect;
- });
- if (dialectIt == dialects.end()) {
- llvm::errs() << "selected dialect with '-dialect' does not exist\n";
- return std::nullopt;
- }
- return *dialectIt;
-}
-
//===----------------------------------------------------------------------===//
// GEN: Dialect declarations
//===----------------------------------------------------------------------===//
@@ -124,7 +55,7 @@ class {0} : public ::mlir::{2} {
/// Registration for a single dependent dialect: to be inserted in the ctor
/// above for each dependent dialect.
-const char *const dialectRegistrationTemplate =
+static const char *const dialectRegistrationTemplate =
"getContext()->loadDialect<{0}>();";
/// The code block for the attribute parser/printer hooks.
@@ -236,20 +167,87 @@ static const char *const discardableAttrHelperDecl = R"(
public:
)";
-/// Generate the declaration for the given dialect class.
-static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
+/// The code block to generate a dialect constructor definition.
+///
+/// {0}: The name of the dialect class.
+/// {1}: Initialization code emitted in the ctor body before initialize().
+/// {2}: The dialect parent class.
+/// {3}: Extra members to initialize.
+static const char *const dialectConstructorStr = R"(
+{0}::{0}(::mlir::MLIRContext *context)
+ : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+ {3}
+ {{
+ {1}
+ initialize();
+}
+)";
+
+/// The code block to generate a default destructor definition.
+///
+/// {0}: The name of the dialect class.
+static const char *const dialectDestructorStr = R"(
+{0}::~{0}() = default;
+
+)";
+
+/// Append one (name, value) pair to `discardableAttributes` for every argument
+/// of `discardableAttrDag`. The name is the DAG argument name and the value is
+/// the argument rendered with `getAsUnquotedString()`. An argument without a
+/// name is reported as a fatal error at the location of the dialect's record.
+void mlir::tblgen::populateDiscardableAttributes(
+ Dialect &dialect, const llvm::DagInit *discardableAttrDag,
+ llvm::SmallVectorImpl<std::pair<std::string, std::string>>
+ &discardableAttributes) {
+ for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
+ const llvm::Init *arg = discardableAttrDag->getArg(i);
+
+ llvm::StringRef givenName = discardableAttrDag->getArgNameStr(i);
+ if (givenName.empty())
+ llvm::PrintFatalError(dialect.getDef()->getLoc(),
+ "discardable attributes must be named");
+ discardableAttributes.push_back(
+ {givenName.str(), arg->getAsUnquotedString()});
+ }
+}
+
+/// Pick the dialect to generate code for out of `dialects`.
+///
+/// `selectedDialect` is the requested dialect name; an empty string means that
+/// no dialect was explicitly requested. Returns `std::nullopt`, after printing
+/// a message to `llvm::errs()`, when `dialects` is empty, when several
+/// dialects are present but none was requested, or when no dialect has the
+/// requested name.
+std::optional<Dialect>
+mlir::tblgen::findDialectToGenerate(llvm::ArrayRef<Dialect> dialects,
+ llvm::StringRef selectedDialect) {
+ if (dialects.empty()) {
+ llvm::errs() << "no dialect was found\n";
+ return std::nullopt;
+ }
+
+ // If there is exactly one dialect and none was explicitly selected, use it.
+ if (dialects.size() == 1 && selectedDialect.empty())
+ return dialects.front();
+
+ if (selectedDialect.empty()) {
+ llvm::errs() << "when more than 1 dialect is present, one must be selected "
+ "via '-dialect'\n";
+ return std::nullopt;
+ }
+
+ // A name was requested: look it up, even when only one dialect is present.
+ const auto *dialectIt = llvm::find_if(dialects, [&](const Dialect &dialect) {
+ return dialect.getName() == selectedDialect;
+ });
+ if (dialectIt == dialects.end()) {
+ llvm::errs() << "selected dialect with '-dialect' does not exist\n";
+ return std::nullopt;
+ }
+ return *dialectIt;
+}
+
+void mlir::tblgen::emitDialectDecl(Dialect &dialect, llvm::raw_ostream &os) {
// Emit all nested namespaces.
{
DialectNamespaceEmitter nsEmitter(os, dialect);
// Emit the start of the decl.
std::string cppName = dialect.getCppClassName();
- StringRef superClassName =
+ llvm::StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
tblgen::emitSummaryAndDescComments(os, dialect.getSummary(),
dialect.getDescription(),
- /*terminateCmment=*/false);
+ /*terminateComment=*/false);
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
superClassName);
@@ -258,7 +256,7 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
if (dialect.useDefaultAttributePrinterParser())
os << attrParserDecl;
// If the dialect requested the default type printer and parser, emit the
- // delcarations for the hooks.
+ // declarations for the hooks.
if (dialect.useDefaultTypePrinterParser())
os << typeParserDecl;
@@ -278,7 +276,8 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
const llvm::DagInit *discardableAttrDag =
dialect.getDiscardableAttributes();
- SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ llvm::SmallVector<std::pair<std::string, std::string>>
+ discardableAttributes;
populateDiscardableAttributes(dialect, discardableAttrDag,
discardableAttributes);
@@ -292,7 +291,8 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
dialect.getName());
}
- if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
+ if (std::optional<llvm::StringRef> extraDecl =
+ dialect.getExtraClassDeclaration())
os << *extraDecl;
// End the dialect decl.
@@ -303,52 +303,26 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
<< "::" << dialect.getCppClassName() << ")\n";
}
-static bool emitDialectDecls(const RecordKeeper &records, raw_ostream &os) {
+/// Emit the source file header followed by the declaration of the dialect
+/// chosen by `findDialectToGenerate(dialects, selectedDialect)`.
+///
+/// Returns false when there are no `Dialect` records (only the header is
+/// written) and after the declaration has been emitted; returns true when no
+/// dialect could be selected.
+bool mlir::tblgen::emitDialectDecls(const RecordKeeper &records,
+ llvm::StringRef selectedDialect,
+ llvm::raw_ostream &os) {
emitSourceFileHeader("Dialect Declarations", os, records);
auto dialectDefs = records.getAllDerivedDefinitions("Dialect");
if (dialectDefs.empty())
return false;
- SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
- std::optional<Dialect> dialect = findDialectToGenerate(dialects);
+ llvm::SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
+ std::optional<Dialect> dialect =
+ findDialectToGenerate(dialects, selectedDialect);
if (!dialect)
return true;
emitDialectDecl(*dialect, os);
return false;
}
-//===----------------------------------------------------------------------===//
-// GEN: Dialect definitions
-//===----------------------------------------------------------------------===//
-
-/// The code block to generate a dialect constructor definition.
-///
-/// {0}: The name of the dialect class.
-/// {1}: Initialization code that is emitted in the ctor body before calling
-/// initialize(), such as dependent dialect registration.
-/// {2}: The dialect parent class.
-/// {3}: Extra members to initialize
-static const char *const dialectConstructorStr = R"(
-{0}::{0}(::mlir::MLIRContext *context)
- : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
- {3}
- {{
- {1}
- initialize();
-}
-)";
-
-/// The code block to generate a default destructor definition.
-///
-/// {0}: The name of the dialect class.
-static const char *const dialectDestructorStr = R"(
-{0}::~{0}() = default;
-
-)";
-
-static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
- raw_ostream &os) {
+void mlir::tblgen::emitDialectDef(Dialect &dialect, const RecordKeeper &records,
+ llvm::raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
// Emit the TypeID explicit specializations to have a single symbol def.
@@ -359,13 +333,13 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
// Emit all nested namespaces.
DialectNamespaceEmitter nsEmitter(os, dialect);
- /// Build the list of dependent dialects.
+ // Build the list of dependent dialects.
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
llvm::interleave(
dialect.getDependentDialects(), dialectsOs,
- [&](StringRef dependentDialect) {
+ [&](llvm::StringRef dependentDialect) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
},
@@ -373,19 +347,19 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
}
// Emit the constructor and destructor.
- StringRef superClassName =
+ llvm::StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
- SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ llvm::SmallVector<std::pair<std::string, std::string>> discardableAttributes;
populateDiscardableAttributes(dialect, discardableAttrDag,
discardableAttributes);
std::string discardableAttributesInit;
for (const auto &attrPair : discardableAttributes) {
std::string camelName = llvm::convertToCamelFromSnakeCase(
attrPair.first, /*capitalizeFirst=*/false);
- llvm::raw_string_ostream os(discardableAttributesInit);
- os << ", " << camelName << "AttrName(context)";
+ llvm::raw_string_ostream initOs(discardableAttributesInit);
+ initOs << ", " << camelName << "AttrName(context)";
}
os << llvm::formatv(dialectConstructorStr, cppClassName,
@@ -395,33 +369,20 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
os << llvm::formatv(dialectDestructorStr, cppClassName);
}
-static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {
+/// Emit the source file header followed by the definition of the dialect
+/// chosen by `findDialectToGenerate(dialects, selectedDialect)`.
+///
+/// Returns false when there are no `Dialect` records (only the header is
+/// written) and after the definition has been emitted; returns true when no
+/// dialect could be selected.
+bool mlir::tblgen::emitDialectDefs(const RecordKeeper &records,
+ llvm::StringRef selectedDialect,
+ llvm::raw_ostream &os) {
emitSourceFileHeader("Dialect Definitions", os, records);
auto dialectDefs = records.getAllDerivedDefinitions("Dialect");
if (dialectDefs.empty())
return false;
- SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
- std::optional<Dialect> dialect = findDialectToGenerate(dialects);
+ llvm::SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
+ std::optional<Dialect> dialect =
+ findDialectToGenerate(dialects, selectedDialect);
if (!dialect)
return true;
emitDialectDef(*dialect, records, os);
return false;
}
-
-//===----------------------------------------------------------------------===//
-// GEN: Dialect registration hooks
-//===----------------------------------------------------------------------===//
-
-static mlir::GenRegistration
- genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
- [](const RecordKeeper &records, raw_ostream &os) {
- return emitDialectDecls(records, os);
- });
-
-static mlir::GenRegistration
- genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
- [](const RecordKeeper &records, raw_ostream &os) {
- return emitDialectDefs(records, os);
- });
diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/lib/TableGen/Generators/DialectInterfacesGen.cpp
similarity index 72%
rename from mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
rename to mlir/lib/TableGen/Generators/DialectInterfacesGen.cpp
index e695b8c761895..6078ce09fa490 100644
--- a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
+++ b/mlir/lib/TableGen/Generators/DialectInterfacesGen.cpp
@@ -1,4 +1,4 @@
-//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
+//===- DialectInterfacesGen.cpp - MLIR dialect interface generator --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,10 +10,10 @@
//
//===----------------------------------------------------------------------===//
-#include "CppGenUtilities.h"
-#include "DocGenUtilities.h"
+#include "mlir/TableGen/Generators/DialectInterfacesGen.h"
#include "mlir/Support/IndentedOstream.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/CppGenUtilities.h"
+#include "mlir/TableGen/Generators/DocGenUtilities.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
@@ -50,39 +50,21 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
os << ") const";
}
-/// Get an array of all Dialect Interface definitions
-static std::vector<const Record *>
-getAllInterfaceDefinitions(const RecordKeeper &records) {
+std::vector<const Record *>
+mlir::tblgen::getAllDialectInterfaceDefinitions(const RecordKeeper &records) {
std::vector<const Record *> defs =
records.getAllDerivedDefinitions("DialectInterface");
llvm::erase_if(defs, [&](const Record *def) {
- // Ignore interfaces defined outside of the top-level file.
return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
llvm::SrcMgr.getMainFileID();
});
return defs;
}
-namespace {
-/// This struct is the generator used when processing tablegen dialect
-/// interfaces.
-class DialectInterfaceGenerator {
-public:
- DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
- : defs(getAllInterfaceDefinitions(records)), os(os) {}
-
- bool emitInterfaceDecls();
-
-protected:
- void emitInterfaceDecl(const DialectInterface &interface);
-
- /// The set of interface records to emit.
- std::vector<const Record *> defs;
- // The stream to emit to.
- raw_ostream &os;
-};
-} // namespace
+/// Initializes `defs` with the result of
+/// `getAllDialectInterfaceDefinitions(records)` and `os` with the given output
+/// stream.
+mlir::tblgen::DialectInterfaceGenerator::DialectInterfaceGenerator(
+ const RecordKeeper &records, raw_ostream &os)
+ : defs(getAllDialectInterfaceDefinitions(records)), os(os) {}
//===----------------------------------------------------------------------===//
// GEN: Interface declarations
@@ -98,7 +80,6 @@ static void emitInterfaceMethodDoc(const InterfaceMethod &method,
static void emitInterfaceMethodsDef(const DialectInterface &interface,
raw_ostream &os) {
-
raw_indented_ostream ios(os);
ios.indent(2);
@@ -118,7 +99,6 @@ static void emitInterfaceMethodsDef(const DialectInterface &interface,
continue;
}
- // if it is not a method declaration, then it's a normal interface method.
ios << " {";
if (auto body = method.getBody()) {
@@ -133,11 +113,8 @@ static void emitInterfaceMethodsDef(const DialectInterface &interface,
static void emitConstructor(const DialectInterface &interface,
raw_ostream &os) {
-
raw_indented_ostream ios(os);
- // We consider a constructor protected if interface has at least one pure
- // virtual method
auto hasProtectedConstructor =
llvm::any_of(interface.getMethods(), [](const InterfaceMethod &method) {
return method.isPureVirtual();
@@ -152,14 +129,13 @@ static void emitConstructor(const DialectInterface &interface,
interface.getName());
}
-void DialectInterfaceGenerator::emitInterfaceDecl(
+void mlir::tblgen::DialectInterfaceGenerator::emitInterfaceDecl(
const DialectInterface &interface) {
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
tblgen::emitSummaryAndDescComments(os, "",
interface.getDescription().value_or(""));
- // Emit the main interface class declaration.
os << llvm::formatv(
"class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
"public:\n",
@@ -167,7 +143,6 @@ void DialectInterfaceGenerator::emitInterfaceDecl(
emitInterfaceMethodsDef(interface, os);
- // Emit any extra declarations.
if (std::optional<StringRef> extraDecls =
interface.getExtraClassDeclaration()) {
raw_indented_ostream ios(os);
@@ -183,12 +158,9 @@ void DialectInterfaceGenerator::emitInterfaceDecl(
os << "};\n";
}
-bool DialectInterfaceGenerator::emitInterfaceDecls() {
-
+bool mlir::tblgen::DialectInterfaceGenerator::emitInterfaceDecls() {
llvm::emitSourceFileHeader("Dialect Interface Declarations", os);
- // Sort according to ID, so defs are emitted in the order in which they appear
- // in the Tablegen file.
std::vector<const Record *> sortedDefs(defs);
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
return lhs->getID() < rhs->getID();
@@ -199,13 +171,3 @@ bool DialectInterfaceGenerator::emitInterfaceDecls() {
return false;
}
-
-//===----------------------------------------------------------------------===//
-// GEN: Interface registration hooks
-//===----------------------------------------------------------------------===//
-
-static mlir::GenRegistration genDecls(
- "gen-dialect-interface-decls", "Generate dialect interface declarations.",
- [](const RecordKeeper &records, raw_ostream &os) {
- return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
- });
diff --git a/mlir/lib/TableGen/Generators/DocGenUtilities.cpp b/mlir/lib/TableGen/Generators/DocGenUtilities.cpp
new file mode 100644
index 0000000000000..ecc3436d653ea
--- /dev/null
+++ b/mlir/lib/TableGen/Generators/DocGenUtilities.cpp
@@ -0,0 +1,54 @@
+//===- DocGenUtilities.cpp - MLIR doc gen utilities -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines common utilities for generating documentation from TableGen
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Generators/DocGenUtilities.h"
+#include "mlir/Support/IndentedOstream.h"
+#include "llvm/ADT/Twine.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+using llvm::StringRef;
+
+/// Emits `summary` on its own line wrapped in `_..._`, with surrounding
+/// whitespace removed and the first character upper-cased. Nothing is emitted
+/// when the summary is empty or consists only of whitespace.
+void mlir::tblgen::emitSummary(StringRef summary, llvm::raw_ostream &os) {
+  // Trim before testing for emptiness: a whitespace-only summary trims down
+  // to an empty string, and calling `front()` on that is invalid.
+  StringRef trimmed = summary.trim();
+  if (trimmed.empty())
+    return;
+  // `std::toupper` requires a value representable as `unsigned char`.
+  char first = static_cast<char>(
+      std::toupper(static_cast<unsigned char>(trimmed.front())));
+  StringRef rest = trimmed.drop_front();
+  os << "\n_" << first << rest << "_\n";
+}
+
+/// Writes a blank line followed by `description`, reindented through a
+/// `raw_indented_ostream`. Trailing spaces and tabs are dropped, and a final
+/// newline is appended when the text does not already end with one. An empty
+/// description produces no output.
+void mlir::tblgen::emitDescription(StringRef description,
+                                   llvm::raw_ostream &os) {
+  if (description.empty())
+    return;
+  os << "\n";
+  // Only spaces and tabs are stripped, so an existing trailing newline stays.
+  StringRef body = description.rtrim(" \t");
+  const bool needsNewline = !body.ends_with("\n");
+  raw_indented_ostream indentedOs(os);
+  indentedOs.printReindented(body);
+  if (needsNewline)
+    indentedOs << "\n";
+}
+
+/// Writes a blank line followed by `description`, reindented with
+/// `prefix` + "/// " passed as the extra prefix to `printReindented`. Trailing
+/// spaces and tabs are dropped, and a final newline is appended when the text
+/// does not already end with one. An empty description produces no output.
+void mlir::tblgen::emitDescriptionComment(StringRef description,
+                                          llvm::raw_ostream &os,
+                                          StringRef prefix) {
+  if (description.empty())
+    return;
+  os << "\n";
+  std::string commentPrefix = (llvm::Twine(prefix) + "/// ").str();
+  StringRef body = description.rtrim(" \t");
+  const bool needsNewline = !body.ends_with("\n");
+  raw_indented_ostream indentedOs(os);
+  indentedOs.printReindented(body, commentPrefix);
+  if (needsNewline)
+    indentedOs << "\n";
+}
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/lib/TableGen/Generators/EnumPythonBindingGen.cpp
similarity index 89%
rename from mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
rename to mlir/lib/TableGen/Generators/EnumPythonBindingGen.cpp
index 6cef09d9958c7..a74ce49e82ce1 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/lib/TableGen/Generators/EnumPythonBindingGen.cpp
@@ -10,14 +10,13 @@
// generate the corresponding Python binding classes.
//
//===----------------------------------------------------------------------===//
-#include "OpGenHelpers.h"
+#include "mlir/TableGen/Generators/EnumPythonBindingGen.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/EnumInfo.h"
-#include "mlir/TableGen/GenInfo.h"
-#include "llvm/Support/CommandLine.h"
+#include "mlir/TableGen/Generators/OpGenHelpers.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
@@ -27,10 +26,6 @@ using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
-// Declared in OpPythonBindingGen.cpp; the two generators share the same
-// -bind-dialect option to allow filtering enum registrations by dialect.
-extern std::string dialectNameStorage;
-
/// File header and includes.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
@@ -130,9 +125,11 @@ static bool emitDialectEnumAttributeBuilder(StringRef dialect,
return false;
}
-/// Emits Python bindings for all enums in the record keeper. Returns
-/// `false` on success, `true` on failure.
-static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
+namespace mlir {
+namespace tblgen {
+
+bool emitPythonEnums(const RecordKeeper &records, StringRef dialectName,
+ raw_ostream &os) {
os << fileHeader;
for (const Record *it :
records.getAllDerivedDefinitionsIfDefined("EnumInfo")) {
@@ -147,7 +144,7 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
// When -bind-dialect is specified, only emit builders for EnumAttr records
// belonging to that dialect. This prevents duplicate registrations when
// multiple dialects include the same .td files.
- if (!dialectNameStorage.empty() && dialect != dialectNameStorage)
+ if (!dialectName.empty() && dialect != dialectName)
continue;
if (!attr.getMnemonic()) {
llvm::errs() << "enum case " << attr
@@ -174,8 +171,5 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
return false;
}
-// Registers the enum utility generator to mlir-tblgen.
-static mlir::GenRegistration
- genPythonEnumBindings("gen-python-enum-bindings",
- "Generate Python bindings for enum attributes",
- &emitPythonEnums);
+} // namespace tblgen
+} // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/lib/TableGen/Generators/EnumsGen.cpp
similarity index 96%
rename from mlir/tools/mlir-tblgen/EnumsGen.cpp
rename to mlir/lib/TableGen/Generators/EnumsGen.cpp
index abf3fd9505f25..59b05a230b125 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/lib/TableGen/Generators/EnumsGen.cpp
@@ -10,11 +10,11 @@
//
//===----------------------------------------------------------------------===//
-#include "FormatGen.h"
+#include "mlir/TableGen/Generators/EnumsGen.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/EnumInfo.h"
#include "mlir/TableGen/Format.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/FormatGen.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
@@ -297,6 +297,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n"
" return p << '\"' << valueStr << '\"';\n";
}
+
os << " return p << valueStr;\n"
"}\n"
"} // namespace llvm\n";
@@ -350,7 +351,7 @@ static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
}
- // Emit the function to return the max enum value
+ // Emit the function to return the max enum value.
os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
os << formatv(" return {0};\n", maxEnumVal);
os << "}\n\n";
@@ -623,7 +624,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
const Record *baseAttrDef = enumInfo.getBaseAttrClass();
Attribute baseAttr(baseAttrDef);
- // Emit classof method
+ // Emit classof method.
os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n",
attrClassName);
@@ -639,14 +640,14 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
os << "}\n";
- // Emit get method
+ // Emit get method.
os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
attrClassName, enumName);
StringRef underlyingType = enumInfo.getUnderlyingType();
- // Assuming that it is IntegerAttr constraint
+ // Assuming that it is IntegerAttr constraint.
int64_t bitwidth = 64;
if (baseAttrDef->getValue("valueType")) {
auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
@@ -664,7 +665,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
os << "}\n";
- // Emit getValue method
+ // Emit getValue method.
os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
@@ -699,7 +700,7 @@ static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
os << "}\n";
}
-static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
+void mlir::tblgen::emitEnumDecl(const Record &enumDef, raw_ostream &os) {
EnumInfo enumInfo(enumDef);
StringRef enumName = enumInfo.getEnumClassName();
StringRef cppNamespace = enumInfo.getCppNamespace();
@@ -714,11 +715,11 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
{
llvm::NamespaceEmitter ns(os, cppNamespace);
- // Emit the enum class definition
+ // Emit the enum class definition.
emitEnumClass(enumDef, enumName, underlyingType, description, enumerants,
os);
- // Emit conversion function declarations
+ // Emit conversion function declarations.
if (llvm::all_of(enumerants, [](EnumCase enumerant) {
return enumerant.getValue() >= 0;
})) {
@@ -731,11 +732,10 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName,
strToSymFnName);
- if (enumInfo.isBitEnum()) {
+ if (enumInfo.isBitEnum())
emitOperators(enumDef, os);
- } else {
+ else
emitMaxValueFn(enumDef, os);
- }
// Generate a generic `stringifyEnum` function that forwards to the method
// specified by the user.
@@ -782,11 +782,11 @@ class {1} : public ::mlir::{2} {
std::string(formatv("{0}::{1}", cppNamespace, enumName));
emitParserPrinter(enumInfo, qualName, cppNamespace, os);
- // Emit DenseMapInfo for this enum class
+ // Emit DenseMapInfo for this enum class.
emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
}
-static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
+bool mlir::tblgen::emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Declarations", os, records);
for (const Record *def :
@@ -796,7 +796,7 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
return false;
}
-static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
+void mlir::tblgen::emitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumInfo enumInfo(enumDef);
llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace());
@@ -815,7 +815,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
emitSpecializedAttrDef(enumDef, os);
}
-static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
+bool mlir::tblgen::emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Definitions", os, records);
for (const Record *def :
@@ -824,17 +824,3 @@ static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
return false;
}
-
-// Registers the enum utility generator to mlir-tblgen.
-static mlir::GenRegistration
- genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
- [](const RecordKeeper &records, raw_ostream &os) {
- return emitEnumDecls(records, os);
- });
-
-// Registers the enum utility generator to mlir-tblgen.
-static mlir::GenRegistration
- genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
- [](const RecordKeeper &records, raw_ostream &os) {
- return emitEnumDefs(records, os);
- });
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/lib/TableGen/Generators/FormatGen.cpp
similarity index 87%
rename from mlir/tools/mlir-tblgen/FormatGen.cpp
rename to mlir/lib/TableGen/Generators/FormatGen.cpp
index 04d3ed1f3b70d..0ca0a50f9f86c 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/lib/TableGen/Generators/FormatGen.cpp
@@ -1,4 +1,4 @@
-//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
+//===- FormatGen.cpp - Utilities for custom assembly formats --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "FormatGen.h"
+#include "mlir/TableGen/Generators/FormatGen.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/TableGen/Error.h"
@@ -19,10 +19,23 @@ using llvm::SourceMgr;
// FormatToken
//===----------------------------------------------------------------------===//
+/// Creates a token of the given `kind` whose text is `spelling`.
+FormatToken::FormatToken(Kind kind, StringRef spelling)
+ : kind(kind), spelling(spelling) {}
+
+/// Returns the spelling this token was constructed with.
+StringRef FormatToken::getSpelling() const { return spelling; }
+
+/// Returns the kind this token was constructed with.
+FormatToken::Kind FormatToken::getKind() const { return kind; }
+
SMLoc FormatToken::getLoc() const {
return SMLoc::getFromPointer(spelling.data());
}
+/// Returns true if this token's kind equals `k`.
+/// NOTE(review): unlike `getKind()` and `isKeyword()`, this is not
+/// const-qualified; changing that requires a matching header change.
+bool FormatToken::is(Kind k) { return getKind() == k; }
+
+/// Returns true if this token's kind lies strictly between the
+/// `keyword_start` and `keyword_end` enumerators.
+bool FormatToken::isKeyword() const {
+  const Kind tokenKind = getKind();
+  return Kind::keyword_start < tokenKind && tokenKind < Kind::keyword_end;
+}
+
//===----------------------------------------------------------------------===//
// FormatLexer
//===----------------------------------------------------------------------===//
@@ -32,6 +45,11 @@ FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc)
curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
curPtr(curBuffer.begin()) {}
+/// Builds a token of `kind` whose spelling spans from `tokStart` up to (but
+/// not including) the lexer's current position `curPtr`.
+FormatToken FormatLexer::formToken(FormatToken::Kind kind,
+                                   const char *tokStart) {
+  StringRef tokenText(tokStart, curPtr - tokStart);
+  return FormatToken(kind, tokenText);
+}
+
FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
@@ -196,13 +214,73 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
}
//===----------------------------------------------------------------------===//
-// FormatParser
+// FormatElement
//===----------------------------------------------------------------------===//
FormatElement::~FormatElement() = default;
+/// Creates an optional group, taking ownership of the `then` and `else`
+/// element vectors by move and storing the two parse-start indices, the
+/// anchor element, and the `inverted` flag.
+OptionalElement::OptionalElement(std::vector<FormatElement *> &&thenElements,
+ std::vector<FormatElement *> &&elseElements,
+ unsigned thenParseStart,
+ unsigned elseParseStart, FormatElement *anchor,
+ bool inverted)
+ : thenElements(std::move(thenElements)),
+ elseElements(std::move(elseElements)), thenParseStart(thenParseStart),
+ elseParseStart(elseParseStart), anchor(anchor), inverted(inverted) {}
+
+/// Returns a view of the `then` elements. When `parseable` is true, the first
+/// `thenParseStart` elements are skipped; otherwise all are returned.
+ArrayRef<FormatElement *>
+OptionalElement::getThenElements(bool parseable) const {
+  const size_t numSkipped = parseable ? thenParseStart : 0;
+  return llvm::ArrayRef(thenElements).drop_front(numSkipped);
+}
+
+/// Returns a view of the `else` elements. When `parseable` is true, the first
+/// `elseParseStart` elements are skipped; otherwise all are returned.
+ArrayRef<FormatElement *>
+OptionalElement::getElseElements(bool parseable) const {
+  const size_t numSkipped = parseable ? elseParseStart : 0;
+  return llvm::ArrayRef(elseElements).drop_front(numSkipped);
+}
+
+/// Returns the anchor element supplied at construction.
+FormatElement *OptionalElement::getAnchor() const { return anchor; }
+
+/// Returns the `inverted` flag supplied at construction.
+bool OptionalElement::isInverted() const { return inverted; }
+
+//===----------------------------------------------------------------------===//
+// FormatParser
+//===----------------------------------------------------------------------===//
+
FormatParser::~FormatParser() = default;
+/// Constructs the lexer from `mgr` and `loc`, then primes `curToken` with the
+/// first token the lexer produces.
+FormatParser::FormatParser(llvm::SourceMgr &mgr, llvm::SMLoc loc)
+ : lexer(mgr, loc), curToken(lexer.lexToken()) {}
+
+/// Forwards `msg` at `loc` to the lexer's `emitError` and returns failure, so
+/// callers can write `return emitError(...)`.
+LogicalResult FormatParser::emitError(llvm::SMLoc loc, const Twine &msg) {
+ lexer.emitError(loc, msg);
+ return failure();
+}
+
+/// Forwards `msg` and the accompanying `note` at `loc` to the lexer's
+/// `emitErrorAndNote` and returns failure, so callers can write
+/// `return emitErrorAndNote(...)`.
+LogicalResult FormatParser::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+                                             const Twine &note) {
+  lexer.emitErrorAndNote(loc, msg, note);
+  return failure();
+}
+
+/// Consumes the current token if it has the expected `kind` and returns it.
+/// Otherwise emits `msg` at the current token's location and returns failure
+/// without consuming anything.
+FailureOr<FormatToken> FormatParser::parseToken(FormatToken::Kind kind,
+                                                const Twine &msg) {
+  if (curToken.is(kind)) {
+    // Copy the token before advancing, since `consumeToken` overwrites it.
+    FormatToken matched = curToken;
+    consumeToken();
+    return matched;
+  }
+  return emitError(curToken.getLoc(), msg);
+}
+
+/// Replaces `curToken` with the next token from the lexer. Asserts that the
+/// current token is neither `eof` nor `error`.
+void FormatParser::consumeToken() {
+  assert(!(curToken.is(FormatToken::eof) || curToken.is(FormatToken::error)) &&
+         "shouldn't advance past EOF or errors");
+  curToken = lexer.lexToken();
+}
+
+/// Returns a copy of the current token without advancing the lexer.
+FormatToken FormatParser::peekToken() { return curToken; }
+
FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
SMLoc loc = curToken.getLoc();
@@ -535,12 +613,3 @@ bool mlir::tblgen::isValidLiteral(StringRef value,
// Otherwise, this must be an identifier.
return canFormatStringAsKeyword(value, emitError);
}
-
-//===----------------------------------------------------------------------===//
-// Commandline Options
-//===----------------------------------------------------------------------===//
-
-llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
- "asmformat-error-is-fatal",
- llvm::cl::desc("Emit a fatal error if format parsing fails"),
- llvm::cl::init(true));
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/lib/TableGen/Generators/OpClass.cpp
similarity index 97%
rename from mlir/tools/mlir-tblgen/OpClass.cpp
rename to mlir/lib/TableGen/Generators/OpClass.cpp
index 60fa1833ce625..008016a1a6085 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/lib/TableGen/Generators/OpClass.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "OpClass.h"
+#include "mlir/TableGen/Generators/OpClass.h"
using namespace mlir;
using namespace mlir::tblgen;
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/lib/TableGen/Generators/OpDefinitionsGen.cpp
similarity index 90%
rename from mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
rename to mlir/lib/TableGen/Generators/OpDefinitionsGen.cpp
index edb009938f005..f6536e590c683 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/lib/TableGen/Generators/OpDefinitionsGen.cpp
@@ -11,17 +11,17 @@
//
//===----------------------------------------------------------------------===//
-#include "CppGenUtilities.h"
-#include "OpClass.h"
-#include "OpFormatGen.h"
-#include "OpGenHelpers.h"
+#include "mlir/TableGen/Generators/OpDefinitionsGen.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/CppGenUtilities.h"
+#include "mlir/TableGen/Generators/OpClass.h"
+#include "mlir/TableGen/Generators/OpFormatGen.h"
+#include "mlir/TableGen/Generators/OpGenHelpers.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Property.h"
@@ -297,198 +297,6 @@ static std::string constBuildAttrFromParam(const tblgen::Attribute &attr,
return tgfmt(builderTemplate, &fctx, paramName).str();
}
-namespace {
-/// Metadata on a registered attribute. Given that attributes are stored in
-/// sorted order on operations, we can use information from ODS to deduce the
-/// number of required attributes less and and greater than each attribute,
-/// allowing us to search only a subrange of the attributes in ODS-generated
-/// getters.
-struct AttributeMetadata {
- /// The attribute name.
- StringRef attrName;
- /// Whether the attribute is required.
- bool isRequired;
- /// The ODS attribute constraint. Not present for implicit attributes.
- std::optional<Attribute> constraint;
- /// The number of required attributes less than this attribute.
- unsigned lowerBound = 0;
- /// The number of required attributes greater than this attribute.
- unsigned upperBound = 0;
-};
-
-/// Helper class to select between OpAdaptor and Op code templates.
-class OpOrAdaptorHelper {
-public:
- OpOrAdaptorHelper(const Operator &op, bool emitForOp)
- : op(op), emitForOp(emitForOp) {
- computeAttrMetadata();
- }
-
- /// Object that wraps a functor in a stream operator for interop with
- /// llvm::formatv.
- class Formatter {
- public:
- template <typename Functor>
- Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
-
- std::string str() const {
- std::string result;
- llvm::raw_string_ostream os(result);
- os << *this;
- return os.str();
- }
-
- private:
- std::function<raw_ostream &(raw_ostream &)> func;
-
- friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
- return fmt.func(os);
- }
- };
-
- // Generate code for getting an attribute.
- Formatter getAttr(StringRef attrName, bool isNamed = false) const {
- assert(attrMetadata.count(attrName) && "expected attribute metadata");
- return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
- const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
- if (hasProperties()) {
- assert(!isNamed);
- return os << "getProperties()." << attrName;
- }
- return os << formatv(subrangeGetAttr, getAttrName(attrName),
- attr.lowerBound, attr.upperBound, getAttrRange(),
- isNamed ? "Named" : "");
- };
- }
-
- // Generate code for getting the name of an attribute.
- Formatter getAttrName(StringRef attrName) const {
- return [this, attrName](raw_ostream &os) -> raw_ostream & {
- if (emitForOp)
- return os << op.getGetterName(attrName) << "AttrName()";
- return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(),
- op.getGetterName(attrName));
- };
- }
-
- // Get the code snippet for getting the named attribute range.
- StringRef getAttrRange() const {
- return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
- }
-
- // Get the prefix code for emitting an error.
- Formatter emitErrorPrefix() const {
- return [this](raw_ostream &os) -> raw_ostream & {
- if (emitForOp)
- return os << "emitOpError(\"";
- return os << formatv("emitError(loc, \"'{0}' op ", op.getOperationName());
- };
- }
-
- // Get the call to get an operand or segment of operands.
- Formatter getOperand(unsigned index) const {
- return [this, index](raw_ostream &os) -> raw_ostream & {
- return os << formatv(op.getOperand(index).isVariadic()
- ? "this->getODSOperands({0})"
- : "(*this->getODSOperands({0}).begin())",
- index);
- };
- }
-
- // Get the call to get a result of segment of results.
- Formatter getResult(unsigned index) const {
- return [this, index](raw_ostream &os) -> raw_ostream & {
- if (!emitForOp)
- return os << "<no results should be generated>";
- return os << formatv(op.getResult(index).isVariadic()
- ? "this->getODSResults({0})"
- : "(*this->getODSResults({0}).begin())",
- index);
- };
- }
-
- // Return whether an op instance is available.
- bool isEmittingForOp() const { return emitForOp; }
-
- // Return the ODS operation wrapper.
- const Operator &getOp() const { return op; }
-
- // Get the attribute metadata sorted by name.
- const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
- return attrMetadata;
- }
-
- /// Returns whether to emit a `Properties` struct for this operation or not.
- bool hasProperties() const {
- if (!op.getProperties().empty())
- return true;
- return true;
- }
-
- /// Returns whether the operation will have a non-empty `Properties` struct.
- bool hasNonEmptyPropertiesStruct() const {
- if (!op.getProperties().empty())
- return true;
- if (!hasProperties())
- return false;
- if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") ||
- op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
- return true;
- return llvm::any_of(getAttrMetadata(),
- [](const std::pair<StringRef, AttributeMetadata> &it) {
- return !it.second.constraint ||
- !it.second.constraint->isDerivedAttr();
- });
- }
-
- std::optional<NamedProperty> &getOperandSegmentsSize() {
- return operandSegmentsSize;
- }
-
- std::optional<NamedProperty> &getResultSegmentsSize() {
- return resultSegmentsSize;
- }
-
- uint32_t getOperandSegmentSizesLegacyIndex() {
- return operandSegmentSizesLegacyIndex;
- }
-
- uint32_t getResultSegmentSizesLegacyIndex() {
- return resultSegmentSizesLegacyIndex;
- }
-
-private:
- // Compute the attribute metadata.
- void computeAttrMetadata();
-
- // The operation ODS wrapper.
- const Operator &op;
- // True if code is being generate for an op. False for an adaptor.
- const bool emitForOp;
-
- // The attribute metadata, mapped by name.
- llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
-
- // Property
- std::optional<NamedProperty> operandSegmentsSize;
- std::string operandSegmentsSizeStorage;
- std::string operandSegmentsSizeParser;
- std::optional<NamedProperty> resultSegmentsSize;
- std::string resultSegmentsSizeStorage;
- std::string resultSegmentsSizeParser;
-
- // Indices to store the position in the emission order of the operand/result
- // segment sizes attribute if emitted as part of the properties for legacy
- // bytecode encodings, i.e. versions less than 6.
- uint32_t operandSegmentSizesLegacyIndex = 0;
- uint32_t resultSegmentSizesLegacyIndex = 0;
-
- // The number of required attributes.
- unsigned numRequired;
-};
-
-} // namespace
-
void OpOrAdaptorHelper::computeAttrMetadata() {
// Enumerate the attribute names of this op, ensuring the attribute names are
// unique in case implicit attributes are explicitly registered.
@@ -581,240 +389,25 @@ void OpOrAdaptorHelper::computeAttrMetadata() {
attrMetadata.insert({attr.attrName, attr});
}
+/// Returns a `Formatter` that, when streamed, prints the C++ expression used
+/// to read the attribute `attrName`. Asserts that metadata for `attrName` has
+/// been registered. When `hasProperties()` is true the expression reads the
+/// field from `getProperties()` and `isNamed` must be false; otherwise the
+/// `subrangeGetAttr` template is formatted with the attribute's bounds.
+/// The functor captures `this` and `attrName` by value.
+/// NOTE(review): presumably it must not outlive this helper or the storage
+/// behind `attrName` -- confirm against callers.
+OpOrAdaptorHelper::Formatter OpOrAdaptorHelper::getAttr(StringRef attrName,
+                                                        bool isNamed) const {
+  assert(attrMetadata.count(attrName) && "expected attribute metadata");
+  return [this, attrName, isNamed](raw_ostream &out) -> raw_ostream & {
+    const AttributeMetadata &metadata = attrMetadata.find(attrName)->second;
+    if (!hasProperties()) {
+      const char *namedInfix = isNamed ? "Named" : "";
+      return out << formatv(subrangeGetAttr, getAttrName(attrName),
+                            metadata.lowerBound, metadata.upperBound,
+                            getAttrRange(), namedInfix);
+    }
+    assert(!isNamed);
+    out << "getProperties()." << attrName;
+    return out;
+  };
+}
+
//===----------------------------------------------------------------------===//
// Op emitter
//===----------------------------------------------------------------------===//
-namespace {
-// Helper class to emit a record into the given output stream.
-class OpEmitter {
- using ConstArgument =
- llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
-
-public:
- static void
- emitDecl(const Operator &op, raw_ostream &os,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter);
- static void
- emitDef(const Operator &op, raw_ostream &os,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter);
-
-private:
- OpEmitter(const Operator &op,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter);
-
- void emitDecl(raw_ostream &os);
- void emitDef(raw_ostream &os);
-
- // Generate methods for accessing the attribute names of this operation.
- void genAttrNameGetters();
-
- // Generates the OpAsmOpInterface for this operation if possible.
- void genOpAsmInterface();
-
- // Generates the `getOperationName` method for this op.
- void genOpNameGetter();
-
- // Generates code to manage the properties, if any!
- void genPropertiesSupport();
-
- // Generates code to manage the encoding of properties to bytecode.
- void
- genPropertiesSupportForBytecode(ArrayRef<ConstArgument> attrOrProperties);
-
- // Generates getters for the properties.
- void genPropGetters();
-
- // Generates seters for the properties.
- void genPropSetters();
-
- // Generates getters for the attributes.
- void genAttrGetters();
-
- // Generates setter for the attributes.
- void genAttrSetters();
-
- // Generates removers for optional attributes.
- void genOptionalAttrRemovers();
-
- // Generates getters for named operands.
- void genNamedOperandGetters();
-
- // Generates setters for named operands.
- void genNamedOperandSetters();
-
- // Generates getters for named results.
- void genNamedResultGetters();
-
- // Generates getters for named regions.
- void genNamedRegionGetters();
-
- // Generates getters for named successors.
- void genNamedSuccessorGetters();
-
- // Generates the method to populate default attributes.
- void genPopulateDefaultAttributes();
-
- // Generates builder methods for the operation.
- void genBuilder();
-
- // Generates the build() method that takes each operand/attribute
- // as a stand-alone parameter.
- void genSeparateArgParamBuilder();
- void genInlineCreateBody(const SmallVector<MethodParameter> &paramList);
-
- // Generates the build() method that takes each operand/attribute as a
- // stand-alone parameter. The generated build() method uses first operand's
- // type as all results' types.
- void genUseOperandAsResultTypeSeparateParamBuilder();
-
- // The kind of collective builder to generate
- enum class CollectiveBuilderKind {
- PropStruct, // Inherent attributes/properties are passed by `const
- // Properties&`
- AttrDict, // Inherent attributes/properties are passed by attribute
- // dictionary
- };
-
- // Generates the build() method that takes all operands/attributes
- // collectively as one parameter. The generated build() method uses first
- // operand's type as all results' types.
- void
- genUseOperandAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
-
- // Generates the build() method that takes aggregate operands/attributes
- // parameters. This build() method uses inferred types as result types.
- // Requires: The type needs to be inferable via InferTypeOpInterface.
- void genInferredTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
-
- // Generates the build() method that takesaggregate operands/attributes as
- // parameters. The generated build() method uses first attribute's
- // type as all result's types.
- void genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
-
- // Generates the build() method that takes all result types collectively as
- // one parameter. Similarly for operands and attributes.
- void genCollectiveParamBuilder(CollectiveBuilderKind kind);
-
- // The kind of parameter to generate for result types in builders.
- enum class TypeParamKind {
- None, // No result type in parameter list.
- Separate, // A separate parameter for each result type.
- Collective, // An ArrayRef<Type> for all result types.
- };
-
- // The kind of parameter to generate for attributes in builders.
- enum class AttrParamKind {
- WrappedAttr, // A wrapped MLIR Attribute instance.
- UnwrappedValue, // A raw value without MLIR Attribute wrapper.
- };
-
- // Builds the parameter list for build() method of this op. This method writes
- // to `paramList` the comma-separated parameter list and updates
- // `resultTypeNames` with the names for parameters for specifying result
- // types. `inferredAttributes` is populated with any attributes that are
- // elided from the build list. The given `typeParamKind` and `attrParamKind`
- // controls how result types and attributes are placed in the parameter list.
- void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
- llvm::StringSet<> &inferredAttributes,
- SmallVectorImpl<std::string> &resultTypeNames,
- TypeParamKind typeParamKind,
- AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
-
- // Adds op arguments and regions into operation state for build() methods.
- void
- genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
- llvm::StringSet<> &inferredAttributes,
- bool isRawValueAttr = false);
-
- // Generates canonicalizer declaration for the operation.
- void genCanonicalizerDecls();
-
- // Generates the folder declaration for the operation.
- void genFolderDecls();
-
- // Generates the parser for the operation.
- void genParser();
-
- // Generates the printer for the operation.
- void genPrinter();
-
- // Generates verify method for the operation.
- void genVerifier();
-
- // Generates custom verify methods for the operation.
- void genCustomVerifier();
-
- // Generates verify statements for operands and results in the operation.
- // The generated code will be attached to `body`.
- void genOperandResultVerifier(MethodBody &body,
- Operator::const_value_range values,
- StringRef valueKind);
-
- // Generates verify statements for regions in the operation.
- // The generated code will be attached to `body`.
- void genRegionVerifier(MethodBody &body);
-
- // Generates verify statements for successors in the operation.
- // The generated code will be attached to `body`.
- void genSuccessorVerifier(MethodBody &body);
-
- // Generates the traits used by the object.
- void genTraits();
-
- // Generate the OpInterface methods for all interfaces.
- void genOpInterfaceMethods();
-
- // Generate op interface methods for the given interface.
- void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
-
- // Generate op interface method for the given interface method. If
- // 'declaration' is true, generates a declaration, else a definition.
- Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
- bool declaration = true);
-
- // Generate a `using` declaration for the op interface method to include
- // the default implementation from the interface trait.
- // This is needed when the interface defines multiple methods with the same
- // name, but some have a default implementation and some don't.
- UsingDeclaration *
- genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
- const tblgen::InterfaceMethod &method);
-
- // Generate the side effect interface methods.
- void genSideEffectInterfaceMethods();
-
- // Generate the type inference interface methods.
- void genTypeInterfaceMethods();
-
-private:
- // The TableGen record for this op.
- // TODO: OpEmitter should not have a Record directly,
- // it should rather go through the Operator for better abstraction.
- const Record &def;
-
- // The wrapper operator class for querying information from this op.
- const Operator &op;
-
- // The C++ code builder for this op
- OpClass opClass;
-
- // The format context for verification code generation.
- FmtContext verifyCtx;
-
- // The emitter containing all of the locally emitted verification functions.
- const StaticVerifierFunctionEmitter &staticVerifierEmitter;
-
- // Helper for emitting op code.
- OpOrAdaptorHelper emitHelper;
-
- // Keep track of the interface using declarations that have been generated to
- // avoid duplicates.
- llvm::StringSet<> interfaceUsingNames;
-};
-
-} // namespace
-
// Populate the format context `ctx` with substitutions of attributes, operands
// and results.
static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
@@ -1131,12 +724,13 @@ static std::string formatExtraDefinitions(const Operator &op) {
}
OpEmitter::OpEmitter(const Operator &op,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter)
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool fatalOnError)
: def(op.getDef()), op(op),
opClass(op.getCppClassName(), formatExtraDeclarations(op),
formatExtraDefinitions(op)),
staticVerifierEmitter(staticVerifierEmitter),
- emitHelper(op, /*emitForOp=*/true) {
+ emitHelper(op, /*emitForOp=*/true), fatalOnError(fatalOnError) {
verifyCtx.addSubst("_op", "(*this->getOperation())");
verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
@@ -1168,19 +762,21 @@ OpEmitter::OpEmitter(const Operator &op,
genFolderDecls();
genTypeInterfaceMethods();
genOpInterfaceMethods();
- generateOpFormat(op, opClass, emitHelper.hasProperties());
+ generateOpFormat(op, opClass, emitHelper.hasProperties(), fatalOnError);
genSideEffectInterfaceMethods();
}
void OpEmitter::emitDecl(
const Operator &op, raw_ostream &os,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
- OpEmitter(op, staticVerifierEmitter).emitDecl(os);
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool fatalOnError) {
+ OpEmitter(op, staticVerifierEmitter, fatalOnError).emitDecl(os);
}
void OpEmitter::emitDef(
const Operator &op, raw_ostream &os,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
- OpEmitter(op, staticVerifierEmitter).emitDef(os);
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool fatalOnError) {
+ OpEmitter(op, staticVerifierEmitter, fatalOnError).emitDef(os);
}
void OpEmitter::emitDecl(raw_ostream &os) {
@@ -4694,9 +4290,11 @@ void OpOperandAdaptorEmitter::emitDef(
}
/// Emit the class declarations or definitions for the given op defs.
-static void emitOpClasses(
- const RecordKeeper &records, ArrayRef<const Record *> defs, raw_ostream &os,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter, bool emitDecl) {
+static void
+emitOpClasses(const RecordKeeper &records, ArrayRef<const Record *> defs,
+ raw_ostream &os,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool emitDecl, bool fatalOnError = true) {
if (defs.empty())
return;
@@ -4709,7 +4307,7 @@ static void emitOpClasses(
os << formatv(opCommentHeader, op.getQualCppClassName(),
"declarations");
OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
- OpEmitter::emitDecl(op, os, staticVerifierEmitter);
+ OpEmitter::emitDecl(op, os, staticVerifierEmitter, fatalOnError);
}
// Emit the TypeID explicit specialization to have a single definition.
if (!op.getCppNamespace().empty()) {
@@ -4726,7 +4324,7 @@ static void emitOpClasses(
NamespaceEmitter emitter(os, op.getCppNamespace());
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
- OpEmitter::emitDef(op, os, staticVerifierEmitter);
+ OpEmitter::emitDef(op, os, staticVerifierEmitter, fatalOnError);
}
// Emit the TypeID explicit specialization to have a single definition.
if (!op.getCppNamespace().empty()) {
@@ -4744,7 +4342,8 @@ static void emitOpClasses(
/// Emit the declarations for the provided op classes.
static void emitOpClassDecls(const RecordKeeper &records,
- ArrayRef<const Record *> defs, raw_ostream &os) {
+ ArrayRef<const Record *> defs, raw_ostream &os,
+ bool fatalOnError = true) {
// First emit forward declaration for each class, this allows them to refer
// to each others in traits for example.
for (const Record *def : defs) {
@@ -4762,13 +4361,14 @@ static void emitOpClassDecls(const RecordKeeper &records,
StaticVerifierFunctionEmitter staticVerifierEmitter(os, records);
staticVerifierEmitter.collectOpConstraints(defs);
emitOpClasses(records, defs, os, staticVerifierEmitter,
- /*emitDecl=*/true);
+ /*emitDecl=*/true, fatalOnError);
}
/// Emit the definitions for the provided op classes.
static void emitOpClassDefs(const RecordKeeper &records,
ArrayRef<const Record *> defs, raw_ostream &os,
- StringRef constraintPrefix = "") {
+ StringRef constraintPrefix = "",
+ bool fatalOnError = true) {
if (defs.empty())
return;
@@ -4781,21 +4381,21 @@ static void emitOpClassDefs(const RecordKeeper &records,
// Emit the classes.
emitOpClasses(records, defs, os, staticVerifierEmitter,
- /*emitDecl=*/false);
+ /*emitDecl=*/false, fatalOnError);
}
-/// Emit op declarations for all op records.
-static bool emitOpDecls(const RecordKeeper &records, raw_ostream &os) {
+/// Emit op declarations for all op records in \p defs.
+bool mlir::tblgen::emitOpDecls(const RecordKeeper &records,
+ ArrayRef<const Record *> defs,
+ unsigned shardCount, raw_ostream &os,
+ bool fatalOnError) {
emitSourceFileHeader("Op Declarations", os, records);
- std::vector<const Record *> defs = getRequestedOpDefinitions(records);
- emitOpClassDecls(records, defs, os);
+ emitOpClassDecls(records, defs, os, fatalOnError);
// If we are generating sharded op definitions, emit the sharded op
// registration hooks.
- SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
- shardOpDefinitions(defs, shardedDefs);
- if (defs.empty() || shardedDefs.size() <= 1)
+ if (defs.empty() || shardCount <= 1)
return false;
Dialect dialect = Operator(defs.front()).getDialect();
@@ -4805,7 +4405,7 @@ static bool emitOpDecls(const RecordKeeper &records, raw_ostream &os) {
"void register{0}Operations{1}({2}::{0} *dialect);\n";
os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
dialect.getCppNamespace());
- for (unsigned i = 0; i < shardedDefs.size(); ++i) {
+ for (unsigned i = 0; i < shardCount; ++i) {
os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
dialect.getCppNamespace());
}
@@ -4815,10 +4415,11 @@ static bool emitOpDecls(const RecordKeeper &records, raw_ostream &os) {
/// Generate the dialect op registration hook and the op class definitions for a
/// shard of ops.
-static void emitOpDefShard(const RecordKeeper &records,
- ArrayRef<const Record *> defs,
- const Dialect &dialect, unsigned shardIndex,
- unsigned shardCount, raw_ostream &os) {
+void mlir::tblgen::emitOpDefShard(const RecordKeeper &records,
+ ArrayRef<const Record *> defs,
+ const Dialect &dialect, unsigned shardIndex,
+ unsigned shardCount, raw_ostream &os,
+ bool fatalOnError) {
std::string shardGuard = "GET_OP_DEFS_";
std::string indexStr = std::to_string(shardIndex);
shardGuard += indexStr;
@@ -4849,16 +4450,18 @@ static void emitOpDefShard(const RecordKeeper &records,
os << "}\n";
// Generate the per-shard op definitions.
- emitOpClassDefs(records, defs, os, indexStr);
+ emitOpClassDefs(records, defs, os, indexStr, fatalOnError);
}
-/// Emit op definitions for all op records.
-static bool emitOpDefs(const RecordKeeper &records, raw_ostream &os) {
+/// Emit op definitions for all op records in \p defs.
+bool mlir::tblgen::emitOpDefs(const RecordKeeper &records,
+ ArrayRef<const Record *> defs,
+ unsigned shardCount, raw_ostream &os,
+ bool fatalOnError) {
emitSourceFileHeader("Op Definitions", os, records);
- std::vector<const Record *> defs = getRequestedOpDefinitions(records);
SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
- shardOpDefinitions(defs, shardedDefs);
+ shardOpDefinitions(defs, shardedDefs, shardCount);
// If no shard was requested, emit the regular op list and class definitions.
if (shardedDefs.size() == 1) {
@@ -4871,7 +4474,7 @@ static bool emitOpDefs(const RecordKeeper &records, raw_ostream &os) {
}
{
IfDefEmitter scope(os, "GET_OP_CLASSES");
- emitOpClassDefs(records, defs, os);
+ emitOpClassDefs(records, defs, os, "", fatalOnError);
}
return false;
}
@@ -4880,19 +4483,8 @@ static bool emitOpDefs(const RecordKeeper &records, raw_ostream &os) {
return false;
Dialect dialect = Operator(defs.front()).getDialect();
for (auto [idx, value] : llvm::enumerate(shardedDefs)) {
- emitOpDefShard(records, value, dialect, idx, shardedDefs.size(), os);
+ emitOpDefShard(records, value, dialect, idx, shardedDefs.size(), os,
+ fatalOnError);
}
return false;
}
-
-static mlir::GenRegistration
- genOpDecls("gen-op-decls", "Generate op declarations",
- [](const RecordKeeper &records, raw_ostream &os) {
- return emitOpDecls(records, os);
- });
-
-static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
- [](const RecordKeeper &records,
- raw_ostream &os) {
- return emitOpDefs(records, os);
- });
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/lib/TableGen/Generators/OpDocGen.cpp
similarity index 69%
rename from mlir/tools/mlir-tblgen/OpDocGen.cpp
rename to mlir/lib/TableGen/Generators/OpDocGen.cpp
index 5e3cf302ed3ea..1cae820cc0c04 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/lib/TableGen/Generators/OpDocGen.cpp
@@ -11,26 +11,21 @@
//
//===----------------------------------------------------------------------===//
-#include "DialectGenUtilities.h"
-#include "DocGenUtilities.h"
-#include "OpGenHelpers.h"
-#include "mlir/Support/IndentedOstream.h"
+#include "mlir/TableGen/Generators/OpDocGen.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/EnumInfo.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/DocGenUtilities.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Regex.h"
-#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
-#include "llvm/TableGen/TableGenBackend.h"
#include <set>
#include <string>
@@ -40,107 +35,6 @@ using namespace mlir;
using namespace mlir::tblgen;
using mlir::tblgen::Operator;
-//===----------------------------------------------------------------------===//
-// Commandline Options
-//===----------------------------------------------------------------------===//
-static cl::OptionCategory
- docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
-static cl::opt<std::string>
- stripPrefix("strip-prefix",
- cl::desc("Strip prefix of the fully qualified names"),
- cl::init("::mlir::"), cl::cat(docCat));
-static cl::opt<bool> allowHugoSpecificFeatures(
- "allow-hugo-specific-features",
- cl::desc("Allows using features specific to Hugo"), cl::init(false),
- cl::cat(docCat));
-static cl::opt<bool>
- keepOpSourceOrder("keep-op-source-order",
- cl::desc("Do not sort ops alphabetically"),
- cl::init(false), cl::cat(docCat));
-
-void mlir::tblgen::emitSummary(StringRef summary, raw_ostream &os) {
- if (summary.empty())
- return;
- StringRef trimmed = summary.trim();
- char first = std::toupper(trimmed.front());
- StringRef rest = trimmed.drop_front();
- os << "\n_" << first << rest << "_\n";
-}
-
-// Emit the description by aligning the text to the left per line (e.g.,
-// removing the minimum indentation across the block).
-//
-// This expects that the description in the tablegen file is already formatted
-// in a way the user wanted but has some additional indenting due to being
-// nested in the op definition.
-void mlir::tblgen::emitDescription(StringRef description, raw_ostream &os) {
- if (description.empty())
- return;
- os << "\n";
- raw_indented_ostream ros(os);
- StringRef trimmed = description.rtrim(" \t");
- ros.printReindented(trimmed);
- if (!trimmed.ends_with("\n"))
- ros << "\n";
-}
-
-void mlir::tblgen::emitDescriptionComment(StringRef description,
- raw_ostream &os, StringRef prefix) {
- if (description.empty())
- return;
- os << "\n";
- raw_indented_ostream ros(os);
- StringRef trimmed = description.rtrim(" \t");
- ros.printReindented(trimmed, (Twine(prefix) + "/// ").str());
- if (!trimmed.ends_with("\n"))
- ros << "\n";
-}
-
-/// Emit the given named constraint.
-template <typename T>
-static void emitNamedConstraint(const T &it, raw_ostream &os) {
- if (!it.name.empty())
- os << "| `" << it.name << "`";
- else
- os << "| &laquo;unnamed&raquo;";
- os << " | " << it.constraint.getSummary() << " |\n";
-}
-
-//===----------------------------------------------------------------------===//
-// Records
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct OpDocGroup {
- const Dialect &getDialect() const { return ops.front().getDialect(); }
-
- /// Summary description of the section.
- std::string summary = "";
-
- /// Description of the section.
- StringRef description = "";
-
- /// Instances inside the section.
- std::vector<Operator> ops;
-};
-
-/// Holds all records collected from a dialect relevant for documentation
-/// generation.
-struct DialectRecords {
- DialectRecords(Dialect dialect, StringRef inputFilename)
- : dialect(dialect), inputFilename(inputFilename) {}
-
- Dialect dialect;
- StringRef inputFilename;
- std::vector<Attribute> attributes;
- std::vector<AttrDef> attrDefs;
- std::vector<OpDocGroup> ops;
- std::vector<Type> types;
- std::vector<TypeDef> typeDefs;
- std::vector<EnumInfo> enums;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// Operation Documentation
//===----------------------------------------------------------------------===//
@@ -233,7 +127,18 @@ static StringRef resolveAttrDescription(const Attribute &attr) {
return description;
}
-static void emitOpDoc(const Operator &op, raw_ostream &os) {
+/// Emit the given named constraint.
+template <typename T>
+static void emitNamedConstraint(const T &it, raw_ostream &os) {
+ if (!it.name.empty())
+ os << "| `" << it.name << "`";
+ else
+ os << "| &laquo;unnamed&raquo;";
+ os << " | " << it.constraint.getSummary() << " |\n";
+}
+
+void mlir::tblgen::emitOpDoc(const Operator &op, StringRef stripPrefix,
+ bool allowHugoSpecificFeatures, raw_ostream &os) {
std::string classNameStr = op.getQualCppClassName();
StringRef className = classNameStr;
(void)className.consume_front(stripPrefix);
@@ -335,7 +240,8 @@ static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
}
}
-static void emitOpDocGroup(const OpDocGroup &grouping, raw_ostream &os) {
+static void emitOpDocGroup(const OpDocGroup &grouping, StringRef stripPrefix,
+ bool allowHugoSpecificFeatures, raw_ostream &os) {
bool nested = !grouping.summary.empty();
maybeNest(
nested,
@@ -345,18 +251,20 @@ static void emitOpDocGroup(const OpDocGroup &grouping, raw_ostream &os) {
emitDescription(grouping.description, os);
os << "\n";
}
- for (const Operator &op : grouping.ops) {
- emitOpDoc(op, os);
- }
+ for (const Operator &op : grouping.ops)
+ mlir::tblgen::emitOpDoc(op, stripPrefix, allowHugoSpecificFeatures,
+ os);
},
os);
}
-static bool emitOpDoc(const DialectRecords &records, raw_ostream &os) {
+bool mlir::tblgen::emitOpDoc(const DialectRecords &records,
+ StringRef stripPrefix,
+ bool allowHugoSpecificFeatures, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
emitSourceLink(records.inputFilename, os);
for (const OpDocGroup &grouping : records.ops)
- emitOpDocGroup(grouping, os);
+ emitOpDocGroup(grouping, stripPrefix, allowHugoSpecificFeatures, os);
return false;
}
@@ -381,7 +289,7 @@ static void emitTypeDoc(const Type &type, raw_ostream &os) {
}
//===----------------------------------------------------------------------===//
-// TypeDef Documentation
+// TypeDef/AttrDef Documentation
//===----------------------------------------------------------------------===//
static void emitAttrOrTypeDefAssemblyFormat(const AttrOrTypeDef &def,
@@ -419,9 +327,8 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
emitAttrOrTypeDefAssemblyFormat(def, os);
// Emit the description if present.
- if (def.hasDescription()) {
+ if (def.hasDescription())
mlir::tblgen::emitDescription(def.getDescription(), os);
- }
// Emit parameter documentation.
ArrayRef<AttrOrTypeParameter> parameters = def.getParameters();
@@ -439,14 +346,16 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
os << "\n";
}
-static bool emitAttrDefDoc(const DialectRecords &records, raw_ostream &os) {
+bool mlir::tblgen::emitAttrDefDoc(const DialectRecords &records,
+ raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const AttrDef &def : records.attrDefs)
emitAttrOrTypeDefDoc(def, os);
return false;
}
-static bool emitTypeDefDoc(const DialectRecords &records, raw_ostream &os) {
+bool mlir::tblgen::emitTypeDefDoc(const DialectRecords &records,
+ raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const TypeDef &def : records.typeDefs)
emitAttrOrTypeDefDoc(def, os);
@@ -457,7 +366,7 @@ static bool emitTypeDefDoc(const DialectRecords &records, raw_ostream &os) {
// Enum Documentation
//===----------------------------------------------------------------------===//
-static void emitEnumDoc(const EnumInfo &def, raw_ostream &os) {
+static void emitSingleEnumDoc(const EnumInfo &def, raw_ostream &os) {
os << formatv("\n### {0}\n", def.getEnumClassName());
// Emit the summary if present.
@@ -476,10 +385,10 @@ static void emitEnumDoc(const EnumInfo &def, raw_ostream &os) {
os << "\n";
}
-static bool emitEnumDoc(const DialectRecords &records, raw_ostream &os) {
+bool mlir::tblgen::emitEnumDoc(const DialectRecords &records, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const EnumInfo &def : records.enums)
- emitEnumDoc(def, os);
+ emitSingleEnumDoc(def, os);
return false;
}
@@ -487,12 +396,13 @@ static bool emitEnumDoc(const DialectRecords &records, raw_ostream &os) {
// Dialect Documentation
//===----------------------------------------------------------------------===//
-static void emitBlock(const DialectRecords &records, raw_ostream &os) {
+static void emitBlock(const DialectRecords &records, StringRef stripPrefix,
+ bool allowHugoSpecificFeatures, raw_ostream &os) {
if (!records.ops.empty()) {
os << "\n## Operations\n";
emitSourceLink(records.inputFilename, os);
for (const OpDocGroup &grouping : records.ops)
- emitOpDocGroup(grouping, os);
+ emitOpDocGroup(grouping, stripPrefix, allowHugoSpecificFeatures, os);
}
if (!records.attributes.empty()) {
@@ -523,11 +433,14 @@ static void emitBlock(const DialectRecords &records, raw_ostream &os) {
if (!records.enums.empty()) {
os << "\n## Enums\n";
for (const EnumInfo &def : records.enums)
- emitEnumDoc(def, os);
+ emitSingleEnumDoc(def, os);
}
}
-static bool emitDialectDoc(const DialectRecords &records, raw_ostream &os) {
+bool mlir::tblgen::emitDialectDoc(const DialectRecords &records,
+ StringRef stripPrefix,
+ bool allowHugoSpecificFeatures,
+ raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
os << "\n# '" << records.dialect.getName() << "' Dialect\n";
emitSummary(records.dialect.getSummary(), os);
@@ -538,7 +451,7 @@ static bool emitDialectDoc(const DialectRecords &records, raw_ostream &os) {
if (!r.match(records.dialect.getDescription()))
os << "\n[TOC]\n";
- emitBlock(records, os);
+ emitBlock(records, stripPrefix, allowHugoSpecificFeatures, os);
return false;
}
@@ -546,25 +459,17 @@ static bool emitDialectDoc(const DialectRecords &records, raw_ostream &os) {
// Record Collection
//===----------------------------------------------------------------------===//
-/// Collect, filter, and organize all records relevant for dialect documentation
-/// generation. Returns none if no single dialect could be determined. See
-/// `mlir::tblgen::findDialectToGenerate`.
-static std::optional<DialectRecords>
-collectRecords(const RecordKeeper &records) {
- auto dialectDefs = records.getAllDerivedDefinitionsIfDefined("Dialect");
- SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
- std::optional<Dialect> dialect = findDialectToGenerate(dialects);
- if (!dialect)
- return std::nullopt;
-
- std::vector<const Record *> opDefs = getRequestedOpDefinitions(records);
+std::optional<DialectRecords>
+mlir::tblgen::collectRecords(const RecordKeeper &records,
+ ArrayRef<const Record *> opDefs,
+ const Dialect &dialect, bool keepOpSourceOrder) {
auto attrDefs = records.getAllDerivedDefinitionsIfDefined("DialectAttr");
auto typeDefs = records.getAllDerivedDefinitionsIfDefined("DialectType");
auto typeDefDefs = records.getAllDerivedDefinitionsIfDefined("TypeDef");
auto attrDefDefs = records.getAllDerivedDefinitionsIfDefined("AttrDef");
auto enumDefs = records.getAllDerivedDefinitionsIfDefined("EnumInfo");
- DialectRecords result(*dialect, records.getInputFilename());
+ DialectRecords result(dialect, records.getInputFilename());
SmallDenseSet<const Record *> seen;
auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) {
if (seen.insert(record).second) {
@@ -574,7 +479,7 @@ collectRecords(const RecordKeeper &records) {
return false;
};
auto addIfInDialect = [&](const Record *record, const auto &def, auto &vec) {
- return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
+ return def.getDialect() == dialect && addIfNotSeen(record, def, vec);
};
SmallDenseMap<const Record *, OpDocGroup> opDocGroup;
@@ -609,8 +514,7 @@ collectRecords(const RecordKeeper &records) {
for (const Record *def : enumDefs)
addIfNotSeen(def, EnumInfo(def), result.enums);
- // Sort alphabetically ignorning dialect for ops and section name for
- // sections.
+ // Sort alphabetically ignoring dialect for ops and section name for sections.
// TODO: The sorting order could be revised, currently attempting to sort of
// keep in alphabetical order.
if (keepOpSourceOrder)
@@ -626,48 +530,3 @@ collectRecords(const RecordKeeper &records) {
return result;
}
-
-//===----------------------------------------------------------------------===//
-// Gen Registration
-//===----------------------------------------------------------------------===//
-
-static mlir::GenRegistration
- genAttrRegister("gen-attrdef-doc",
- "Generate dialect attribute documentation",
- [](const RecordKeeper &records, raw_ostream &os) {
- if (auto filtered = collectRecords(records))
- return emitAttrDefDoc(*filtered, os);
- return true;
- });
-
-static mlir::GenRegistration
- genOpRegister("gen-op-doc", "Generate dialect documentation",
- [](const RecordKeeper &records, raw_ostream &os) {
- if (auto filtered = collectRecords(records))
- return emitOpDoc(*filtered, os);
- return true;
- });
-
-static mlir::GenRegistration
- genTypeRegister("gen-typedef-doc", "Generate dialect type documentation",
- [](const RecordKeeper &records, raw_ostream &os) {
- if (auto filtered = collectRecords(records))
- return emitTypeDefDoc(*filtered, os);
- return true;
- });
-
-static mlir::GenRegistration
- genEnumRegister("gen-enum-doc", "Generate dialect enum documentation",
- [](const RecordKeeper &records, raw_ostream &os) {
- if (auto filtered = collectRecords(records))
- return emitEnumDoc(*filtered, os);
- return true;
- });
-
-static mlir::GenRegistration
- genRegister("gen-dialect-doc", "Generate dialect documentation",
- [](const RecordKeeper &records, raw_ostream &os) {
- if (auto filtered = collectRecords(records))
- return emitDialectDoc(*filtered, os);
- return true;
- });
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/lib/TableGen/Generators/OpFormatGen.cpp
similarity index 96%
rename from mlir/tools/mlir-tblgen/OpFormatGen.cpp
rename to mlir/lib/TableGen/Generators/OpFormatGen.cpp
index c79f0f377644e..d698c438c24e8 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/lib/TableGen/Generators/OpFormatGen.cpp
@@ -6,19 +6,20 @@
//
//===----------------------------------------------------------------------===//
-#include "OpFormatGen.h"
-#include "FormatGen.h"
-#include "OpClass.h"
+#include "mlir/TableGen/Generators/OpFormatGen.h"
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/EnumInfo.h"
#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Generators/FormatGen.h"
+#include "mlir/TableGen/Generators/OpClass.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Signals.h"
@@ -291,138 +292,17 @@ class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
// OperationFormat
//===----------------------------------------------------------------------===//
-namespace {
-
-using ConstArgument =
- llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
-
-struct OperationFormat {
- /// This class represents a specific resolver for an operand or result type.
- class TypeResolution {
- public:
- TypeResolution() = default;
-
- /// Get the index into the buildable types for this type, or std::nullopt.
- std::optional<int> getBuilderIdx() const { return builderIdx; }
- void setBuilderIdx(int idx) { builderIdx = idx; }
-
- /// Get the variable this type is resolved to, or nullptr.
- const NamedTypeConstraint *getVariable() const {
- return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
- }
- /// Get the attribute this type is resolved to, or nullptr.
- const NamedAttribute *getAttribute() const {
- return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
- }
- /// Get the transformer for the type of the variable, or std::nullopt.
- std::optional<StringRef> getVarTransformer() const {
- return variableTransformer;
- }
- void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
- resolver = arg;
- variableTransformer = transformer;
- assert(getVariable() || getAttribute());
- }
-
- private:
- /// If the type is resolved with a buildable type, this is the index into
- /// 'buildableTypes' in the parent format.
- std::optional<int> builderIdx;
- /// If the type is resolved based upon another operand or result, this is
- /// the variable or the attribute that this type is resolved to.
- ConstArgument resolver;
- /// If the type is resolved based upon another operand or result, this is
- /// a transformer to apply to the variable when resolving.
- std::optional<StringRef> variableTransformer;
- };
-
- /// The context in which an element is generated.
- enum class GenContext {
- /// The element is generated at the top-level or with the same behaviour.
- Normal,
- /// The element is generated inside an optional group.
- Optional
- };
-
- OperationFormat(const Operator &op, bool hasProperties)
- : useProperties(hasProperties), opCppClassName(op.getCppClassName()) {
- operandTypes.resize(op.getNumOperands(), TypeResolution());
- resultTypes.resize(op.getNumResults(), TypeResolution());
-
- hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
- return trait.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
- });
-
- hasSingleBlockTrait = op.getTrait("::mlir::OpTrait::SingleBlock");
- }
+OperationFormat::OperationFormat(const Operator &op, bool hasProperties)
+ : useProperties(hasProperties), opCppClassName(op.getCppClassName()) {
+ operandTypes.resize(op.getNumOperands(), TypeResolution());
+ resultTypes.resize(op.getNumResults(), TypeResolution());
- /// Generate the operation parser from this format.
- void genParser(Operator &op, OpClass &opClass);
- /// Generate the parser code for a specific format element.
- void genElementParser(FormatElement *element, MethodBody &body,
- FmtContext &attrTypeCtx,
- GenContext genCtx = GenContext::Normal);
- /// Generate the C++ to resolve the types of operands and results during
- /// parsing.
- void genParserTypeResolution(Operator &op, MethodBody &body);
- /// Generate the C++ to resolve the types of the operands during parsing.
- void genParserOperandTypeResolution(
- Operator &op, MethodBody &body,
- function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
- /// Generate the C++ to resolve regions during parsing.
- void genParserRegionResolution(Operator &op, MethodBody &body);
- /// Generate the C++ to resolve successors during parsing.
- void genParserSuccessorResolution(Operator &op, MethodBody &body);
- /// Generate the C++ to handling variadic segment size traits.
- void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
-
- /// Generate the operation printer from this format.
- void genPrinter(Operator &op, OpClass &opClass);
-
- /// Generate the printer code for a specific format element.
- void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
- bool &shouldEmitSpace, bool &lastWasPunctuation);
-
- /// The various elements in this format.
- std::vector<FormatElement *> elements;
-
- /// A flag indicating if all operand/result types were seen. If the format
- /// contains these, it can not contain individual type resolvers.
- bool allOperands = false, allOperandTypes = false, allResultTypes = false;
-
- /// A flag indicating if this operation infers its result types
- bool infersResultTypes = false;
-
- /// A flag indicating if this operation has the SingleBlockImplicitTerminator
- /// trait.
- bool hasImplicitTermTrait;
-
- /// A flag indicating if this operation has the SingleBlock trait.
- bool hasSingleBlockTrait;
-
- /// Indicate whether we need to use properties for the current operator.
- bool useProperties;
-
- /// Indicate whether prop-dict is used in the format
- bool hasPropDict;
-
- /// The Operation class name
- StringRef opCppClassName;
-
- /// A map of buildable types to indices.
- llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes;
-
- /// The index of the buildable type, if valid, for every operand and result.
- std::vector<TypeResolution> operandTypes, resultTypes;
-
- /// The set of attributes explicitly used within the format.
- llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes;
- llvm::StringSet<> inferredAttributes;
+ hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
+ return trait.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
+ });
- /// The set of properties explicitly used within the format.
- llvm::SmallSetVector<const NamedProperty *, 8> usedProperties;
-};
-} // namespace
+ hasSingleBlockTrait = op.getTrait("::mlir::OpTrait::SingleBlock");
+}
//===----------------------------------------------------------------------===//
// Parser Gen
@@ -2679,6 +2559,8 @@ static auto findArg(RangeT &&range, StringRef name) {
}
namespace {
+using ConstArgument = OperationFormat::ConstArgument;
+
/// This class implements a parser for an instance of an operation assembly
/// format.
class OpFormatParser : public FormatParser {
@@ -3846,7 +3728,7 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
//===----------------------------------------------------------------------===//
void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass,
- bool hasProperties) {
+ bool hasProperties, bool fatalOnError) {
// TODO: Operator doesn't expose all necessary functionality via
// the const interface.
Operator &op = const_cast<Operator &>(constOp);
@@ -3868,7 +3750,7 @@ void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass,
FailureOr<std::vector<FormatElement *>> elements = parser.parse();
if (failed(elements)) {
// Exit the process if format errors are treated as fatal.
- if (formatErrorIsFatal) {
+ if (fatalOnError) {
// Invoke the interrupt handlers to run the file cleanup handlers.
llvm::sys::RunInterruptHandlers();
std::exit(1);
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/lib/TableGen/Generators/OpGenHelpers.cpp
similarity index 61%
rename from mlir/tools/mlir-tblgen/OpGenHelpers.cpp
rename to mlir/lib/TableGen/Generators/OpGenHelpers.cpp
index 44dbacf19fffd..e3b72ed19601b 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/lib/TableGen/Generators/OpGenHelpers.cpp
@@ -10,9 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#include "OpGenHelpers.h"
+#include "mlir/TableGen/Generators/OpGenHelpers.h"
#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Regex.h"
#include "llvm/TableGen/Error.h"
@@ -21,21 +20,6 @@ using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
-cl::OptionCategory opDefGenCat("Options for op definition generators");
-
-static cl::opt<std::string> opIncFilter(
- "op-include-regex",
- cl::desc("Regex of name of op's to include (no filter if empty)"),
- cl::cat(opDefGenCat));
-static cl::opt<std::string> opExcFilter(
- "op-exclude-regex",
- cl::desc("Regex of name of op's to exclude (no filter if empty)"),
- cl::cat(opDefGenCat));
-static cl::opt<unsigned> opShardCount(
- "op-shard-count",
- cl::desc("The number of shards into which the op classes will be divided"),
- cl::cat(opDefGenCat), cl::init(1));
-
static std::string getOperationName(const Record &def) {
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
auto opName = def.getValueAsString("opName");
@@ -45,23 +29,21 @@ static std::string getOperationName(const Record &def) {
}
std::vector<const Record *>
-mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &records) {
+mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &records,
+ StringRef includeRegex,
+ StringRef excludeRegex) {
const Record *classDef = records.getClass("Op");
if (!classDef)
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
- Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
+ Regex incRegex(includeRegex), excRegex(excludeRegex);
std::vector<const Record *> defs;
for (const auto &def : records.getDefs()) {
if (!def.second->isSubClassOf(classDef))
continue;
- // Include if no include filter or include filter matches.
- if (!opIncFilter.empty() &&
- !includeRegex.match(getOperationName(*def.second)))
+ if (!includeRegex.empty() && !incRegex.match(getOperationName(*def.second)))
continue;
- // Unless there is an exclude filter and it matches.
- if (!opExcFilter.empty() &&
- excludeRegex.match(getOperationName(*def.second)))
+ if (!excludeRegex.empty() && excRegex.match(getOperationName(*def.second)))
continue;
defs.push_back(def.second.get());
}
@@ -77,8 +59,6 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
"import", "in", "is", "lambda", "nonlocal", "not", "or",
"pass", "raise", "return", "try", "while", "with", "yield",
});
- // These aren't Python keywords but builtin functions that shouldn't/can't be
- // shadowed.
reserved.insert("callable");
reserved.insert("issubclass");
reserved.insert("type");
@@ -87,17 +67,18 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
void mlir::tblgen::shardOpDefinitions(
ArrayRef<const Record *> defs,
- SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) {
- assert(opShardCount > 0 && "expected a positive shard count");
- if (opShardCount == 1) {
+ SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs,
+ unsigned shardCount) {
+ assert(shardCount > 0 && "expected a positive shard count");
+ if (shardCount == 1) {
shardedDefs.push_back(defs);
return;
}
- unsigned minShardSize = defs.size() / opShardCount;
- unsigned numMissing = defs.size() - minShardSize * opShardCount;
- shardedDefs.reserve(opShardCount);
- for (unsigned i = 0, start = 0; i < opShardCount; ++i) {
+ unsigned minShardSize = defs.size() / shardCount;
+ unsigned numMissing = defs.size() - minShardSize * shardCount;
+ shardedDefs.reserve(shardCount);
+ for (unsigned i = 0, start = 0; i < shardCount; ++i) {
unsigned size = minShardSize + (i < numMissing);
shardedDefs.push_back(defs.slice(start, size));
start += size;
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/lib/TableGen/Generators/OpInterfacesGen.cpp
similarity index 70%
rename from mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
rename to mlir/lib/TableGen/Generators/OpInterfacesGen.cpp
index ab8d534a99f19..64a49374a010a 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/lib/TableGen/Generators/OpInterfacesGen.cpp
@@ -10,10 +10,10 @@
//
//===----------------------------------------------------------------------===//
-#include "CppGenUtilities.h"
-#include "DocGenUtilities.h"
+#include "mlir/TableGen/Generators/OpInterfacesGen.h"
#include "mlir/TableGen/Format.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/CppGenUtilities.h"
+#include "mlir/TableGen/Generators/DocGenUtilities.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
@@ -63,109 +63,64 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
os << " const";
}
-/// Get an array of all OpInterface definitions but exclude those subclassing
-/// "DeclareOpInterfaceMethods".
-static std::vector<const Record *>
-getAllInterfaceDefinitions(const RecordKeeper &records, StringRef name) {
+/// Return the records deriving from `<name>Interface`, after dropping the
+/// entries rejected by the filter below.
+std::vector<const Record *>
+mlir::tblgen::getAllInterfaceDefinitions(const RecordKeeper &records,
+                                         StringRef name) {
   std::vector<const Record *> defs =
       records.getAllDerivedDefinitions((name + "Interface").str());
   std::string declareName = ("Declare" + name + "InterfaceMethods").str();
   llvm::erase_if(defs, [&](const Record *def) {
     // Ignore any "declare methods" interfaces.
     if (def->isSubClassOf(declareName))
       return true;
     // Ignore interfaces defined outside of the top-level file.
     return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
            llvm::SrcMgr.getMainFileID();
   });
   return defs;
 }
-namespace {
-/// This struct is the base generator used when processing tablegen interfaces.
-class InterfaceGenerator {
-public:
- bool emitInterfaceDefs();
- bool emitInterfaceDecls();
- bool emitInterfaceDocs();
-
-protected:
- InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os)
- : defs(std::move(defs)), os(os) {}
-
- void emitConceptDecl(const Interface &interface);
- void emitModelDecl(const Interface &interface);
- void emitModelMethodsDef(const Interface &interface);
- void forwardDeclareInterface(const Interface &interface);
- void emitInterfaceDecl(const Interface &interface);
- void emitInterfaceTraitDecl(const Interface &interface);
-
- /// The set of interface records to emit.
- std::vector<const Record *> defs;
- // The stream to emit to.
- raw_ostream &os;
- /// The C++ value type of the interface, e.g. Operation*.
- StringRef valueType;
- /// The C++ base interface type.
- StringRef interfaceBaseType;
- /// The name of the typename for the value template.
- StringRef valueTemplate;
- /// The name of the substituion variable for the value.
- StringRef substVar;
- /// The format context to use for methods.
- tblgen::FmtContext nonStaticMethodFmt;
- tblgen::FmtContext traitMethodFmt;
- tblgen::FmtContext extraDeclsFmt;
-};
-
-/// A specialized generator for attribute interfaces.
-struct AttrInterfaceGenerator : public InterfaceGenerator {
- AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
- : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
- valueType = "::mlir::Attribute";
- interfaceBaseType = "AttributeInterface";
- valueTemplate = "ConcreteAttr";
- substVar = "_attr";
- StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))";
- nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
- traitMethodFmt.addSubst(substVar,
- "(*static_cast<const ConcreteAttr *>(this))");
- extraDeclsFmt.addSubst(substVar, "(*this)");
- }
-};
-/// A specialized generator for operation interfaces.
-struct OpInterfaceGenerator : public InterfaceGenerator {
- OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
- : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
- valueType = "::mlir::Operation *";
- interfaceBaseType = "OpInterface";
- valueTemplate = "ConcreteOp";
- substVar = "_op";
- StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
- nonStaticMethodFmt.addSubst("_this", "impl")
- .addSubst(substVar, castCode)
- .withSelf(castCode);
- traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
- extraDeclsFmt.addSubst(substVar, "(*this)");
- }
-};
-/// A specialized generator for type interfaces.
-struct TypeInterfaceGenerator : public InterfaceGenerator {
- TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
- : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
- valueType = "::mlir::Type";
- interfaceBaseType = "TypeInterface";
- valueTemplate = "ConcreteType";
- substVar = "_type";
- StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))";
- nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
- traitMethodFmt.addSubst(substVar,
- "(*static_cast<const ConcreteType *>(this))");
- extraDeclsFmt.addSubst(substVar, "(*this)");
- }
-};
-} // namespace
+/// Build the generator for attribute interfaces: gather the "Attr" interface
+/// records and set the type names and format substitutions used on emission.
+mlir::tblgen::AttrInterfaceGenerator::AttrInterfaceGenerator(
+    const RecordKeeper &records, raw_ostream &os)
+    : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
+  // Names spliced into the emitted C++.
+  substVar = "_attr";
+  valueTemplate = "ConcreteAttr";
+  interfaceBaseType = "AttributeInterface";
+  valueType = "::mlir::Attribute";
+
+  // Non-static methods see the opaque value cast to the concrete attribute,
+  // both through the substitution variable and through the self placeholder.
+  const StringRef concreteCast =
+      "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))";
+  nonStaticMethodFmt.addSubst(substVar, concreteCast);
+  nonStaticMethodFmt.withSelf(concreteCast);
+
+  // The trait and extra-declaration contexts substitute a form of `this`.
+  traitMethodFmt.addSubst(substVar,
+                          "(*static_cast<const ConcreteAttr *>(this))");
+  extraDeclsFmt.addSubst(substVar, "(*this)");
+}
+
+/// Build the generator for operation interfaces: gather the "Op" interface
+/// records and set the type names and format substitutions used on emission.
+mlir::tblgen::OpInterfaceGenerator::OpInterfaceGenerator(
+    const RecordKeeper &records, raw_ostream &os)
+    : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
+  // Names spliced into the emitted C++.
+  substVar = "_op";
+  valueTemplate = "ConcreteOp";
+  interfaceBaseType = "OpInterface";
+  valueType = "::mlir::Operation *";
+
+  // Non-static methods additionally map `_this` to `impl`, and see the opaque
+  // value cast to the concrete op.
+  // NOTE(review): this cast is spelled `llvm::cast`, without the leading `::`
+  // used by the attribute and type generators - confirm that is intended.
+  const StringRef concreteCast =
+      "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
+  nonStaticMethodFmt.addSubst("_this", "impl");
+  nonStaticMethodFmt.addSubst(substVar, concreteCast);
+  nonStaticMethodFmt.withSelf(concreteCast);
+
+  // The trait and extra-declaration contexts substitute a form of `this`.
+  traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
+  extraDeclsFmt.addSubst(substVar, "(*this)");
+}
+
+/// Build the generator for type interfaces: gather the "Type" interface
+/// records and set the type names and format substitutions used on emission.
+mlir::tblgen::TypeInterfaceGenerator::TypeInterfaceGenerator(
+    const RecordKeeper &records, raw_ostream &os)
+    : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
+  // Names spliced into the emitted C++.
+  substVar = "_type";
+  valueTemplate = "ConcreteType";
+  interfaceBaseType = "TypeInterface";
+  valueType = "::mlir::Type";
+
+  // Non-static methods see the opaque value cast to the concrete type, both
+  // through the substitution variable and through the self placeholder.
+  const StringRef concreteCast =
+      "(::llvm::cast<ConcreteType>(tablegen_opaque_val))";
+  nonStaticMethodFmt.addSubst(substVar, concreteCast);
+  nonStaticMethodFmt.withSelf(concreteCast);
+
+  // The trait and extra-declaration contexts substitute a form of `this`.
+  traitMethodFmt.addSubst(substVar,
+                          "(*static_cast<const ConcreteType *>(this))");
+  extraDeclsFmt.addSubst(substVar, "(*this)");
+}
//===----------------------------------------------------------------------===//
// GEN: Interface definitions
@@ -209,12 +164,10 @@ static void emitInterfaceDef(const Interface &interface, StringRef valueType,
StringRef interfaceQualName = interfaceQualNameStr;
interfaceQualName.consume_front("::");
- // Insert the method definitions.
bool isOpInterface = isa<OpInterface>(interface);
emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()",
os, isOpInterface);
- // Insert the method definitions for base classes.
for (auto &base : interface.getBaseInterfaces()) {
emitInterfaceDefMethods(interfaceQualName, base, valueType,
"getImpl()->impl" + base.getName(), os,
@@ -222,7 +175,7 @@ static void emitInterfaceDef(const Interface &interface, StringRef valueType,
}
}
-bool InterfaceGenerator::emitInterfaceDefs() {
+bool mlir::tblgen::InterfaceGenerator::emitInterfaceDefs() {
llvm::emitSourceFileHeader("Interface Definitions", os);
for (const auto *def : defs)
@@ -234,10 +187,10 @@ bool InterfaceGenerator::emitInterfaceDefs() {
// GEN: Interface declarations
//===----------------------------------------------------------------------===//
-void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
+void mlir::tblgen::InterfaceGenerator::emitConceptDecl(
+ const Interface &interface) {
os << " struct Concept {\n";
- // Insert each of the pure virtual concept methods.
os << " /// The methods defined by the interface.\n";
for (auto &method : interface.getMethods()) {
os << " ";
@@ -253,7 +206,6 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
os << ");\n";
}
- // Insert a field containing a concept for each of the base interfaces.
auto baseInterfaces = interface.getBaseInterfaces();
if (!baseInterfaces.empty()) {
os << " /// The base classes of this interface.\n";
@@ -262,8 +214,6 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
<< base.getName() << " = nullptr;\n";
}
- // Define an "initialize" method that allows for the initialization of the
- // base class concepts.
os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
"&interfaceMap) {\n";
std::string interfaceQualName = interface.getFullyQualifiedName();
@@ -282,8 +232,8 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
os << " };\n";
}
-void InterfaceGenerator::emitModelDecl(const Interface &interface) {
- // Emit the basic model and the fallback model.
+void mlir::tblgen::InterfaceGenerator::emitModelDecl(
+ const Interface &interface) {
for (const char *modelClass : {"Model", "FallbackModel"}) {
os << " template<typename " << valueTemplate << ">\n";
os << " class " << modelClass << " : public Concept {\n public:\n";
@@ -295,7 +245,6 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
[&](const InterfaceMethod &method) { os << method.getUniqueName(); });
os << "} {}\n\n";
- // Insert each of the virtual method overrides.
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os << " static inline ");
emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType,
@@ -306,15 +255,12 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
os << " };\n";
}
- // Emit the template for the external model.
os << " template<typename ConcreteModel, typename " << valueTemplate
<< ">\n";
os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n";
os << " public:\n";
os << " using ConcreteEntity = " << valueTemplate << ";\n";
- // Emit declarations for methods that have default implementations. Other
- // methods are expected to be implemented by the concrete derived model.
for (auto &method : interface.getMethods()) {
if (!method.getDefaultImplementation())
continue;
@@ -342,7 +288,8 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
os << " };\n";
}
-void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
+void mlir::tblgen::InterfaceGenerator::emitModelMethodsDef(
+ const Interface &interface) {
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
for (auto &method : interface.getMethods()) {
os << "template<typename " << valueTemplate << ">\n";
@@ -354,7 +301,6 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
/*addConst=*/false);
os << " {\n ";
- // Check for a provided body to the function.
if (std::optional<StringRef> body = method.getBody()) {
if (method.isStatic())
os << body->trim();
@@ -364,13 +310,11 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
continue;
}
- // Forward to the method on the concrete operation type.
if (method.isStatic())
os << "return " << valueTemplate << "::";
else
os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
- // Add the arguments to the call.
os << method.getName() << '(';
llvm::interleaveComma(
method.getArguments(), os,
@@ -388,13 +332,11 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
/*addConst=*/false);
os << " {\n ";
- // Forward to the method on the concrete Model implementation.
if (method.isStatic())
os << "return " << valueTemplate << "::";
else
os << "return static_cast<const " << valueTemplate << " *>(impl)->";
- // Add the arguments to the call.
os << method.getUniqueName() << '(';
if (!method.isStatic())
os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
@@ -404,7 +346,6 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
os << ");\n}\n";
}
- // Emit default implementations for the external model.
for (auto &method : interface.getMethods()) {
if (!method.getDefaultImplementation())
continue;
@@ -433,7 +374,6 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
os << " {\n";
- // Use the empty context for static methods.
tblgen::FmtContext ctx;
os << tblgen::tgfmt(method.getDefaultImplementation()->trim(),
method.isStatic() ? &ctx : &nonStaticMethodFmt);
@@ -441,7 +381,8 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
}
}
-void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
+void mlir::tblgen::InterfaceGenerator::emitInterfaceTraitDecl(
+ const Interface &interface) {
auto cppNamespace = (interface.getCppNamespace() + "::detail").str();
llvm::NamespaceEmitter ns(os, cppNamespace);
@@ -453,10 +394,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
interfaceName, interfaceTraitsName, interfaceBaseType,
valueTemplate);
- // Insert the default implementation for any methods.
bool isOpInterface = isa<OpInterface>(interface);
for (auto &method : interface.getMethods()) {
- // Flag interface methods named verifyTrait.
if (method.getName() == "verifyTrait")
PrintFatalError(
formatv("'verifyTrait' method cannot be specified as interface "
@@ -509,7 +448,6 @@ static void emitInterfaceDeclMethods(const Interface &interface,
os << ";\n";
}
- // Emit any extra declarations.
if (std::optional<StringRef> extraDecls =
interface.getExtraClassDeclaration())
os << extraDecls->rtrim() << "\n";
@@ -518,11 +456,10 @@ static void emitInterfaceDeclMethods(const Interface &interface,
os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
}
-void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) {
+void mlir::tblgen::InterfaceGenerator::forwardDeclareInterface(
+ const Interface &interface) {
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
- // Emit a forward declaration of the interface class so that it becomes usable
- // in the signature of its methods.
tblgen::emitSummaryAndDescComments(os, "",
interface.getDescription().value_or(""));
@@ -530,48 +467,41 @@ void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) {
os << "class " << interfaceName << ";\n";
}
-void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
+void mlir::tblgen::InterfaceGenerator::emitInterfaceDecl(
+ const Interface &interface) {
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
- // Emit a forward declaration of the interface class so that it becomes usable
- // in the signature of its methods.
tblgen::emitSummaryAndDescComments(os, "",
interface.getDescription().value_or(""));
- // Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
<< "struct " << interfaceTraitsName << " {\n";
emitConceptDecl(interface);
emitModelDecl(interface);
os << "};\n";
- // Emit the derived trait for the interface.
os << "template <typename " << valueTemplate << ">\n";
os << "struct " << interface.getName() << "Trait;\n";
os << "\n} // namespace detail\n";
- // Emit the main interface class declaration.
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
"public:\n"
" using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);
- // Emit a utility wrapper trait class.
os << llvm::formatv(" template <typename {1}>\n"
" struct Trait : public detail::{0}Trait<{1}> {{};\n",
interfaceName, valueTemplate);
- // Insert the method declarations.
bool isOpInterface = isa<OpInterface>(interface);
emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
extraDeclsFmt);
- // Insert the method declarations for base classes.
for (auto &base : interface.getBaseInterfaces()) {
std::string baseQualName = base.getFullyQualifiedName();
os << " //"
@@ -582,18 +512,15 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
"===---------------------------------------------------------------"
"-===//\n\n";
- // Allow implicit conversion to the base interface.
os << " operator " << baseQualName << " () const {\n"
<< " if (!*this) return nullptr;\n"
<< " return " << baseQualName << "(*this, getImpl()->impl"
<< base.getName() << ");\n"
<< " }\n\n";
- // Inherit the base interface's methods.
emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt);
}
- // Emit classof code if necessary.
if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
auto extraClassOfFmt = tblgen::FmtContext();
extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance");
@@ -610,10 +537,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
os << "};\n";
}
-bool InterfaceGenerator::emitInterfaceDecls() {
+bool mlir::tblgen::InterfaceGenerator::emitInterfaceDecls() {
llvm::emitSourceFileHeader("Interface Declarations", os);
- // Sort according to ID, so defs are emitted in the order in which they appear
- // in the Tablegen file.
std::vector<const Record *> sortedDefs(defs);
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
return lhs->getID() < rhs->getID();
@@ -637,19 +562,15 @@ bool InterfaceGenerator::emitInterfaceDecls() {
static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
Interface interface(&interfaceDef);
- // Emit the interface name followed by the description.
os << "\n## " << interface.getName() << " (`" << interfaceDef.getName()
<< "`)\n";
if (auto description = interface.getDescription())
mlir::tblgen::emitDescription(*description, os);
- // Emit the methods required by the interface.
os << "\n### Methods:\n";
for (const auto &method : interface.getMethods()) {
- // Emit the method name.
os << "\n#### `" << method.getName() << "`\n\n```c++\n";
- // Emit the method signature.
if (method.isStatic())
os << "static ";
emitCPPType(method.getReturnType(), os) << method.getName() << '(';
@@ -659,11 +580,9 @@ static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
});
os << ");\n```\n";
- // Emit the description.
if (auto description = method.getDescription())
mlir::tblgen::emitDescription(*description, os);
- // If the body is not provided, this method must be provided by the user.
if (!method.getBody())
os << "\nNOTE: This method *must* be implemented by the user.";
@@ -671,7 +590,7 @@ static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
}
}
-bool InterfaceGenerator::emitInterfaceDocs() {
+bool mlir::tblgen::InterfaceGenerator::emitInterfaceDocs() {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
os << "\n# " << interfaceBaseType << " definitions\n";
@@ -679,41 +598,3 @@ bool InterfaceGenerator::emitInterfaceDocs() {
emitInterfaceDoc(*def, os);
return false;
}
-
-//===----------------------------------------------------------------------===//
-// GEN: Interface registration hooks
-//===----------------------------------------------------------------------===//
-
-namespace {
-template <typename GeneratorT>
-struct InterfaceGenRegistration {
- InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
- : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
- genDefArg(("gen-" + genArg + "-interface-defs").str()),
- genDocArg(("gen-" + genArg + "-interface-docs").str()),
- genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
- genDefDesc(("Generate " + genDesc + " interface definitions").str()),
- genDocDesc(("Generate " + genDesc + " interface documentation").str()),
- genDecls(genDeclArg, genDeclDesc,
- [](const RecordKeeper &records, raw_ostream &os) {
- return GeneratorT(records, os).emitInterfaceDecls();
- }),
- genDefs(genDefArg, genDefDesc,
- [](const RecordKeeper &records, raw_ostream &os) {
- return GeneratorT(records, os).emitInterfaceDefs();
- }),
- genDocs(genDocArg, genDocDesc,
- [](const RecordKeeper &records, raw_ostream &os) {
- return GeneratorT(records, os).emitInterfaceDocs();
- }) {}
-
- std::string genDeclArg, genDefArg, genDocArg;
- std::string genDeclDesc, genDefDesc, genDocDesc;
- mlir::GenRegistration genDecls, genDefs, genDocs;
-};
-} // namespace
-
-static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
- "attribute");
-static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
-static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/lib/TableGen/Generators/OpPythonBindingGen.cpp
similarity index 97%
rename from mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
rename to mlir/lib/TableGen/Generators/OpPythonBindingGen.cpp
index 81c598ebbef0a..9ec3de46741e5 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/lib/TableGen/Generators/OpPythonBindingGen.cpp
@@ -11,15 +11,14 @@
//
//===----------------------------------------------------------------------===//
-#include "OpGenHelpers.h"
+#include "mlir/TableGen/Generators/OpPythonBindingGen.h"
+#include "mlir/TableGen/Generators/OpGenHelpers.h"
#include "mlir/Support/IndentedOstream.h"
-#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -97,11 +96,17 @@ static constexpr std::pair<StringRef, StringRef> builtinAttrTypeMappings[] = {
{"F64ElementsAttr", "_Union[_Sequence[float], _Buffer]"},
};
+// These maps are populated once per emitPythonOpBindings call and are read by
+// static helper functions throughout the file. Suppress the global-constructors
+// warning since the empty-default construction is trivially safe.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wglobal-constructors"
/// Maps from C++ type names to Python type annotations.
static llvm::StringMap<std::string> pythonTypeMap;
/// Maps from TableGen attribute def names to Python types.
static llvm::StringMap<std::string> pythonAttrTypeMap;
+#pragma clang diagnostic pop
/// File header and includes.
/// {0} is the dialect namespace.
@@ -419,21 +424,6 @@ def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
)Py";
-static llvm::cl::OptionCategory
- clOpPythonBindingCat("Options for -gen-python-op-bindings");
-
-std::string dialectNameStorage;
-
-llvm::cl::opt<std::string, /*ExternalStorage=*/true>
- clDialectName("bind-dialect",
- llvm::cl::desc("The dialect to run the generator for"),
- llvm::cl::location(dialectNameStorage),
- llvm::cl::cat(clOpPythonBindingCat));
-
-static llvm::cl::opt<std::string> clDialectExtensionName(
- "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
- llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
-
using AttributeClasses = DenseMap<StringRef, StringRef>;
/// Checks whether `str` would shadow a generated variable or attribute
@@ -1369,7 +1359,7 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
}
}
-/// Emits builder that extracts results from op
+/// Emits builder that extracts results from op.
static void emitValueBuilder(const Operator &op,
SmallVector<std::string> functionArgs,
raw_ostream &os) {
@@ -1501,11 +1491,15 @@ static void populateTypeMap(llvm::StringMap<std::string> &map,
}
}
-/// Emits bindings for the dialect specified in the command line, including file
+namespace mlir {
+namespace tblgen {
+
+/// Emits bindings for the dialect specified via \p dialectName, including file
/// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements.
-static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
- if (dialectNameStorage.empty())
+bool emitPythonOpBindings(const RecordKeeper &records, StringRef dialectName,
+ StringRef dialectExtensionName, raw_ostream &os) {
+ if (dialectName.empty())
llvm::PrintFatalError("dialect name not provided");
populateTypeMap(pythonTypeMap, builtinTypeMappings, records, "PythonTypeName",
@@ -1514,19 +1508,18 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
"PythonAttrType", "defName", "pyType");
os << fileHeader;
- if (!clDialectExtensionName.empty())
- os << formatv(dialectExtensionTemplate, dialectNameStorage);
+ if (!dialectExtensionName.empty())
+ os << formatv(dialectExtensionTemplate, dialectName);
else
- os << formatv(dialectClassTemplate, dialectNameStorage);
+ os << formatv(dialectClassTemplate, dialectName);
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
- if (op.getDialectName() == dialectNameStorage)
+ if (op.getDialectName() == dialectName)
emitOpBindings(op, os);
}
return false;
}
-static GenRegistration
- genPythonBindings("gen-python-op-bindings",
- "Generate Python bindings for MLIR Ops", &emitAllOps);
+} // namespace tblgen
+} // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/lib/TableGen/Generators/PassCAPIGen.cpp
similarity index 50%
rename from mlir/tools/mlir-tblgen/PassCAPIGen.cpp
rename to mlir/lib/TableGen/Generators/PassCAPIGen.cpp
index 8c13c9b031335..bcc3e6bf8b5a1 100644
--- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
+++ b/mlir/lib/TableGen/Generators/PassCAPIGen.cpp
@@ -1,4 +1,4 @@
-//===- Pass.cpp - MLIR pass registration generator ------------------------===//
+//===- PassCAPIGen.cpp - MLIR pass C API generator ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,16 +6,14 @@
//
//===----------------------------------------------------------------------===//
//
-// PassCAPIGen uses the description of passes to generate C API for the passes.
+// PassCAPIGen uses the description of passes to generate C API bindings.
//
//===----------------------------------------------------------------------===//
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/PassCAPIGen.h"
#include "mlir/TableGen/Pass.h"
#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
@@ -23,23 +21,14 @@ using namespace mlir::tblgen;
using llvm::formatv;
using llvm::RecordKeeper;
-static llvm::cl::OptionCategory
- passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl");
-static llvm::cl::opt<std::string>
- groupName("prefix",
- llvm::cl::desc("The prefix to use for this group of passes. The "
- "form will be mlirCreate<prefix><passname>, the "
- "prefix can avoid conflicts across libraries."),
- llvm::cl::cat(passGenCat));
-
-const char *const passDecl = R"(
+static const char *const passDecl = R"(
/* Create {0} Pass. */
MLIR_CAPI_EXPORTED MlirPass mlirCreate{0}{1}(void);
MLIR_CAPI_EXPORTED void mlirRegister{0}{1}(void);
)";
-const char *const fileHeader = R"(
+static const char *const fileHeader = R"(
/* Autogenerated by mlir-tblgen; don't manually edit. */
#include "mlir-c/Pass.h"
@@ -50,29 +39,14 @@ extern "C" {
)";
-const char *const fileFooter = R"(
+static const char *const fileFooter = R"(
#ifdef __cplusplus
}
#endif
)";
-/// Emit TODO
-static bool emitCAPIHeader(const RecordKeeper &records, raw_ostream &os) {
- os << fileHeader;
- os << "// Registration for the entire group\n";
- os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName
- << "Passes(void);\n\n";
- for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
- Pass pass(def);
- StringRef defName = pass.getDef()->getName();
- os << formatv(passDecl, groupName, defName);
- }
- os << fileFooter;
- return false;
-}
-
-const char *const passCreateDef = R"(
+static const char *const passCreateDef = R"(
MlirPass mlirCreate{0}{1}(void) {
return wrap({2}.release());
}
@@ -83,7 +57,7 @@ void mlirRegister{0}{1}(void) {
)";
/// {0}: The name of the pass group.
-const char *const passGroupRegistrationCode = R"(
+static const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Group Registration
//===----------------------------------------------------------------------===//
@@ -93,29 +67,37 @@ void mlirRegister{0}Passes(void) {{
}
)";
-static bool emitCAPIImpl(const RecordKeeper &records, raw_ostream &os) {
+void mlir::tblgen::emitPassCAPIHeader(const RecordKeeper &records,
+ llvm::StringRef prefix,
+ llvm::raw_ostream &os) {
+ os << fileHeader;
+ os << "// Registration for the entire group\n";
+ os << "MLIR_CAPI_EXPORTED void mlirRegister" << prefix << "Passes(void);\n\n";
+ for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
+ Pass pass(def);
+ llvm::StringRef defName = pass.getDef()->getName();
+ os << formatv(passDecl, prefix, defName);
+ }
+ os << fileFooter;
+}
+
+void mlir::tblgen::emitPassCAPIImpl(const RecordKeeper &records,
+ llvm::StringRef prefix,
+ llvm::raw_ostream &os) {
os << "/* Autogenerated by mlir-tblgen; don't manually edit. */";
- os << formatv(passGroupRegistrationCode, groupName);
+ os << formatv(passGroupRegistrationCode, prefix);
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
Pass pass(def);
- StringRef defName = pass.getDef()->getName();
+ llvm::StringRef defName = pass.getDef()->getName();
std::string constructorCall;
- if (StringRef constructor = pass.getConstructor(); !constructor.empty())
+ if (llvm::StringRef constructor = pass.getConstructor();
+ !constructor.empty())
constructorCall = constructor.str();
else
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
- os << formatv(passCreateDef, groupName, defName, constructorCall);
+ os << formatv(passCreateDef, prefix, defName, constructorCall);
}
- return false;
}
-
-static mlir::GenRegistration genCAPIHeader("gen-pass-capi-header",
- "Generate pass C API header",
- &emitCAPIHeader);
-
-static mlir::GenRegistration genCAPIImpl("gen-pass-capi-impl",
- "Generate pass C API implementation",
- &emitCAPIImpl);
diff --git a/mlir/tools/mlir-tblgen/PassDocGen.cpp b/mlir/lib/TableGen/Generators/PassDocGen.cpp
similarity index 75%
rename from mlir/tools/mlir-tblgen/PassDocGen.cpp
rename to mlir/lib/TableGen/Generators/PassDocGen.cpp
index 456f9ceffeb9b..421b1aabac24a 100644
--- a/mlir/tools/mlir-tblgen/PassDocGen.cpp
+++ b/mlir/lib/TableGen/Generators/PassDocGen.cpp
@@ -10,24 +10,27 @@
//
//===----------------------------------------------------------------------===//
-#include "DocGenUtilities.h"
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/PassDocGen.h"
+#include "mlir/TableGen/Generators/DocGenUtilities.h"
#include "mlir/TableGen/Pass.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
using llvm::RecordKeeper;
-/// Emit the documentation for the given pass.
-static void emitDoc(const Pass &pass, raw_ostream &os) {
+/// Emit the documentation for a single pass.
+static void emitPassDoc(const Pass &pass, llvm::raw_ostream &os) {
os << llvm::formatv("\n### `-{0}`\n", pass.getArgument());
emitSummary(pass.getSummary(), os);
emitDescription(pass.getDescription(), os);
// Handle the options of the pass.
- ArrayRef<PassOption> options = pass.getOptions();
+ llvm::ArrayRef<PassOption> options = pass.getOptions();
if (!options.empty()) {
os << "\n#### Options\n\n```\n";
size_t longestOption = 0;
@@ -42,7 +45,7 @@ static void emitDoc(const Pass &pass, raw_ostream &os) {
}
// Handle the statistics of the pass.
- ArrayRef<PassStatistic> stats = pass.getStatistics();
+ llvm::ArrayRef<PassStatistic> stats = pass.getStatistics();
if (!stats.empty()) {
os << "\n#### Statistics\n\n```\n";
size_t longestStat = 0;
@@ -57,25 +60,19 @@ static void emitDoc(const Pass &pass, raw_ostream &os) {
}
}
-static void emitDocs(const RecordKeeper &records, raw_ostream &os) {
+void mlir::tblgen::emitPassDocs(const RecordKeeper &records,
+ llvm::raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
auto passDefs = records.getAllDerivedDefinitions("PassBase");
// Collect the registered passes, sorted by argument name.
- SmallVector<Pass, 16> passes(passDefs.begin(), passDefs.end());
- SmallVector<Pass *, 16> sortedPasses(llvm::make_pointer_range(passes));
+ llvm::SmallVector<Pass, 16> passes(passDefs.begin(), passDefs.end());
+ llvm::SmallVector<Pass *, 16> sortedPasses(llvm::make_pointer_range(passes));
llvm::array_pod_sort(sortedPasses.begin(), sortedPasses.end(),
[](Pass *const *lhs, Pass *const *rhs) {
return (*lhs)->getArgument().compare(
(*rhs)->getArgument());
});
for (Pass *pass : sortedPasses)
- emitDoc(*pass, os);
+ emitPassDoc(*pass, os);
}
-
-static mlir::GenRegistration
- genRegister("gen-pass-doc", "Generate pass documentation",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitDocs(records, os);
- return false;
- });
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/lib/TableGen/Generators/PassGen.cpp
similarity index 79%
rename from mlir/tools/mlir-tblgen/PassGen.cpp
rename to mlir/lib/TableGen/Generators/PassGen.cpp
index e4ae78f022405..d4905c2b58cee 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/lib/TableGen/Generators/PassGen.cpp
@@ -1,4 +1,4 @@
-//===- Pass.cpp - MLIR pass registration generator ------------------------===//
+//===- PassGen.cpp - MLIR pass C++ generation utilities -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,10 +11,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/PassGen.h"
#include "mlir/TableGen/Pass.h"
#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -24,36 +23,17 @@ using namespace mlir::tblgen;
using llvm::formatv;
using llvm::RecordKeeper;
-static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
-static llvm::cl::opt<std::string>
- groupName("name", llvm::cl::desc("The name of this group of passes"),
- llvm::cl::cat(passGenCat));
-
-/// Extract the list of passes from the TableGen records.
-static std::vector<Pass> getPasses(const RecordKeeper &records) {
- std::vector<Pass> passes;
-
- for (const auto *def : records.getAllDerivedDefinitions("PassBase"))
- passes.emplace_back(def);
-
- return passes;
-}
-
-const char *const passHeader = R"(
+static const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
-//===----------------------------------------------------------------------===//
-// GEN: Pass registration generation
-//===----------------------------------------------------------------------===//
-
/// The code snippet used to generate a pass registration.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
-const char *const passRegistrationCode = R"(
+static const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
@@ -80,7 +60,7 @@ inline void register{0}Pass() {{
/// group.
///
/// {0}: The name of the pass group.
-const char *const passGroupRegistrationCode = R"(
+static const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
@@ -88,110 +68,14 @@ const char *const passGroupRegistrationCode = R"(
inline void register{0}Passes() {{
)";
-/// Emits the definition of the struct to be used to control the pass options.
-static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
- StringRef passName = pass.getDef()->getName();
- ArrayRef<PassOption> options = pass.getOptions();
-
- // Emit the struct only if the pass has at least one option.
- if (options.empty())
- return;
-
- os << formatv("struct {0}Options {{\n", passName);
-
- for (const PassOption &opt : options) {
- std::string type = opt.getType().str();
-
- if (opt.isListOption())
- type = "::llvm::SmallVector<" + type + ">";
-
- os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName());
-
- if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
- os << " = " << defaultVal;
-
- os << ";\n";
- }
-
- os << "};\n";
-}
-
-static std::string getPassDeclVarName(const Pass &pass) {
- return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
-}
-
-static std::string getPassRegistrationVarName(const Pass &pass) {
- return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper();
-}
-
-/// Emit the code to be included in the public header of the pass.
-static void emitPassDecls(const Pass &pass, raw_ostream &os) {
- StringRef passName = pass.getDef()->getName();
- std::string enableVarName = getPassDeclVarName(pass);
-
- os << "#ifdef " << enableVarName << "\n";
- emitPassOptionsStruct(pass, os);
-
- if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
- // Default constructor declaration.
- os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
-
- // Declaration of the constructor with options.
- if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
- os << formatv("std::unique_ptr<::mlir::Pass> create{0}("
- "{0}Options options);\n",
- passName);
- }
-
- os << "#undef " << enableVarName << "\n";
- os << "#endif // " << enableVarName << "\n";
-}
-
-/// Emit the code for registering each of the given passes with the global
-/// PassRegistry.
-static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
- os << "#ifdef GEN_PASS_REGISTRATION\n";
- os << "// Generate registrations for all passes.\n";
- for (const Pass &pass : passes)
- os << "#define " << getPassRegistrationVarName(pass) << "\n";
- os << "#endif // GEN_PASS_REGISTRATION\n";
-
- for (const Pass &pass : passes) {
- std::string passName = pass.getDef()->getName().str();
- std::string passEnableVarName = getPassRegistrationVarName(pass);
-
- std::string constructorCall;
- if (StringRef constructor = pass.getConstructor(); !constructor.empty())
- constructorCall = constructor.str();
- else
- constructorCall = formatv("create{0}()", passName).str();
- os << formatv(passRegistrationCode, passName, passEnableVarName,
- constructorCall);
- }
-
- os << "#ifdef GEN_PASS_REGISTRATION\n";
- os << formatv(passGroupRegistrationCode, groupName);
-
- for (const Pass &pass : passes)
- os << " register" << pass.getDef()->getName() << "();\n";
-
- os << "}\n";
- os << "#undef GEN_PASS_REGISTRATION\n";
- os << "#endif // GEN_PASS_REGISTRATION\n";
-}
-
-//===----------------------------------------------------------------------===//
-// GEN: Pass base class generation
-//===----------------------------------------------------------------------===//
-
/// The code snippet used to generate the start of a pass base class.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
-/// {2): The command line argument for the pass.
+/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration.
-const char *const baseClassBegin = R"(
+static const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
@@ -240,48 +124,112 @@ class {0}Base : public {1} {
)";
-/// Registration for a single dependent dialect, to be inserted for each
-/// dependent dialect in the `getDependentDialects` above.
-const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
+static const char *const dialectRegistrationTemplate =
+ "registry.insert<{0}>();";
-const char *const friendDefaultConstructorDeclTemplate = R"(
+static const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";
-const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
+static const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
} // namespace impl
)";
-const char *const friendDefaultConstructorDefTemplate = R"(
+static const char *const friendDefaultConstructorDefTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}() {{
return std::make_unique<DerivedT>();
}
)";
-const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
+static const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
return std::make_unique<DerivedT>(std::move(options));
}
)";
-const char *const defaultConstructorDefTemplate = R"(
+static const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
return impl::create{0}();
}
)";
-const char *const defaultConstructorWithOptionsDefTemplate = R"(
+static const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
return impl::create{0}(std::move(options));
}
)";
-/// Emit the declarations for each of the pass options.
-static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
+static std::string getPassDeclVarName(const Pass &pass) {
+ return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
+}
+
+static std::string getPassRegistrationVarName(const Pass &pass) {
+ return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper();
+}
+
+std::vector<Pass> mlir::tblgen::getPasses(const RecordKeeper &records) {
+ std::vector<Pass> passes;
+ for (const auto *def : records.getAllDerivedDefinitions("PassBase"))
+ passes.emplace_back(def);
+ return passes;
+}
+
+void mlir::tblgen::emitPassOptionsStruct(const Pass &pass,
+ llvm::raw_ostream &os) {
+ StringRef passName = pass.getDef()->getName();
+ ArrayRef<PassOption> options = pass.getOptions();
+
+ // Emit the struct only if the pass has at least one option.
+ if (options.empty())
+ return;
+
+ os << formatv("struct {0}Options {{\n", passName);
+
+ for (const PassOption &opt : options) {
+ std::string type = opt.getType().str();
+
+ if (opt.isListOption())
+ type = "::llvm::SmallVector<" + type + ">";
+
+ os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName());
+
+ if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
+ os << " = " << defaultVal;
+
+ os << ";\n";
+ }
+
+ os << "};\n";
+}
+
+void mlir::tblgen::emitPassDecls(const Pass &pass, llvm::raw_ostream &os) {
+ StringRef passName = pass.getDef()->getName();
+ std::string enableVarName = getPassDeclVarName(pass);
+
+ os << "#ifdef " << enableVarName << "\n";
+ emitPassOptionsStruct(pass, os);
+
+ if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
+ // Default constructor declaration.
+ os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
+
+ // Declaration of the constructor with options.
+ if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
+ os << formatv("std::unique_ptr<::mlir::Pass> create{0}("
+ "{0}Options options);\n",
+ passName);
+ }
+
+ os << "#undef " << enableVarName << "\n";
+ os << "#endif // " << enableVarName << "\n";
+}
+
+void mlir::tblgen::emitPassOptionDecls(const Pass &pass,
+ llvm::raw_ostream &os) {
for (const PassOption &opt : pass.getOptions()) {
os.indent(2) << "::mlir::Pass::"
<< (opt.isListOption() ? "ListOption" : "Option");
@@ -297,8 +245,8 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
}
}
-/// Emit the declarations for each of the pass statistics.
-static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
+void mlir::tblgen::emitPassStatisticDecls(const Pass &pass,
+ llvm::raw_ostream &os) {
for (const PassStatistic &stat : pass.getStatistics()) {
os << formatv(
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", R\"PS({2})PS\"};\n",
@@ -307,8 +255,7 @@ static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
}
}
-/// Emit the code to be used in the implementation of the pass.
-static void emitPassDefs(const Pass &pass, raw_ostream &os) {
+void mlir::tblgen::emitPassDefs(const Pass &pass, llvm::raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
bool emitDefaultConstructors = pass.getConstructor().empty();
@@ -350,12 +297,12 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
os.indent(2) << "}\n";
}
- // Protected content
+ // Protected content.
os << "protected:\n";
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
- // Private content
+ // Private content.
os << "private:\n";
if (emitDefaultConstructors) {
@@ -379,7 +326,7 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
os << "#endif // " << enableVarName << "\n";
}
-static void emitPass(const Pass &pass, raw_ostream &os) {
+void mlir::tblgen::emitPass(const Pass &pass, llvm::raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
os << formatv(passHeader, passName);
@@ -387,7 +334,42 @@ static void emitPass(const Pass &pass, raw_ostream &os) {
emitPassDefs(pass, os);
}
-static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
+void mlir::tblgen::emitRegistrations(llvm::ArrayRef<Pass> passes,
+ llvm::StringRef groupName,
+ llvm::raw_ostream &os) {
+ os << "#ifdef GEN_PASS_REGISTRATION\n";
+ os << "// Generate registrations for all passes.\n";
+ for (const Pass &pass : passes)
+ os << "#define " << getPassRegistrationVarName(pass) << "\n";
+ os << "#endif // GEN_PASS_REGISTRATION\n";
+
+ for (const Pass &pass : passes) {
+ std::string passName = pass.getDef()->getName().str();
+ std::string passEnableVarName = getPassRegistrationVarName(pass);
+
+ std::string constructorCall;
+ if (StringRef constructor = pass.getConstructor(); !constructor.empty())
+ constructorCall = constructor.str();
+ else
+ constructorCall = formatv("create{0}()", passName).str();
+ os << formatv(passRegistrationCode, passName, passEnableVarName,
+ constructorCall);
+ }
+
+ os << "#ifdef GEN_PASS_REGISTRATION\n";
+ os << formatv(passGroupRegistrationCode, groupName);
+
+ for (const Pass &pass : passes)
+ os << " register" << pass.getDef()->getName() << "();\n";
+
+ os << "}\n";
+ os << "#undef GEN_PASS_REGISTRATION\n";
+ os << "#endif // GEN_PASS_REGISTRATION\n";
+}
+
+void mlir::tblgen::emitPasses(const RecordKeeper &records,
+ llvm::StringRef groupName,
+ llvm::raw_ostream &os) {
std::vector<Pass> passes = getPasses(records);
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
@@ -402,7 +384,7 @@ static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
for (const Pass &pass : passes)
emitPass(pass, os);
- emitRegistrations(passes, os);
+ emitRegistrations(passes, groupName, os);
// TODO: Remove warning, kept in to make error understandable.
// Emit the old code until all the passes have switched to the new design.
@@ -411,10 +393,3 @@ static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
os << "#undef GEN_PASS_CLASSES\n";
os << "#endif // GEN_PASS_CLASSES\n";
}
-
-static mlir::GenRegistration
- genPassDecls("gen-pass-decls", "Generate pass declarations",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitPasses(records, os);
- return false;
- });
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/lib/TableGen/Generators/RewriterGen.cpp
similarity index 86%
rename from mlir/tools/mlir-tblgen/RewriterGen.cpp
rename to mlir/lib/TableGen/Generators/RewriterGen.cpp
index e3043708a46d1..8b1bb1c506fb0 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/lib/TableGen/Generators/RewriterGen.cpp
@@ -10,12 +10,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/TableGen/Generators/RewriterGen.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
-#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "mlir/TableGen/Predicate.h"
@@ -25,7 +25,6 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/PrettyStackTrace.h"
@@ -64,18 +63,18 @@ class StaticMatcherHelper;
class PatternEmitter {
public:
- PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
- StaticMatcherHelper &helper);
+ PatternEmitter(const Record *pat, tblgen::RecordOperatorMap *mapper,
+ raw_ostream &os, StaticMatcherHelper &helper);
// Emits the mlir::RewritePattern struct named `rewriteName`.
void emit(StringRef rewriteName);
// Emits the static function of DAG matcher.
- void emitStaticMatcher(DagNode tree, std::string funcName);
+ void emitStaticMatcher(tblgen::DagNode tree, std::string funcName);
private:
// Emits the code for matching ops.
- void emitMatchLogic(DagNode tree, StringRef opName);
+ void emitMatchLogic(tblgen::DagNode tree, StringRef opName);
// Emits the code for rewriting ops.
void emitRewriteLogic();
@@ -85,50 +84,51 @@ class PatternEmitter {
//===--------------------------------------------------------------------===//
// Emits C++ statements for matching the DAG structure.
- void emitMatch(DagNode tree, StringRef name, int depth);
+ void emitMatch(tblgen::DagNode tree, StringRef name, int depth);
// Emit C++ function call to static DAG matcher.
- void emitStaticMatchCall(DagNode tree, StringRef name);
+ void emitStaticMatchCall(tblgen::DagNode tree, StringRef name);
// Emit C++ function call to static type/attribute constraint function.
void emitStaticVerifierCall(StringRef funcName, StringRef opName,
StringRef arg, StringRef failureStr);
// Emits C++ statements for matching using a native code call.
- void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
+ void emitNativeCodeMatch(tblgen::DagNode tree, StringRef name, int depth);
// Emits C++ statements for matching the op constrained by the given DAG
// `tree` returning the op's variable name.
- void emitOpMatch(DagNode tree, StringRef opName, int depth);
+ void emitOpMatch(tblgen::DagNode tree, StringRef opName, int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
// bound name and the constraint of the operand respectively.
- void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
- int operandIndex, DagLeaf operandMatcher,
- StringRef argName, int argIndex,
- std::optional<int> variadicSubIndex);
+ void emitOperandMatch(tblgen::DagNode tree, StringRef opName,
+ StringRef operandName, int operandIndex,
+ tblgen::DagLeaf operandMatcher, StringRef argName,
+ int argIndex, std::optional<int> variadicSubIndex);
// Emits C++ statements for matching the operands which can be matched in
// either order.
- void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
- StringRef opName, int argIndex, int &operandIndex,
- int depth);
+ void emitEitherOperandMatch(tblgen::DagNode tree,
+ tblgen::DagNode eitherArgTree, StringRef opName,
+ int argIndex, int &operandIndex, int depth);
// Emits C++ statements for matching a variadic operand.
- void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
+ void emitVariadicOperandMatch(tblgen::DagNode tree,
+ tblgen::DagNode variadicArgTree,
StringRef opName, int argIndex,
int &operandIndex, int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
- void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
- int depth);
+ void emitAttributeMatch(tblgen::DagNode tree, StringRef castedName,
+ int argIndex, int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as a property.
- void emitPropertyMatch(DagNode tree, StringRef castedName, int argIndex,
- int depth);
+ void emitPropertyMatch(tblgen::DagNode tree, StringRef castedName,
+ int argIndex, int depth);
// Emits C++ for checking a match with a corresponding match failure
// diagnostic.
@@ -154,93 +154,98 @@ class PatternEmitter {
// of the matched root op this pattern is intended to replace, which can be
// used to deduce the result type of the op generated from this result
// pattern.
- std::string handleResultPattern(DagNode resultTree, int resultIndex,
+ std::string handleResultPattern(tblgen::DagNode resultTree, int resultIndex,
int depth);
// Emits the C++ statement to replace the matched DAG with a value built via
// calling native C++ code.
- std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
+ std::string handleReplaceWithNativeCodeCall(tblgen::DagNode resultTree,
+ int depth);
// Returns the symbol of the old value serving as the replacement.
- StringRef handleReplaceWithValue(DagNode tree);
+ StringRef handleReplaceWithValue(tblgen::DagNode tree);
// Emits the C++ statement to replace the matched DAG with an array of
// matched values.
- std::string handleVariadic(DagNode tree, int depth);
+ std::string handleVariadic(tblgen::DagNode tree, int depth);
// Trailing directives are used at the end of DAG node argument lists to
// specify additional behaviour for op matchers and creators, etc.
struct TrailingDirectives {
// DAG node containing the `location` directive. Null if there is none.
- DagNode location;
+ tblgen::DagNode location;
// DAG node containing the `returnType` directive. Null if there is none.
- DagNode returnType;
+ tblgen::DagNode returnType;
// Number of found trailing directives.
int numDirectives;
};
// Collect any trailing directives.
- TrailingDirectives getTrailingDirectives(DagNode tree);
+ TrailingDirectives getTrailingDirectives(tblgen::DagNode tree);
// Returns the location value to use.
std::string getLocation(TrailingDirectives &tail);
// Returns the location value to use.
- std::string handleLocationDirective(DagNode tree);
+ std::string handleLocationDirective(tblgen::DagNode tree);
// Emit return type argument.
- std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
+ std::string handleReturnTypeArg(tblgen::DagNode returnType, int i, int depth);
// Emits the C++ statement to build a new op out of the given DAG `tree` and
// returns the variable name that this op is assigned to. If the root op in
// DAG `tree` has a specified name, the created op will be assigned to a
// variable of the given name. Otherwise, a unique name will be used as the
// result value name.
- std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
+ std::string handleOpCreation(tblgen::DagNode tree, int resultIndex,
+ int depth);
using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
// Emits a local variable for each value and attribute to be used for creating
// an op.
- void createSeparateLocalVarsForOpArgs(DagNode node,
+ void createSeparateLocalVarsForOpArgs(tblgen::DagNode node,
ChildNodeIndexNameMap &childNodeNames);
// Emits the concrete arguments used to call an op's builder.
- void supplyValuesForOpArgs(DagNode node,
+ void supplyValuesForOpArgs(tblgen::DagNode node,
const ChildNodeIndexNameMap &childNodeNames,
int depth);
// Emits the local variables for holding all values as a whole and all named
// attributes as a whole to be used for creating an op.
- void createAggregateLocalVarsForOpArgs(
- DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
+ void
+ createAggregateLocalVarsForOpArgs(tblgen::DagNode node,
+ const ChildNodeIndexNameMap &childNodeNames,
+ int depth);
// Returns the C++ expression to construct a constant attribute of the given
// `value` for the given attribute kind `attr`.
- std::string handleConstantAttr(Attribute attr, const Twine &value);
+ std::string handleConstantAttr(tblgen::Attribute attr, const Twine &value);
// Returns the C++ expression to build an argument from the given DAG `leaf`.
// `patArgName` is used to bound the argument to the source pattern.
- std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
+ std::string handleOpArgument(tblgen::DagLeaf leaf, StringRef patArgName);
//===--------------------------------------------------------------------===//
// General utilities
//===--------------------------------------------------------------------===//
// Collects all of the operations within the given dag tree.
- void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
+ void collectOps(tblgen::DagNode tree,
+ llvm::SmallPtrSetImpl<const tblgen::Operator *> &ops);
// Returns a unique symbol for a local variable of the given `op`.
- std::string getUniqueSymbol(const Operator *op);
+ std::string getUniqueSymbol(const tblgen::Operator *op);
//===--------------------------------------------------------------------===//
// Symbol utilities
//===--------------------------------------------------------------------===//
// Returns how many static values the given DAG `node` correspond to.
- int getNodeValueCount(DagNode node);
+ int getNodeValueCount(tblgen::DagNode node);
private:
// Pattern instantiation location followed by the location of multiclass
@@ -249,13 +254,13 @@ class PatternEmitter {
ArrayRef<SMLoc> loc;
// Op's TableGen Record to wrapper object.
- RecordOperatorMap *opMap;
+ tblgen::RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted.
- Pattern pattern;
+ tblgen::Pattern pattern;
// Map for all bound symbols' info.
- SymbolInfoMap symbolInfoMap;
+ tblgen::SymbolInfoMap symbolInfoMap;
StaticMatcherHelper &staticMatcherHelper;
@@ -265,7 +270,7 @@ class PatternEmitter {
raw_indented_ostream os;
// Format contexts containing placeholder substitutions.
- FmtContext fmtCtx;
+ tblgen::FmtContext fmtCtx;
};
// Tracks DagNode's reference multiple times across patterns. Enables generating
@@ -274,11 +279,11 @@ class PatternEmitter {
class StaticMatcherHelper {
public:
StaticMatcherHelper(raw_ostream &os, const RecordKeeper &records,
- RecordOperatorMap &mapper);
+ tblgen::RecordOperatorMap &mapper);
// Determine if we should inline the match logic or delegate to a static
// function.
- bool useStaticMatcher(DagNode node) {
+ bool useStaticMatcher(tblgen::DagNode node) {
// either/variadic node must be associated to the parentOp, thus we can't
// emit a static matcher rooted at them.
if (node.isEither() || node.isVariadic())
@@ -288,13 +293,13 @@ class StaticMatcherHelper {
}
// Get the name of the static DAG matcher function corresponding to the node.
- std::string getMatcherName(DagNode node) {
+ std::string getMatcherName(tblgen::DagNode node) {
assert(useStaticMatcher(node));
return matcherNames[node];
}
// Get the name of static type/attribute verification function.
- StringRef getVerifierName(DagLeaf leaf);
+ StringRef getVerifierName(tblgen::DagLeaf leaf);
// Collect the `Record`s, i.e., the DRR, so that we can get the information of
// the duplicated DAGs.
@@ -327,39 +332,40 @@ class StaticMatcherHelper {
// inlining.
//
// The topological order of all the DagNodes among all patterns.
- SmallVector<std::pair<DagNode, const Record *>> topologicalOrder;
+ SmallVector<std::pair<tblgen::DagNode, const Record *>> topologicalOrder;
- RecordOperatorMap &opMap;
+ tblgen::RecordOperatorMap &opMap;
- // Records of the static function name of each DagNode
- DenseMap<DagNode, std::string> matcherNames;
+ // Records of the static function name of each DagNode.
+ DenseMap<tblgen::DagNode, std::string> matcherNames;
// After collecting all the DagNode in each pattern, `refStats` records the
// number of users for each DagNode. We will generate the static matcher for a
// DagNode while the number of users exceeds a certain threshold.
- DenseMap<DagNode, unsigned> refStats;
+ DenseMap<tblgen::DagNode, unsigned> refStats;
// Number of static matcher generated. This is used to generate a unique name
// for each DagNode.
int staticMatcherCounter = 0;
// The DagLeaf which contains type, attr, or prop constraint.
- SetVector<DagLeaf> constraints;
+ SetVector<tblgen::DagLeaf> constraints;
// Static type/attribute verification function emitter.
- StaticVerifierFunctionEmitter staticVerifierEmitter;
+ tblgen::StaticVerifierFunctionEmitter staticVerifierEmitter;
};
} // namespace
-PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper,
+PatternEmitter::PatternEmitter(const Record *pat,
+ tblgen::RecordOperatorMap *mapper,
raw_ostream &os, StaticMatcherHelper &helper)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
fmtCtx.withBuilder("rewriter");
}
-std::string PatternEmitter::handleConstantAttr(Attribute attr,
+std::string PatternEmitter::handleConstantAttr(tblgen::Attribute attr,
const Twine &value) {
if (!attr.isConstBuildable())
PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
@@ -369,7 +375,8 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
}
-void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
+void PatternEmitter::emitStaticMatcher(tblgen::DagNode tree,
+ std::string funcName) {
os << formatv(
"static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
"::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
@@ -397,7 +404,8 @@ void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
}
// Helper function to match patterns.
-void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
+void PatternEmitter::emitMatch(tblgen::DagNode tree, StringRef name,
+ int depth) {
if (tree.isNativeCodeCall()) {
emitNativeCodeMatch(tree, name, depth);
return;
@@ -411,7 +419,8 @@ void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
}
-void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
+void PatternEmitter::emitStaticMatchCall(tblgen::DagNode tree,
+ StringRef opName) {
std::string funcName = staticMatcherHelper.getMatcherName(tree);
os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName,
opName);
@@ -427,7 +436,7 @@ void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
// global symbolInfoMap.
// Collect all the bound symbols in the Dag
- SymbolInfoMap localSymbolMap(loc);
+ tblgen::SymbolInfoMap localSymbolMap(loc);
pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
for (const auto &info : localSymbolMap) {
@@ -452,7 +461,7 @@ void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
}
// Helper function to match patterns.
-void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
+void PatternEmitter::emitNativeCodeMatch(tblgen::DagNode tree, StringRef opName,
int depth) {
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs()));
@@ -481,7 +490,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
std::string argName = formatv("arg{0}_{1}", depth, i);
- if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+ if (tblgen::DagNode argTree = tree.getArgAsNestedDag(i)) {
if (argTree.isEither())
PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
if (argTree.isVariadic())
@@ -499,7 +508,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
"with unspecified interface type");
os << interfaceType << " " << argName;
if (leaf.isPropDefinition()) {
- Property propDef = leaf.getAsProperty();
+ tblgen::Property propDef = leaf.getAsProperty();
// Ensure properties that aren't zero-arg-constructable still work.
if (propDef.hasDefaultValue())
os << " = " << propDef.getDefaultValue();
@@ -548,7 +557,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
argName, i));
}
- DagLeaf leaf = tree.getArgAsLeaf(i);
+ tblgen::DagLeaf leaf = tree.getArgAsLeaf(i);
// The parameter for native function doesn't bind any constraints.
if (leaf.isUnspecified())
@@ -576,8 +585,9 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
}
// Helper function to match patterns.
-void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
- Operator &op = tree.getDialectOp(opMap);
+void PatternEmitter::emitOpMatch(tblgen::DagNode tree, StringRef opName,
+ int depth) {
+ tblgen::Operator &op = tree.getDialectOp(opMap);
LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
<< op.getOperationName() << "' at depth " << depth
<< '\n');
@@ -625,7 +635,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
std::string argName = formatv("op{0}", depth + 1);
// Handle nested DAG construct first
- if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+ if (tblgen::DagNode argTree = tree.getArgAsNestedDag(i)) {
if (argTree.isEither()) {
emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand,
depth);
@@ -633,7 +643,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
continue;
}
if (auto *operand =
- llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
+ llvm::dyn_cast_if_present<tblgen::NamedTypeConstraint *>(opArg)) {
if (argTree.isVariadic()) {
if (!operand->isVariadic()) {
auto error = formatv("variadic DAG construct can't match op {0}'s "
@@ -674,7 +684,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
}
// Next handle DAG leaf: operand or attribute
- if (isa<NamedTypeConstraint *>(opArg)) {
+ if (isa<tblgen::NamedTypeConstraint *>(opArg)) {
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
@@ -682,9 +692,9 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
/*argName=*/tree.getArgName(i), opArgIdx,
/*variadicSubIndex=*/std::nullopt);
++nextOperand;
- } else if (isa<NamedAttribute *>(opArg)) {
+ } else if (isa<tblgen::NamedAttribute *>(opArg)) {
emitAttributeMatch(tree, castedName, opArgIdx, depth);
- } else if (isa<NamedProperty *>(opArg)) {
+ } else if (isa<tblgen::NamedProperty *>(opArg)) {
emitPropertyMatch(tree, castedName, opArgIdx, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
@@ -695,13 +705,13 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
<< '\n');
}
-void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
+void PatternEmitter::emitOperandMatch(tblgen::DagNode tree, StringRef opName,
StringRef operandName, int operandIndex,
- DagLeaf operandMatcher, StringRef argName,
- int argIndex,
+ tblgen::DagLeaf operandMatcher,
+ StringRef argName, int argIndex,
std::optional<int> variadicSubIndex) {
- Operator &op = tree.getDialectOp(opMap);
- NamedTypeConstraint operand = op.getOperand(operandIndex);
+ tblgen::Operator &op = tree.getDialectOp(opMap);
+ tblgen::NamedTypeConstraint operand = op.getOperand(operandIndex);
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
@@ -713,7 +723,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Only need to verify if the matcher's type is different from the one
// of op definition.
- Constraint constraint = operandMatcher.getAsConstraint();
+ tblgen::Constraint constraint = operandMatcher.getAsConstraint();
if (operand.constraint != constraint) {
if (operand.isVariableLength()) {
auto error = formatv(
@@ -745,14 +755,15 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
}
}
-void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
+void PatternEmitter::emitEitherOperandMatch(tblgen::DagNode tree,
+ tblgen::DagNode eitherArgTree,
StringRef opName, int argIndex,
int &operandIndex, int depth) {
constexpr int numEitherArgs = 2;
if (eitherArgTree.getNumArgs() != numEitherArgs)
PrintFatalError(loc, "`either` only supports grouping two operands");
- Operator &op = tree.getDialectOp(opMap);
+ tblgen::Operator &op = tree.getDialectOp(opMap);
std::string codeBuffer;
llvm::raw_string_ostream tblgenOps(codeBuffer);
@@ -765,7 +776,7 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
os.indent();
for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
- if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
+ if (tblgen::DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
if (argTree.isEither())
PrintFatalError(loc, "either cannot be nested");
@@ -791,7 +802,7 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
// need to queue the operation only if the matching success. Thus we emit
// the code at the end.
tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
- } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
+ } else if (isa<tblgen::NamedTypeConstraint *>(op.getArg(argIndex))) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
operandIndex,
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
@@ -824,11 +835,11 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
os.unindent().unindent() << "}\n";
}
-void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
- DagNode variadicArgTree,
+void PatternEmitter::emitVariadicOperandMatch(tblgen::DagNode tree,
+ tblgen::DagNode variadicArgTree,
StringRef opName, int argIndex,
int &operandIndex, int depth) {
- Operator &op = tree.getDialectOp(opMap);
+ tblgen::Operator &op = tree.getDialectOp(opMap);
os << "{\n";
os.indent();
@@ -852,7 +863,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
}
for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
- if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
+ if (tblgen::DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
if (!argTree.isOperation())
PrintFatalError(loc, "variadic only accepts operation sub-dags");
@@ -872,7 +883,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
os << formatv("tblgen_ops.push_back({0});\n", argName);
os.unindent() << "}\n";
- } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
+ } else if (isa<tblgen::NamedTypeConstraint *>(op.getArg(argIndex))) {
auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
emitOperandMatch(tree, opName, operandName.str(), operandIndex,
/*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
@@ -885,10 +896,11 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
os.unindent() << "}\n";
}
-void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
- int argIndex, int depth) {
- Operator &op = tree.getDialectOp(opMap);
- auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
+void PatternEmitter::emitAttributeMatch(tblgen::DagNode tree,
+ StringRef castedName, int argIndex,
+ int depth) {
+ tblgen::Operator &op = tree.getDialectOp(opMap);
+ auto *namedAttr = cast<tblgen::NamedAttribute *>(op.getArg(argIndex));
const auto &attr = namedAttr->attr;
os << "{\n";
@@ -957,10 +969,11 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
os.unindent() << "}\n";
}
-void PatternEmitter::emitPropertyMatch(DagNode tree, StringRef castedName,
- int argIndex, int depth) {
- Operator &op = tree.getDialectOp(opMap);
- auto *namedProp = cast<NamedProperty *>(op.getArg(argIndex));
+void PatternEmitter::emitPropertyMatch(tblgen::DagNode tree,
+ StringRef castedName, int argIndex,
+ int depth) {
+ tblgen::Operator &op = tree.getDialectOp(opMap);
+ auto *namedProp = cast<tblgen::NamedProperty *>(op.getArg(argIndex));
os << "{\n";
os.indent() << formatv(
@@ -1013,7 +1026,7 @@ void PatternEmitter::emitMatchCheck(StringRef opName,
<< failureStr << ";\n});";
}
-void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
+void PatternEmitter::emitMatchLogic(tblgen::DagNode tree, StringRef opName) {
LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
int depth = 0;
emitMatch(tree, opName, depth);
@@ -1049,7 +1062,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
auto &entities = appliedConstraint.entities;
auto condition = constraint.getConditionTemplate();
- if (isa<TypeConstraint>(constraint)) {
+ if (isa<tblgen::TypeConstraint>(constraint)) {
if (entities.size() != 1)
PrintFatalError(loc, "type constraint requires exactly one argument");
@@ -1060,7 +1073,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
entities.front(), escapeString(constraint.getSummary())));
- } else if (isa<AttrConstraint>(constraint)) {
+ } else if (isa<tblgen::AttrConstraint>(constraint)) {
PrintFatalError(
loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
} else {
@@ -1091,11 +1104,12 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
-void PatternEmitter::collectOps(DagNode tree,
- llvm::SmallPtrSetImpl<const Operator *> &ops) {
+void PatternEmitter::collectOps(
+ tblgen::DagNode tree,
+ llvm::SmallPtrSetImpl<const tblgen::Operator *> &ops) {
// Check if this tree is an operation.
if (tree.isOperation()) {
- const Operator &op = tree.getDialectOp(opMap);
+ const tblgen::Operator &op = tree.getDialectOp(opMap);
LLVM_DEBUG(llvm::dbgs()
<< "found operation " << op.getOperationName() << '\n');
ops.insert(&op);
@@ -1109,13 +1123,13 @@ void PatternEmitter::collectOps(DagNode tree,
void PatternEmitter::emit(StringRef rewriteName) {
// Get the DAG tree for the source pattern.
- DagNode sourceTree = pattern.getSourcePattern();
+ tblgen::DagNode sourceTree = pattern.getSourcePattern();
- const Operator &rootOp = pattern.getSourceRootOp();
+ const tblgen::Operator &rootOp = pattern.getSourceRootOp();
auto rootName = rootOp.getOperationName();
// Collect the set of result operations.
- llvm::SmallPtrSet<const Operator *, 4> resultOps;
+ llvm::SmallPtrSet<const tblgen::Operator *, 4> resultOps;
LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
collectOps(pattern.getResultPattern(i), resultOps);
@@ -1131,12 +1145,13 @@ void PatternEmitter::emit(StringRef rewriteName) {
: ::mlir::RewritePattern("{1}", {2}, context, {{)",
rewriteName, rootName, pattern.getBenefit());
// Sort result operators by name.
- llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
- resultOps.end());
- llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
- return lhs->getOperationName() < rhs->getOperationName();
- });
- llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
+ llvm::SmallVector<const tblgen::Operator *, 4> sortedResultOps(
+ resultOps.begin(), resultOps.end());
+ llvm::sort(sortedResultOps,
+ [&](const tblgen::Operator *lhs, const tblgen::Operator *rhs) {
+ return lhs->getOperationName() < rhs->getOperationName();
+ });
+ llvm::interleaveComma(sortedResultOps, os, [&](const tblgen::Operator *op) {
os << '"' << op->getOperationName() << '"';
});
os << "}) {}\n";
@@ -1188,7 +1203,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
void PatternEmitter::emitRewriteLogic() {
LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
- const Operator &rootOp = pattern.getSourceRootOp();
+ const tblgen::Operator &rootOp = pattern.getSourceRootOp();
int numExpectedResults = rootOp.getNumResults();
int numResultPatterns = pattern.getNumResultPatterns();
@@ -1232,7 +1247,7 @@ void PatternEmitter::emitRewriteLogic() {
// Process auxiliary result patterns.
for (int i = 0; i < replStartIndex; ++i) {
- DagNode resultTree = pattern.getResultPattern(i);
+ tblgen::DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
// Normal op creation will be streamed to `os` by the above call; but
// NativeCodeCall will only be materialized to `os` if it is used. Here
@@ -1247,7 +1262,7 @@ void PatternEmitter::emitRewriteLogic() {
int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
for (int i = 0, offset = -numSupplementalPatterns;
i < numSupplementalPatterns; ++i) {
- DagNode resultTree = pattern.getSupplementalPattern(i);
+ tblgen::DagNode resultTree = pattern.getSupplementalPattern(i);
auto val = handleResultPattern(resultTree, offset++, 0);
if (resultTree.isNativeCodeCall() &&
resultTree.getNumReturnsOfNativeCode() == 0)
@@ -1265,7 +1280,7 @@ void PatternEmitter::emitRewriteLogic() {
// Process replacement result patterns.
os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
for (int i = replStartIndex; i < numResultPatterns; ++i) {
- DagNode resultTree = pattern.getResultPattern(i);
+ tblgen::DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
os << "\n";
// Resolve each symbol for all range use so that we can loop over them.
@@ -1287,12 +1302,12 @@ void PatternEmitter::emitRewriteLogic() {
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
-std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
+std::string PatternEmitter::getUniqueSymbol(const tblgen::Operator *op) {
return std::string(
formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
}
-std::string PatternEmitter::handleResultPattern(DagNode resultTree,
+std::string PatternEmitter::handleResultPattern(tblgen::DagNode resultTree,
int resultIndex, int depth) {
LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
LLVM_DEBUG(resultTree.print(llvm::dbgs()));
@@ -1322,7 +1337,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
return symbol;
}
-std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
+std::string PatternEmitter::handleVariadic(tblgen::DagNode tree, int depth) {
assert(tree.isVariadic());
std::string output;
@@ -1345,7 +1360,7 @@ std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
return name;
}
-StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
+StringRef PatternEmitter::handleReplaceWithValue(tblgen::DagNode tree) {
assert(tree.isReplaceWithValue());
if (tree.getNumArgs() != 1) {
@@ -1360,7 +1375,7 @@ StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
return tree.getArgName(0);
}
-std::string PatternEmitter::handleLocationDirective(DagNode tree) {
+std::string PatternEmitter::handleLocationDirective(tblgen::DagNode tree) {
assert(tree.isLocationDirective());
auto lookUpArgLoc = [this, &tree](int idx) {
const auto *const lookupFmt = "{0}.getLoc()";
@@ -1375,7 +1390,7 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) {
PrintFatalError(loc, "cannot bind symbol to location");
if (tree.getNumArgs() == 1) {
- DagLeaf leaf = tree.getArgAsLeaf(0);
+ tblgen::DagLeaf leaf = tree.getArgAsLeaf(0);
if (leaf.isStringAttr())
return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
leaf.getStringAttr())
@@ -1389,7 +1404,7 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) {
os << "rewriter.getFusedLoc({";
bool first = true;
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
- DagLeaf leaf = tree.getArgAsLeaf(i);
+ tblgen::DagLeaf leaf = tree.getArgAsLeaf(i);
// Handle the optional string value.
if (leaf.isStringAttr()) {
if (!strAttr.empty())
@@ -1408,8 +1423,8 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) {
return os.str();
}
-std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
- int depth) {
+std::string PatternEmitter::handleReturnTypeArg(tblgen::DagNode returnType,
+ int i, int depth) {
// Nested NativeCodeCall.
if (auto dagNode = returnType.getArgAsNestedDag(i)) {
if (!dagNode.isNativeCodeCall())
@@ -1426,7 +1441,7 @@ std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
}
-std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
+std::string PatternEmitter::handleOpArgument(tblgen::DagLeaf leaf,
StringRef patArgName) {
if (leaf.isStringAttr())
PrintFatalError(loc, "raw string not supported as argument");
@@ -1440,7 +1455,7 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
// This is an enum case backed by an IntegerAttr. We need to get its value
// to build the constant.
std::string val = std::to_string(enumCase.getValue());
- return handleConstantAttr(Attribute(&enumCase.getDef()), val);
+ return handleConstantAttr(tblgen::Attribute(&enumCase.getDef()), val);
}
if (leaf.isConstantProp()) {
auto constantProp = leaf.getAsConstantProp();
@@ -1463,8 +1478,9 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
PrintFatalError(loc, "unhandled case when rewriting op");
}
-std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
- int depth) {
+std::string
+PatternEmitter::handleReplaceWithNativeCodeCall(tblgen::DagNode tree,
+ int depth) {
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
@@ -1511,7 +1527,7 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
if (tree.getNumReturnsOfNativeCode() != 0) {
// Determine the local variable name for return value.
std::string varName =
- SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
+ tblgen::SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
if (varName.empty()) {
varName = formatv("nativeVar_{0}", nextValueId++);
// Register the local variable for later uses.
@@ -1530,7 +1546,7 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
return symbol;
}
-int PatternEmitter::getNodeValueCount(DagNode node) {
+int PatternEmitter::getNodeValueCount(tblgen::DagNode node) {
if (node.isOperation()) {
// If the op is bound to a symbol in the rewrite rule, query its result
// count from the symbol info map.
@@ -1549,8 +1565,9 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
}
PatternEmitter::TrailingDirectives
-PatternEmitter::getTrailingDirectives(DagNode tree) {
- TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
+PatternEmitter::getTrailingDirectives(tblgen::DagNode tree) {
+ TrailingDirectives tail = {tblgen::DagNode(nullptr), tblgen::DagNode(nullptr),
+ 0};
// Look backwards through the arguments.
auto numPatArgs = tree.getNumArgs();
@@ -1593,13 +1610,13 @@ PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
return "odsLoc";
}
-std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
- int depth) {
+std::string PatternEmitter::handleOpCreation(tblgen::DagNode tree,
+ int resultIndex, int depth) {
LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
- Operator &resultOp = tree.getDialectOp(opMap);
+ tblgen::Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs = resultOp.getNumArgs();
auto numPatArgs = tree.getNumArgs();
@@ -1624,7 +1641,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// FIXME: This does not yet check for variable/leaf case.
// FIXME: Change so that native code call can be handled.
const auto *operand =
- llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(i));
+ llvm::dyn_cast_if_present<tblgen::NamedTypeConstraint *>(
+ resultOp.getArg(i));
if (!operand || !operand->isVariadic())
return;
@@ -1667,7 +1685,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
resultValue = std::string(tree.getSymbol());
// Strip the index to get the name for the value pack and use it to name the
// local variable for the op.
- valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
+ valuePackName =
+ std::string(tblgen::SymbolInfoMap::getValuePackName(resultValue));
}
// Create the local variable for this op.
@@ -1758,8 +1777,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
void PatternEmitter::createSeparateLocalVarsForOpArgs(
- DagNode node, ChildNodeIndexNameMap &childNodeNames) {
- Operator &resultOp = node.getDialectOp(opMap);
+ tblgen::DagNode node, ChildNodeIndexNameMap &childNodeNames) {
+ tblgen::Operator &resultOp = node.getDialectOp(opMap);
// Now prepare operands used for building this op:
// * If the operand is non-variadic, we create a `Value` local variable.
@@ -1768,8 +1787,9 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
int valueIndex = 0; // An index for uniquing local variable names.
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
- const auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(
- resultOp.getArg(argIndex));
+ const auto *operand =
+ llvm::dyn_cast_if_present<tblgen::NamedTypeConstraint *>(
+ resultOp.getArg(argIndex));
// We do not need special handling for attributes or properties.
if (!operand)
continue;
@@ -1796,7 +1816,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
} else {
- DagLeaf leaf = node.getArgAsLeaf(argIndex);
+ tblgen::DagLeaf leaf = node.getArgAsLeaf(argIndex);
auto symbol =
symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
if (leaf.isNativeCodeCall()) {
@@ -1815,17 +1835,18 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
}
void PatternEmitter::supplyValuesForOpArgs(
- DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
- Operator &resultOp = node.getDialectOp(opMap);
+ tblgen::DagNode node, const ChildNodeIndexNameMap &childNodeNames,
+ int depth) {
+ tblgen::Operator &resultOp = node.getDialectOp(opMap);
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
argIndex != numOpArgs; ++argIndex) {
// Start each argument on its own line.
os << ",\n ";
- Argument opArg = resultOp.getArg(argIndex);
+ tblgen::Argument opArg = resultOp.getArg(argIndex);
// Handle the case of operand first.
if (auto *operand =
- llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
+ llvm::dyn_cast_if_present<tblgen::NamedTypeConstraint *>(opArg)) {
if (!operand->name.empty())
os << "/*" << operand->name << "=*/";
os << childNodeNames.lookup(argIndex);
@@ -1845,12 +1866,12 @@ void PatternEmitter::supplyValuesForOpArgs(
auto patArgName = node.getArgName(argIndex);
if (leaf.isConstantAttr() || leaf.isEnumCase()) {
// TODO: Refactor out into map to avoid recomputing these.
- if (!isa<NamedAttribute *>(opArg))
+ if (!isa<tblgen::NamedAttribute *>(opArg))
PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
} else if (leaf.isConstantProp()) {
- if (!isa<NamedProperty *>(opArg))
+ if (!isa<tblgen::NamedProperty *>(opArg))
PrintFatalError(loc, Twine("expected property ") + Twine(argIndex));
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
@@ -1863,8 +1884,9 @@ void PatternEmitter::supplyValuesForOpArgs(
}
void PatternEmitter::createAggregateLocalVarsForOpArgs(
- DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
- Operator &resultOp = node.getDialectOp(opMap);
+ tblgen::DagNode node, const ChildNodeIndexNameMap &childNodeNames,
+ int depth) {
+ tblgen::Operator &resultOp = node.getDialectOp(opMap);
auto scope = os.scope();
os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
@@ -1881,7 +1903,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
bool hasOperandSegmentSizes = false;
std::vector<std::string> sizes;
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
- if (isa<NamedAttribute *>(resultOp.getArg(argIndex))) {
+ if (isa<tblgen::NamedAttribute *>(resultOp.getArg(argIndex))) {
// The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex);
hasOperandSegmentSizes =
@@ -1901,7 +1923,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
continue;
}
- if (isa<NamedProperty *>(resultOp.getArg(argIndex))) {
+ if (isa<tblgen::NamedProperty *>(resultOp.getArg(argIndex))) {
// The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex);
auto setterName = resultOp.getSetterName(opArgName);
@@ -1924,7 +1946,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
}
const auto *operand =
- cast<NamedTypeConstraint *>(resultOp.getArg(argIndex));
+ cast<tblgen::NamedTypeConstraint *>(resultOp.getArg(argIndex));
if (operand->isVariadic()) {
++numVariadic;
std::string range;
@@ -1946,7 +1968,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
os << symbolInfoMap.getValueAndRangeUse(
childNodeNames.lookup(argIndex));
} else {
- DagLeaf leaf = node.getArgAsLeaf(argIndex);
+ tblgen::DagLeaf leaf = node.getArgAsLeaf(argIndex);
if (leaf.isConstantAttr())
// TODO: Use better location
PrintFatalError(
@@ -1983,7 +2005,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
const RecordKeeper &records,
- RecordOperatorMap &mapper)
+ tblgen::RecordOperatorMap &mapper)
: opMap(mapper), staticVerifierEmitter(os, records) {}
void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
@@ -1991,7 +2013,7 @@ void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
// ensure that all the dependent static matchers are generated before emitting
// the matching logic of the DagNode, we use topological order to achieve it.
for (auto &dagInfo : topologicalOrder) {
- DagNode node = dagInfo.first;
+ tblgen::DagNode node = dagInfo.first;
if (!useStaticMatcher(node))
continue;
@@ -2009,23 +2031,23 @@ void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
}
void StaticMatcherHelper::addPattern(const Record *record) {
- Pattern pat(record, &opMap);
+ tblgen::Pattern pat(record, &opMap);
// While generating the function body of the DAG matcher, it may depends on
// other DAG matchers. To ensure the dependent matchers are ready, we compute
// the topological order for all the DAGs and emit the DAG matchers in this
// order.
- llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
+ llvm::unique_function<void(tblgen::DagNode)> dfs = [&](tblgen::DagNode node) {
++refStats[node];
if (refStats[node] != 1)
return;
for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
- if (DagNode sibling = node.getArgAsNestedDag(i))
+ if (tblgen::DagNode sibling = node.getArgAsNestedDag(i))
dfs(sibling);
else {
- DagLeaf leaf = node.getArgAsLeaf(i);
+ tblgen::DagLeaf leaf = node.getArgAsLeaf(i);
if (!leaf.isUnspecified())
constraints.insert(leaf);
}
@@ -2036,7 +2058,7 @@ void StaticMatcherHelper::addPattern(const Record *record) {
dfs(pat.getSourcePattern());
}
-StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
+StringRef StaticMatcherHelper::getVerifierName(tblgen::DagLeaf leaf) {
if (leaf.isAttrMatcher()) {
std::optional<StringRef> constraint =
staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
@@ -2053,7 +2075,10 @@ StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
}
-static void emitRewriters(const RecordKeeper &records, raw_ostream &os) {
+namespace mlir {
+namespace tblgen {
+
+void emitRewriters(const llvm::RecordKeeper &records, llvm::raw_ostream &os) {
emitSourceFileHeader("Rewriters", os, records);
auto patterns = records.getAllDerivedDefinitions("Pattern");
@@ -2101,9 +2126,5 @@ static void emitRewriters(const RecordKeeper &records, raw_ostream &os) {
os << "}\n";
}
-static mlir::GenRegistration
- genRewriters("gen-rewriters", "Generate pattern rewriters",
- [](const RecordKeeper &records, raw_ostream &os) {
- emitRewriters(records, os);
- return false;
- });
+} // namespace tblgen
+} // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
deleted file mode 100644
index d4711532a79bb..0000000000000
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
+++ /dev/null
@@ -1,26 +0,0 @@
-//===- AttrOrTypeFormatGen.h - MLIR attribute and type format generator ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
-#define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
-
-#include "mlir/TableGen/Class.h"
-
-namespace mlir {
-namespace tblgen {
-class AttrOrTypeDef;
-
-/// Generate a parser and printer based on a custom assembly format for an
-/// attribute or type.
-void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser,
- MethodBody &printer);
-
-} // namespace tblgen
-} // namespace mlir
-
-#endif // MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index d7087cba3c874..8abc27a7ff4c9 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -8,37 +8,19 @@ set(LLVM_LINK_COMPONENTS
add_tablegen(mlir-tblgen MLIR
DESTINATION "${MLIR_TOOLS_INSTALL_DIR}"
EXPORT MLIR
- AttrOrTypeDefGen.cpp
- AttrOrTypeFormatGen.cpp
- BytecodeDialectGen.cpp
- DialectGen.cpp
- DialectInterfacesGen.cpp
DirectiveCommonGen.cpp
- EnumsGen.cpp
- EnumPythonBindingGen.cpp
- FormatGen.cpp
+ Generators.cpp
LLVMIRConversionGen.cpp
LLVMIRIntrinsicGen.cpp
mlir-tblgen.cpp
OmpOpGen.cpp
- OpClass.cpp
- OpDefinitionsGen.cpp
- OpDocGen.cpp
- OpFormatGen.cpp
- OpGenHelpers.cpp
- OpInterfacesGen.cpp
- OpPythonBindingGen.cpp
- PassCAPIGen.cpp
- PassDocGen.cpp
- PassGen.cpp
- RewriterGen.cpp
SPIRVUtilsGen.cpp
TosaUtilsGen.cpp
- CppGenUtilities.cpp
)
target_link_libraries(mlir-tblgen
PRIVATE
+ MLIRTableGenGenerators
MLIRTblgenLib)
mlir_check_all_link_libraries(mlir-tblgen)
diff --git a/mlir/tools/mlir-tblgen/DialectGenUtilities.h b/mlir/tools/mlir-tblgen/DialectGenUtilities.h
deleted file mode 100644
index 979a9d67b4047..0000000000000
--- a/mlir/tools/mlir-tblgen/DialectGenUtilities.h
+++ /dev/null
@@ -1,24 +0,0 @@
-//===- DialectGenUtilities.h - Utilities for dialect generation -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_
-#define MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_
-
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-namespace tblgen {
-class Dialect;
-
-/// Find the dialect selected by the user to generate for. Returns std::nullopt
-/// if no dialect was found, or if more than one potential dialect was found.
-std::optional<Dialect> findDialectToGenerate(ArrayRef<Dialect> dialects);
-} // namespace tblgen
-} // namespace mlir
-
-#endif // MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_
diff --git a/mlir/tools/mlir-tblgen/Generators.cpp b/mlir/tools/mlir-tblgen/Generators.cpp
new file mode 100644
index 0000000000000..343be1121fb1a
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/Generators.cpp
@@ -0,0 +1,429 @@
+//===- Generators.cpp - Generator registrations for mlir-tblgen -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file registers all generators for mlir-tblgen by calling into the
+// MLIRTableGenGenerators library. CLI options are read here and threaded as
+// explicit parameters to the library functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Generators/AttrOrTypeDefGen.h"
+#include "mlir/TableGen/Generators/BytecodeDialectGen.h"
+#include "mlir/TableGen/Generators/DialectGen.h"
+#include "mlir/TableGen/Generators/DialectInterfacesGen.h"
+#include "mlir/TableGen/Generators/EnumPythonBindingGen.h"
+#include "mlir/TableGen/Generators/EnumsGen.h"
+#include "mlir/TableGen/Generators/FormatGen.h"
+#include "mlir/TableGen/Generators/OpDefinitionsGen.h"
+#include "mlir/TableGen/Generators/OpDocGen.h"
+#include "mlir/TableGen/Generators/OpGenHelpers.h"
+#include "mlir/TableGen/Generators/OpInterfacesGen.h"
+#include "mlir/TableGen/Generators/OpPythonBindingGen.h"
+#include "mlir/TableGen/Generators/PassCAPIGen.h"
+#include "mlir/TableGen/Generators/PassDocGen.h"
+#include "mlir/TableGen/Generators/PassGen.h"
+#include "mlir/TableGen/Generators/RewriterGen.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// Assembly format options (shared by AttrOrTypeDef and Op generators)
+//===----------------------------------------------------------------------===//
+
+static cl::opt<bool>
+ formatErrorIsFatal("asmformat-error-is-fatal",
+ cl::desc("Emit a fatal error if format parsing fails"),
+ cl::init(true));
+
+//===----------------------------------------------------------------------===//
+// Op definition generator options (shared by op-def and op-doc generators)
+//===----------------------------------------------------------------------===//
+
+static cl::OptionCategory opDefGenCat("Options for op definition generators");
+static cl::opt<std::string> opIncFilter(
+ "op-include-regex",
+ cl::desc("Regex of name of op's to include (no filter if empty)"),
+ cl::cat(opDefGenCat));
+static cl::opt<std::string> opExcFilter(
+ "op-exclude-regex",
+ cl::desc("Regex of name of op's to exclude (no filter if empty)"),
+ cl::cat(opDefGenCat));
+static cl::opt<unsigned> opShardCount(
+ "op-shard-count",
+ cl::desc("The number of shards into which the op classes will be divided"),
+ cl::cat(opDefGenCat), cl::init(1));
+
+static std::vector<const Record *>
+getRequestedOpDefs(const RecordKeeper &records) {
+ return getRequestedOpDefinitions(records, opIncFilter, opExcFilter);
+}
+
+static void shardOps(ArrayRef<const Record *> defs,
+ SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) {
+ shardOpDefinitions(defs, shardedDefs, opShardCount);
+}
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeDef generators
+//===----------------------------------------------------------------------===//
+
+static cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
+static cl::opt<std::string>
+ attrDialect("attrdefs-dialect",
+ cl::desc("Generate attributes for this dialect"),
+ cl::cat(attrdefGenCat), cl::CommaSeparated);
+
+static GenRegistration
+ genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ AttrDefGenerator generator(records, os, formatErrorIsFatal);
+ return generator.emitDefs(attrDialect);
+ });
+static GenRegistration
+ genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ AttrDefGenerator generator(records, os, formatErrorIsFatal);
+ return generator.emitDecls(attrDialect);
+ });
+static GenRegistration
+ genAttrConstrDefs("gen-attr-constraint-defs",
+ "Generate attribute constraint definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ emitAttrConstraintDefs(records, os);
+ return false;
+ });
+static GenRegistration
+ genAttrConstrDecls("gen-attr-constraint-decls",
+ "Generate attribute constraint declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ emitAttrConstraintDecls(records, os);
+ return false;
+ });
+
+static cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
+static cl::opt<std::string>
+ typeDialect("typedefs-dialect", cl::desc("Generate types for this dialect"),
+ cl::cat(typedefGenCat), cl::CommaSeparated);
+
+static GenRegistration
+ genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ TypeDefGenerator generator(records, os, formatErrorIsFatal);
+ return generator.emitDefs(typeDialect);
+ });
+static GenRegistration
+ genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ TypeDefGenerator generator(records, os, formatErrorIsFatal);
+ return generator.emitDecls(typeDialect);
+ });
+static GenRegistration genTypeConstrDefs("gen-type-constraint-defs",
+ "Generate type constraint definitions",
+ [](const RecordKeeper &records,
+ raw_ostream &os) {
+ emitTypeConstraintDefs(records, os);
+ return false;
+ });
+static GenRegistration
+ genTypeConstrDecls("gen-type-constraint-decls",
+ "Generate type constraint declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ emitTypeConstraintDecls(records, os);
+ return false;
+ });
+
+//===----------------------------------------------------------------------===//
+// Bytecode dialect generator
+//===----------------------------------------------------------------------===//
+
+static cl::OptionCategory bytecodeGenCat("Options for -gen-bytecode");
+static cl::opt<std::string>
+ selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
+ cl::cat(bytecodeGenCat), cl::CommaSeparated);
+
+static GenRegistration
+ genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitBytecodeDialect(records, selectedBcDialect, os);
+ });
+
+//===----------------------------------------------------------------------===//
+// Dialect generators
+//===----------------------------------------------------------------------===//
+
+static cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
+static cl::opt<std::string> selectedDialect("dialect",
+ cl::desc("The dialect to gen for"),
+ cl::cat(dialectGenCat),
+ cl::CommaSeparated);
+
+static GenRegistration
+ genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitDialectDecls(records, selectedDialect, os);
+ });
+static GenRegistration
+ genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitDialectDefs(records, selectedDialect, os);
+ });
+
+//===----------------------------------------------------------------------===//
+// Dialect interface generator
+//===----------------------------------------------------------------------===//
+
+static GenRegistration genDialectInterfaceDecls(
+ "gen-dialect-interface-decls", "Generate dialect interface declarations.",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
+ });
+
+//===----------------------------------------------------------------------===//
+// Enum generators
+//===----------------------------------------------------------------------===//
+
+static GenRegistration
+ genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitEnumDecls(records, os);
+ });
+static GenRegistration
+ genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitEnumDefs(records, os);
+ });
+
+//===----------------------------------------------------------------------===//
+// Op definition generators
+//===----------------------------------------------------------------------===//
+
+static GenRegistration
+ genOpDecls("gen-op-decls", "Generate op declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ std::vector<const Record *> defs = getRequestedOpDefs(records);
+ SmallVector<ArrayRef<const Record *>> shardedDefs;
+ shardOps(defs, shardedDefs);
+ return emitOpDecls(records, defs, shardedDefs.size(), os,
+ formatErrorIsFatal);
+ });
+static GenRegistration
+ genOpDefs("gen-op-defs", "Generate op definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ std::vector<const Record *> defs = getRequestedOpDefs(records);
+ SmallVector<ArrayRef<const Record *>> shardedDefs;
+ shardOps(defs, shardedDefs);
+ return emitOpDefs(records, defs, shardedDefs.size(), os,
+ formatErrorIsFatal);
+ });
+
+//===----------------------------------------------------------------------===//
+// Documentation generators (attrdef, typedef, enum, op, and dialect docs)
+//===----------------------------------------------------------------------===//
+
+static cl::OptionCategory
+ docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
+static cl::opt<std::string>
+ stripPrefix("strip-prefix",
+ cl::desc("Strip prefix of the fully qualified names"),
+ cl::init("::mlir::"), cl::cat(docCat));
+static cl::opt<bool> allowHugoSpecificFeatures(
+ "allow-hugo-specific-features",
+ cl::desc("Allows using features specific to Hugo"), cl::init(false),
+ cl::cat(docCat));
+static cl::opt<bool>
+ keepOpSourceOrder("keep-op-source-order",
+ cl::desc("Do not sort ops alphabetically"),
+ cl::init(false), cl::cat(docCat));
+
+static bool
+withDialectRecords(const RecordKeeper &records,
+ llvm::function_ref<bool(const DialectRecords &)> fn) {
+ auto dialectDefs = records.getAllDerivedDefinitionsIfDefined("Dialect");
+ SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
+ std::optional<Dialect> dialect =
+ findDialectToGenerate(dialects, selectedDialect.getNumOccurrences() > 0
+ ? selectedDialect.getValue()
+ : "");
+ if (!dialect)
+ return true;
+ std::optional<DialectRecords> filtered = collectRecords(
+ records, getRequestedOpDefs(records), *dialect, keepOpSourceOrder);
+ if (!filtered)
+ return true;
+ return fn(*filtered);
+}
+
+static GenRegistration genAttrDocRegister(
+ "gen-attrdef-doc", "Generate dialect attribute documentation",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return withDialectRecords(records, [&](const DialectRecords &r) {
+ return emitAttrDefDoc(r, os);
+ });
+ });
+static GenRegistration genOpDocRegister(
+ "gen-op-doc", "Generate dialect documentation",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return withDialectRecords(records, [&](const DialectRecords &r) {
+ return emitOpDoc(r, stripPrefix, allowHugoSpecificFeatures, os);
+ });
+ });
+static GenRegistration genTypeDocRegister(
+ "gen-typedef-doc", "Generate dialect type documentation",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return withDialectRecords(records, [&](const DialectRecords &r) {
+ return emitTypeDefDoc(r, os);
+ });
+ });
+static GenRegistration genEnumDocRegister(
+ "gen-enum-doc", "Generate dialect enum documentation",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return withDialectRecords(
+ records, [&](const DialectRecords &r) { return emitEnumDoc(r, os); });
+ });
+static GenRegistration genDialectDocRegister(
+ "gen-dialect-doc", "Generate dialect documentation",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return withDialectRecords(records, [&](const DialectRecords &r) {
+ return emitDialectDoc(r, stripPrefix, allowHugoSpecificFeatures, os);
+ });
+ });
+
+//===----------------------------------------------------------------------===//
+// Attribute, op, and type interface generators
+//===----------------------------------------------------------------------===//
+
+namespace {
+template <typename GeneratorT>
+struct InterfaceGenRegistration {
+ InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
+ : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
+ genDefArg(("gen-" + genArg + "-interface-defs").str()),
+ genDocArg(("gen-" + genArg + "-interface-docs").str()),
+ genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
+ genDefDesc(("Generate " + genDesc + " interface definitions").str()),
+ genDocDesc(("Generate " + genDesc + " interface documentation").str()),
+ genDecls(genDeclArg, genDeclDesc,
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return GeneratorT(records, os).emitInterfaceDecls();
+ }),
+ genDefs(genDefArg, genDefDesc,
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return GeneratorT(records, os).emitInterfaceDefs();
+ }),
+ genDocs(genDocArg, genDocDesc,
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return GeneratorT(records, os).emitInterfaceDocs();
+ }) {}
+
+ std::string genDeclArg, genDefArg, genDocArg;
+ std::string genDeclDesc, genDefDesc, genDocDesc;
+ GenRegistration genDecls, genDefs, genDocs;
+};
+} // namespace
+
+static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
+ "attribute");
+static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
+static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");
+
+//===----------------------------------------------------------------------===//
+// Python binding generators
+//===----------------------------------------------------------------------===//
+
+static cl::OptionCategory
+ clOpPythonBindingCat("Options for -gen-python-op-bindings");
+
+// dialectNameStorage is shared between gen-python-op-bindings and
+// gen-python-enum-bindings via the -bind-dialect option.
+static std::string dialectNameStorage;
+
+static cl::opt<std::string, /*ExternalStorage=*/true> clDialectName(
+ "bind-dialect", cl::desc("The dialect to run the generator for"),
+ cl::location(dialectNameStorage), cl::cat(clOpPythonBindingCat));
+
+static cl::opt<std::string>
+ clDialectExtensionName("dialect-extension",
+ cl::desc("The prefix of the dialect extension"),
+ cl::init(""), cl::cat(clOpPythonBindingCat));
+
+static GenRegistration genPythonEnumBindings(
+ "gen-python-enum-bindings", "Generate Python bindings for enum attributes",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitPythonEnums(records, dialectNameStorage, os);
+ });
+static GenRegistration
+ genPythonBindings("gen-python-op-bindings",
+ "Generate Python bindings for MLIR Ops",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitPythonOpBindings(records, dialectNameStorage,
+ clDialectExtensionName, os);
+ });
+
+//===----------------------------------------------------------------------===//
+// Pass generators
+//===----------------------------------------------------------------------===//
+
+static cl::OptionCategory
+ passCAPIGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl");
+static cl::opt<std::string> passCAPIGroupName(
+ "prefix",
+ cl::desc("The prefix to use for this group of passes. The form will be "
+ "mlirCreate<prefix><passname>, the prefix can avoid conflicts "
+ "across libraries."),
+ cl::cat(passCAPIGenCat));
+
+static GenRegistration
+ genCAPIHeader("gen-pass-capi-header", "Generate pass C API header",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ emitPassCAPIHeader(records, passCAPIGroupName, os);
+ return false;
+ });
+static GenRegistration
+ genCAPIImpl("gen-pass-capi-impl", "Generate pass C API implementation",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ emitPassCAPIImpl(records, passCAPIGroupName, os);
+ return false;
+ });
+static GenRegistration genPassDoc("gen-pass-doc", "Generate pass documentation",
+ [](const RecordKeeper &records,
+ raw_ostream &os) {
+ emitPassDocs(records, os);
+ return false;
+ });
+
+static cl::OptionCategory passGenCat("Options for -gen-pass-decls");
+static cl::opt<std::string>
+ passGroupName("name", cl::desc("The name of this group of passes"),
+ cl::cat(passGenCat));
+
+static GenRegistration
+ genPassDecls("gen-pass-decls", "Generate pass declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ emitPasses(records, passGroupName, os);
+ return false;
+ });
+
+//===----------------------------------------------------------------------===//
+// Rewriter generator
+//===----------------------------------------------------------------------===//
+
+static GenRegistration
+ genRewriters("gen-rewriters", "Generate pattern rewriters",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ emitRewriters(records, os);
+ return false;
+ });
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.h b/mlir/tools/mlir-tblgen/OpFormatGen.h
deleted file mode 100644
index 5e43f38498664..0000000000000
--- a/mlir/tools/mlir-tblgen/OpFormatGen.h
+++ /dev/null
@@ -1,29 +0,0 @@
-//===- OpFormatGen.h - MLIR operation format generator ----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the interface for generating parsers and printers from the
-// declarative format.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_
-#define MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_
-
-namespace mlir {
-namespace tblgen {
-class OpClass;
-class Operator;
-
-// Generate the assembly format for the given operator.
-void generateOpFormat(const Operator &constOp, OpClass &opClass,
- bool hasProperties);
-
-} // namespace tblgen
-} // namespace mlir
-
-#endif // MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_
More information about the Mlir-commits
mailing list