[Mlir-commits] [mlir] [MLIR][TableGen] Change MLIR TableGen to use const Record * (PR #110687)

Rahul Joshi llvmlistbot at llvm.org
Tue Oct 1 08:27:38 PDT 2024


https://github.com/jurahul created https://github.com/llvm/llvm-project/pull/110687

None

>From 100f91697ee6f34ce1010d7df0d9cbc1371eb0c9 Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Tue, 1 Oct 2024 08:25:43 -0700
Subject: [PATCH] [MLIR][TableGen] Change MLIR TableGen to use const Record *

---
 mlir/include/mlir/TableGen/Predicate.h        |  2 +-
 mlir/lib/TableGen/Predicate.cpp               |  4 +--
 mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp | 35 ++++++++++---------
 mlir/tools/mlir-tblgen/OmpOpGen.cpp           |  2 +-
 mlir/tools/mlir-tblgen/OpDocGen.cpp           |  2 +-
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |  6 ++--
 .../tools/tblgen-to-irdl/OpDefinitionsGen.cpp |  8 +++--
 7 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h
index fd5a5a5dc99ae1..2eeb253d56b266 100644
--- a/mlir/include/mlir/TableGen/Predicate.h
+++ b/mlir/include/mlir/TableGen/Predicate.h
@@ -104,7 +104,7 @@ class CombinedPred : public Pred {
   const llvm::Record *getCombinerDef() const;
 
   // Get the predicates that are combined by this predicate.
-  std::vector<llvm::Record *> getChildren() const;
+  std::vector<const llvm::Record *> getChildren() const;
 };
 
 // A combined predicate that requires all child predicates of 'CPred' type to
diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp
index 3c3c475f4d3659..a2122f15af1f6a 100644
--- a/mlir/lib/TableGen/Predicate.cpp
+++ b/mlir/lib/TableGen/Predicate.cpp
@@ -79,10 +79,10 @@ const llvm::Record *CombinedPred::getCombinerDef() const {
   return def->getValueAsDef("kind");
 }
 
-std::vector<llvm::Record *> CombinedPred::getChildren() const {
+std::vector<const llvm::Record *> CombinedPred::getChildren() const {
   assert(def->getValue("children") &&
          "CombinedPred must have a value 'children'");
-  return def->getValueAsListOfDefs("children");
+  return def->getValueAsListOfConstDefs("children");
 }
 
 namespace {
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index 964b33a9fa41f8..01669b701622c3 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -32,14 +32,14 @@ class Generator {
   Generator(raw_ostream &output) : output(output) {}
 
   /// Returns whether successfully emitted attribute/type parsers.
-  void emitParse(StringRef kind, Record &x);
+  void emitParse(StringRef kind, const Record &x);
 
   /// Returns whether successfully emitted attribute/type printers.
   void emitPrint(StringRef kind, StringRef type,
-                 ArrayRef<std::pair<int64_t, Record *>> vec);
+                 ArrayRef<std::pair<int64_t, const Record *>> vec);
 
   /// Emits parse dispatch table.
-  void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
+  void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec);
 
   /// Emits print dispatch table.
   void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
@@ -51,8 +51,9 @@ class Generator {
                        StringRef failure, mlir::raw_indented_ostream &ios);
 
   /// Emits print instructions.
-  void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
-                       StringRef name, mlir::raw_indented_ostream &ios);
+  void emitPrintHelper(const Record *memberRec, StringRef kind,
+                       StringRef parent, StringRef name,
+                       mlir::raw_indented_ostream &ios);
 
   raw_ostream &output;
 };
@@ -75,7 +76,7 @@ static std::string capitalize(StringRef str) {
 }
 
 /// Return the C++ type for the given record.
-static std::string getCType(Record *def) {
+static std::string getCType(const Record *def) {
   std::string format = "{0}";
   if (def->isSubClassOf("Array")) {
     def = def->getValueAsDef("elemT");
@@ -92,7 +93,8 @@ static std::string getCType(Record *def) {
   return formatv(format.c_str(), cType.str());
 }
 
-void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
+void Generator::emitParseDispatch(StringRef kind,
+                                  ArrayRef<const Record *> vec) {
   mlir::raw_indented_ostream os(output);
   char const *head =
       R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
@@ -126,7 +128,7 @@ void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
   os << "return " << capitalize(kind) << "();\n";
 }
 
-void Generator::emitParse(StringRef kind, Record &x) {
+void Generator::emitParse(StringRef kind, const Record &x) {
   if (x.getNameInitAsString() == "ReservedOrDead")
     return;
 
@@ -293,7 +295,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
 }
 
 void Generator::emitPrint(StringRef kind, StringRef type,
-                          ArrayRef<std::pair<int64_t, Record *>> vec) {
+                          ArrayRef<std::pair<int64_t, const Record *>> vec) {
   if (type == "ReservedOrDead")
     return;
 
@@ -304,7 +306,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
   auto funScope = os.scope("{\n", "}\n\n");
 
   // Check that predicates specified if multiple bytecode instances.
-  for (llvm::Record *rec : make_second_range(vec)) {
+  for (const llvm::Record *rec : make_second_range(vec)) {
     StringRef pred = rec->getValueAsString("printerPredicate");
     if (vec.size() > 1 && pred.empty()) {
       for (auto [index, rec] : vec) {
@@ -344,7 +346,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
   }
 }
 
-void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
+void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
                                 StringRef parent, StringRef name,
                                 mlir::raw_indented_ostream &ios) {
   std::string getter;
@@ -423,7 +425,7 @@ void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
 namespace {
 /// Container of Attribute or Type for Dialect.
 struct AttrOrType {
-  std::vector<Record *> attr, type;
+  std::vector<const Record *> attr, type;
 };
 } // namespace
 
@@ -435,14 +437,14 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
         it->getValueAsString("dialect") != selectedBcDialect)
       continue;
     dialectAttrOrType[it->getValueAsString("dialect")].attr =
-        it->getValueAsListOfDefs("elems");
+        it->getValueAsListOfConstDefs("elems");
   }
   for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
     if (!selectedBcDialect.empty() &&
         it->getValueAsString("dialect") != selectedBcDialect)
       continue;
     dialectAttrOrType[it->getValueAsString("dialect")].type =
-        it->getValueAsListOfDefs("elems");
+        it->getValueAsListOfConstDefs("elems");
   }
 
   if (dialectAttrOrType.size() != 1)
@@ -452,7 +454,7 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
   auto it = dialectAttrOrType.front();
   Generator gen(os);
 
-  SmallVector<std::vector<Record *> *, 2> vecs;
+  SmallVector<std::vector<const Record *> *, 2> vecs;
   SmallVector<std::string, 2> kinds;
   vecs.push_back(&it.second.attr);
   kinds.push_back("attribute");
@@ -460,7 +462,8 @@ static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
   kinds.push_back("type");
   for (auto [vec, kind] : zip(vecs, kinds)) {
     // Handle Attribute/Type emission.
-    std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
+    std::map<std::string, std::vector<std::pair<int64_t, const Record *>>>
+        perType;
     for (auto kt : llvm::enumerate(*vec))
       perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
     for (const auto &jt : perType) {
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index c1e0e9fab6682d..c529c190382c6f 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -318,7 +318,7 @@ static void genOperandsDef(const Record *op, raw_ostream &os) {
     return;
 
   SmallVector<std::string> clauseNames;
-  for (Record *clause : op->getValueAsListOfDefs("clauseList"))
+  for (const Record *clause : op->getValueAsListOfDefs("clauseList"))
     clauseNames.push_back((extractOmpClauseName(clause) + "ClauseOps").str());
 
   StringRef opName = stripPrefixAndSuffix(
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index bf759572d25013..ed9d90a25625fc 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -155,7 +155,7 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
         llvm::raw_string_ostream os(effectStr);
         os << effectName << "{";
         auto list = trait.getDef().getValueAsListOfDefs("effects");
-        llvm::interleaveComma(list, os, [&](Record *rec) {
+        llvm::interleaveComma(list, os, [&](const Record *rec) {
           StringRef effect = rec->getValueAsString("effect");
           effect.consume_front("::");
           effect.consume_front("mlir::");
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 8f830cdf513fbd..fa4925cbeed2fd 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -166,7 +166,8 @@ std::vector<Availability> getAvailabilities(const Record &def) {
   std::vector<Availability> availabilities;
 
   if (def.getValue("availability")) {
-    std::vector<Record *> availDefs = def.getValueAsListOfDefs("availability");
+    std::vector<const Record *> availDefs =
+        def.getValueAsListOfConstDefs("availability");
     availabilities.reserve(availDefs.size());
     for (const Record *avail : availDefs)
       availabilities.emplace_back(avail);
@@ -1449,7 +1450,8 @@ static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
     if (!def.getValue("implies"))
       continue;
 
-    std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs("implies");
+    std::vector<const Record *> impliedCapsDefs =
+        def.getValueAsListOfConstDefs("implies");
     os << "  case spirv::Capability::" << enumerant.getSymbol()
        << ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
        << "] = {";
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index d0a3552fb123da..f0fd5bba1a65a1 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -249,7 +249,7 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     constraints.push_back(createTypeConstraint(
         builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
-    for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
+    for (const Record *child : predRec.getValueAsListOfDefs("predicateList")) {
       constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
     }
     auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
@@ -273,7 +273,8 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     constraints.push_back(createAttrConstraint(
         builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
-    for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+    for (const Record *child :
+         predRec.getValueAsListOfDefs("attrConstraints")) {
       constraints.push_back(createPredicate(
           builder, tblgen::Pred(child->getValueAsDef("predicate"))));
     }
@@ -283,7 +284,8 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
 
   if (predRec.isSubClassOf("AnyAttrOf")) {
     std::vector<Value> constraints;
-    for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+    for (const Record *child :
+         predRec.getValueAsListOfDefs("allowedAttributes")) {
       constraints.push_back(
           createAttrConstraint(builder, tblgen::Constraint(child)));
     }



More information about the Mlir-commits mailing list