[llvm] [mlir] [LLVM][TableGen] Change `RecordKeeper::getClass` to return const pointer (PR #112261)

Rahul Joshi via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 22 14:38:19 PDT 2024


https://github.com/jurahul updated https://github.com/llvm/llvm-project/pull/112261

>From 0b116f671a03846b2ababd8ef19b927ee2dfb8be Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Mon, 14 Oct 2024 13:49:12 -0700
Subject: [PATCH] [LLVM][TableGen] Change `RecordKeeper::getClass` to return
 const pointer

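Change `RecordKeeper::getClass` to return a `const Record *` instead of a
mutable `Record *`, and propagate the const qualifier to its users:
`VarDefInit`, the `TGParser` helpers that only read the class record, and
the callers in CodeGenRegisters.cpp and SPIRVUtilsGen.cpp.
`TGParser::ParseClass` still needs to modify the record it looks up, so it
uses a `const_cast` on the result.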
---
 llvm/include/llvm/TableGen/Record.h           |  8 +++---
 llvm/lib/TableGen/Record.cpp                  |  6 ++---
 llvm/lib/TableGen/TGParser.cpp                | 27 ++++++++++---------
 llvm/lib/TableGen/TGParser.h                  | 10 +++----
 .../TableGen/Common/CodeGenRegisters.cpp      |  2 +-
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |  2 +-
 6 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 78b44cfc649a5c..e64b78c3c1e3b9 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -1355,11 +1355,11 @@ class VarDefInit final
       public FoldingSetNode,
       public TrailingObjects<VarDefInit, const ArgumentInit *> {
   SMLoc Loc;
-  Record *Class;
+  const Record *Class;
   const DefInit *Def = nullptr; // after instantiation
   unsigned NumArgs;
 
-  explicit VarDefInit(SMLoc Loc, Record *Class, unsigned N);
+  explicit VarDefInit(SMLoc Loc, const Record *Class, unsigned N);
 
   const DefInit *instantiate();
 
@@ -1373,7 +1373,7 @@ class VarDefInit final
   static bool classof(const Init *I) {
     return I->getKind() == IK_VarDefInit;
   }
-  static const VarDefInit *get(SMLoc Loc, Record *Class,
+  static const VarDefInit *get(SMLoc Loc, const Record *Class,
                                ArrayRef<const ArgumentInit *> Args);
 
   void Profile(FoldingSetNodeID &ID) const;
@@ -2000,7 +2000,7 @@ class RecordKeeper {
   const GlobalMap &getGlobals() const { return ExtraGlobals; }
 
   /// Get the class with the specified name.
-  Record *getClass(StringRef Name) const {
+  const Record *getClass(StringRef Name) const {
     auto I = Classes.find(Name);
     return I == Classes.end() ? nullptr : I->second.get();
   }
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index 9241fb3d8e72d9..1d71482b020b22 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -2294,7 +2294,7 @@ const RecTy *DefInit::getFieldType(const StringInit *FieldName) const {
 
 std::string DefInit::getAsString() const { return std::string(Def->getName()); }
 
-static void ProfileVarDefInit(FoldingSetNodeID &ID, Record *Class,
+static void ProfileVarDefInit(FoldingSetNodeID &ID, const Record *Class,
                               ArrayRef<const ArgumentInit *> Args) {
   ID.AddInteger(Args.size());
   ID.AddPointer(Class);
@@ -2303,11 +2303,11 @@ static void ProfileVarDefInit(FoldingSetNodeID &ID, Record *Class,
     ID.AddPointer(I);
 }
 
-VarDefInit::VarDefInit(SMLoc Loc, Record *Class, unsigned N)
+VarDefInit::VarDefInit(SMLoc Loc, const Record *Class, unsigned N)
     : TypedInit(IK_VarDefInit, RecordRecTy::get(Class)), Loc(Loc), Class(Class),
       NumArgs(N) {}
 
-const VarDefInit *VarDefInit::get(SMLoc Loc, Record *Class,
+const VarDefInit *VarDefInit::get(SMLoc Loc, const Record *Class,
                                   ArrayRef<const ArgumentInit *> Args) {
   FoldingSetNodeID ID;
   ProfileVarDefInit(ID, Class, Args);
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index f315557f38aadd..e01342ffcd3c8f 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -34,7 +34,7 @@ namespace llvm {
 
 struct SubClassReference {
   SMRange RefRange;
-  Record *Rec = nullptr;
+  const Record *Rec = nullptr;
   SmallVector<const ArgumentInit *, 4> TemplateArgs;
 
   SubClassReference() = default;
@@ -110,7 +110,7 @@ static void checkConcrete(Record &R) {
 
 /// Return an Init with a qualifier prefix referring
 /// to CurRec's name.
-static const Init *QualifyName(Record &CurRec, const Init *Name) {
+static const Init *QualifyName(const Record &CurRec, const Init *Name) {
   RecordKeeper &RK = CurRec.getRecords();
   const Init *NewName = BinOpInit::getStrConcat(
       CurRec.getNameInit(),
@@ -127,7 +127,7 @@ static const Init *QualifyName(MultiClass *MC, const Init *Name) {
 }
 
 /// Return the qualified version of the implicit 'NAME' template argument.
-static const Init *QualifiedNameOfImplicitName(Record &Rec) {
+static const Init *QualifiedNameOfImplicitName(const Record &Rec) {
   return QualifyName(Rec, StringInit::get(Rec.getRecords(), "NAME"));
 }
 
@@ -298,7 +298,7 @@ bool TGParser::SetValue(Record *CurRec, SMLoc Loc, const Init *ValName,
 /// AddSubClass - Add SubClass as a subclass to CurRec, resolving its template
 /// args as SubClass's template arguments.
 bool TGParser::AddSubClass(Record *CurRec, SubClassReference &SubClass) {
-  Record *SC = SubClass.Rec;
+  const Record *SC = SubClass.Rec;
   MapResolver R(CurRec);
 
   // Loop over all the subclass record's fields. Add regular fields to the new
@@ -588,7 +588,7 @@ bool TGParser::addDefOne(std::unique_ptr<Record> Rec) {
   return false;
 }
 
-bool TGParser::resolveArguments(Record *Rec,
+bool TGParser::resolveArguments(const Record *Rec,
                                 ArrayRef<const ArgumentInit *> ArgValues,
                                 SMLoc Loc, ArgValueHandler ArgValueHandler) {
   ArrayRef<const Init *> ArgNames = Rec->getTemplateArgs();
@@ -632,7 +632,7 @@ bool TGParser::resolveArguments(Record *Rec,
 
 /// Resolve the arguments of class and set them to MapResolver.
 /// Returns true if failed.
-bool TGParser::resolveArgumentsOfClass(MapResolver &R, Record *Rec,
+bool TGParser::resolveArgumentsOfClass(MapResolver &R, const Record *Rec,
                                        ArrayRef<const ArgumentInit *> ArgValues,
                                        SMLoc Loc) {
   return resolveArguments(
@@ -710,13 +710,13 @@ const Init *TGParser::ParseObjectName(MultiClass *CurMultiClass) {
 ///
 ///    ClassID ::= ID
 ///
-Record *TGParser::ParseClassID() {
+const Record *TGParser::ParseClassID() {
   if (Lex.getCode() != tgtok::Id) {
     TokError("expected name for ClassID");
     return nullptr;
   }
 
-  Record *Result = Records.getClass(Lex.getCurStrVal());
+  const Record *Result = Records.getClass(Lex.getCurStrVal());
   if (!Result) {
     std::string Msg("Couldn't find class '" + Lex.getCurStrVal() + "'");
     if (MultiClasses[Lex.getCurStrVal()].get())
@@ -2708,7 +2708,7 @@ const Init *TGParser::ParseSimpleValue(Record *CurRec, const RecTy *ItemType,
     // Value ::= CLASSID '<' ArgValueList '>' (CLASSID has been consumed)
     // This is supposed to synthesize a new anonymous definition, deriving
     // from the class with the template arguments, but no body.
-    Record *Class = Records.getClass(Name->getValue());
+    const Record *Class = Records.getClass(Name->getValue());
     if (!Class) {
       Error(NameLoc.Start,
             "Expected a class name, got '" + Name->getValue() + "'");
@@ -3196,7 +3196,7 @@ void TGParser::ParseValueList(SmallVectorImpl<const Init *> &Result,
 //   NamedArgValueList ::= [NameValue '=' Value {',' NameValue '=' Value}*]
 bool TGParser::ParseTemplateArgValueList(
     SmallVectorImpl<const ArgumentInit *> &Result, Record *CurRec,
-    Record *ArgsRec) {
+    const Record *ArgsRec) {
   assert(Result.empty() && "Result vector is not empty");
   ArrayRef<const Init *> TArgs = ArgsRec->getTemplateArgs();
 
@@ -3990,7 +3990,7 @@ bool TGParser::ParseClass() {
     return TokError("expected class name after 'class' keyword");
 
   const std::string &Name = Lex.getCurStrVal();
-  Record *CurRec = Records.getClass(Name);
+  Record *CurRec = const_cast<Record *>(Records.getClass(Name));
   if (CurRec) {
     // If the body was previously defined, this is an error.
     if (!CurRec->getValues().empty() ||
@@ -4411,7 +4411,8 @@ bool TGParser::ParseFile() {
 // If necessary, replace an argument with a cast to the required type.
 // The argument count has already been checked.
 bool TGParser::CheckTemplateArgValues(
-    SmallVectorImpl<const ArgumentInit *> &Values, SMLoc Loc, Record *ArgsRec) {
+    SmallVectorImpl<const ArgumentInit *> &Values, SMLoc Loc,
+    const Record *ArgsRec) {
   ArrayRef<const Init *> TArgs = ArgsRec->getTemplateArgs();
 
   for (const ArgumentInit *&Value : Values) {
@@ -4421,7 +4422,7 @@ bool TGParser::CheckTemplateArgValues(
     if (Value->isNamed())
       ArgName = Value->getName();
 
-    RecordVal *Arg = ArgsRec->getValue(ArgName);
+    const RecordVal *Arg = ArgsRec->getValue(ArgName);
     const RecTy *ArgType = Arg->getType();
 
     if (const auto *ArgValue = dyn_cast<TypedInit>(Value->getValue())) {
diff --git a/llvm/lib/TableGen/TGParser.h b/llvm/lib/TableGen/TGParser.h
index a1f1db6622aceb..cac1ba827f1138 100644
--- a/llvm/lib/TableGen/TGParser.h
+++ b/llvm/lib/TableGen/TGParser.h
@@ -248,9 +248,9 @@ class TGParser {
 
   using ArgValueHandler = std::function<void(const Init *, const Init *)>;
   bool resolveArguments(
-      Record *Rec, ArrayRef<const ArgumentInit *> ArgValues, SMLoc Loc,
+      const Record *Rec, ArrayRef<const ArgumentInit *> ArgValues, SMLoc Loc,
       ArgValueHandler ArgValueHandler = [](const Init *, const Init *) {});
-  bool resolveArgumentsOfClass(MapResolver &R, Record *Rec,
+  bool resolveArgumentsOfClass(MapResolver &R, const Record *Rec,
                                ArrayRef<const ArgumentInit *> ArgValues,
                                SMLoc Loc);
   bool resolveArgumentsOfMultiClass(SubstStack &Substs, MultiClass *MC,
@@ -296,7 +296,7 @@ class TGParser {
   void ParseValueList(SmallVectorImpl<const Init *> &Result, Record *CurRec,
                       const RecTy *ItemType = nullptr);
   bool ParseTemplateArgValueList(SmallVectorImpl<const ArgumentInit *> &Result,
-                                 Record *CurRec, Record *ArgsRec);
+                                 Record *CurRec, const Record *ArgsRec);
   void ParseDagArgList(
       SmallVectorImpl<std::pair<const Init *, const StringInit *>> &Result,
       Record *CurRec);
@@ -316,12 +316,12 @@ class TGParser {
   const Init *ParseOperationCond(Record *CurRec, const RecTy *ItemType);
   const RecTy *ParseOperatorType();
   const Init *ParseObjectName(MultiClass *CurMultiClass);
-  Record *ParseClassID();
+  const Record *ParseClassID();
   MultiClass *ParseMultiClassID();
   bool ApplyLetStack(Record *CurRec);
   bool ApplyLetStack(RecordsEntry &Entry);
   bool CheckTemplateArgValues(SmallVectorImpl<const ArgumentInit *> &Values,
-                              SMLoc Loc, Record *ArgsRec);
+                              SMLoc Loc, const Record *ArgsRec);
 };
 
 } // end namespace llvm
diff --git a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
index 9e1ebf32c46444..78f6dcdf305ffe 100644
--- a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
+++ b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
@@ -649,7 +649,7 @@ struct TupleExpander : SetTheory::Expander {
       return;
 
     // Precompute some types.
-    Record *RegisterCl = Def->getRecords().getClass("Register");
+    const Record *RegisterCl = Def->getRecords().getClass("Register");
     const RecTy *RegisterRecTy = RecordRecTy::get(RegisterCl);
     std::vector<StringRef> RegNames =
         Def->getValueAsListOfStrings("RegAsmNames");
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 3480c81ff7d4bb..75286231f5902f 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -1184,7 +1184,7 @@ static bool emitSerializationFns(const RecordKeeper &records, raw_ostream &os) {
       utilsString;
   raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
       serFn(serFnString), deserFn(deserFnString);
-  Record *attrClass = records.getClass("Attr");
+  const Record *attrClass = records.getClass("Attr");
 
   // Emit the serialization and deserialization functions simultaneously.
   StringRef opVar("op");



More information about the llvm-commits mailing list