[Mlir-commits] [mlir] [MLIR][TableGen] Change MLIR TableGen to use const Record * (PR #110687)
Rahul Joshi
llvmlistbot at llvm.org
Tue Oct 1 08:27:38 PDT 2024
https://github.com/jurahul created https://github.com/llvm/llvm-project/pull/110687
None
From 100f91697ee6f34ce1010d7df0d9cbc1371eb0c9 Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Tue, 1 Oct 2024 08:25:43 -0700
Subject: [PATCH] [MLIR][TableGen] Change MLIR TableGen to use const Record *
---
mlir/include/mlir/TableGen/Predicate.h | 2 +-
mlir/lib/TableGen/Predicate.cpp | 4 +--
mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp | 35 ++++++++++---------
mlir/tools/mlir-tblgen/OmpOpGen.cpp | 2 +-
mlir/tools/mlir-tblgen/OpDocGen.cpp | 2 +-
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 6 ++--
.../tools/tblgen-to-irdl/OpDefinitionsGen.cpp | 8 +++--
7 files changed, 33 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h
index fd5a5a5dc99ae1..2eeb253d56b266 100644
--- a/mlir/include/mlir/TableGen/Predicate.h
+++ b/mlir/include/mlir/TableGen/Predicate.h
@@ -104,7 +104,7 @@ class CombinedPred : public Pred {
const llvm::Record *getCombinerDef() const;
// Get the predicates that are combined by this predicate.
- std::vector<llvm::Record *> getChildren() const;
+ std::vector<const llvm::Record *> getChildren() const;
};
// A combined predicate that requires all child predicates of 'CPred' type to
diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp
index 3c3c475f4d3659..a2122f15af1f6a 100644
--- a/mlir/lib/TableGen/Predicate.cpp
+++ b/mlir/lib/TableGen/Predicate.cpp
@@ -79,10 +79,10 @@ const llvm::Record *CombinedPred::getCombinerDef() const {
return def->getValueAsDef("kind");
}
-std::vector<llvm::Record *> CombinedPred::getChildren() const {
+std::vector<const llvm::Record *> CombinedPred::getChildren() const {
assert(def->getValue("children") &&
"CombinedPred must have a value 'children'");
- return def->getValueAsListOfDefs("children");
+ return def->getValueAsListOfConstDefs("children");
}
namespace {
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index 964b33a9fa41f8..01669b701622c3 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -32,14 +32,14 @@ class Generator {
Generator(raw_ostream &output) : output(output) {}
/// Returns whether successfully emitted attribute/type parsers.
- void emitParse(StringRef kind, Record &x);
+ void emitParse(StringRef kind, const Record &x);
/// Returns whether successfully emitted attribute/type printers.
void emitPrint(StringRef kind, StringRef type,
- ArrayRef<std::pair<int64_t, Record *>> vec);
+ ArrayRef<std::pair<int64_t, const Record *>> vec);
/// Emits parse dispatch table.
- void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
+ void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec);
/// Emits print dispatch table.
void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
@@ -51,8 +51,9 @@ class Generator {
StringRef failure, mlir::raw_indented_ostream &ios);
/// Emits print instructions.
- void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
- StringRef name, mlir::raw_indented_ostream &ios);
+ void emitPrintHelper(const Record *memberRec, StringRef kind,
+ StringRef parent, StringRef name,
+ mlir::raw_indented_ostream &ios);
raw_ostream &output;
};
@@ -75,7 +76,7 @@ static std::string capitalize(StringRef str) {
}
/// Return the C++ type for the given record.
-static std::string getCType(Record *def) {
+static std::string getCType(const Record *def) {
std::string format = "{0}";
if (def->isSubClassOf("Array")) {
def = def->getValueAsDef("elemT");
@@ -92,7 +93,8 @@ static std::string getCType(Record *def) {
return formatv(format.c_str(), cType.str());
}
-void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
+void Generator::emitParseDispatch(StringRef kind,
+ ArrayRef<const Record *> vec) {
mlir::raw_indented_ostream os(output);
char const *head =
R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
@@ -126,7 +128,7 @@ void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
os << "return " << capitalize(kind) << "();\n";
}
-void Generator::emitParse(StringRef kind, Record &x) {
+void Generator::emitParse(StringRef kind, const Record &x) {
if (x.getNameInitAsString() == "ReservedOrDead")
return;
@@ -293,7 +295,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
}
void Generator::emitPrint(StringRef kind, StringRef type,
- ArrayRef<std::pair<int64_t, Record *>> vec) {
+ ArrayRef<std::pair<int64_t, const Record *>> vec) {
if (type == "ReservedOrDead")
return;
@@ -304,7 +306,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
auto funScope = os.scope("{\n", "}\n\n");
// Check that predicates specified if multiple bytecode instances.
- for (llvm::Record *rec : make_second_range(vec)) {
+ for (const llvm::Record *rec : make_second_range(vec)) {
StringRef pred = rec->getValueAsString("printerPredicate");
if (vec.size() > 1 && pred.empty()) {
for (auto [index, rec] : vec) {
@@ -344,7 +346,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
}
}
-void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
+void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
StringRef parent, StringRef name,
mlir::raw_indented_ostream &ios) {
std::string getter;
@@ -423,7 +425,7 @@ void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
namespace {
/// Container of Attribute or Type for Dialect.
struct AttrOrType {
- std::vector<Record *> attr, type;
+ std::vector<const Record *> attr, type;
};
} // namespace
@@ -435,14 +437,14 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
it->getValueAsString("dialect") != selectedBcDialect)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].attr =
- it->getValueAsListOfDefs("elems");
+ it->getValueAsListOfConstDefs("elems");
}
for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
if (!selectedBcDialect.empty() &&
it->getValueAsString("dialect") != selectedBcDialect)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].type =
- it->getValueAsListOfDefs("elems");
+ it->getValueAsListOfConstDefs("elems");
}
if (dialectAttrOrType.size() != 1)
@@ -452,7 +454,7 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
auto it = dialectAttrOrType.front();
Generator gen(os);
- SmallVector<std::vector<Record *> *, 2> vecs;
+ SmallVector<std::vector<const Record *> *, 2> vecs;
SmallVector<std::string, 2> kinds;
vecs.push_back(&it.second.attr);
kinds.push_back("attribute");
@@ -460,7 +462,8 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
kinds.push_back("type");
for (auto [vec, kind] : zip(vecs, kinds)) {
// Handle Attribute/Type emission.
- std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
+ std::map<std::string, std::vector<std::pair<int64_t, const Record *>>>
+ perType;
for (auto kt : llvm::enumerate(*vec))
perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
for (const auto &jt : perType) {
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index c1e0e9fab6682d..c529c190382c6f 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -318,7 +318,7 @@ static void genOperandsDef(const Record *op, raw_ostream &os) {
return;
SmallVector<std::string> clauseNames;
- for (Record *clause : op->getValueAsListOfDefs("clauseList"))
+ for (const Record *clause : op->getValueAsListOfDefs("clauseList"))
clauseNames.push_back((extractOmpClauseName(clause) + "ClauseOps").str());
StringRef opName = stripPrefixAndSuffix(
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index bf759572d25013..ed9d90a25625fc 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -155,7 +155,7 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
llvm::raw_string_ostream os(effectStr);
os << effectName << "{";
auto list = trait.getDef().getValueAsListOfDefs("effects");
- llvm::interleaveComma(list, os, [&](Record *rec) {
+ llvm::interleaveComma(list, os, [&](const Record *rec) {
StringRef effect = rec->getValueAsString("effect");
effect.consume_front("::");
effect.consume_front("mlir::");
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 8f830cdf513fbd..fa4925cbeed2fd 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -166,7 +166,8 @@ std::vector<Availability> getAvailabilities(const Record &def) {
std::vector<Availability> availabilities;
if (def.getValue("availability")) {
- std::vector<Record *> availDefs = def.getValueAsListOfDefs("availability");
+ std::vector<const Record *> availDefs =
+ def.getValueAsListOfConstDefs("availability");
availabilities.reserve(availDefs.size());
for (const Record *avail : availDefs)
availabilities.emplace_back(avail);
@@ -1449,7 +1450,8 @@ static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
if (!def.getValue("implies"))
continue;
- std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs("implies");
+ std::vector<const Record *> impliedCapsDefs =
+ def.getValueAsListOfConstDefs("implies");
os << " case spirv::Capability::" << enumerant.getSymbol()
<< ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
<< "] = {";
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index d0a3552fb123da..f0fd5bba1a65a1 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -249,7 +249,7 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
constraints.push_back(createTypeConstraint(
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
- for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
+ for (const Record *child : predRec.getValueAsListOfDefs("predicateList")) {
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
@@ -273,7 +273,8 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
constraints.push_back(createAttrConstraint(
builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
- for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+ for (const Record *child :
+ predRec.getValueAsListOfDefs("attrConstraints")) {
constraints.push_back(createPredicate(
builder, tblgen::Pred(child->getValueAsDef("predicate"))));
}
@@ -283,7 +284,8 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
if (predRec.isSubClassOf("AnyAttrOf")) {
std::vector<Value> constraints;
- for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+ for (const Record *child :
+ predRec.getValueAsListOfDefs("allowedAttributes")) {
constraints.push_back(
createAttrConstraint(builder, tblgen::Constraint(child)));
}
More information about the Mlir-commits mailing list