[Mlir-commits] [mlir] [MLIR][TableGen] Migrate MLIR backend to use const RecordKeeper (PR #107505)
Rahul Joshi
llvmlistbot at llvm.org
Thu Sep 5 19:42:17 PDT 2024
https://github.com/jurahul created https://github.com/llvm/llvm-project/pull/107505
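- Migrate the MLIR TableGen backends (mlir-tblgen generators, PDLL's .td include processing, and shared TableGen helpers) to take a `const RecordKeeper &` and `const Record *` instead of mutable references/pointers.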
>From b40830f98257a25c7c73304116bf0dd71797e5fa Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Wed, 4 Sep 2024 09:38:19 -0700
Subject: [PATCH] [MLIR][TableGen] Migrate MLIR backend to use const
RecordKeeper
- Migrate MLIR backend to use a const RecordKeeper reference.
---
mlir/include/mlir/TableGen/CodeGenHelpers.h | 4 +-
mlir/include/mlir/TableGen/GenInfo.h | 6 +-
mlir/lib/TableGen/CodeGenHelpers.cpp | 6 +-
mlir/lib/Tools/PDLL/Parser/Parser.cpp | 15 ++--
mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp | 2 +-
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 37 +++++-----
mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp | 5 +-
mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp | 7 +-
.../mlir-tblgen/EnumPythonBindingGen.cpp | 5 +-
mlir/tools/mlir-tblgen/EnumsGen.cpp | 8 +-
mlir/tools/mlir-tblgen/OmpOpGen.cpp | 17 +++--
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 24 +++---
mlir/tools/mlir-tblgen/OpDocGen.cpp | 74 +++++++++----------
mlir/tools/mlir-tblgen/OpGenHelpers.cpp | 10 +--
mlir/tools/mlir-tblgen/OpGenHelpers.h | 7 +-
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 27 +++----
mlir/tools/mlir-tblgen/RewriterGen.cpp | 24 +++---
.../tools/tblgen-to-irdl/OpDefinitionsGen.cpp | 20 ++---
18 files changed, 146 insertions(+), 152 deletions(-)
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index c263c69c53d1e3..465240907a3dee 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -106,7 +106,7 @@ class StaticVerifierFunctionEmitter {
StringRef tag = "");
/// Collect and unique all the constraints used by operations.
- void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
+ void collectOpConstraints(ArrayRef<const llvm::Record *> opDefs);
/// Collect and unique all compatible type, attribute, successor, and region
/// constraints from the operations in the file and emit them at the top of
@@ -114,7 +114,7 @@ class StaticVerifierFunctionEmitter {
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
- void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
+ void emitOpConstraints(ArrayRef<const llvm::Record *> opDefs);
/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
diff --git a/mlir/include/mlir/TableGen/GenInfo.h b/mlir/include/mlir/TableGen/GenInfo.h
index d59d64223827bd..ef2e12f07df16d 100644
--- a/mlir/include/mlir/TableGen/GenInfo.h
+++ b/mlir/include/mlir/TableGen/GenInfo.h
@@ -21,8 +21,8 @@ class RecordKeeper;
namespace mlir {
/// Generator function to invoke.
-using GenFunction =
- std::function<bool(llvm::RecordKeeper &recordKeeper, raw_ostream &os)>;
+using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
+ raw_ostream &os)>;
/// Structure to group information about a generator (argument to invoke via
/// mlir-tblgen, description, and generator function).
@@ -34,7 +34,7 @@ class GenInfo {
: arg(arg), description(description), generator(std::move(generator)) {}
/// Invokes the generator and returns whether the generator failed.
- bool invoke(llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
+ bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
assert(generator && "Cannot call generator with null generator");
return generator(recordKeeper, os);
}
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index 718a8136944bdf..314b20491460aa 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -49,7 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
void StaticVerifierFunctionEmitter::emitOpConstraints(
- ArrayRef<llvm::Record *> opDefs) {
+ ArrayRef<const llvm::Record *> opDefs) {
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraints();
emitAttrConstraints();
@@ -264,14 +264,14 @@ void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
}
void StaticVerifierFunctionEmitter::collectOpConstraints(
- ArrayRef<Record *> opDefs) {
+ ArrayRef<const Record *> opDefs) {
const auto collectTypeConstraints = [&](Operator::const_value_range values) {
for (const NamedTypeConstraint &value : values)
if (value.hasPredicate())
collectConstraint(typeConstraints, "type", value.constraint);
};
- for (Record *def : opDefs) {
+ for (const Record *def : opDefs) {
Operator op(*def);
/// Collect type constraints.
collectTypeConstraints(op.getOperands());
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 01c78e280080ee..2f842df48826d5 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -164,7 +164,7 @@ class Parser {
SmallVectorImpl<ast::Decl *> &decls);
/// Process the records of a parsed tablegen include file.
- void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
+ void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
SmallVectorImpl<ast::Decl *> &decls);
/// Create a user defined native constraint for a constraint imported from
@@ -863,7 +863,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
return success();
}
-void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
+void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
SmallVectorImpl<ast::Decl *> &decls) {
// Return the length kind of the given value.
auto getLengthKind = [](const auto &value) {
@@ -887,7 +887,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
// Process the parsed tablegen records to build ODS information.
/// Operations.
- for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
+ for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
tblgen::Operator op(def);
// Check to see if this operation is known to support type inferrence.
@@ -920,13 +920,13 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
}
}
- auto shouldBeSkipped = [this](llvm::Record *def) {
+ auto shouldBeSkipped = [this](const llvm::Record *def) {
return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
def->isSubClassOf("DeclareInterfaceMethods");
};
/// Attr constraints.
- for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
+ for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
if (shouldBeSkipped(def))
continue;
@@ -936,7 +936,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
constraint.getStorageType()));
}
/// Type constraints.
- for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
+ for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
if (shouldBeSkipped(def))
continue;
@@ -947,7 +947,8 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
}
/// OpInterfaces.
ast::Type opTy = ast::OperationType::get(ctx);
- for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
+ for (const llvm::Record *def :
+ tdRecords.getAllDerivedDefinitions("OpInterface")) {
if (shouldBeSkipped(def))
continue;
diff --git a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
index 564161fe4c1a24..1911b6e3aa3927 100644
--- a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
+++ b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
@@ -90,7 +90,7 @@ static bool findUse(Record &record, Init *deprecatedInit,
});
}
-static void warnOfDeprecatedUses(RecordKeeper &records) {
+static void warnOfDeprecatedUses(const RecordKeeper &records) {
// This performs a direct check for any def marked as deprecated and then
// finds all uses of deprecated def. Deprecated defs are not expected to be
// either numerous or long lived.
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index feca04bff643d5..0a10d54e479211 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -30,7 +30,7 @@ using namespace mlir::tblgen;
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
/// specified and can only find one dialect's defs, use that.
static void collectAllDefs(StringRef selectedDialect,
- std::vector<llvm::Record *> records,
+ const std::vector<const llvm::Record *> &records,
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
// Nothing to do if no defs were found.
if (records.empty())
@@ -690,14 +690,15 @@ class DefGenerator {
bool emitDefs(StringRef selectedDialect);
protected:
- DefGenerator(const std::vector<llvm::Record *> &defs, raw_ostream &os,
+ DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os,
StringRef defType, StringRef valueType, bool isAttrGenerator)
: defRecords(defs), os(os), defType(defType), valueType(valueType),
isAttrGenerator(isAttrGenerator) {
// Sort by occurrence in file.
- llvm::sort(defRecords, [](llvm::Record *lhs, llvm::Record *rhs) {
- return lhs->getID() < rhs->getID();
- });
+ llvm::sort(defRecords,
+ [](const llvm::Record *lhs, const llvm::Record *rhs) {
+ return lhs->getID() < rhs->getID();
+ });
}
/// Emit the list of def type names.
@@ -706,7 +707,7 @@ class DefGenerator {
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
/// The set of def records to emit.
- std::vector<llvm::Record *> defRecords;
+ std::vector<const llvm::Record *> defRecords;
/// The attribute or type class to emit.
/// The stream to emit to.
raw_ostream &os;
@@ -721,13 +722,13 @@ class DefGenerator {
/// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator {
- AttrDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
+ AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
};
/// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator {
- TypeDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
+ TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
"Type", "Type", /*isAttrGenerator=*/false) {}
};
@@ -1029,9 +1030,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
/// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint>
-getAllTypeConstraints(llvm::RecordKeeper &records) {
+getAllTypeConstraints(const llvm::RecordKeeper &records) {
std::vector<Constraint> result;
- for (llvm::Record *def :
+ for (const llvm::Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
// Ignore constraints defined outside of the top-level file.
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
@@ -1046,7 +1047,7 @@ getAllTypeConstraints(llvm::RecordKeeper &records) {
return result;
}
-static void emitTypeConstraintDecls(llvm::RecordKeeper &records,
+static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type);
@@ -1056,7 +1057,7 @@ bool {0}(::mlir::Type type);
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
}
-static void emitTypeConstraintDefs(llvm::RecordKeeper &records,
+static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) {
@@ -1087,13 +1088,13 @@ static llvm::cl::opt<std::string>
static mlir::GenRegistration
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDefs(attrDialect);
});
static mlir::GenRegistration
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDecls(attrDialect);
});
@@ -1109,13 +1110,13 @@ static llvm::cl::opt<std::string>
static mlir::GenRegistration
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDefs(typeDialect);
});
static mlir::GenRegistration
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});
@@ -1123,14 +1124,14 @@ static mlir::GenRegistration
static mlir::GenRegistration
genTypeConstrDefs("gen-type-constraint-defs",
"Generate type constraint definitions",
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDefs(records, os);
return false;
});
static mlir::GenRegistration
genTypeConstrDecls("gen-type-constraint-decls",
"Generate type constraint declarations",
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDecls(records, os);
return false;
});
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index 66a3750d7c8266..964b33a9fa41f8 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -429,14 +429,15 @@ struct AttrOrType {
static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
MapVector<StringRef, AttrOrType> dialectAttrOrType;
- for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
+ for (const Record *it :
+ records.getAllDerivedDefinitions("DialectAttributes")) {
if (!selectedBcDialect.empty() &&
it->getValueAsString("dialect") != selectedBcDialect)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].attr =
it->getValueAsListOfDefs("elems");
}
- for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
+ for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
if (!selectedBcDialect.empty() &&
it->getValueAsString("dialect") != selectedBcDialect)
continue;
diff --git a/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp b/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp
index 337b6a5e5d5bd1..de3e6d8ee8cbc8 100644
--- a/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp
+++ b/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp
@@ -23,6 +23,7 @@
using llvm::Clause;
using llvm::ClauseVal;
using llvm::raw_ostream;
+using llvm::Record;
using llvm::RecordKeeper;
// LLVM has multiple places (Clang, Flang, MLIR) where information about
@@ -49,13 +50,11 @@ static bool emitDecls(const RecordKeeper &recordKeeper, llvm::StringRef dialect,
"'--directives-dialect'");
}
- const auto &directiveLanguages =
+ const auto directiveLanguages =
recordKeeper.getAllDerivedDefinitions("DirectiveLanguage");
assert(!directiveLanguages.empty() && "DirectiveLanguage missing.");
- const auto &clauses = recordKeeper.getAllDerivedDefinitions("Clause");
-
- for (const auto &r : clauses) {
+ for (const Record *r : recordKeeper.getAllDerivedDefinitions("Clause")) {
Clause c{r};
const auto &clauseVals = c.getClauseVals();
if (clauseVals.empty())
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772ed..79249944e484f7 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -136,13 +136,14 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
os << fileHeader;
- for (auto &it :
+ for (const llvm::Record *it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
EnumAttr enumAttr(*it);
emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os);
}
- for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+ for (const llvm::Record *it :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it);
if (!attr.getMnemonic()) {
llvm::errs() << "enum case " << attr
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index f1d7a233b66a9a..95767a29b9c3cf 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -645,8 +645,8 @@ class {1} : public ::mlir::{2} {
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Declarations", os, recordKeeper);
- auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
- for (const auto *def : defs)
+ for (const Record *def :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
emitEnumDecl(*def, os);
return false;
@@ -683,8 +683,8 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Definitions", os, recordKeeper);
- auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
- for (const auto *def : defs)
+ for (const Record *def :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
emitEnumDef(*def, os);
return false;
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index b7f6ca975a9a34..15458212637888 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -24,8 +24,8 @@ using namespace llvm;
/// `OpenMP_Clause` class the record is based on is found, the optional
/// "OpenMP_" prefix and "Skip" and "Clause" suffixes are removed to return only
/// the clause name, i.e. "OpenMP_CollapseClauseSkip" is returned as "Collapse".
-static StringRef extractOmpClauseName(Record *clause) {
- Record *ompClause = clause->getRecords().getClass("OpenMP_Clause");
+static StringRef extractOmpClauseName(const Record *clause) {
+ const Record *ompClause = clause->getRecords().getClass("OpenMP_Clause");
assert(ompClause && "base OpenMP records expected to be defined");
StringRef clauseClassName;
@@ -33,7 +33,7 @@ static StringRef extractOmpClauseName(Record *clause) {
clause->getDirectSuperClasses(clauseSuperClasses);
// Check if OpenMP_Clause is a direct superclass.
- for (Record *superClass : clauseSuperClasses) {
+ for (const Record *superClass : clauseSuperClasses) {
if (superClass == ompClause) {
clauseClassName = clause->getName();
break;
@@ -83,7 +83,8 @@ static bool verifyArgument(DagInit *arguments, StringRef argName,
/// Check that the given string record value, identified by its name \c value,
/// is either undefined or empty in both the given operation and clause record
/// or its contents for the clause record are contained in the operation record.
-static bool verifyStringValue(StringRef value, Record *op, Record *clause) {
+static bool verifyStringValue(StringRef value, const Record *op,
+ const Record *clause) {
auto opValue = op->getValueAsOptionalString(value);
auto clauseValue = clause->getValueAsOptionalString(value);
@@ -100,7 +101,7 @@ static bool verifyStringValue(StringRef value, Record *op, Record *clause) {
/// present in the corresponding operation field.
///
/// Print warnings or errors where this is not the case.
-static void verifyClause(Record *op, Record *clause) {
+static void verifyClause(const Record *op, const Record *clause) {
StringRef clauseClassName = extractOmpClauseName(clause);
if (!clause->getValueAsBit("ignoreArgs")) {
@@ -149,9 +150,9 @@ static void verifyClause(Record *op, Record *clause) {
/// Verify that all properties of `OpenMP_Clause`s of records deriving from
/// `OpenMP_Op`s have been inherited by the latter.
-static bool verifyDecls(RecordKeeper &recordKeeper, raw_ostream &) {
- for (Record *op : recordKeeper.getAllDerivedDefinitions("OpenMP_Op")) {
- for (Record *clause : op->getValueAsListOfDefs("clauseList"))
+static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) {
+ for (const Record *op : recordKeeper.getAllDerivedDefinitions("OpenMP_Op")) {
+ for (const Record *clause : op->getValueAsListOfDefs("clauseList"))
verifyClause(op, clause);
}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 572c1545b43fcb..c3edd00ed14e7d 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4503,7 +4503,7 @@ void OpOperandAdaptorEmitter::emitDef(
/// Emit the class declarations or definitions for the given op defs.
static void
emitOpClasses(const RecordKeeper &recordKeeper,
- const std::vector<Record *> &defs, raw_ostream &os,
+ const std::vector<const Record *> &defs, raw_ostream &os,
const StaticVerifierFunctionEmitter &staticVerifierEmitter,
bool emitDecl) {
if (defs.empty())
@@ -4540,7 +4540,7 @@ emitOpClasses(const RecordKeeper &recordKeeper,
/// Emit the declarations for the provided op classes.
static void emitOpClassDecls(const RecordKeeper &recordKeeper,
- const std::vector<Record *> &defs,
+ const std::vector<const Record *> &defs,
raw_ostream &os) {
// First emit forward declaration for each class, this allows them to refer
// to each others in traits for example.
@@ -4562,7 +4562,7 @@ static void emitOpClassDecls(const RecordKeeper &recordKeeper,
/// Emit the definitions for the provided op classes.
static void emitOpClassDefs(const RecordKeeper &recordKeeper,
- ArrayRef<Record *> defs, raw_ostream &os,
+ ArrayRef<const Record *> defs, raw_ostream &os,
StringRef constraintPrefix = "") {
if (defs.empty())
return;
@@ -4583,12 +4583,12 @@ static void emitOpClassDefs(const RecordKeeper &recordKeeper,
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Declarations", os, recordKeeper);
- std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
+ std::vector<const Record *> defs = getRequestedOpDefinitions(recordKeeper);
emitOpClassDecls(recordKeeper, defs, os);
// If we are generating sharded op definitions, emit the sharded op
// registration hooks.
- SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+ SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
shardOpDefinitions(defs, shardedDefs);
if (defs.empty() || shardedDefs.size() <= 1)
return false;
@@ -4611,9 +4611,9 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
/// Generate the dialect op registration hook and the op class definitions for a
/// shard of ops.
static void emitOpDefShard(const RecordKeeper &recordKeeper,
- ArrayRef<Record *> defs, const Dialect &dialect,
- unsigned shardIndex, unsigned shardCount,
- raw_ostream &os) {
+ ArrayRef<const Record *> defs,
+ const Dialect &dialect, unsigned shardIndex,
+ unsigned shardCount, raw_ostream &os) {
std::string shardGuard = "GET_OP_DEFS_";
std::string indexStr = std::to_string(shardIndex);
shardGuard += indexStr;
@@ -4637,7 +4637,7 @@ static void emitOpDefShard(const RecordKeeper &recordKeeper,
"Op Registration Hook")
<< formatv(opRegistrationHook, dialect.getCppNamespace(),
dialect.getCppClassName(), shardIndex);
- for (Record *def : defs) {
+ for (const Record *def : defs) {
os << formatv(" ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
Operator(def).getQualCppClassName());
}
@@ -4651,8 +4651,8 @@ static void emitOpDefShard(const RecordKeeper &recordKeeper,
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Definitions", os, recordKeeper);
- std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
- SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+ std::vector<const Record *> defs = getRequestedOpDefinitions(recordKeeper);
+ SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
shardOpDefinitions(defs, shardedDefs);
// If no shard was requested, emit the regular op list and class definitions.
@@ -4661,7 +4661,7 @@ static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
IfDefScope scope("GET_OP_LIST", os);
interleave(
defs, os,
- [&](Record *def) { os << Operator(def).getQualCppClassName(); },
+ [&](const Record *def) { os << Operator(def).getQualCppClassName(); },
",\n");
}
{
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index 066e5b24f5a3c1..d60eda0a8b958b 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -282,7 +282,7 @@ static void emitSourceLink(StringRef inputFilename, raw_ostream &os) {
<< inputFromMlirInclude << ")\n\n";
}
-static void emitOpDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
+static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
auto opDefs = getRequestedOpDefinitions(recordKeeper);
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
@@ -371,10 +371,9 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
os << "\n";
}
-static void emitAttrOrTypeDefDoc(RecordKeeper &recordKeeper, raw_ostream &os,
- StringRef recordTypeName) {
- std::vector<llvm::Record *> defs =
- recordKeeper.getAllDerivedDefinitions(recordTypeName);
+static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
+ raw_ostream &os, StringRef recordTypeName) {
+ auto defs = recordKeeper.getAllDerivedDefinitions(recordTypeName);
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const llvm::Record *def : defs)
@@ -405,12 +404,10 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
os << "\n";
}
-static void emitEnumDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
- std::vector<llvm::Record *> defs =
- recordKeeper.getAllDerivedDefinitions("EnumAttr");
-
+static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
- for (const llvm::Record *def : defs)
+ for (const llvm::Record *def :
+ recordKeeper.getAllDerivedDefinitions("EnumAttr"))
emitEnumDoc(EnumAttr(def), os);
}
@@ -518,24 +515,19 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
os);
}
-static bool emitDialectDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
- std::vector<Record *> dialectDefs =
- recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect");
+static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ auto dialectDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect");
SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
std::optional<Dialect> dialect = findDialectToGenerate(dialects);
if (!dialect)
return true;
- std::vector<Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
- std::vector<Record *> attrDefs =
- recordKeeper.getAllDerivedDefinitionsIfDefined("DialectAttr");
- std::vector<Record *> typeDefs =
- recordKeeper.getAllDerivedDefinitionsIfDefined("DialectType");
- std::vector<Record *> typeDefDefs =
- recordKeeper.getAllDerivedDefinitionsIfDefined("TypeDef");
- std::vector<Record *> attrDefDefs =
- recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef");
- std::vector<Record *> enumDefs =
+ std::vector<const Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
+ auto attrDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("DialectAttr");
+ auto typeDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("DialectType");
+ auto typeDefDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("TypeDef");
+ auto attrDefDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef");
+ auto enumDefs =
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
std::vector<Attribute> dialectAttrs;
@@ -545,26 +537,28 @@ static bool emitDialectDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<TypeDef> dialectTypeDefs;
std::vector<EnumAttr> dialectEnums;
- llvm::SmallDenseSet<Record *> seen;
- auto addIfNotSeen = [&](llvm::Record *record, const auto &def, auto &vec) {
+ llvm::SmallDenseSet<const Record *> seen;
+ auto addIfNotSeen = [&](const llvm::Record *record, const auto &def,
+ auto &vec) {
if (seen.insert(record).second) {
vec.push_back(def);
return true;
}
return false;
};
- auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) {
+ auto addIfInDialect = [&](const llvm::Record *record, const auto &def,
+ auto &vec) {
return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
};
- SmallDenseMap<Record *, OpDocGroup> opDocGroup;
+ SmallDenseMap<const Record *, OpDocGroup> opDocGroup;
- for (Record *def : attrDefDefs)
+ for (const Record *def : attrDefDefs)
addIfInDialect(def, AttrDef(def), dialectAttrDefs);
- for (Record *def : attrDefs)
+ for (const Record *def : attrDefs)
addIfInDialect(def, Attribute(def), dialectAttrs);
- for (Record *def : opDefs) {
- if (Record *group = def->getValueAsOptionalDef("opDocGroup")) {
+ for (const Record *def : opDefs) {
+ if (const Record *group = def->getValueAsOptionalDef("opDocGroup")) {
OpDocGroup &op = opDocGroup[group];
addIfInDialect(def, Operator(def), op.ops);
} else {
@@ -573,7 +567,7 @@ static bool emitDialectDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
addIfInDialect(def, op, dialectOps);
}
}
- for (Record *rec :
+ for (const Record *rec :
recordKeeper.getAllDerivedDefinitionsIfDefined("OpDocGroup")) {
if (opDocGroup[rec].ops.empty())
continue;
@@ -581,12 +575,12 @@ static bool emitDialectDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
opDocGroup[rec].description = rec->getValueAsString("description");
dialectOps.push_back(opDocGroup[rec]);
}
- for (Record *def : typeDefDefs)
+ for (const Record *def : typeDefDefs)
addIfInDialect(def, TypeDef(def), dialectTypeDefs);
- for (Record *def : typeDefs)
+ for (const Record *def : typeDefs)
addIfInDialect(def, Type(def), dialectTypes);
dialectEnums.reserve(enumDefs.size());
- for (Record *def : enumDefs)
+ for (const Record *def : enumDefs)
addIfNotSeen(def, EnumAttr(def), dialectEnums);
// Sort alphabetically ignorning dialect for ops and section name for
@@ -617,34 +611,34 @@ static bool emitDialectDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
static mlir::GenRegistration
genAttrRegister("gen-attrdef-doc",
"Generate dialect attribute documentation",
- [](RecordKeeper &records, raw_ostream &os) {
+ [](const RecordKeeper &records, raw_ostream &os) {
emitAttrOrTypeDefDoc(records, os, "AttrDef");
return false;
});
static mlir::GenRegistration
genOpRegister("gen-op-doc", "Generate dialect documentation",
- [](RecordKeeper &records, raw_ostream &os) {
+ [](const RecordKeeper &records, raw_ostream &os) {
emitOpDoc(records, os);
return false;
});
static mlir::GenRegistration
genTypeRegister("gen-typedef-doc", "Generate dialect type documentation",
- [](RecordKeeper &records, raw_ostream &os) {
+ [](const RecordKeeper &records, raw_ostream &os) {
emitAttrOrTypeDefDoc(records, os, "TypeDef");
return false;
});
static mlir::GenRegistration
genEnumRegister("gen-enum-doc", "Generate dialect enum documentation",
- [](RecordKeeper &records, raw_ostream &os) {
+ [](const RecordKeeper &records, raw_ostream &os) {
emitEnumDoc(records, os);
return false;
});
static mlir::GenRegistration
genRegister("gen-dialect-doc", "Generate dialect documentation",
- [](RecordKeeper &records, raw_ostream &os) {
+ [](const RecordKeeper &records, raw_ostream &os) {
return emitDialectDoc(records, os);
});
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index c2a2423a240269..702ea664324554 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -44,14 +44,14 @@ static std::string getOperationName(const Record &def) {
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
}
-std::vector<Record *>
+std::vector<const Record *>
mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
- Record *classDef = recordKeeper.getClass("Op");
+ const Record *classDef = recordKeeper.getClass("Op");
if (!classDef)
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
- std::vector<Record *> defs;
+ std::vector<const Record *> defs;
for (const auto &def : recordKeeper.getDefs()) {
if (!def.second->isSubClassOf(classDef))
continue;
@@ -86,8 +86,8 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
}
void mlir::tblgen::shardOpDefinitions(
- ArrayRef<llvm::Record *> defs,
- SmallVectorImpl<ArrayRef<llvm::Record *>> &shardedDefs) {
+ ArrayRef<const llvm::Record *> defs,
+ SmallVectorImpl<ArrayRef<const llvm::Record *>> &shardedDefs) {
assert(opShardCount > 0 && "expected a positive shard count");
if (opShardCount == 1) {
shardedDefs.push_back(defs);
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 1b43d5d3ce3a7d..b7c9fe3a7b799c 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -22,7 +22,7 @@ namespace tblgen {
/// Returns all the op definitions filtered by the user. The filtering is via
/// command-line option "op-include-regex" and "op-exclude-regex".
-std::vector<llvm::Record *>
+std::vector<const llvm::Record *>
getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
/// Checks whether `str` is a Python keyword or would shadow builtin function.
@@ -30,8 +30,9 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
bool isPythonReserved(llvm::StringRef str);
/// Shard the op defintions into the number of shards set by "op-shard-count".
-void shardOpDefinitions(ArrayRef<llvm::Record *> defs,
- SmallVectorImpl<ArrayRef<llvm::Record *>> &shardedDefs);
+void shardOpDefinitions(
+ ArrayRef<const llvm::Record *> defs,
+ SmallVectorImpl<ArrayRef<const llvm::Record *>> &shardedDefs);
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 00f21a1cefbdd8..7c32c2549d788f 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -61,9 +61,10 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
/// Get an array of all OpInterface definitions but exclude those subclassing
/// "DeclareOpInterfaceMethods".
-static std::vector<llvm::Record *>
-getAllInterfaceDefinitions(llvm::RecordKeeper &recordKeeper, StringRef name) {
- std::vector<llvm::Record *> defs =
+static std::vector<const llvm::Record *>
+getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
+ StringRef name) {
+ std::vector<const llvm::Record *> defs =
recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
std::string declareName = ("Declare" + name + "InterfaceMethods").str();
@@ -87,7 +88,7 @@ class InterfaceGenerator {
bool emitInterfaceDocs();
protected:
- InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
+ InterfaceGenerator(std::vector<const llvm::Record *> &&defs, raw_ostream &os)
: defs(std::move(defs)), os(os) {}
void emitConceptDecl(const Interface &interface);
@@ -98,7 +99,7 @@ class InterfaceGenerator {
void emitInterfaceDecl(const Interface &interface);
/// The set of interface records to emit.
- std::vector<llvm::Record *> defs;
+ std::vector<const llvm::Record *> defs;
// The stream to emit to.
raw_ostream &os;
/// The C++ value type of the interface, e.g. Operation*.
@@ -117,7 +118,7 @@ class InterfaceGenerator {
/// A specialized generator for attribute interfaces.
struct AttrInterfaceGenerator : public InterfaceGenerator {
- AttrInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os)
+ AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
valueType = "::mlir::Attribute";
interfaceBaseType = "AttributeInterface";
@@ -132,7 +133,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
};
/// A specialized generator for operation interfaces.
struct OpInterfaceGenerator : public InterfaceGenerator {
- OpInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os)
+ OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
valueType = "::mlir::Operation *";
interfaceBaseType = "OpInterface";
@@ -148,7 +149,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
};
/// A specialized generator for type interfaces.
struct TypeInterfaceGenerator : public InterfaceGenerator {
- TypeInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os)
+ TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
valueType = "::mlir::Type";
interfaceBaseType = "TypeInterface";
@@ -606,8 +607,8 @@ bool InterfaceGenerator::emitInterfaceDecls() {
llvm::emitSourceFileHeader("Interface Declarations", os);
// Sort according to ID, so defs are emitted in the order in which they appear
// in the Tablegen file.
- std::vector<llvm::Record *> sortedDefs(defs);
- llvm::sort(sortedDefs, [](llvm::Record *lhs, llvm::Record *rhs) {
+ std::vector<const llvm::Record *> sortedDefs(defs);
+ llvm::sort(sortedDefs, [](const llvm::Record *lhs, const llvm::Record *rhs) {
return lhs->getID() < rhs->getID();
});
for (const llvm::Record *def : sortedDefs)
@@ -683,15 +684,15 @@ struct InterfaceGenRegistration {
genDefDesc(("Generate " + genDesc + " interface definitions").str()),
genDocDesc(("Generate " + genDesc + " interface documentation").str()),
genDecls(genDeclArg, genDeclDesc,
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDecls();
}),
genDefs(genDefArg, genDefDesc,
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDefs();
}),
genDocs(genDocArg, genDocDesc,
- [](llvm::RecordKeeper &records, raw_ostream &os) {
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDocs();
}) {}
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 401f02246ed235..598eb8ea12fcc1 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -64,7 +64,7 @@ class StaticMatcherHelper;
class PatternEmitter {
public:
- PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
+ PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
StaticMatcherHelper &helper);
// Emits the mlir::RewritePattern struct named `rewriteName`.
@@ -268,7 +268,7 @@ class PatternEmitter {
// inlining them.
class StaticMatcherHelper {
public:
- StaticMatcherHelper(raw_ostream &os, RecordKeeper &recordKeeper,
+ StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
RecordOperatorMap &mapper);
// Determine if we should inline the match logic or delegate to a static
@@ -293,7 +293,7 @@ class StaticMatcherHelper {
// Collect the `Record`s, i.e., the DRR, so that we can get the information of
// the duplicated DAGs.
- void addPattern(Record *record);
+ void addPattern(const Record *record);
// Emit all static functions of DAG Matcher.
void populateStaticMatchers(raw_ostream &os);
@@ -322,7 +322,7 @@ class StaticMatcherHelper {
// inlining.
//
// The topological order of all the DagNodes among all patterns.
- SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
+ SmallVector<std::pair<DagNode, const Record *>> topologicalOrder;
RecordOperatorMap &opMap;
@@ -347,7 +347,7 @@ class StaticMatcherHelper {
} // namespace
-PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
+PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper,
raw_ostream &os, StaticMatcherHelper &helper)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
@@ -1886,7 +1886,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
}
StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
- RecordKeeper &recordKeeper,
+ const RecordKeeper &recordKeeper,
RecordOperatorMap &mapper)
: opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
@@ -1912,7 +1912,7 @@ void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
}
-void StaticMatcherHelper::addPattern(Record *record) {
+void StaticMatcherHelper::addPattern(const Record *record) {
Pattern pat(record, &opMap);
// While generating the function body of the DAG matcher, it may depends on
@@ -1951,10 +1951,10 @@ StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
}
-static void emitRewriters(RecordKeeper &recordKeeper, raw_ostream &os) {
+static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Rewriters", os, recordKeeper);
- const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
+ auto patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
// We put the map here because it can be shared among multiple patterns.
RecordOperatorMap recordOpMap;
@@ -1962,7 +1962,7 @@ static void emitRewriters(RecordKeeper &recordKeeper, raw_ostream &os) {
// Exam all the patterns and generate static matcher for the duplicated
// DagNode.
StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
- for (Record *p : patterns)
+ for (const Record *p : patterns)
staticMatcher.addPattern(p);
staticMatcher.populateStaticConstraintFunctions(os);
staticMatcher.populateStaticMatchers(os);
@@ -1973,7 +1973,7 @@ static void emitRewriters(RecordKeeper &recordKeeper, raw_ostream &os) {
std::string baseRewriterName = "GeneratedConvert";
int rewriterIndex = 0;
- for (Record *p : patterns) {
+ for (const Record *p : patterns) {
std::string name;
if (p->isAnonymous()) {
// If no name is provided, ensure unique rewriter names simply by
@@ -2001,7 +2001,7 @@ static void emitRewriters(RecordKeeper &recordKeeper, raw_ostream &os) {
static mlir::GenRegistration
genRewriters("gen-rewriters", "Generate pattern rewriters",
- [](RecordKeeper &records, raw_ostream &os) {
+ [](const RecordKeeper &records, raw_ostream &os) {
emitRewriters(records, os);
return false;
});
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 0957a5d55db959..4a13a00335f655 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -60,7 +60,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
if (predRec.isSubClassOf("AnyTypeOf")) {
std::vector<Value> constraints;
- for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
+ for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
@@ -70,7 +70,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
if (predRec.isSubClassOf("AllOfType")) {
std::vector<Value> constraints;
- for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
+ for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
@@ -145,14 +145,8 @@ static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
StringAttr::get(ctx, selectedDialect));
}
-static std::vector<llvm::Record *>
-getOpDefinitions(RecordKeeper &recordKeeper) {
- if (!recordKeeper.getClass("Op"))
- return {};
- return recordKeeper.getAllDerivedDefinitions("Op");
-}
-
-static bool emitDialectIRDLDefs(RecordKeeper &recordKeeper, raw_ostream &os) {
+static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
+ raw_ostream &os) {
// Initialize.
MLIRContext ctx;
ctx.getOrLoadDialect<irdl::IRDLDialect>();
@@ -167,8 +161,8 @@ static bool emitDialectIRDLDefs(RecordKeeper &recordKeeper, raw_ostream &os) {
// Set insertion point to start of DialectOp.
builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock());
- std::vector<Record *> defs = getOpDefinitions(recordKeeper);
- for (auto *def : defs) {
+ for (const Record *def :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
tblgen::Operator tblgenOp(def);
if (tblgenOp.getDialectName() != selectedDialect)
continue;
@@ -184,6 +178,6 @@ static bool emitDialectIRDLDefs(RecordKeeper &recordKeeper, raw_ostream &os) {
static mlir::GenRegistration
genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
- [](RecordKeeper &records, raw_ostream &os) {
+ [](const RecordKeeper &records, raw_ostream &os) {
return emitDialectIRDLDefs(records, os);
});
More information about the Mlir-commits mailing list