[clang] [Clang] Preserve partially substituted pack indexing type/expressions (PR #116782)

Younan Zhang via cfe-commits cfe-commits at lists.llvm.org
Tue Nov 19 02:55:21 PST 2024


https://github.com/zyn0217 created https://github.com/llvm/llvm-project/pull/116782

Substituting into pack indexing types/expressions can still result in unexpanded types/expressions, such as PackIndexingType or PackIndexingExpr. To handle these cases correctly, we should defer the pack size checks to the next round of transformation, when the patterns can be fully expanded.

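To that end, the `FullySubstituted` flag is now necessary for computing the dependencies of `PackIndexingExpr`s. Conveniently, this flag can also represent the prior "expands to empty pack" status (previously stored as `ExpandedToEmptyPack` on `PackIndexingExpr` and `ExpandsToEmptyPack` on `PackIndexingType`) with an additional emptiness check: `expandsToEmptyPack()` is now computed as `isFullySubstituted()` combined with an empty list of expansions. Therefore, I converted all stored flags to use `FullySubstituted`.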

Fixes https://github.com/llvm/llvm-project/issues/116105

>From 5c0d947576b692c32febf89703033715b0c51cda Mon Sep 17 00:00:00 2001
From: Younan Zhang <zyn7109 at gmail.com>
Date: Tue, 19 Nov 2024 18:33:34 +0800
Subject: [PATCH] [Clang] Preserve partially substituted pack indexing
 type/expressions

---
 clang/include/clang/AST/ExprCXX.h          | 14 +++++----
 clang/include/clang/AST/Type.h             | 12 ++++----
 clang/include/clang/AST/TypeProperties.td  |  6 ++--
 clang/include/clang/Sema/Sema.h            |  2 +-
 clang/lib/AST/ASTContext.cpp               |  8 ++---
 clang/lib/AST/ComputeDependence.cpp        |  3 +-
 clang/lib/AST/ExprCXX.cpp                  |  6 ++--
 clang/lib/AST/Type.cpp                     |  8 ++---
 clang/lib/Sema/SemaTemplateVariadic.cpp    | 16 +++++-----
 clang/lib/Sema/TreeTransform.h             | 13 +++++----
 clang/lib/Serialization/ASTReaderStmt.cpp  |  2 +-
 clang/lib/Serialization/ASTWriterStmt.cpp  |  2 +-
 clang/test/SemaCXX/cxx2c-pack-indexing.cpp | 34 ++++++++++++++++++++++
 13 files changed, 84 insertions(+), 42 deletions(-)

diff --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h
index 696a574833dad2..1a24b8857674ca 100644
--- a/clang/include/clang/AST/ExprCXX.h
+++ b/clang/include/clang/AST/ExprCXX.h
@@ -4390,17 +4390,17 @@ class PackIndexingExpr final
   unsigned TransformedExpressions : 31;
 
   LLVM_PREFERRED_TYPE(bool)
-  unsigned ExpandedToEmptyPack : 1;
+  unsigned FullySubstituted : 1;
 
   PackIndexingExpr(QualType Type, SourceLocation EllipsisLoc,
                    SourceLocation RSquareLoc, Expr *PackIdExpr, Expr *IndexExpr,
                    ArrayRef<Expr *> SubstitutedExprs = {},
-                   bool ExpandedToEmptyPack = false)
+                   bool FullySubstituted = false)
       : Expr(PackIndexingExprClass, Type, VK_LValue, OK_Ordinary),
         EllipsisLoc(EllipsisLoc), RSquareLoc(RSquareLoc),
         SubExprs{PackIdExpr, IndexExpr},
         TransformedExpressions(SubstitutedExprs.size()),
-        ExpandedToEmptyPack(ExpandedToEmptyPack) {
+        FullySubstituted(FullySubstituted) {
 
     auto *Exprs = getTrailingObjects<Expr *>();
     std::uninitialized_copy(SubstitutedExprs.begin(), SubstitutedExprs.end(),
@@ -4424,12 +4424,16 @@ class PackIndexingExpr final
                                   SourceLocation RSquareLoc, Expr *PackIdExpr,
                                   Expr *IndexExpr, std::optional<int64_t> Index,
                                   ArrayRef<Expr *> SubstitutedExprs = {},
-                                  bool ExpandedToEmptyPack = false);
+                                  bool FullySubstituted = false);
   static PackIndexingExpr *CreateDeserialized(ASTContext &Context,
                                               unsigned NumTransformedExprs);
 
+  bool isFullySubstituted() const { return FullySubstituted; }
+
   /// Determine if the expression was expanded to empty.
-  bool expandsToEmptyPack() const { return ExpandedToEmptyPack; }
+  bool expandsToEmptyPack() const {
+    return isFullySubstituted() && TransformedExpressions == 0;
+  }
 
   /// Determine the location of the 'sizeof' keyword.
   SourceLocation getEllipsisLoc() const { return EllipsisLoc; }
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 1ed5c22361ca68..90a52b1dcbf624 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -5922,12 +5922,12 @@ class PackIndexingType final
   unsigned Size : 31;
 
   LLVM_PREFERRED_TYPE(bool)
-  unsigned ExpandsToEmptyPack : 1;
+  unsigned FullySubstituted : 1;
 
 protected:
   friend class ASTContext; // ASTContext creates these.
   PackIndexingType(const ASTContext &Context, QualType Canonical,
-                   QualType Pattern, Expr *IndexExpr, bool ExpandsToEmptyPack,
+                   QualType Pattern, Expr *IndexExpr, bool FullySubstituted,
                    ArrayRef<QualType> Expansions = {});
 
 public:
@@ -5951,7 +5951,9 @@ class PackIndexingType final
 
   bool hasSelectedType() const { return getSelectedIndex() != std::nullopt; }
 
-  bool expandsToEmptyPack() const { return ExpandsToEmptyPack; }
+  bool isFullySubstituted() const { return FullySubstituted; }
+
+  bool expandsToEmptyPack() const { return isFullySubstituted() && Size == 0; }
 
   ArrayRef<QualType> getExpansions() const {
     return {getExpansionsPtr(), Size};
@@ -5965,10 +5967,10 @@ class PackIndexingType final
     if (hasSelectedType())
       getSelectedType().Profile(ID);
     else
-      Profile(ID, Context, getPattern(), getIndexExpr(), expandsToEmptyPack());
+      Profile(ID, Context, getPattern(), getIndexExpr(), isFullySubstituted());
   }
   static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
-                      QualType Pattern, Expr *E, bool ExpandsToEmptyPack);
+                      QualType Pattern, Expr *E, bool FullySubstituted);
 
 private:
   const QualType *getExpansionsPtr() const {
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index a8b9c920b617c0..6f1a76bd18fb50 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -473,12 +473,12 @@ let Class = PackIndexingType in {
   def : Property<"indexExpression", ExprRef> {
     let Read = [{ node->getIndexExpr() }];
   }
-  def : Property<"expandsToEmptyPack", Bool> {
-    let Read = [{ node->expandsToEmptyPack() }];
+  def : Property<"isFullySubstituted", Bool> {
+    let Read = [{ node->isFullySubstituted() }];
   }
 
   def : Creator<[{
-    return ctx.getPackIndexingType(pattern, indexExpression, expandsToEmptyPack);
+    return ctx.getPackIndexingType(pattern, indexExpression, isFullySubstituted);
   }]>;
 }
 
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index d6f3508a5243f3..91e1a884775d5e 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14251,7 +14251,7 @@ class Sema final : public SemaBase {
                                    SourceLocation EllipsisLoc, Expr *IndexExpr,
                                    SourceLocation RSquareLoc,
                                    ArrayRef<Expr *> ExpandedExprs = {},
-                                   bool EmptyPack = false);
+                                   bool FullySubstituted = false);
 
   /// Handle a C++1z fold-expression: ( expr op ... op expr ).
   ExprResult ActOnCXXFoldExpr(Scope *S, SourceLocation LParenLoc, Expr *LHS,
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 5226ca6f5d0191..d5954d436fd25b 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -6226,13 +6226,11 @@ QualType ASTContext::getPackIndexingType(QualType Pattern, Expr *IndexExpr,
                                          ArrayRef<QualType> Expansions,
                                          int Index) const {
   QualType Canonical;
-  bool ExpandsToEmptyPack = FullySubstituted && Expansions.empty();
   if (FullySubstituted && Index != -1) {
     Canonical = getCanonicalType(Expansions[Index]);
   } else {
     llvm::FoldingSetNodeID ID;
-    PackIndexingType::Profile(ID, *this, Pattern, IndexExpr,
-                              ExpandsToEmptyPack);
+    PackIndexingType::Profile(ID, *this, Pattern, IndexExpr, FullySubstituted);
     void *InsertPos = nullptr;
     PackIndexingType *Canon =
         DependentPackIndexingTypes.FindNodeOrInsertPos(ID, InsertPos);
@@ -6241,7 +6239,7 @@ QualType ASTContext::getPackIndexingType(QualType Pattern, Expr *IndexExpr,
           PackIndexingType::totalSizeToAlloc<QualType>(Expansions.size()),
           TypeAlignment);
       Canon = new (Mem) PackIndexingType(*this, QualType(), Pattern, IndexExpr,
-                                         ExpandsToEmptyPack, Expansions);
+                                         FullySubstituted, Expansions);
       DependentPackIndexingTypes.InsertNode(Canon, InsertPos);
     }
     Canonical = QualType(Canon, 0);
@@ -6251,7 +6249,7 @@ QualType ASTContext::getPackIndexingType(QualType Pattern, Expr *IndexExpr,
       Allocate(PackIndexingType::totalSizeToAlloc<QualType>(Expansions.size()),
                TypeAlignment);
   auto *T = new (Mem) PackIndexingType(*this, Canonical, Pattern, IndexExpr,
-                                       ExpandsToEmptyPack, Expansions);
+                                       FullySubstituted, Expansions);
   Types.push_back(T);
   return QualType(T, 0);
 }
diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp
index e37ebec0851951..07c4419e3cf407 100644
--- a/clang/lib/AST/ComputeDependence.cpp
+++ b/clang/lib/AST/ComputeDependence.cpp
@@ -388,9 +388,8 @@ ExprDependence clang::computeDependence(PackIndexingExpr *E) {
          ExprDependence::Instantiation;
 
   ArrayRef<Expr *> Exprs = E->getExpressions();
-  if (Exprs.empty())
+  if (Exprs.empty() || !E->isFullySubstituted())
     D |= PatternDep | ExprDependence::Instantiation;
-
   else if (!E->getIndexExpr()->isInstantiationDependent()) {
     std::optional<unsigned> Index = E->getSelectedIndex();
     assert(Index && *Index < Exprs.size() && "pack index out of bound");
diff --git a/clang/lib/AST/ExprCXX.cpp b/clang/lib/AST/ExprCXX.cpp
index a2c0c60d43dd14..b7e1d502a3b163 100644
--- a/clang/lib/AST/ExprCXX.cpp
+++ b/clang/lib/AST/ExprCXX.cpp
@@ -1718,9 +1718,9 @@ NonTypeTemplateParmDecl *SubstNonTypeTemplateParmExpr::getParameter() const {
 PackIndexingExpr *PackIndexingExpr::Create(
     ASTContext &Context, SourceLocation EllipsisLoc, SourceLocation RSquareLoc,
     Expr *PackIdExpr, Expr *IndexExpr, std::optional<int64_t> Index,
-    ArrayRef<Expr *> SubstitutedExprs, bool ExpandedToEmptyPack) {
+    ArrayRef<Expr *> SubstitutedExprs, bool FullySubstituted) {
   QualType Type;
-  if (Index && !SubstitutedExprs.empty())
+  if (Index && FullySubstituted && !SubstitutedExprs.empty())
     Type = SubstitutedExprs[*Index]->getType();
   else
     Type = Context.DependentTy;
@@ -1729,7 +1729,7 @@ PackIndexingExpr *PackIndexingExpr::Create(
       Context.Allocate(totalSizeToAlloc<Expr *>(SubstitutedExprs.size()));
   return new (Storage)
       PackIndexingExpr(Type, EllipsisLoc, RSquareLoc, PackIdExpr, IndexExpr,
-                       SubstitutedExprs, ExpandedToEmptyPack);
+                       SubstitutedExprs, FullySubstituted);
 }
 
 NamedDecl *PackIndexingExpr::getPackDecl() const {
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 7ecb986e4dc2b5..f570b322fbef2b 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -4035,12 +4035,12 @@ void DependentDecltypeType::Profile(llvm::FoldingSetNodeID &ID,
 
 PackIndexingType::PackIndexingType(const ASTContext &Context,
                                    QualType Canonical, QualType Pattern,
-                                   Expr *IndexExpr, bool ExpandsToEmptyPack,
+                                   Expr *IndexExpr, bool FullySubstituted,
                                    ArrayRef<QualType> Expansions)
     : Type(PackIndexing, Canonical,
            computeDependence(Pattern, IndexExpr, Expansions)),
       Context(Context), Pattern(Pattern), IndexExpr(IndexExpr),
-      Size(Expansions.size()), ExpandsToEmptyPack(ExpandsToEmptyPack) {
+      Size(Expansions.size()), FullySubstituted(FullySubstituted) {
 
   std::uninitialized_copy(Expansions.begin(), Expansions.end(),
                           getTrailingObjects<QualType>());
@@ -4085,10 +4085,10 @@ PackIndexingType::computeDependence(QualType Pattern, Expr *IndexExpr,
 
 void PackIndexingType::Profile(llvm::FoldingSetNodeID &ID,
                                const ASTContext &Context, QualType Pattern,
-                               Expr *E, bool ExpandsToEmptyPack) {
+                               Expr *E, bool FullySubstituted) {
   Pattern.Profile(ID);
   E->Profile(ID, Context, true);
-  ID.AddBoolean(ExpandsToEmptyPack);
+  ID.AddBoolean(FullySubstituted);
 }
 
 UnaryTransformType::UnaryTransformType(QualType BaseType,
diff --git a/clang/lib/Sema/SemaTemplateVariadic.cpp b/clang/lib/Sema/SemaTemplateVariadic.cpp
index 2ea2a368dd24cf..86d15f6324f4f5 100644
--- a/clang/lib/Sema/SemaTemplateVariadic.cpp
+++ b/clang/lib/Sema/SemaTemplateVariadic.cpp
@@ -1157,10 +1157,12 @@ ExprResult Sema::ActOnPackIndexingExpr(Scope *S, Expr *PackExpression,
   return Res;
 }
 
-ExprResult
-Sema::BuildPackIndexingExpr(Expr *PackExpression, SourceLocation EllipsisLoc,
-                            Expr *IndexExpr, SourceLocation RSquareLoc,
-                            ArrayRef<Expr *> ExpandedExprs, bool EmptyPack) {
+ExprResult Sema::BuildPackIndexingExpr(Expr *PackExpression,
+                                       SourceLocation EllipsisLoc,
+                                       Expr *IndexExpr,
+                                       SourceLocation RSquareLoc,
+                                       ArrayRef<Expr *> ExpandedExprs,
+                                       bool FullySubstituted) {
 
   std::optional<int64_t> Index;
   if (!IndexExpr->isInstantiationDependent()) {
@@ -1174,8 +1176,8 @@ Sema::BuildPackIndexingExpr(Expr *PackExpression, SourceLocation EllipsisLoc,
     IndexExpr = Res.get();
   }
 
-  if (Index && (!ExpandedExprs.empty() || EmptyPack)) {
-    if (*Index < 0 || EmptyPack || *Index >= int64_t(ExpandedExprs.size())) {
+  if (Index && FullySubstituted) {
+    if (*Index < 0 || *Index >= int64_t(ExpandedExprs.size())) {
       Diag(PackExpression->getBeginLoc(), diag::err_pack_index_out_of_bound)
           << *Index << PackExpression << ExpandedExprs.size();
       return ExprError();
@@ -1184,7 +1186,7 @@ Sema::BuildPackIndexingExpr(Expr *PackExpression, SourceLocation EllipsisLoc,
 
   return PackIndexingExpr::Create(getASTContext(), EllipsisLoc, RSquareLoc,
                                   PackExpression, IndexExpr, Index,
-                                  ExpandedExprs, EmptyPack);
+                                  ExpandedExprs, FullySubstituted);
 }
 
 TemplateArgumentLoc Sema::getTemplateArgumentPackExpansionPattern(
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 1465bba87724b9..9cf1b2d073a90f 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -3670,10 +3670,10 @@ class TreeTransform {
                                      SourceLocation RSquareLoc,
                                      Expr *PackIdExpression, Expr *IndexExpr,
                                      ArrayRef<Expr *> ExpandedExprs,
-                                     bool EmptyPack = false) {
+                                     bool FullySubstituted = false) {
     return getSema().BuildPackIndexingExpr(PackIdExpression, EllipsisLoc,
                                            IndexExpr, RSquareLoc, ExpandedExprs,
-                                           EmptyPack);
+                                           FullySubstituted);
   }
 
   /// Build a new expression representing a call to a source location
@@ -6769,6 +6769,7 @@ TreeTransform<Derived>::TransformPackIndexingType(TypeLocBuilder &TLB,
       if (Out.isNull())
         return QualType();
       SubtitutedTypes.push_back(Out);
+      FullySubstituted &= !Out->containsUnexpandedParameterPack();
     }
     // If we're supposed to retain a pack expansion, do so by temporarily
     // forgetting the partially-substituted parameter pack.
@@ -15581,6 +15582,7 @@ TreeTransform<Derived>::TransformPackIndexingExpr(PackIndexingExpr *E) {
   }
 
   SmallVector<Expr *, 5> ExpandedExprs;
+  bool FullySubstituted = true;
   if (!E->expandsToEmptyPack() && E->getExpressions().empty()) {
     Expr *Pattern = E->getPackIdExpression();
     SmallVector<UnexpandedParameterPack, 2> Unexpanded;
@@ -15605,7 +15607,7 @@ TreeTransform<Derived>::TransformPackIndexingExpr(PackIndexingExpr *E) {
         return ExprError();
       return getDerived().RebuildPackIndexingExpr(
           E->getEllipsisLoc(), E->getRSquareLoc(), Pack.get(), IndexExpr.get(),
-          {});
+          {}, /*FullySubstituted=*/false);
     }
     for (unsigned I = 0; I != *NumExpansions; ++I) {
       Sema::ArgumentPackSubstitutionIndexRAII SubstIndex(getSema(), I);
@@ -15617,6 +15619,7 @@ TreeTransform<Derived>::TransformPackIndexingExpr(PackIndexingExpr *E) {
                                                 OrigNumExpansions);
         if (Out.isInvalid())
           return true;
+        FullySubstituted = false;
       }
       ExpandedExprs.push_back(Out.get());
     }
@@ -15633,6 +15636,7 @@ TreeTransform<Derived>::TransformPackIndexingExpr(PackIndexingExpr *E) {
                                               OrigNumExpansions);
       if (Out.isInvalid())
         return true;
+      FullySubstituted = false;
       ExpandedExprs.push_back(Out.get());
     }
   } else if (!E->expandsToEmptyPack()) {
@@ -15644,8 +15648,7 @@ TreeTransform<Derived>::TransformPackIndexingExpr(PackIndexingExpr *E) {
 
   return getDerived().RebuildPackIndexingExpr(
       E->getEllipsisLoc(), E->getRSquareLoc(), E->getPackIdExpression(),
-      IndexExpr.get(), ExpandedExprs,
-      /*EmptyPack=*/ExpandedExprs.size() == 0);
+      IndexExpr.get(), ExpandedExprs, FullySubstituted);
 }
 
 template<typename Derived>
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index c39a1950a6cf24..731ad0b64dc850 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2191,7 +2191,7 @@ void ASTStmtReader::VisitSizeOfPackExpr(SizeOfPackExpr *E) {
 void ASTStmtReader::VisitPackIndexingExpr(PackIndexingExpr *E) {
   VisitExpr(E);
   E->TransformedExpressions = Record.readInt();
-  E->ExpandedToEmptyPack = Record.readInt();
+  E->FullySubstituted = Record.readInt();
   E->EllipsisLoc = readSourceLocation();
   E->RSquareLoc = readSourceLocation();
   E->SubExprs[0] = Record.readStmt();
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index e7f567ff59a8ad..4994047d9fe10f 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2191,7 +2191,7 @@ void ASTStmtWriter::VisitSizeOfPackExpr(SizeOfPackExpr *E) {
 void ASTStmtWriter::VisitPackIndexingExpr(PackIndexingExpr *E) {
   VisitExpr(E);
   Record.push_back(E->TransformedExpressions);
-  Record.push_back(E->ExpandedToEmptyPack);
+  Record.push_back(E->FullySubstituted);
   Record.AddSourceLocation(E->getEllipsisLoc());
   Record.AddSourceLocation(E->getRSquareLoc());
   Record.AddStmt(E->getPackIdExpression());
diff --git a/clang/test/SemaCXX/cxx2c-pack-indexing.cpp b/clang/test/SemaCXX/cxx2c-pack-indexing.cpp
index 962dbb8137f289..cb679a6c3ad879 100644
--- a/clang/test/SemaCXX/cxx2c-pack-indexing.cpp
+++ b/clang/test/SemaCXX/cxx2c-pack-indexing.cpp
@@ -271,3 +271,37 @@ void f() {
 }
 
 } // namespace GH105903
+
+namespace GH116105 {
+
+template <unsigned long Np, class... Ts> using pack_type = Ts...[Np];
+
+template <unsigned long Np, auto... Ts> using pack_expr = decltype(Ts...[Np]);
+
+template <class...> struct types;
+
+template <class, long... Is> struct indices;
+
+template <class> struct repack;
+
+template <long... Idx> struct repack<indices<long, Idx...>> {
+  template <class... Ts>
+  using pack_type_alias = types<pack_type<Idx, Ts...>...>;
+
+  template <class... Ts>
+  using pack_expr_alias = types<pack_expr<Idx, Ts{}...>...>;
+};
+
+template <class... Args> struct mdispatch_ {
+  using Idx = __make_integer_seq<indices, long, sizeof...(Args)>;
+
+  static_assert(__is_same(
+      typename repack<Idx>::template pack_type_alias<Args...>, types<Args...>));
+
+  static_assert(__is_same(
+      typename repack<Idx>::template pack_expr_alias<Args...>, types<Args...>));
+};
+
+mdispatch_<int, int> d;
+
+} // namespace GH116105



More information about the cfe-commits mailing list