[clang] [Clang] Distinguish expanding-packs-in-place cases for SubstTemplateTypeParmTypes (PR #114220)

via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 30 18:58:43 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-clang-modules

Author: Younan Zhang (zyn0217)

<details>
<summary>Changes</summary>

In 50e5411e4, we preserved the pack substitution index within SubstTemplateTypeParmType nodes and performed in-place expansions of packs. This allowed type constraints on a lambda that serves as the pattern of a fold expression to be evaluated when those type constraints contain packs that are expanded by the fold expression.

However, we made an incorrect assumption about the conditions under which in-place expansion should occur. For example, a SizeOfPackExpr case relies on SubstTemplateTypeParmType nodes being transformed to SubstTemplateTypeParmPackTypes rather than being expanded immediately in place.

This patch fixes the problem by adding a flag to SubstTemplateTypeParmType that identifies the cases in which in-place expansion should be performed.

Fixes https://github.com/llvm/llvm-project/issues/113518

---
Full diff: https://github.com/llvm/llvm-project/pull/114220.diff


8 Files Affected:

- (modified) clang/include/clang/AST/ASTContext.h (+4-4) 
- (modified) clang/include/clang/AST/Type.h (+13-3) 
- (modified) clang/include/clang/AST/TypeProperties.td (+4-1) 
- (modified) clang/lib/AST/ASTContext.cpp (+4-4) 
- (modified) clang/lib/AST/ASTImporter.cpp (+2-2) 
- (modified) clang/lib/AST/Type.cpp (+2-1) 
- (modified) clang/lib/Sema/SemaTemplateInstantiate.cpp (+23-11) 
- (modified) clang/test/SemaCXX/cxx20-ctad-type-alias.cpp (+19) 


``````````diff
diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index a4d36f2eacd5d1..929e0f73064872 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1728,10 +1728,10 @@ class ASTContext : public RefCountedBase<ASTContext> {
       QualType Wrapped, QualType Contained,
       const HLSLAttributedResourceType::Attributes &Attrs);
 
-  QualType
-  getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
-                               unsigned Index,
-                               std::optional<unsigned> PackIndex) const;
+  QualType getSubstTemplateTypeParmType(QualType Replacement,
+                                        Decl *AssociatedDecl, unsigned Index,
+                                        std::optional<unsigned> PackIndex,
+                                        bool ExpandPacksInPlace = false) const;
   QualType getSubstTemplateTypeParmPackType(Decl *AssociatedDecl,
                                             unsigned Index, bool Final,
                                             const TemplateArgument &ArgPack);
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 40e617bf8f3b8d..fe17ccf690b9d9 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2170,6 +2170,9 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     LLVM_PREFERRED_TYPE(bool)
     unsigned HasNonCanonicalUnderlyingType : 1;
 
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned ExpandPacksInPlace : 1;
+
     // The index of the template parameter this substitution represents.
     unsigned Index : 15;
 
@@ -6393,7 +6396,8 @@ class SubstTemplateTypeParmType final
   Decl *AssociatedDecl;
 
   SubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
-                            unsigned Index, std::optional<unsigned> PackIndex);
+                            unsigned Index, std::optional<unsigned> PackIndex,
+                            bool ExpandPacksInPlace);
 
 public:
   /// Gets the type that was substituted for the template
@@ -6422,21 +6426,27 @@ class SubstTemplateTypeParmType final
     return SubstTemplateTypeParmTypeBits.PackIndex - 1;
   }
 
+  bool expandPacksInPlace() const {
+    return SubstTemplateTypeParmTypeBits.ExpandPacksInPlace;
+  }
+
   bool isSugared() const { return true; }
   QualType desugar() const { return getReplacementType(); }
 
   void Profile(llvm::FoldingSetNodeID &ID) {
     Profile(ID, getReplacementType(), getAssociatedDecl(), getIndex(),
-            getPackIndex());
+            getPackIndex(), expandPacksInPlace());
   }
 
   static void Profile(llvm::FoldingSetNodeID &ID, QualType Replacement,
                       const Decl *AssociatedDecl, unsigned Index,
-                      std::optional<unsigned> PackIndex) {
+                      std::optional<unsigned> PackIndex,
+                      bool ExpandPacksInPlace) {
     Replacement.Profile(ID);
     ID.AddPointer(AssociatedDecl);
     ID.AddInteger(Index);
     ID.AddInteger(PackIndex ? *PackIndex - 1 : 0);
+    ID.AddInteger(ExpandPacksInPlace);
   }
 
   static bool classof(const Type *T) {
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index d05072607e949c..f572a768b539b3 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -820,11 +820,14 @@ let Class = SubstTemplateTypeParmType in {
   def : Property<"PackIndex", Optional<UInt32>> {
     let Read = [{ node->getPackIndex() }];
   }
+  def : Property<"ExpandPacksInPlace", Bool> {
+    let Read = [{ node->expandPacksInPlace() }];
+  }
 
   // The call to getCanonicalType here existed in ASTReader.cpp, too.
   def : Creator<[{
     return ctx.getSubstTemplateTypeParmType(
-        replacementType, associatedDecl, Index, PackIndex);
+        replacementType, associatedDecl, Index, PackIndex, ExpandPacksInPlace);
   }]>;
 }
 
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 4bf8ddd762e9a5..a7c797b56c3dd1 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -5248,10 +5248,10 @@ QualType ASTContext::getHLSLAttributedResourceType(
 /// Retrieve a substitution-result type.
 QualType ASTContext::getSubstTemplateTypeParmType(
     QualType Replacement, Decl *AssociatedDecl, unsigned Index,
-    std::optional<unsigned> PackIndex) const {
+    std::optional<unsigned> PackIndex, bool ExpandPacksInPlace) const {
   llvm::FoldingSetNodeID ID;
   SubstTemplateTypeParmType::Profile(ID, Replacement, AssociatedDecl, Index,
-                                     PackIndex);
+                                     PackIndex, ExpandPacksInPlace);
   void *InsertPos = nullptr;
   SubstTemplateTypeParmType *SubstParm =
       SubstTemplateTypeParmTypes.FindNodeOrInsertPos(ID, InsertPos);
@@ -5260,8 +5260,8 @@ QualType ASTContext::getSubstTemplateTypeParmType(
     void *Mem = Allocate(SubstTemplateTypeParmType::totalSizeToAlloc<QualType>(
                              !Replacement.isCanonical()),
                          alignof(SubstTemplateTypeParmType));
-    SubstParm = new (Mem) SubstTemplateTypeParmType(Replacement, AssociatedDecl,
-                                                    Index, PackIndex);
+    SubstParm = new (Mem) SubstTemplateTypeParmType(
+        Replacement, AssociatedDecl, Index, PackIndex, ExpandPacksInPlace);
     Types.push_back(SubstParm);
     SubstTemplateTypeParmTypes.InsertNode(SubstParm, InsertPos);
   }
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index e7a6509167f0a0..8bc584c1b42569 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1627,8 +1627,8 @@ ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmType(
     return ToReplacementTypeOrErr.takeError();
 
   return Importer.getToContext().getSubstTemplateTypeParmType(
-      *ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(),
-      T->getPackIndex());
+      *ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(), T->getPackIndex(),
+      T->expandPacksInPlace());
 }
 
 ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmPackType(
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 5232efae4e3630..748a1d3a6e1164 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -4194,7 +4194,7 @@ static const TemplateTypeParmDecl *getReplacedParameter(Decl *D,
 
 SubstTemplateTypeParmType::SubstTemplateTypeParmType(
     QualType Replacement, Decl *AssociatedDecl, unsigned Index,
-    std::optional<unsigned> PackIndex)
+    std::optional<unsigned> PackIndex, bool ExpandPacksInPlace)
     : Type(SubstTemplateTypeParm, Replacement.getCanonicalType(),
            Replacement->getDependence()),
       AssociatedDecl(AssociatedDecl) {
@@ -4205,6 +4205,7 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(
 
   SubstTemplateTypeParmTypeBits.Index = Index;
   SubstTemplateTypeParmTypeBits.PackIndex = PackIndex ? *PackIndex + 1 : 0;
+  SubstTemplateTypeParmTypeBits.ExpandPacksInPlace = ExpandPacksInPlace;
   assert(AssociatedDecl != nullptr);
 }
 
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index 457a9968c32a4a..e74175bae330b8 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -1648,14 +1648,16 @@ namespace {
     QualType
     TransformSubstTemplateTypeParmType(TypeLocBuilder &TLB,
                                        SubstTemplateTypeParmTypeLoc TL) {
-      if (SemaRef.CodeSynthesisContexts.back().Kind !=
-          Sema::CodeSynthesisContext::ConstraintSubstitution)
+      const SubstTemplateTypeParmType *Type = TL.getTypePtr();
+      if (!Type->expandPacksInPlace())
         return inherited::TransformSubstTemplateTypeParmType(TLB, TL);
 
-      auto PackIndex = TL.getTypePtr()->getPackIndex();
-      std::optional<Sema::ArgumentPackSubstitutionIndexRAII> SubstIndex;
-      if (SemaRef.ArgumentPackSubstitutionIndex == -1 && PackIndex)
-        SubstIndex.emplace(SemaRef, *PackIndex);
+      assert(Type->getPackIndex());
+      TemplateArgument TA = TemplateArgs(
+          Type->getReplacedParameter()->getDepth(), Type->getIndex());
+      assert(*Type->getPackIndex() + 1 <= TA.pack_size());
+      Sema::ArgumentPackSubstitutionIndexRAII SubstIndex(
+          SemaRef, TA.pack_size() - 1 - *Type->getPackIndex());
 
       return inherited::TransformSubstTemplateTypeParmType(TLB, TL);
     }
@@ -3133,7 +3135,11 @@ struct ExpandPackedTypeConstraints
 
   using inherited = TreeTransform<ExpandPackedTypeConstraints>;
 
-  ExpandPackedTypeConstraints(Sema &SemaRef) : inherited(SemaRef) {}
+  const MultiLevelTemplateArgumentList &TemplateArgs;
+
+  ExpandPackedTypeConstraints(
+      Sema &SemaRef, const MultiLevelTemplateArgumentList &TemplateArgs)
+      : inherited(SemaRef), TemplateArgs(TemplateArgs) {}
 
   using inherited::TransformTemplateTypeParmType;
 
@@ -3149,9 +3155,15 @@ struct ExpandPackedTypeConstraints
 
     assert(SemaRef.ArgumentPackSubstitutionIndex != -1);
 
+    TemplateArgument Arg = TemplateArgs(T->getDepth(), T->getIndex());
+
+    std::optional<unsigned> PackIndex;
+    if (Arg.getKind() == TemplateArgument::Pack)
+      PackIndex = Arg.pack_size() - 1 - SemaRef.ArgumentPackSubstitutionIndex;
+
     QualType Result = SemaRef.Context.getSubstTemplateTypeParmType(
-        TL.getType(), T->getDecl(), T->getIndex(),
-        SemaRef.ArgumentPackSubstitutionIndex);
+        TL.getType(), T->getDecl(), T->getIndex(), PackIndex,
+        /*ExpandPacksInPlace=*/true);
     SubstTemplateTypeParmTypeLoc NewTL =
         TLB.push<SubstTemplateTypeParmTypeLoc>(Result);
     NewTL.setNameLoc(TL.getNameLoc());
@@ -3210,8 +3222,8 @@ bool Sema::SubstTypeConstraint(
       TemplateArgumentListInfo InstArgs;
       InstArgs.setLAngleLoc(TemplArgInfo->LAngleLoc);
       InstArgs.setRAngleLoc(TemplArgInfo->RAngleLoc);
-      if (ExpandPackedTypeConstraints(*this).SubstTemplateArguments(
-              TemplArgInfo->arguments(), InstArgs))
+      if (ExpandPackedTypeConstraints(*this, TemplateArgs)
+              .SubstTemplateArguments(TemplArgInfo->arguments(), InstArgs))
         return true;
 
       // The type of the original parameter.
diff --git a/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp b/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp
index 675c32a81f1ae8..2d43e46b9e3d76 100644
--- a/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp
+++ b/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp
@@ -494,3 +494,22 @@ template <typename V> using Alias = S<V>;
 Alias A(42);
 
 } // namespace GH111508
+
+namespace GH113518 {
+
+template <class T, unsigned N> struct array {
+  T value[N];
+};
+
+template <typename Tp, typename... Up>
+array(Tp, Up...) -> array<Tp, 1 + sizeof...(Up)>;
+
+template <typename T> struct ArrayType {
+  template <unsigned size> using Array = array<T, size>;
+};
+
+template <ArrayType<int>::Array array> void test() {}
+
+void foo() { test<{1, 2, 3}>(); }
+
+} // namespace GH113518

``````````

</details>


https://github.com/llvm/llvm-project/pull/114220


More information about the cfe-commits mailing list