[llvm] 6251adc - [TableGen] Refactor the implementation of arguments to introduce ArgumentInit [nfc]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 11 11:42:39 PDT 2023


Author: wangpc
Date: 2023-07-11T11:42:07-07:00
New Revision: 6251adc64d3d5fbc12c71c818a37a74b8999f989

URL: https://github.com/llvm/llvm-project/commit/6251adc64d3d5fbc12c71c818a37a74b8999f989
DIFF: https://github.com/llvm/llvm-project/commit/6251adc64d3d5fbc12c71c818a37a74b8999f989.diff

LOG: [TableGen] Refactor the implementation of arguments to introduce ArgumentInit [nfc]

A new Init type, ArgumentInit, is added to represent arguments. We currently only support positional arguments; an upcoming change will add named-argument support.

The index of the argument is no longer included in the error message.

Differential Revision: https://reviews.llvm.org/D154066

Added: 
    

Modified: 
    llvm/include/llvm/TableGen/Record.h
    llvm/lib/TableGen/Record.cpp
    llvm/lib/TableGen/TGParser.cpp
    llvm/lib/TableGen/TGParser.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 321b0f4566be84..b77336a896fc4d 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -317,7 +317,8 @@ class Init {
     IK_VarBitInit,
     IK_VarDefInit,
     IK_LastTypedInit,
-    IK_UnsetInit
+    IK_UnsetInit,
+    IK_ArgumentInit,
   };
 
 private:
@@ -480,6 +481,39 @@ class UnsetInit : public Init {
   std::string getAsString() const override { return "?"; }
 };
 
+// Represent an argument.
+class ArgumentInit : public Init, public FoldingSetNode {
+  Init *Value;
+
+protected:
+  explicit ArgumentInit(Init *Value) : Init(IK_ArgumentInit), Value(Value) {}
+
+public:
+  ArgumentInit(const ArgumentInit &) = delete;
+  ArgumentInit &operator=(const ArgumentInit &) = delete;
+
+  static bool classof(const Init *I) { return I->getKind() == IK_ArgumentInit; }
+
+  RecordKeeper &getRecordKeeper() const { return Value->getRecordKeeper(); }
+
+  static ArgumentInit *get(Init *Value);
+
+  Init *getValue() const { return Value; }
+
+  void Profile(FoldingSetNodeID &ID) const;
+
+  Init *resolveReferences(Resolver &R) const override;
+  std::string getAsString() const override { return Value->getAsString(); }
+
+  bool isComplete() const override { return false; }
+  bool isConcrete() const override { return false; }
+  Init *getBit(unsigned Bit) const override { return Value->getBit(Bit); }
+  Init *getCastTo(RecTy *Ty) const override { return Value->getCastTo(Ty); }
+  Init *convertInitializerTo(RecTy *Ty) const override {
+    return Value->convertInitializerTo(Ty);
+  }
+};
+
 /// 'true'/'false' - Represent a concrete initializer for a bit.
 class BitInit final : public TypedInit {
   friend detail::RecordKeeperImpl;
@@ -1278,8 +1312,9 @@ class DefInit : public TypedInit {
 
 /// classname<targs...> - Represent an uninstantiated anonymous class
 /// instantiation.
-class VarDefInit final : public TypedInit, public FoldingSetNode,
-                         public TrailingObjects<VarDefInit, Init *> {
+class VarDefInit final : public TypedInit,
+                         public FoldingSetNode,
+                         public TrailingObjects<VarDefInit, ArgumentInit *> {
   Record *Class;
   DefInit *Def = nullptr; // after instantiation
   unsigned NumArgs;
@@ -1298,7 +1333,7 @@ class VarDefInit final : public TypedInit, public FoldingSetNode,
   static bool classof(const Init *I) {
     return I->getKind() == IK_VarDefInit;
   }
-  static VarDefInit *get(Record *Class, ArrayRef<Init *> Args);
+  static VarDefInit *get(Record *Class, ArrayRef<ArgumentInit *> Args);
 
   void Profile(FoldingSetNodeID &ID) const;
 
@@ -1307,20 +1342,24 @@ class VarDefInit final : public TypedInit, public FoldingSetNode,
 
   std::string getAsString() const override;
 
-  Init *getArg(unsigned i) const {
+  ArgumentInit *getArg(unsigned i) const {
     assert(i < NumArgs && "Argument index out of range!");
-    return getTrailingObjects<Init *>()[i];
+    return getTrailingObjects<ArgumentInit *>()[i];
   }
 
-  using const_iterator = Init *const *;
+  using const_iterator = ArgumentInit *const *;
 
-  const_iterator args_begin() const { return getTrailingObjects<Init *>(); }
+  const_iterator args_begin() const {
+    return getTrailingObjects<ArgumentInit *>();
+  }
   const_iterator args_end  () const { return args_begin() + NumArgs; }
 
   size_t         args_size () const { return NumArgs; }
   bool           args_empty() const { return NumArgs == 0; }
 
-  ArrayRef<Init *> args() const { return ArrayRef(args_begin(), NumArgs); }
+  ArrayRef<ArgumentInit *> args() const {
+    return ArrayRef(args_begin(), NumArgs);
+  }
 
   Init *getBit(unsigned Bit) const override {
     llvm_unreachable("Illegal bit reference off anonymous def");

diff  --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index 039f5bcd93a249..c38e824143ad40 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -70,6 +70,7 @@ struct RecordKeeperImpl {
   BitInit TrueBitInit;
   BitInit FalseBitInit;
 
+  FoldingSet<ArgumentInit> TheArgumentInitPool;
   FoldingSet<BitsInit> TheBitsInitPool;
   std::map<int64_t, IntInit *> TheIntInitPool;
   StringMap<StringInit *, BumpPtrAllocator &> StringInitStringPool;
@@ -349,6 +350,8 @@ LLVM_DUMP_METHOD void Init::dump() const { return print(errs()); }
 RecordKeeper &Init::getRecordKeeper() const {
   if (auto *TyInit = dyn_cast<TypedInit>(this))
     return TyInit->getType()->getRecordKeeper();
+  if (auto *ArgInit = dyn_cast<ArgumentInit>(this))
+    return ArgInit->getRecordKeeper();
   return cast<UnsetInit>(this)->getRecordKeeper();
 }
 
@@ -364,6 +367,37 @@ Init *UnsetInit::convertInitializerTo(RecTy *Ty) const {
   return const_cast<UnsetInit *>(this);
 }
 
+static void ProfileArgumentInit(FoldingSetNodeID &ID, Init *Value) {
+  ID.AddPointer(Value);
+}
+
+void ArgumentInit::Profile(FoldingSetNodeID &ID) const {
+  ProfileArgumentInit(ID, Value);
+}
+
+ArgumentInit *ArgumentInit::get(Init *Value) {
+  FoldingSetNodeID ID;
+  ProfileArgumentInit(ID, Value);
+
+  RecordKeeper &RK = Value->getRecordKeeper();
+  detail::RecordKeeperImpl &RKImpl = RK.getImpl();
+  void *IP = nullptr;
+  if (ArgumentInit *I = RKImpl.TheArgumentInitPool.FindNodeOrInsertPos(ID, IP))
+    return I;
+
+  ArgumentInit *I = new (RKImpl.Allocator) ArgumentInit(Value);
+  RKImpl.TheArgumentInitPool.InsertNode(I, IP);
+  return I;
+}
+
+Init *ArgumentInit::resolveReferences(Resolver &R) const {
+  Init *NewValue = Value->resolveReferences(R);
+  if (NewValue != Value)
+    return ArgumentInit::get(NewValue);
+
+  return const_cast<ArgumentInit *>(this);
+}
+
 BitInit *BitInit::get(RecordKeeper &RK, bool V) {
   return V ? &RK.getImpl().TrueBitInit : &RK.getImpl().FalseBitInit;
 }
@@ -2131,9 +2165,8 @@ RecTy *DefInit::getFieldType(StringInit *FieldName) const {
 
 std::string DefInit::getAsString() const { return std::string(Def->getName()); }
 
-static void ProfileVarDefInit(FoldingSetNodeID &ID,
-                              Record *Class,
-                              ArrayRef<Init *> Args) {
+static void ProfileVarDefInit(FoldingSetNodeID &ID, Record *Class,
+                              ArrayRef<ArgumentInit *> Args) {
   ID.AddInteger(Args.size());
   ID.AddPointer(Class);
 
@@ -2145,7 +2178,7 @@ VarDefInit::VarDefInit(Record *Class, unsigned N)
     : TypedInit(IK_VarDefInit, RecordRecTy::get(Class)), Class(Class),
       NumArgs(N) {}
 
-VarDefInit *VarDefInit::get(Record *Class, ArrayRef<Init *> Args) {
+VarDefInit *VarDefInit::get(Record *Class, ArrayRef<ArgumentInit *> Args) {
   FoldingSetNodeID ID;
   ProfileVarDefInit(ID, Class, Args);
 
@@ -2154,11 +2187,11 @@ VarDefInit *VarDefInit::get(Record *Class, ArrayRef<Init *> Args) {
   if (VarDefInit *I = RK.TheVarDefInitPool.FindNodeOrInsertPos(ID, IP))
     return I;
 
-  void *Mem = RK.Allocator.Allocate(totalSizeToAlloc<Init *>(Args.size()),
-                                    alignof(VarDefInit));
+  void *Mem = RK.Allocator.Allocate(
+      totalSizeToAlloc<ArgumentInit *>(Args.size()), alignof(VarDefInit));
   VarDefInit *I = new (Mem) VarDefInit(Class, Args.size());
   std::uninitialized_copy(Args.begin(), Args.end(),
-                          I->getTrailingObjects<Init *>());
+                          I->getTrailingObjects<ArgumentInit *>());
   RK.TheVarDefInitPool.InsertNode(I, IP);
   return I;
 }
@@ -2188,7 +2221,7 @@ DefInit *VarDefInit::instantiate() {
 
     for (unsigned i = 0, e = TArgs.size(); i != e; ++i) {
       if (i < args_size())
-        R.set(TArgs[i], getArg(i));
+        R.set(TArgs[i], getArg(i)->getValue());
       else
         R.set(TArgs[i], NewRec->getValue(TArgs[i])->getValue());
 
@@ -2222,11 +2255,11 @@ DefInit *VarDefInit::instantiate() {
 Init *VarDefInit::resolveReferences(Resolver &R) const {
   TrackUnresolvedResolver UR(&R);
   bool Changed = false;
-  SmallVector<Init *, 8> NewArgs;
+  SmallVector<ArgumentInit *, 8> NewArgs;
   NewArgs.reserve(args_size());
 
-  for (Init *Arg : args()) {
-    Init *NewArg = Arg->resolveReferences(UR);
+  for (ArgumentInit *Arg : args()) {
+    auto *NewArg = cast<ArgumentInit>(Arg->resolveReferences(UR));
     NewArgs.push_back(NewArg);
     Changed |= NewArg != Arg;
   }

diff  --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index f373f2c4d7de57..6935f6e6d438ad 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -36,7 +36,7 @@ namespace llvm {
 struct SubClassReference {
   SMRange RefRange;
   Record *Rec;
-  SmallVector<Init*, 4> TemplateArgs;
+  SmallVector<ArgumentInit *, 4> TemplateArgs;
 
   SubClassReference() : Rec(nullptr) {}
 
@@ -46,7 +46,7 @@ struct SubClassReference {
 struct SubMultiClassReference {
   SMRange RefRange;
   MultiClass *MC;
-  SmallVector<Init*, 4> TemplateArgs;
+  SmallVector<ArgumentInit *, 4> TemplateArgs;
 
   SubMultiClassReference() : MC(nullptr) {}
 
@@ -569,7 +569,7 @@ bool TGParser::addDefOne(std::unique_ptr<Record> Rec) {
   return false;
 }
 
-bool TGParser::resolveArguments(Record *Rec, ArrayRef<Init *> ArgValues,
+bool TGParser::resolveArguments(Record *Rec, ArrayRef<ArgumentInit *> ArgValues,
                                 SMLoc Loc, ArgValueHandler ArgValueHandler) {
   ArrayRef<Init *> ArgNames = Rec->getTemplateArgs();
   assert(ArgValues.size() <= ArgNames.size() &&
@@ -579,7 +579,7 @@ bool TGParser::resolveArguments(Record *Rec, ArrayRef<Init *> ArgValues,
   // handle the (name, value) pair. If not and there was no default, complain.
   for (unsigned I = 0, E = ArgNames.size(); I != E; ++I) {
     if (I < ArgValues.size())
-      ArgValueHandler(ArgNames[I], ArgValues[I]);
+      ArgValueHandler(ArgNames[I], ArgValues[I]->getValue());
     else {
       Init *Default = Rec->getValue(ArgNames[I])->getValue();
       if (!Default->isComplete())
@@ -597,7 +597,8 @@ bool TGParser::resolveArguments(Record *Rec, ArrayRef<Init *> ArgValues,
 /// Resolve the arguments of class and set them to MapResolver.
 /// Returns true if failed.
 bool TGParser::resolveArgumentsOfClass(MapResolver &R, Record *Rec,
-                                       ArrayRef<Init *> ArgValues, SMLoc Loc) {
+                                       ArrayRef<ArgumentInit *> ArgValues,
+                                       SMLoc Loc) {
   return resolveArguments(Rec, ArgValues, Loc,
                           [&](Init *Name, Init *Value) { R.set(Name, Value); });
 }
@@ -605,7 +606,7 @@ bool TGParser::resolveArgumentsOfClass(MapResolver &R, Record *Rec,
 /// Resolve the arguments of multiclass and store them into SubstStack.
 /// Returns true if failed.
 bool TGParser::resolveArgumentsOfMultiClass(SubstStack &Substs, MultiClass *MC,
-                                            ArrayRef<Init *> ArgValues,
+                                            ArrayRef<ArgumentInit *> ArgValues,
                                             Init *DefmName, SMLoc Loc) {
   // Add an implicit argument NAME.
   Substs.emplace_back(QualifiedNameOfImplicitName(MC), DefmName);
@@ -2596,7 +2597,7 @@ Init *TGParser::ParseSimpleValue(Record *CurRec, RecTy *ItemType,
       return nullptr;
     }
 
-    SmallVector<Init *, 8> Args;
+    SmallVector<ArgumentInit *, 8> Args;
     Lex.Lex(); // consume the <
     if (ParseTemplateArgValueList(Args, CurRec, Class))
       return nullptr; // Error parsing value list.
@@ -3121,8 +3122,8 @@ void TGParser::ParseValueList(SmallVectorImpl<Init *> &Result, Record *CurRec,
 // error was detected.
 //
 //   TemplateArgList ::= '<' [Value {',' Value}*] '>'
-bool TGParser::ParseTemplateArgValueList(SmallVectorImpl<Init *> &Result,
-                                         Record *CurRec, Record *ArgsRec) {
+bool TGParser::ParseTemplateArgValueList(
+    SmallVectorImpl<ArgumentInit *> &Result, Record *CurRec, Record *ArgsRec) {
 
   assert(Result.empty() && "Result vector is not empty");
   ArrayRef<Init *> TArgs = ArgsRec->getTemplateArgs();
@@ -3144,7 +3145,7 @@ bool TGParser::ParseTemplateArgValueList(SmallVectorImpl<Init *> &Result,
     Init *Value = ParseValue(CurRec, ItemType);
     if (!Value)
       return true;
-    Result.push_back(Value);
+    Result.push_back(ArgumentInit::get(Value));
 
     if (consume(tgtok::greater)) // end of argument list?
       return false;
@@ -4247,9 +4248,8 @@ bool TGParser::ParseFile() {
 // inheritance, multiclass invocation, or anonymous class invocation.
 // If necessary, replace an argument with a cast to the required type.
 // The argument count has already been checked.
-bool TGParser::CheckTemplateArgValues(SmallVectorImpl<llvm::Init *> &Values,
-                                      SMLoc Loc, Record *ArgsRec) {
-
+bool TGParser::CheckTemplateArgValues(
+    SmallVectorImpl<llvm::ArgumentInit *> &Values, SMLoc Loc, Record *ArgsRec) {
   ArrayRef<Init *> TArgs = ArgsRec->getTemplateArgs();
 
   for (unsigned I = 0, E = Values.size(); I < E; ++I) {
@@ -4257,13 +4257,13 @@ bool TGParser::CheckTemplateArgValues(SmallVectorImpl<llvm::Init *> &Values,
     RecTy *ArgType = Arg->getType();
     auto *Value = Values[I];
 
-    if (TypedInit *ArgValue = dyn_cast<TypedInit>(Value)) { 
+    if (TypedInit *ArgValue = dyn_cast<TypedInit>(Value->getValue())) {
       auto *CastValue = ArgValue->getCastTo(ArgType);
       if (CastValue) {
         assert((!isa<TypedInit>(CastValue) ||
                 cast<TypedInit>(CastValue)->getType()->typeIsA(ArgType)) &&
                "result of template arg value cast has wrong type");
-        Values[I] = CastValue;
+        Values[I] = ArgumentInit::get(CastValue);
       } else {
         PrintFatalError(Loc,
                         "Value specified for template argument '" +

diff  --git a/llvm/lib/TableGen/TGParser.h b/llvm/lib/TableGen/TGParser.h
index 538f8ae4c2d3d2..5f27a86be8eff0 100644
--- a/llvm/lib/TableGen/TGParser.h
+++ b/llvm/lib/TableGen/TGParser.h
@@ -244,13 +244,13 @@ class TGParser {
 
   using ArgValueHandler = std::function<void(Init *, Init *)>;
   bool resolveArguments(
-      Record *Rec, ArrayRef<Init *> ArgValues, SMLoc Loc,
+      Record *Rec, ArrayRef<ArgumentInit *> ArgValues, SMLoc Loc,
       ArgValueHandler ArgValueHandler = [](Init *, Init *) {});
   bool resolveArgumentsOfClass(MapResolver &R, Record *Rec,
-                               ArrayRef<Init *> ArgValues, SMLoc Loc);
+                               ArrayRef<ArgumentInit *> ArgValues, SMLoc Loc);
   bool resolveArgumentsOfMultiClass(SubstStack &Substs, MultiClass *MC,
-                                    ArrayRef<Init *> ArgValues, Init *DefmName,
-                                    SMLoc Loc);
+                                    ArrayRef<ArgumentInit *> ArgValues,
+                                    Init *DefmName, SMLoc Loc);
 
 private:  // Parser methods.
   bool consume(tgtok::TokKind K);
@@ -288,7 +288,7 @@ class TGParser {
                    IDParseMode Mode = ParseValueMode);
   void ParseValueList(SmallVectorImpl<llvm::Init*> &Result,
                       Record *CurRec, RecTy *ItemType = nullptr);
-  bool ParseTemplateArgValueList(SmallVectorImpl<llvm::Init *> &Result,
+  bool ParseTemplateArgValueList(SmallVectorImpl<llvm::ArgumentInit *> &Result,
                                  Record *CurRec, Record *ArgsRec);
   void ParseDagArgList(
       SmallVectorImpl<std::pair<llvm::Init*, StringInit*>> &Result,
@@ -312,7 +312,7 @@ class TGParser {
   MultiClass *ParseMultiClassID();
   bool ApplyLetStack(Record *CurRec);
   bool ApplyLetStack(RecordsEntry &Entry);
-  bool CheckTemplateArgValues(SmallVectorImpl<llvm::Init *> &Values,
+  bool CheckTemplateArgValues(SmallVectorImpl<llvm::ArgumentInit *> &Values,
                               SMLoc Loc, Record *ArgsRec);
 };
 


        


More information about the llvm-commits mailing list