[clang] [Clang] Distinguish expanding-packs-in-place cases for SubstTemplateTypeParmTypes (PR #114220)

Younan Zhang via cfe-commits cfe-commits at lists.llvm.org
Wed Nov 6 19:20:24 PST 2024


https://github.com/zyn0217 updated https://github.com/llvm/llvm-project/pull/114220

>From 782caa155a746e7170f1794972b07d28fcf80692 Mon Sep 17 00:00:00 2001
From: Younan Zhang <zyn7109 at gmail.com>
Date: Wed, 30 Oct 2024 20:35:33 +0800
Subject: [PATCH 1/2] [Clang] Distinguish expanding-packs-in-place cases for
 SubstTemplateTypeParmTypes

---
 clang/include/clang/AST/ASTContext.h         |  8 ++---
 clang/include/clang/AST/Type.h               | 16 +++++++--
 clang/include/clang/AST/TypeProperties.td    |  5 ++-
 clang/lib/AST/ASTContext.cpp                 |  8 ++---
 clang/lib/AST/ASTImporter.cpp                |  4 +--
 clang/lib/AST/Type.cpp                       |  3 +-
 clang/lib/Sema/SemaTemplateInstantiate.cpp   | 34 +++++++++++++-------
 clang/test/SemaCXX/cxx20-ctad-type-alias.cpp | 19 +++++++++++
 8 files changed, 71 insertions(+), 26 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index a4d36f2eacd5d1..929e0f73064872 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1728,10 +1728,10 @@ class ASTContext : public RefCountedBase<ASTContext> {
       QualType Wrapped, QualType Contained,
       const HLSLAttributedResourceType::Attributes &Attrs);
 
-  QualType
-  getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
-                               unsigned Index,
-                               std::optional<unsigned> PackIndex) const;
+  QualType getSubstTemplateTypeParmType(QualType Replacement,
+                                        Decl *AssociatedDecl, unsigned Index,
+                                        std::optional<unsigned> PackIndex,
+                                        bool ExpandPacksInPlace = false) const;
   QualType getSubstTemplateTypeParmPackType(Decl *AssociatedDecl,
                                             unsigned Index, bool Final,
                                             const TemplateArgument &ArgPack);
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 40e617bf8f3b8d..fe17ccf690b9d9 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2170,6 +2170,9 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     LLVM_PREFERRED_TYPE(bool)
     unsigned HasNonCanonicalUnderlyingType : 1;
 
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned ExpandPacksInPlace : 1;
+
     // The index of the template parameter this substitution represents.
     unsigned Index : 15;
 
@@ -6393,7 +6396,8 @@ class SubstTemplateTypeParmType final
   Decl *AssociatedDecl;
 
   SubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
-                            unsigned Index, std::optional<unsigned> PackIndex);
+                            unsigned Index, std::optional<unsigned> PackIndex,
+                            bool ExpandPacksInPlace);
 
 public:
   /// Gets the type that was substituted for the template
@@ -6422,21 +6426,27 @@ class SubstTemplateTypeParmType final
     return SubstTemplateTypeParmTypeBits.PackIndex - 1;
   }
 
+  bool expandPacksInPlace() const {
+    return SubstTemplateTypeParmTypeBits.ExpandPacksInPlace;
+  }
+
   bool isSugared() const { return true; }
   QualType desugar() const { return getReplacementType(); }
 
   void Profile(llvm::FoldingSetNodeID &ID) {
     Profile(ID, getReplacementType(), getAssociatedDecl(), getIndex(),
-            getPackIndex());
+            getPackIndex(), expandPacksInPlace());
   }
 
   static void Profile(llvm::FoldingSetNodeID &ID, QualType Replacement,
                       const Decl *AssociatedDecl, unsigned Index,
-                      std::optional<unsigned> PackIndex) {
+                      std::optional<unsigned> PackIndex,
+                      bool ExpandPacksInPlace) {
     Replacement.Profile(ID);
     ID.AddPointer(AssociatedDecl);
     ID.AddInteger(Index);
     ID.AddInteger(PackIndex ? *PackIndex - 1 : 0);
+    ID.AddInteger(ExpandPacksInPlace);
   }
 
   static bool classof(const Type *T) {
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index d05072607e949c..f572a768b539b3 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -820,11 +820,14 @@ let Class = SubstTemplateTypeParmType in {
   def : Property<"PackIndex", Optional<UInt32>> {
     let Read = [{ node->getPackIndex() }];
   }
+  def : Property<"ExpandPacksInPlace", Bool> {
+    let Read = [{ node->expandPacksInPlace() }];
+  }
 
   // The call to getCanonicalType here existed in ASTReader.cpp, too.
   def : Creator<[{
     return ctx.getSubstTemplateTypeParmType(
-        replacementType, associatedDecl, Index, PackIndex);
+        replacementType, associatedDecl, Index, PackIndex, ExpandPacksInPlace);
   }]>;
 }
 
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 4bf8ddd762e9a5..a7c797b56c3dd1 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -5248,10 +5248,10 @@ QualType ASTContext::getHLSLAttributedResourceType(
 /// Retrieve a substitution-result type.
 QualType ASTContext::getSubstTemplateTypeParmType(
     QualType Replacement, Decl *AssociatedDecl, unsigned Index,
-    std::optional<unsigned> PackIndex) const {
+    std::optional<unsigned> PackIndex, bool ExpandPacksInPlace) const {
   llvm::FoldingSetNodeID ID;
   SubstTemplateTypeParmType::Profile(ID, Replacement, AssociatedDecl, Index,
-                                     PackIndex);
+                                     PackIndex, ExpandPacksInPlace);
   void *InsertPos = nullptr;
   SubstTemplateTypeParmType *SubstParm =
       SubstTemplateTypeParmTypes.FindNodeOrInsertPos(ID, InsertPos);
@@ -5260,8 +5260,8 @@ QualType ASTContext::getSubstTemplateTypeParmType(
     void *Mem = Allocate(SubstTemplateTypeParmType::totalSizeToAlloc<QualType>(
                              !Replacement.isCanonical()),
                          alignof(SubstTemplateTypeParmType));
-    SubstParm = new (Mem) SubstTemplateTypeParmType(Replacement, AssociatedDecl,
-                                                    Index, PackIndex);
+    SubstParm = new (Mem) SubstTemplateTypeParmType(
+        Replacement, AssociatedDecl, Index, PackIndex, ExpandPacksInPlace);
     Types.push_back(SubstParm);
     SubstTemplateTypeParmTypes.InsertNode(SubstParm, InsertPos);
   }
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index e7a6509167f0a0..8bc584c1b42569 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1627,8 +1627,8 @@ ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmType(
     return ToReplacementTypeOrErr.takeError();
 
   return Importer.getToContext().getSubstTemplateTypeParmType(
-      *ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(),
-      T->getPackIndex());
+      *ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(), T->getPackIndex(),
+      T->expandPacksInPlace());
 }
 
 ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmPackType(
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 5232efae4e3630..748a1d3a6e1164 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -4194,7 +4194,7 @@ static const TemplateTypeParmDecl *getReplacedParameter(Decl *D,
 
 SubstTemplateTypeParmType::SubstTemplateTypeParmType(
     QualType Replacement, Decl *AssociatedDecl, unsigned Index,
-    std::optional<unsigned> PackIndex)
+    std::optional<unsigned> PackIndex, bool ExpandPacksInPlace)
     : Type(SubstTemplateTypeParm, Replacement.getCanonicalType(),
            Replacement->getDependence()),
       AssociatedDecl(AssociatedDecl) {
@@ -4205,6 +4205,7 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(
 
   SubstTemplateTypeParmTypeBits.Index = Index;
   SubstTemplateTypeParmTypeBits.PackIndex = PackIndex ? *PackIndex + 1 : 0;
+  SubstTemplateTypeParmTypeBits.ExpandPacksInPlace = ExpandPacksInPlace;
   assert(AssociatedDecl != nullptr);
 }
 
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index 457a9968c32a4a..e74175bae330b8 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -1648,14 +1648,16 @@ namespace {
     QualType
     TransformSubstTemplateTypeParmType(TypeLocBuilder &TLB,
                                        SubstTemplateTypeParmTypeLoc TL) {
-      if (SemaRef.CodeSynthesisContexts.back().Kind !=
-          Sema::CodeSynthesisContext::ConstraintSubstitution)
+      const SubstTemplateTypeParmType *Type = TL.getTypePtr();
+      if (!Type->expandPacksInPlace())
         return inherited::TransformSubstTemplateTypeParmType(TLB, TL);
 
-      auto PackIndex = TL.getTypePtr()->getPackIndex();
-      std::optional<Sema::ArgumentPackSubstitutionIndexRAII> SubstIndex;
-      if (SemaRef.ArgumentPackSubstitutionIndex == -1 && PackIndex)
-        SubstIndex.emplace(SemaRef, *PackIndex);
+      assert(Type->getPackIndex());
+      TemplateArgument TA = TemplateArgs(
+          Type->getReplacedParameter()->getDepth(), Type->getIndex());
+      assert(*Type->getPackIndex() + 1 <= TA.pack_size());
+      Sema::ArgumentPackSubstitutionIndexRAII SubstIndex(
+          SemaRef, TA.pack_size() - 1 - *Type->getPackIndex());
 
       return inherited::TransformSubstTemplateTypeParmType(TLB, TL);
     }
@@ -3133,7 +3135,11 @@ struct ExpandPackedTypeConstraints
 
   using inherited = TreeTransform<ExpandPackedTypeConstraints>;
 
-  ExpandPackedTypeConstraints(Sema &SemaRef) : inherited(SemaRef) {}
+  const MultiLevelTemplateArgumentList &TemplateArgs;
+
+  ExpandPackedTypeConstraints(
+      Sema &SemaRef, const MultiLevelTemplateArgumentList &TemplateArgs)
+      : inherited(SemaRef), TemplateArgs(TemplateArgs) {}
 
   using inherited::TransformTemplateTypeParmType;
 
@@ -3149,9 +3155,15 @@ struct ExpandPackedTypeConstraints
 
     assert(SemaRef.ArgumentPackSubstitutionIndex != -1);
 
+    TemplateArgument Arg = TemplateArgs(T->getDepth(), T->getIndex());
+
+    std::optional<unsigned> PackIndex;
+    if (Arg.getKind() == TemplateArgument::Pack)
+      PackIndex = Arg.pack_size() - 1 - SemaRef.ArgumentPackSubstitutionIndex;
+
     QualType Result = SemaRef.Context.getSubstTemplateTypeParmType(
-        TL.getType(), T->getDecl(), T->getIndex(),
-        SemaRef.ArgumentPackSubstitutionIndex);
+        TL.getType(), T->getDecl(), T->getIndex(), PackIndex,
+        /*ExpandPacksInPlace=*/true);
     SubstTemplateTypeParmTypeLoc NewTL =
         TLB.push<SubstTemplateTypeParmTypeLoc>(Result);
     NewTL.setNameLoc(TL.getNameLoc());
@@ -3210,8 +3222,8 @@ bool Sema::SubstTypeConstraint(
       TemplateArgumentListInfo InstArgs;
       InstArgs.setLAngleLoc(TemplArgInfo->LAngleLoc);
       InstArgs.setRAngleLoc(TemplArgInfo->RAngleLoc);
-      if (ExpandPackedTypeConstraints(*this).SubstTemplateArguments(
-              TemplArgInfo->arguments(), InstArgs))
+      if (ExpandPackedTypeConstraints(*this, TemplateArgs)
+              .SubstTemplateArguments(TemplArgInfo->arguments(), InstArgs))
         return true;
 
       // The type of the original parameter.
diff --git a/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp b/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp
index 675c32a81f1ae8..2d43e46b9e3d76 100644
--- a/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp
+++ b/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp
@@ -494,3 +494,22 @@ template <typename V> using Alias = S<V>;
 Alias A(42);
 
 } // namespace GH111508
+
+namespace GH113518 {
+
+template <class T, unsigned N> struct array {
+  T value[N];
+};
+
+template <typename Tp, typename... Up>
+array(Tp, Up...) -> array<Tp, 1 + sizeof...(Up)>;
+
+template <typename T> struct ArrayType {
+  template <unsigned size> using Array = array<T, size>;
+};
+
+template <ArrayType<int>::Array array> void test() {}
+
+void foo() { test<{1, 2, 3}>(); }
+
+} // namespace GH113518

>From 85cfdcf83fa1882f28641cd3657e555140d628aa Mon Sep 17 00:00:00 2001
From: Younan Zhang <zyn7109 at gmail.com>
Date: Thu, 7 Nov 2024 11:19:06 +0800
Subject: [PATCH 2/2] Apply review suggestions

---
 clang/include/clang/AST/ASTContext.h       |  9 +++----
 clang/include/clang/AST/PropertiesBase.td  |  1 +
 clang/include/clang/AST/Type.h             | 28 +++++++++++++++-------
 clang/include/clang/AST/TypeProperties.td  |  6 ++---
 clang/lib/AST/ASTContext.cpp               |  9 +++----
 clang/lib/AST/ASTImporter.cpp              |  2 +-
 clang/lib/AST/Type.cpp                     |  6 +++--
 clang/lib/Sema/SemaTemplateInstantiate.cpp |  5 ++--
 8 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index 929e0f73064872..f4b7cb2f3b77a8 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1728,10 +1728,11 @@ class ASTContext : public RefCountedBase<ASTContext> {
       QualType Wrapped, QualType Contained,
       const HLSLAttributedResourceType::Attributes &Attrs);
 
-  QualType getSubstTemplateTypeParmType(QualType Replacement,
-                                        Decl *AssociatedDecl, unsigned Index,
-                                        std::optional<unsigned> PackIndex,
-                                        bool ExpandPacksInPlace = false) const;
+  QualType getSubstTemplateTypeParmType(
+      QualType Replacement, Decl *AssociatedDecl, unsigned Index,
+      std::optional<unsigned> PackIndex,
+      SubstTemplateTypeParmTypeFlag Flag =
+          SubstTemplateTypeParmTypeFlag::None) const;
   QualType getSubstTemplateTypeParmPackType(Decl *AssociatedDecl,
                                             unsigned Index, bool Final,
                                             const TemplateArgument &ArgPack);
diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td
index 3057669e3758b5..46e5c75086875c 100644
--- a/clang/include/clang/AST/PropertiesBase.td
+++ b/clang/include/clang/AST/PropertiesBase.td
@@ -136,6 +136,7 @@ def Selector : PropertyType;
 def SourceLocation : PropertyType;
 def StmtRef : RefPropertyType<"Stmt"> { let ConstWhenWriting = 1; }
   def ExprRef : SubclassPropertyType<"Expr", StmtRef>;
+def SubstTemplateTypeParmTypeFlag : EnumPropertyType;
 def TemplateArgument : PropertyType;
 def TemplateArgumentKind : EnumPropertyType<"TemplateArgument::ArgKind">;
 def TemplateName : DefaultValuePropertyType;
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index fe17ccf690b9d9..14241072b2ecdf 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -1801,6 +1801,15 @@ enum class AutoTypeKeyword {
   GNUAutoType
 };
 
+enum class SubstTemplateTypeParmTypeFlag {
+  None,
+
+  /// Whether to expand the pack using the stored PackIndex in place. This is
+  /// useful for e.g. substituting into an atomic constraint expression, where
+  /// that expression is part of an unexpanded pack.
+  ExpandPacksInPlace,
+};
+
 enum class ArraySizeModifier;
 enum class ElaboratedTypeKeyword;
 enum class VectorKind;
@@ -2170,8 +2179,8 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     LLVM_PREFERRED_TYPE(bool)
     unsigned HasNonCanonicalUnderlyingType : 1;
 
-    LLVM_PREFERRED_TYPE(bool)
-    unsigned ExpandPacksInPlace : 1;
+    LLVM_PREFERRED_TYPE(SubstTemplateTypeParmTypeFlag)
+    unsigned SubstitutionFlag : 1;
 
     // The index of the template parameter this substitution represents.
     unsigned Index : 15;
@@ -6397,7 +6406,7 @@ class SubstTemplateTypeParmType final
 
   SubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
                             unsigned Index, std::optional<unsigned> PackIndex,
-                            bool ExpandPacksInPlace);
+                            SubstTemplateTypeParmTypeFlag Flag);
 
 public:
   /// Gets the type that was substituted for the template
@@ -6426,8 +6435,9 @@ class SubstTemplateTypeParmType final
     return SubstTemplateTypeParmTypeBits.PackIndex - 1;
   }
 
-  bool expandPacksInPlace() const {
-    return SubstTemplateTypeParmTypeBits.ExpandPacksInPlace;
+  SubstTemplateTypeParmTypeFlag getSubstitutionFlag() const {
+    return static_cast<SubstTemplateTypeParmTypeFlag>(
+        SubstTemplateTypeParmTypeBits.SubstitutionFlag);
   }
 
   bool isSugared() const { return true; }
@@ -6435,18 +6445,20 @@ class SubstTemplateTypeParmType final
 
   void Profile(llvm::FoldingSetNodeID &ID) {
     Profile(ID, getReplacementType(), getAssociatedDecl(), getIndex(),
-            getPackIndex(), expandPacksInPlace());
+            getPackIndex(), getSubstitutionFlag());
   }
 
   static void Profile(llvm::FoldingSetNodeID &ID, QualType Replacement,
                       const Decl *AssociatedDecl, unsigned Index,
                       std::optional<unsigned> PackIndex,
-                      bool ExpandPacksInPlace) {
+                      SubstTemplateTypeParmTypeFlag Flag) {
     Replacement.Profile(ID);
     ID.AddPointer(AssociatedDecl);
     ID.AddInteger(Index);
     ID.AddInteger(PackIndex ? *PackIndex - 1 : 0);
-    ID.AddInteger(ExpandPacksInPlace);
+    ID.AddInteger(llvm::to_underlying(Flag));
+    assert(Flag != SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace ||
+           PackIndex && "ExpandPacksInPlace needs a valid PackIndex");
   }
 
   static bool classof(const Type *T) {
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index f572a768b539b3..80c3c5352606c8 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -820,14 +820,14 @@ let Class = SubstTemplateTypeParmType in {
   def : Property<"PackIndex", Optional<UInt32>> {
     let Read = [{ node->getPackIndex() }];
   }
-  def : Property<"ExpandPacksInPlace", Bool> {
-    let Read = [{ node->expandPacksInPlace() }];
+  def : Property<"SubstitutionFlag", SubstTemplateTypeParmTypeFlag> {
+    let Read = [{ node->getSubstitutionFlag() }];
   }
 
   // The call to getCanonicalType here existed in ASTReader.cpp, too.
   def : Creator<[{
     return ctx.getSubstTemplateTypeParmType(
-        replacementType, associatedDecl, Index, PackIndex, ExpandPacksInPlace);
+        replacementType, associatedDecl, Index, PackIndex, SubstitutionFlag);
   }]>;
 }
 
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index a7c797b56c3dd1..cc8115604e8035 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -5248,10 +5248,11 @@ QualType ASTContext::getHLSLAttributedResourceType(
 /// Retrieve a substitution-result type.
 QualType ASTContext::getSubstTemplateTypeParmType(
     QualType Replacement, Decl *AssociatedDecl, unsigned Index,
-    std::optional<unsigned> PackIndex, bool ExpandPacksInPlace) const {
+    std::optional<unsigned> PackIndex,
+    SubstTemplateTypeParmTypeFlag Flag) const {
   llvm::FoldingSetNodeID ID;
   SubstTemplateTypeParmType::Profile(ID, Replacement, AssociatedDecl, Index,
-                                     PackIndex, ExpandPacksInPlace);
+                                     PackIndex, Flag);
   void *InsertPos = nullptr;
   SubstTemplateTypeParmType *SubstParm =
       SubstTemplateTypeParmTypes.FindNodeOrInsertPos(ID, InsertPos);
@@ -5260,8 +5261,8 @@ QualType ASTContext::getSubstTemplateTypeParmType(
     void *Mem = Allocate(SubstTemplateTypeParmType::totalSizeToAlloc<QualType>(
                              !Replacement.isCanonical()),
                          alignof(SubstTemplateTypeParmType));
-    SubstParm = new (Mem) SubstTemplateTypeParmType(
-        Replacement, AssociatedDecl, Index, PackIndex, ExpandPacksInPlace);
+    SubstParm = new (Mem) SubstTemplateTypeParmType(Replacement, AssociatedDecl,
+                                                    Index, PackIndex, Flag);
     Types.push_back(SubstParm);
     SubstTemplateTypeParmTypes.InsertNode(SubstParm, InsertPos);
   }
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index 8bc584c1b42569..730f077137aa4a 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1628,7 +1628,7 @@ ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmType(
 
   return Importer.getToContext().getSubstTemplateTypeParmType(
       *ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(), T->getPackIndex(),
-      T->expandPacksInPlace());
+      T->getSubstitutionFlag());
 }
 
 ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmPackType(
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 748a1d3a6e1164..5b57ab439263c6 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -4194,7 +4194,7 @@ static const TemplateTypeParmDecl *getReplacedParameter(Decl *D,
 
 SubstTemplateTypeParmType::SubstTemplateTypeParmType(
     QualType Replacement, Decl *AssociatedDecl, unsigned Index,
-    std::optional<unsigned> PackIndex, bool ExpandPacksInPlace)
+    std::optional<unsigned> PackIndex, SubstTemplateTypeParmTypeFlag Flag)
     : Type(SubstTemplateTypeParm, Replacement.getCanonicalType(),
            Replacement->getDependence()),
       AssociatedDecl(AssociatedDecl) {
@@ -4205,7 +4205,9 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(
 
   SubstTemplateTypeParmTypeBits.Index = Index;
   SubstTemplateTypeParmTypeBits.PackIndex = PackIndex ? *PackIndex + 1 : 0;
-  SubstTemplateTypeParmTypeBits.ExpandPacksInPlace = ExpandPacksInPlace;
+  SubstTemplateTypeParmTypeBits.SubstitutionFlag = llvm::to_underlying(Flag);
+  assert(Flag != SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace ||
+         PackIndex && "ExpandPacksInPlace needs a valid PackIndex");
   assert(AssociatedDecl != nullptr);
 }
 
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index e74175bae330b8..7e565b20135ff9 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -1649,7 +1649,8 @@ namespace {
     TransformSubstTemplateTypeParmType(TypeLocBuilder &TLB,
                                        SubstTemplateTypeParmTypeLoc TL) {
       const SubstTemplateTypeParmType *Type = TL.getTypePtr();
-      if (!Type->expandPacksInPlace())
+      if (Type->getSubstitutionFlag() !=
+          SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace)
         return inherited::TransformSubstTemplateTypeParmType(TLB, TL);
 
       assert(Type->getPackIndex());
@@ -3163,7 +3164,7 @@ struct ExpandPackedTypeConstraints
 
     QualType Result = SemaRef.Context.getSubstTemplateTypeParmType(
         TL.getType(), T->getDecl(), T->getIndex(), PackIndex,
-        /*ExpandPacksInPlace=*/true);
+        SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace);
     SubstTemplateTypeParmTypeLoc NewTL =
         TLB.push<SubstTemplateTypeParmTypeLoc>(Result);
     NewTL.setNameLoc(TL.getNameLoc());



More information about the cfe-commits mailing list