[clang] 9381c6f - [Clang][Sema] Use the correct injected template arguments for partial specializations when collecting multi-level template argument lists (#112381)

via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 16 07:40:08 PDT 2024


Author: Krystian Stasiowski
Date: 2024-10-16T10:40:03-04:00
New Revision: 9381c6fd04cc16a7606633f57c96c11e58181ddb

URL: https://github.com/llvm/llvm-project/commit/9381c6fd04cc16a7606633f57c96c11e58181ddb
DIFF: https://github.com/llvm/llvm-project/commit/9381c6fd04cc16a7606633f57c96c11e58181ddb.diff

LOG: [Clang][Sema] Use the correct injected template arguments for partial specializations when collecting multi-level template argument lists (#112381)

After #111852 refactored multi-level template argument list collection,
the following results in a crash:
```
template<typename T, bool B>
struct A;

template<bool B>
struct A<int, B>
{
    void f() requires B;
};

template<bool B>
void A<int, B>::f() requires B { } // crash here
```

This happens because when collecting template arguments for constraint
normalization from a partial specialization, we incorrectly use the
template argument list of the partial specialization (in the example
above, `<int, B>`). We should instead be using the template argument list
of the _template-head_ (as defined in [temp.arg.general] p2), i.e. the
injected arguments `<B>` corresponding to `template<bool B>`. Fixes #112222.

Added: 
    

Modified: 
    clang/include/clang/AST/DeclTemplate.h
    clang/lib/AST/DeclTemplate.cpp
    clang/lib/Sema/SemaTemplateInstantiate.cpp
    clang/test/CXX/temp/temp.constr/temp.constr.decl/p4.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h
index 141f58c4600af0..0f0c0bf6e4ef4f 100644
--- a/clang/include/clang/AST/DeclTemplate.h
+++ b/clang/include/clang/AST/DeclTemplate.h
@@ -2085,7 +2085,11 @@ class ClassTemplateSpecializationDecl : public CXXRecordDecl,
 class ClassTemplatePartialSpecializationDecl
   : public ClassTemplateSpecializationDecl {
   /// The list of template parameters
-  TemplateParameterList* TemplateParams = nullptr;
+  TemplateParameterList *TemplateParams = nullptr;
+
+  /// The set of "injected" template arguments used within this
+  /// partial specialization.
+  TemplateArgument *InjectedArgs = nullptr;
 
   /// The class template partial specialization from which this
   /// class template partial specialization was instantiated.
@@ -2132,6 +2136,10 @@ class ClassTemplatePartialSpecializationDecl
     return TemplateParams;
   }
 
+  /// Retrieve the template arguments list of the template parameter list
+  /// of this template.
+  ArrayRef<TemplateArgument> getInjectedTemplateArgs();
+
   /// \brief All associated constraints of this partial specialization,
   /// including the requires clause and any constraints derived from
   /// constrained-parameters.
@@ -2856,6 +2864,10 @@ class VarTemplatePartialSpecializationDecl
   /// The list of template parameters
   TemplateParameterList *TemplateParams = nullptr;
 
+  /// The set of "injected" template arguments used within this
+  /// partial specialization.
+  TemplateArgument *InjectedArgs = nullptr;
+
   /// The variable template partial specialization from which this
   /// variable template partial specialization was instantiated.
   ///
@@ -2902,6 +2914,10 @@ class VarTemplatePartialSpecializationDecl
     return TemplateParams;
   }
 
+  /// Retrieve the template arguments list of the template parameter list
+  /// of this template.
+  ArrayRef<TemplateArgument> getInjectedTemplateArgs();
+
   /// \brief All associated constraints of this partial specialization,
   /// including the requires clause and any constraints derived from
   /// constrained-parameters.

diff  --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index d9b67b7bedf5a5..d2d8907b884ec8 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -1185,6 +1185,20 @@ SourceRange ClassTemplatePartialSpecializationDecl::getSourceRange() const {
   return Range;
 }
 
+ArrayRef<TemplateArgument>
+ClassTemplatePartialSpecializationDecl::getInjectedTemplateArgs() {
+  TemplateParameterList *Params = getTemplateParameters();
+  auto *First = cast<ClassTemplatePartialSpecializationDecl>(getFirstDecl());
+  if (!First->InjectedArgs) {
+    auto &Context = getASTContext();
+    SmallVector<TemplateArgument, 16> TemplateArgs;
+    Context.getInjectedTemplateArgs(Params, TemplateArgs);
+    First->InjectedArgs = new (Context) TemplateArgument[TemplateArgs.size()];
+    std::copy(TemplateArgs.begin(), TemplateArgs.end(), First->InjectedArgs);
+  }
+  return llvm::ArrayRef(First->InjectedArgs, Params->size());
+}
+
 //===----------------------------------------------------------------------===//
 // FriendTemplateDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -1535,6 +1549,20 @@ SourceRange VarTemplatePartialSpecializationDecl::getSourceRange() const {
   return Range;
 }
 
+ArrayRef<TemplateArgument>
+VarTemplatePartialSpecializationDecl::getInjectedTemplateArgs() {
+  TemplateParameterList *Params = getTemplateParameters();
+  auto *First = cast<VarTemplatePartialSpecializationDecl>(getFirstDecl());
+  if (!First->InjectedArgs) {
+    auto &Context = getASTContext();
+    SmallVector<TemplateArgument, 16> TemplateArgs;
+    Context.getInjectedTemplateArgs(Params, TemplateArgs);
+    First->InjectedArgs = new (Context) TemplateArgument[TemplateArgs.size()];
+    std::copy(TemplateArgs.begin(), TemplateArgs.end(), First->InjectedArgs);
+  }
+  return llvm::ArrayRef(First->InjectedArgs, Params->size());
+}
+
 static TemplateParameterList *
 createMakeIntegerSeqParameterList(const ASTContext &C, DeclContext *DC) {
   // typename T

diff  --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index 8c7f694c09042e..8665c099903dc3 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -237,7 +237,7 @@ struct TemplateInstantiationArgumentCollecter
     if (Innermost)
       AddInnermostTemplateArguments(VTPSD);
     else if (ForConstraintInstantiation)
-      AddOuterTemplateArguments(VTPSD, VTPSD->getTemplateArgs().asArray(),
+      AddOuterTemplateArguments(VTPSD, VTPSD->getInjectedTemplateArgs(),
                                 /*Final=*/false);
 
     if (VTPSD->isMemberSpecialization())
@@ -274,7 +274,7 @@ struct TemplateInstantiationArgumentCollecter
     if (Innermost)
       AddInnermostTemplateArguments(CTPSD);
     else if (ForConstraintInstantiation)
-      AddOuterTemplateArguments(CTPSD, CTPSD->getTemplateArgs().asArray(),
+      AddOuterTemplateArguments(CTPSD, CTPSD->getInjectedTemplateArgs(),
                                 /*Final=*/false);
 
     if (CTPSD->isMemberSpecialization())

diff  --git a/clang/test/CXX/temp/temp.constr/temp.constr.decl/p4.cpp b/clang/test/CXX/temp/temp.constr/temp.constr.decl/p4.cpp
index 70064f867e18e3..f144e14cd122f9 100644
--- a/clang/test/CXX/temp/temp.constr/temp.constr.decl/p4.cpp
+++ b/clang/test/CXX/temp/temp.constr/temp.constr.decl/p4.cpp
@@ -1,175 +1,219 @@
 // RUN: %clang_cc1 -std=c++20 -verify %s
 // expected-no-diagnostics
 
-template<typename T>
-concept D = true;
+namespace Primary {
+  template<typename T>
+  concept D = true;
 
-template<typename T>
-struct A {
-  template<typename U, bool V>
-  void f() requires V;
+  template<typename T>
+  struct A {
+    template<typename U, bool V>
+    void f() requires V;
 
-  template<>
-  void f<short, true>();
+    template<>
+    void f<short, true>();
+
+    template<D U>
+    void g();
+
+    template<typename U, bool V> requires V
+    struct B;
+
+    template<typename U, bool V> requires V
+    struct B<U*, V>;
+
+    template<>
+    struct B<short, true>;
+
+    template<D U>
+    struct C;
+
+    template<D U>
+    struct C<U*>;
 
+    template<typename U, bool V> requires V
+    static int x;
+
+    template<typename U, bool V> requires V
+    static int x<U*, V>;
+
+    template<>
+    int x<short, true>;
+
+    template<D U>
+    static int y;
+
+    template<D U>
+    static int y<U*>;
+  };
+
+  template<typename T>
+  template<typename U, bool V>
+  void A<T>::f() requires V { }
+
+  template<typename T>
   template<D U>
-  void g();
+  void A<T>::g() { }
 
+  template<typename T>
   template<typename U, bool V> requires V
-  struct B;
+  struct A<T>::B { };
 
+  template<typename T>
   template<typename U, bool V> requires V
-  struct B<U*, V>;
+  struct A<T>::B<U*, V> { };
 
-  template<>
-  struct B<short, true>;
+  template<typename T>
+  template<typename U, bool V> requires V
+  struct A<T>::B<U&, V> { };
 
+  template<typename T>
   template<D U>
-  struct C;
+  struct A<T>::C { };
 
+  template<typename T>
   template<D U>
-  struct C<U*>;
+  struct A<T>::C<U*> { };
 
+  template<typename T>
   template<typename U, bool V> requires V
-  static int x;
+  int A<T>::x = 0;
 
+  template<typename T>
   template<typename U, bool V> requires V
-  static int x<U*, V>;
+  int A<T>::x<U*, V> = 0;
 
-  template<>
-  int x<short, true>;
+  template<typename T>
+  template<typename U, bool V> requires V
+  int A<T>::x<U&, V> = 0;
 
+  template<typename T>
   template<D U>
-  static int y;
+  int A<T>::y = 0;
 
+  template<typename T>
   template<D U>
-  static int y<U*>;
-};
-
-template<typename T>
-template<typename U, bool V>
-void A<T>::f() requires V { }
+  int A<T>::y<U*> = 0;
 
-template<typename T>
-template<D U>
-void A<T>::g() { }
-
-template<typename T>
-template<typename U, bool V> requires V
-struct A<T>::B { };
+  template<>
+  template<typename U, bool V>
+  void A<short>::f() requires V;
 
-template<typename T>
-template<typename U, bool V> requires V
-struct A<T>::B<U*, V> { };
+  template<>
+  template<>
+  void A<short>::f<int, true>();
 
-template<typename T>
-template<typename U, bool V> requires V
-struct A<T>::B<U&, V> { };
+  template<>
+  template<>
+  void A<void>::f<int, true>();
 
-template<typename T>
-template<D U>
-struct A<T>::C { };
+  template<>
+  template<D U>
+  void A<short>::g();
 
-template<typename T>
-template<D U>
-struct A<T>::C<U*> { };
+  template<>
+  template<typename U, bool V> requires V
+  struct A<int>::B;
 
-template<typename T>
-template<typename U, bool V> requires V
-int A<T>::x = 0;
+  template<>
+  template<>
+  struct A<int>::B<int, true>;
 
-template<typename T>
-template<typename U, bool V> requires V
-int A<T>::x<U*, V> = 0;
+  template<>
+  template<>
+  struct A<void>::B<int, true>;
 
-template<typename T>
-template<typename U, bool V> requires V
-int A<T>::x<U&, V> = 0;
+  template<>
+  template<typename U, bool V> requires V
+  struct A<int>::B<U*, V>;
 
-template<typename T>
-template<D U>
-int A<T>::y = 0;
+  template<>
+  template<typename U, bool V> requires V
+  struct A<int>::B<U&, V>;
 
-template<typename T>
-template<D U>
-int A<T>::y<U*> = 0;
+  template<>
+  template<D U>
+  struct A<int>::C;
 
-template<>
-template<typename U, bool V>
-void A<short>::f() requires V;
+  template<>
+  template<D U>
+  struct A<int>::C<U*>;
 
-template<>
-template<>
-void A<short>::f<int, true>();
+  template<>
+  template<D U>
+  struct A<int>::C<U&>;
 
-template<>
-template<>
-void A<void>::f<int, true>();
+  template<>
+  template<typename U, bool V> requires V
+  int A<long>::x;
 
-template<>
-template<D U>
-void A<short>::g();
+  template<>
+  template<>
+  int A<long>::x<int, true>;
 
-template<>
-template<typename U, bool V> requires V
-struct A<int>::B;
+  template<>
+  template<>
+  int A<void>::x<int, true>;
 
-template<>
-template<>
-struct A<int>::B<int, true>;
+  template<>
+  template<typename U, bool V> requires V
+  int A<long>::x<U*, V>;
 
-template<>
-template<>
-struct A<void>::B<int, true>;
+  template<>
+  template<typename U, bool V> requires V
+  int A<long>::x<U&, V>;
 
-template<>
-template<typename U, bool V> requires V
-struct A<int>::B<U*, V>;
+  template<>
+  template<D U>
+  int A<long>::y;
 
-template<>
-template<typename U, bool V> requires V
-struct A<int>::B<U&, V>;
+  template<>
+  template<D U>
+  int A<long>::y<U*>;
 
-template<>
-template<D U>
-struct A<int>::C;
+  template<>
+  template<D U>
+  int A<long>::y<U&>;
+} // namespace Primary
 
-template<>
-template<D U>
-struct A<int>::C<U*>;
+namespace Partial {
+  template<typename T, bool B>
+  struct A;
 
-template<>
-template<D U>
-struct A<int>::C<U&>;
+  template<bool U>
+  struct A<int, U>
+  {
+      template<typename V> requires U
+      void f();
 
-template<>
-template<typename U, bool V> requires V
-int A<long>::x;
+      template<typename V> requires U
+      static const int x;
 
-template<>
-template<>
-int A<long>::x<int, true>;
+      template<typename V> requires U
+      struct B;
+  };
 
-template<>
-template<>
-int A<void>::x<int, true>;
+  template<bool U>
+  template<typename V> requires U
+  void A<int, U>::f() { }
 
-template<>
-template<typename U, bool V> requires V
-int A<long>::x<U*, V>;
+  template<bool U>
+  template<typename V> requires U
+  constexpr int A<int, U>::x = 0;
 
-template<>
-template<typename U, bool V> requires V
-int A<long>::x<U&, V>;
+  template<bool U>
+  template<typename V> requires U
+  struct A<int, U>::B { };
 
-template<>
-template<D U>
-int A<long>::y;
+  template<>
+  template<typename V> requires true
+  void A<int, true>::f() { }
 
-template<>
-template<D U>
-int A<long>::y<U*>;
+  template<>
+  template<typename V> requires true
+  constexpr int A<int, true>::x = 1;
 
-template<>
-template<D U>
-int A<long>::y<U&>;
+  template<>
+  template<typename V> requires true
+  struct A<int, true>::B { };
+} // namespace Partial


        


More information about the cfe-commits mailing list