[clang] 639a7ac - [Clang][AST] Store injected template arguments in TemplateParameterList (#113579)

via cfe-commits cfe-commits at lists.llvm.org
Tue Oct 29 10:36:58 PDT 2024


Author: Krystian Stasiowski
Date: 2024-10-29T13:36:55-04:00
New Revision: 639a7ac648f1e50ccd2556e17d401c04f9cce625

URL: https://github.com/llvm/llvm-project/commit/639a7ac648f1e50ccd2556e17d401c04f9cce625
DIFF: https://github.com/llvm/llvm-project/commit/639a7ac648f1e50ccd2556e17d401c04f9cce625.diff

LOG: [Clang][AST] Store injected template arguments in TemplateParameterList (#113579)

Currently, we store injected template arguments in
`RedeclarableTemplateDecl::CommonBase`. This approach has a couple of
problems:
1. We can only access the injected template arguments of
`RedeclarableTemplateDecl`-derived types, but other `Decl` kinds still
make use of the injected arguments (e.g.
`ClassTemplatePartialSpecializationDecl`,
`VarTemplatePartialSpecializationDecl`, and `TemplateTemplateParmDecl`).
2. Accessing the injected template arguments requires the common data
structure to be allocated. This may occur before we determine whether a
previous declaration exists (e.g. when comparing constraints), so if the
template _is_ a redeclaration, we end up discarding the common data
structure.

This patch moves the storage and access of injected template arguments
from `RedeclarableTemplateDecl` to `TemplateParameterList`.

Added: 
    

Modified: 
    clang/include/clang/AST/ASTContext.h
    clang/include/clang/AST/DeclTemplate.h
    clang/lib/AST/ASTContext.cpp
    clang/lib/AST/DeclTemplate.cpp
    clang/lib/Sema/SemaTemplateDeduction.cpp
    clang/lib/Sema/SemaTemplateInstantiate.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index a4d36f2eacd5d1..07b4e36f3ef05e 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -239,7 +239,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
   mutable llvm::ContextualFoldingSet<DependentTemplateSpecializationType,
                                      ASTContext&>
     DependentTemplateSpecializationTypes;
-  llvm::FoldingSet<PackExpansionType> PackExpansionTypes;
+  mutable llvm::FoldingSet<PackExpansionType> PackExpansionTypes;
   mutable llvm::FoldingSet<ObjCObjectTypeImpl> ObjCObjectTypes;
   mutable llvm::FoldingSet<ObjCObjectPointerType> ObjCObjectPointerTypes;
   mutable llvm::FoldingSet<DependentUnaryTransformType>
@@ -1778,13 +1778,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
       ElaboratedTypeKeyword Keyword, NestedNameSpecifier *NNS,
       const IdentifierInfo *Name, ArrayRef<TemplateArgument> Args) const;
 
-  TemplateArgument getInjectedTemplateArg(NamedDecl *ParamDecl);
-
-  /// Get a template argument list with one argument per template parameter
-  /// in a template parameter list, such as for the injected class name of
-  /// a class template.
-  void getInjectedTemplateArgs(const TemplateParameterList *Params,
-                               SmallVectorImpl<TemplateArgument> &Args);
+  TemplateArgument getInjectedTemplateArg(NamedDecl *ParamDecl) const;
 
   /// Form a pack expansion type with the given pattern.
   /// \param NumExpansions The number of expansions for the pack, if known.
@@ -1795,7 +1789,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
   ///        if this is the canonical type of another pack expansion type.
   QualType getPackExpansionType(QualType Pattern,
                                 std::optional<unsigned> NumExpansions,
-                                bool ExpectPackInType = true);
+                                bool ExpectPackInType = true) const;
 
   QualType getObjCInterfaceType(const ObjCInterfaceDecl *Decl,
                                 ObjCInterfaceDecl *PrevDecl = nullptr) const;

diff  --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h
index 0f0c0bf6e4ef4f..a572e3380f1655 100644
--- a/clang/include/clang/AST/DeclTemplate.h
+++ b/clang/include/clang/AST/DeclTemplate.h
@@ -71,6 +71,9 @@ NamedDecl *getAsNamedDecl(TemplateParameter P);
 class TemplateParameterList final
     : private llvm::TrailingObjects<TemplateParameterList, NamedDecl *,
                                     Expr *> {
+  /// The template argument list of the template parameter list.
+  TemplateArgument *InjectedArgs = nullptr;
+
   /// The location of the 'template' keyword.
   SourceLocation TemplateLoc;
 
@@ -196,6 +199,9 @@ class TemplateParameterList final
 
   bool hasAssociatedConstraints() const;
 
+  /// Get the template argument list of the template parameter list.
+  ArrayRef<TemplateArgument> getInjectedTemplateArgs(const ASTContext &Context);
+
   SourceLocation getTemplateLoc() const { return TemplateLoc; }
   SourceLocation getLAngleLoc() const { return LAngleLoc; }
   SourceLocation getRAngleLoc() const { return RAngleLoc; }
@@ -793,15 +799,6 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     /// The first value in the array is the number of specializations/partial
     /// specializations that follow.
     GlobalDeclID *LazySpecializations = nullptr;
-
-    /// The set of "injected" template arguments used within this
-    /// template.
-    ///
-    /// This pointer refers to the template arguments (there are as
-    /// many template arguments as template parameters) for the
-    /// template, and is allocated lazily, since most templates do not
-    /// require the use of this information.
-    TemplateArgument *InjectedArgs = nullptr;
   };
 
   /// Pointer to the common data shared by all declarations of this
@@ -927,7 +924,10 @@ class RedeclarableTemplateDecl : public TemplateDecl,
   /// Although the C++ standard has no notion of the "injected" template
   /// arguments for a template, the notion is convenient when
   /// we need to perform substitutions inside the definition of a template.
-  ArrayRef<TemplateArgument> getInjectedTemplateArgs();
+  ArrayRef<TemplateArgument>
+  getInjectedTemplateArgs(const ASTContext &Context) const {
+    return getTemplateParameters()->getInjectedTemplateArgs(Context);
+  }
 
   using redecl_range = redeclarable_base::redecl_range;
   using redecl_iterator = redeclarable_base::redecl_iterator;
@@ -2087,10 +2087,6 @@ class ClassTemplatePartialSpecializationDecl
   /// The list of template parameters
   TemplateParameterList *TemplateParams = nullptr;
 
-  /// The set of "injected" template arguments used within this
-  /// partial specialization.
-  TemplateArgument *InjectedArgs = nullptr;
-
   /// The class template partial specialization from which this
   /// class template partial specialization was instantiated.
   ///
@@ -2136,9 +2132,11 @@ class ClassTemplatePartialSpecializationDecl
     return TemplateParams;
   }
 
-  /// Retrieve the template arguments list of the template parameter list
-  /// of this template.
-  ArrayRef<TemplateArgument> getInjectedTemplateArgs();
+  /// Get the template argument list of the template parameter list.
+  ArrayRef<TemplateArgument>
+  getInjectedTemplateArgs(const ASTContext &Context) const {
+    return getTemplateParameters()->getInjectedTemplateArgs(Context);
+  }
 
   /// \brief All associated constraints of this partial specialization,
   /// including the requires clause and any constraints derived from
@@ -2864,10 +2862,6 @@ class VarTemplatePartialSpecializationDecl
   /// The list of template parameters
   TemplateParameterList *TemplateParams = nullptr;
 
-  /// The set of "injected" template arguments used within this
-  /// partial specialization.
-  TemplateArgument *InjectedArgs = nullptr;
-
   /// The variable template partial specialization from which this
   /// variable template partial specialization was instantiated.
   ///
@@ -2914,9 +2908,11 @@ class VarTemplatePartialSpecializationDecl
     return TemplateParams;
   }
 
-  /// Retrieve the template arguments list of the template parameter list
-  /// of this template.
-  ArrayRef<TemplateArgument> getInjectedTemplateArgs();
+  /// Get the template argument list of the template parameter list.
+  ArrayRef<TemplateArgument>
+  getInjectedTemplateArgs(const ASTContext &Context) const {
+    return getTemplateParameters()->getInjectedTemplateArgs(Context);
+  }
 
   /// \brief All associated constraints of this partial specialization,
   /// including the requires clause and any constraints derived from

diff  --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 69892bda42b256..1c3f771f417ccf 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -5634,7 +5634,7 @@ ASTContext::getDependentTemplateSpecializationType(
   return QualType(T, 0);
 }
 
-TemplateArgument ASTContext::getInjectedTemplateArg(NamedDecl *Param) {
+TemplateArgument ASTContext::getInjectedTemplateArg(NamedDecl *Param) const {
   TemplateArgument Arg;
   if (const auto *TTP = dyn_cast<TemplateTypeParmDecl>(Param)) {
     QualType ArgType = getTypeDeclType(TTP);
@@ -5678,23 +5678,15 @@ TemplateArgument ASTContext::getInjectedTemplateArg(NamedDecl *Param) {
   }
 
   if (Param->isTemplateParameterPack())
-    Arg = TemplateArgument::CreatePackCopy(*this, Arg);
+    Arg =
+        TemplateArgument::CreatePackCopy(const_cast<ASTContext &>(*this), Arg);
 
   return Arg;
 }
 
-void
-ASTContext::getInjectedTemplateArgs(const TemplateParameterList *Params,
-                                    SmallVectorImpl<TemplateArgument> &Args) {
-  Args.reserve(Args.size() + Params->size());
-
-  for (NamedDecl *Param : *Params)
-    Args.push_back(getInjectedTemplateArg(Param));
-}
-
 QualType ASTContext::getPackExpansionType(QualType Pattern,
                                           std::optional<unsigned> NumExpansions,
-                                          bool ExpectPackInType) {
+                                          bool ExpectPackInType) const {
   assert((!ExpectPackInType || Pattern->containsUnexpandedParameterPack()) &&
          "Pack expansions must expand one or more parameter packs");
 

diff  --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index 4a506b7be45642..755ec72f00bf77 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -51,7 +51,7 @@ DefaultTemplateArgumentContainsUnexpandedPack(const TemplateParam &P) {
          P.getDefaultArgument().getArgument().containsUnexpandedParameterPack();
 }
 
-TemplateParameterList::TemplateParameterList(const ASTContext& C,
+TemplateParameterList::TemplateParameterList(const ASTContext &C,
                                              SourceLocation TemplateLoc,
                                              SourceLocation LAngleLoc,
                                              ArrayRef<NamedDecl *> Params,
@@ -244,6 +244,17 @@ bool TemplateParameterList::hasAssociatedConstraints() const {
   return HasRequiresClause || HasConstrainedParameters;
 }
 
+ArrayRef<TemplateArgument>
+TemplateParameterList::getInjectedTemplateArgs(const ASTContext &Context) {
+  if (!InjectedArgs) {
+    InjectedArgs = new (Context) TemplateArgument[size()];
+    llvm::transform(*this, InjectedArgs, [&](NamedDecl *ND) {
+      return Context.getInjectedTemplateArg(ND);
+    });
+  }
+  return {InjectedArgs, NumParams};
+}
+
 bool TemplateParameterList::shouldIncludeTypeForArgument(
     const PrintingPolicy &Policy, const TemplateParameterList *TPL,
     unsigned Idx) {
@@ -396,22 +407,6 @@ void RedeclarableTemplateDecl::addSpecializationImpl(
                                       SETraits::getDecl(Entry));
 }
 
-ArrayRef<TemplateArgument> RedeclarableTemplateDecl::getInjectedTemplateArgs() {
-  TemplateParameterList *Params = getTemplateParameters();
-  auto *CommonPtr = getCommonPtr();
-  if (!CommonPtr->InjectedArgs) {
-    auto &Context = getASTContext();
-    SmallVector<TemplateArgument, 16> TemplateArgs;
-    Context.getInjectedTemplateArgs(Params, TemplateArgs);
-    CommonPtr->InjectedArgs =
-        new (Context) TemplateArgument[TemplateArgs.size()];
-    std::copy(TemplateArgs.begin(), TemplateArgs.end(),
-              CommonPtr->InjectedArgs);
-  }
-
-  return llvm::ArrayRef(CommonPtr->InjectedArgs, Params->size());
-}
-
 //===----------------------------------------------------------------------===//
 // FunctionTemplateDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -631,13 +626,10 @@ ClassTemplateDecl::getInjectedClassNameSpecialization() {
   //  expansion (14.5.3) whose pattern is the name of the template parameter
   //  pack.
   ASTContext &Context = getASTContext();
-  TemplateParameterList *Params = getTemplateParameters();
-  SmallVector<TemplateArgument, 16> TemplateArgs;
-  Context.getInjectedTemplateArgs(Params, TemplateArgs);
   TemplateName Name = Context.getQualifiedTemplateName(
       /*NNS=*/nullptr, /*TemplateKeyword=*/false, TemplateName(this));
-  CommonPtr->InjectedClassNameType =
-      Context.getTemplateSpecializationType(Name, TemplateArgs);
+  CommonPtr->InjectedClassNameType = Context.getTemplateSpecializationType(
+      Name, getTemplateParameters()->getInjectedTemplateArgs(Context));
   return CommonPtr->InjectedClassNameType;
 }
 
@@ -1184,20 +1176,6 @@ SourceRange ClassTemplatePartialSpecializationDecl::getSourceRange() const {
   return Range;
 }
 
-ArrayRef<TemplateArgument>
-ClassTemplatePartialSpecializationDecl::getInjectedTemplateArgs() {
-  TemplateParameterList *Params = getTemplateParameters();
-  auto *First = cast<ClassTemplatePartialSpecializationDecl>(getFirstDecl());
-  if (!First->InjectedArgs) {
-    auto &Context = getASTContext();
-    SmallVector<TemplateArgument, 16> TemplateArgs;
-    Context.getInjectedTemplateArgs(Params, TemplateArgs);
-    First->InjectedArgs = new (Context) TemplateArgument[TemplateArgs.size()];
-    std::copy(TemplateArgs.begin(), TemplateArgs.end(), First->InjectedArgs);
-  }
-  return llvm::ArrayRef(First->InjectedArgs, Params->size());
-}
-
 //===----------------------------------------------------------------------===//
 // FriendTemplateDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -1548,20 +1526,6 @@ SourceRange VarTemplatePartialSpecializationDecl::getSourceRange() const {
   return Range;
 }
 
-ArrayRef<TemplateArgument>
-VarTemplatePartialSpecializationDecl::getInjectedTemplateArgs() {
-  TemplateParameterList *Params = getTemplateParameters();
-  auto *First = cast<VarTemplatePartialSpecializationDecl>(getFirstDecl());
-  if (!First->InjectedArgs) {
-    auto &Context = getASTContext();
-    SmallVector<TemplateArgument, 16> TemplateArgs;
-    Context.getInjectedTemplateArgs(Params, TemplateArgs);
-    First->InjectedArgs = new (Context) TemplateArgument[TemplateArgs.size()];
-    std::copy(TemplateArgs.begin(), TemplateArgs.end(), First->InjectedArgs);
-  }
-  return llvm::ArrayRef(First->InjectedArgs, Params->size());
-}
-
 static TemplateParameterList *
 createMakeIntegerSeqParameterList(const ASTContext &C, DeclContext *DC) {
   // typename T

diff  --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp
index db1d7fa237131a..b45f30fed49a64 100644
--- a/clang/lib/Sema/SemaTemplateDeduction.cpp
+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -6163,7 +6163,7 @@ struct TemplateArgumentListAreEqual {
             std::enable_if_t<!std::is_same_v<T1, T2>, bool> = true>
   bool operator()(T1 *Spec, T2 *Primary) {
     ArrayRef<TemplateArgument> Args1 = Spec->getTemplateArgs().asArray(),
-                               Args2 = Primary->getInjectedTemplateArgs();
+                               Args2 = Primary->getInjectedTemplateArgs(Ctx);
 
     for (unsigned I = 0, E = Args1.size(); I < E; ++I) {
       // We use profile, instead of structural comparison of the arguments,
@@ -6342,7 +6342,7 @@ bool Sema::isMoreSpecializedThanPrimary(
   VarTemplateDecl *Primary = Spec->getSpecializedTemplate();
   TemplateName Name(Primary);
   QualType PrimaryT = Context.getTemplateSpecializationType(
-      Name, Primary->getInjectedTemplateArgs());
+      Name, Primary->getInjectedTemplateArgs(Context));
   QualType PartialT = Context.getTemplateSpecializationType(
       Name, Spec->getTemplateArgs().asArray());
 
@@ -6372,18 +6372,14 @@ bool Sema::isTemplateTemplateParameterAtLeastAsSpecializedAs(
   //    - Each function template has a single function parameter whose type is
   //      a specialization of X with template arguments corresponding to the
   //      template parameters from the respective function template
-  SmallVector<TemplateArgument, 8> AArgs;
-  Context.getInjectedTemplateArgs(A, AArgs);
+  SmallVector<TemplateArgument, 8> AArgs(A->getInjectedTemplateArgs(Context));
 
   // Check P's arguments against A's parameter list. This will fill in default
   // template arguments as needed. AArgs are already correct by construction.
   // We can't just use CheckTemplateIdType because that will expand alias
   // templates.
-  SmallVector<TemplateArgument, 4> PArgs;
+  SmallVector<TemplateArgument, 4> PArgs(P->getInjectedTemplateArgs(Context));
   {
-    SFINAETrap Trap(*this);
-
-    Context.getInjectedTemplateArgs(P, PArgs);
     TemplateArgumentListInfo PArgList(P->getLAngleLoc(),
                                       P->getRAngleLoc());
     for (unsigned I = 0, N = P->size(); I != N; ++I) {
@@ -6399,6 +6395,7 @@ bool Sema::isTemplateTemplateParameterAtLeastAsSpecializedAs(
     }
     PArgs.clear();
 
+    SFINAETrap Trap(*this);
     // C++1z [temp.arg.template]p3:
     //   If the rewrite produces an invalid type, then P is not at least as
     //   specialized as A.

diff  --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index 6a55861fe5af3b..dea97bfce532c9 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -200,7 +200,7 @@ struct TemplateInstantiationArgumentCollecter
     if (Innermost)
       AddInnermostTemplateArguments(FTD);
     else if (ForConstraintInstantiation)
-      AddOuterTemplateArguments(FTD, FTD->getInjectedTemplateArgs(),
+      AddOuterTemplateArguments(FTD, FTD->getInjectedTemplateArgs(S.Context),
                                 /*Final=*/false);
 
     if (FTD->isMemberSpecialization())
@@ -219,7 +219,7 @@ struct TemplateInstantiationArgumentCollecter
     if (Innermost)
       AddInnermostTemplateArguments(VTD);
     else if (ForConstraintInstantiation)
-      AddOuterTemplateArguments(VTD, VTD->getInjectedTemplateArgs(),
+      AddOuterTemplateArguments(VTD, VTD->getInjectedTemplateArgs(S.Context),
                                 /*Final=*/false);
 
     if (VTD->isMemberSpecialization())
@@ -237,7 +237,8 @@ struct TemplateInstantiationArgumentCollecter
     if (Innermost)
       AddInnermostTemplateArguments(VTPSD);
     else if (ForConstraintInstantiation)
-      AddOuterTemplateArguments(VTPSD, VTPSD->getInjectedTemplateArgs(),
+      AddOuterTemplateArguments(VTPSD,
+                                VTPSD->getInjectedTemplateArgs(S.Context),
                                 /*Final=*/false);
 
     if (VTPSD->isMemberSpecialization())
@@ -254,7 +255,7 @@ struct TemplateInstantiationArgumentCollecter
     if (Innermost)
       AddInnermostTemplateArguments(CTD);
     else if (ForConstraintInstantiation)
-      AddOuterTemplateArguments(CTD, CTD->getInjectedTemplateArgs(),
+      AddOuterTemplateArguments(CTD, CTD->getInjectedTemplateArgs(S.Context),
                                 /*Final=*/false);
 
     if (CTD->isMemberSpecialization())
@@ -274,7 +275,8 @@ struct TemplateInstantiationArgumentCollecter
     if (Innermost)
       AddInnermostTemplateArguments(CTPSD);
     else if (ForConstraintInstantiation)
-      AddOuterTemplateArguments(CTPSD, CTPSD->getInjectedTemplateArgs(),
+      AddOuterTemplateArguments(CTPSD,
+                                CTPSD->getInjectedTemplateArgs(S.Context),
                                 /*Final=*/false);
 
     if (CTPSD->isMemberSpecialization())
@@ -290,7 +292,7 @@ struct TemplateInstantiationArgumentCollecter
     if (Innermost)
       AddInnermostTemplateArguments(TATD);
     else if (ForConstraintInstantiation)
-      AddOuterTemplateArguments(TATD, TATD->getInjectedTemplateArgs(),
+      AddOuterTemplateArguments(TATD, TATD->getInjectedTemplateArgs(S.Context),
                                 /*Final=*/false);
 
     return UseNextDecl(TATD);


        


More information about the cfe-commits mailing list