[clang] [Sema] Fix computations of "unexpanded packs" in substituted lambdas (PR #99882)

via cfe-commits cfe-commits at lists.llvm.org
Mon Jul 22 07:29:34 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-clang

Author: Ilya Biryukov (ilya-biryukov)

<details>
<summary>Changes</summary>

This fixes a crash reported in https://github.com/llvm/llvm-project/issues/99877, which happens on nested lambdas that use
template parameter packs; see the added test.

Before this patch, the code computing `ContainsUnexpandedParameterPack`
for a lambda relied on a flag in `LambdaScopeInfo` that was updated
differently during parsing and during template substitution.
This led to a discrepancy in how template parameters and captures that
contain unexpanded packs were handled. In particular, substitution
missed some cases where it should have marked the lambda as containing
unexpanded packs.
In turn, this led to an invalid state when such a lambda was used to
create a substituted `CXXFoldExpr`: `CXXFoldExpr` relies on this flag to
distinguish left folds from right folds, and the code using it expects
the operand returned by `getPattern` to contain unexpanded packs. That
assumption was violated, causing assertion failures and subsequent
crashes in builds without assertions.
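To make the failure mode concrete, here is a heavily simplified model of how `CXXFoldExpr` tells left and right folds apart. It is a sketch under assumptions: `ExprModel`, `FoldModel` and `patternOfUnaryRightFold` are hypothetical names, and the helpers only roughly mirror `isRightFold`, `isLeftFold` and `getPattern` from `ExprCXX.h`. It shows that when the substituted lambda is not marked as containing an unexpanded pack, a unary right fold such as `(apply(lambda, y) + ...)` is misclassified as a left fold, and its "pattern" becomes the absent RHS.

```cpp
#include <cassert>

// Hypothetical stand-in for clang::Expr: only the
// "contains an unexpanded parameter pack" bit is modeled.
struct ExprModel {
  bool ContainsUnexpandedPack;
};

// Simplified model of the fold classification: a fold is a right fold iff
// its LHS contains an unexpanded pack; otherwise it is treated as a left fold.
struct FoldModel {
  const ExprModel *LHS; // null for a unary left fold, e.g. (... + pack)
  const ExprModel *RHS; // null for a unary right fold, e.g. (pack + ...)

  bool isRightFold() const { return LHS && LHS->ContainsUnexpandedPack; }
  bool isLeftFold() const { return !isRightFold(); }
  // The pattern is the operand that is supposed to contain the pack.
  const ExprModel *getPattern() const { return isLeftFold() ? RHS : LHS; }
};

// `(apply(lambda, y) + ...)` is a unary right fold: the pattern is on the
// left and there is no RHS. Returns the pattern the model picks for it.
inline const ExprModel *patternOfUnaryRightFold(const ExprModel &Lambda) {
  FoldModel Fold{&Lambda, nullptr};
  return Fold.getPattern();
}
```

With the flag set, the pattern is the lambda call on the left; with the flag missing, the model returns a null pattern, the kind of inconsistent state described above.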

This commit addresses those issues by moving the computation of the
`ContainsUnexpandedPack` value into `ComputeDependence.cpp`, where it is
derived structurally from the dependence of the individual components of
the lambda expression, instead of relying on parity between the parsing
and template substitution code paths.
This should help avoid similar issues in the future: checking the
correctness of structural code is much easier than ensuring that global
state is manipulated consistently across two versions of the code.
The logic also lives in a single file now instead of being spread
across several.

The only exception to this rule right now is the lambda body: statements
do not carry any dependence flags, so we still have to rely on the
global state for them.
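The structural computation can be pictured with the following sketch. It assumes a hypothetical, flattened `LambdaModel` (not a real Clang type) whose fields summarize what the patched `computeDependence(LambdaExpr *, bool)` inspects: the body flag that still comes from `LambdaScopeInfo`, the template parameter list, the explicit captures, and `diagnose_if` conditions. The capture rules follow the loop in the diff; everything else is illustrative.

```cpp
#include <cassert>
#include <vector>

// Hypothetical summary of one explicit capture; names are illustrative only.
struct CaptureModel {
  bool CapturesVariable;      // false for `this` and VLA captures
  bool IsPackExpansion;       // e.g. `[xs...]`: the capture expands the pack
  bool IsInitCapture;         // e.g. `[zs = xs]`
  bool VarIsParameterPack;    // the captured variable (or init-capture) is a pack
  bool InitHasUnexpandedPack; // only meaningful for init-captures
};

// Hypothetical summary of a lambda's components.
struct LambdaModel {
  bool BodyContainsUnexpandedPack;       // still tracked via LambdaScopeInfo
  bool TemplateParamsHaveUnexpandedPack; // from the TemplateParameterList
  std::vector<CaptureModel> ExplicitCaptures;
  bool DiagnoseIfCondHasUnexpandedPack;  // the only attribute inspected
};

// Structural check: each component is examined on its own, so parsing and
// template substitution cannot disagree about the result.
inline bool containsUnexpandedPack(const LambdaModel &L) {
  if (L.BodyContainsUnexpandedPack || L.TemplateParamsHaveUnexpandedPack)
    return true;
  for (const CaptureModel &C : L.ExplicitCaptures) {
    if (!C.CapturesVariable || C.IsPackExpansion)
      continue;
    // `[xs]` names a pack without expanding it.
    if (!C.IsInitCapture && C.VarIsParameterPack)
      return true;
    // `[zs = xs]` has an unexpanded pack in its initializer.
    if (C.IsInitCapture && !C.VarIsParameterPack && C.InitHasUnexpandedPack)
      return true;
  }
  return L.DiagnoseIfCondHasUnexpandedPack;
}
```

The two capture cases correspond to `Cartesian1` (`[xs]`) and `Cartesian2` (`[zs = xs]`) in the added test.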

Attributes are another tricky case; one of them (`diagnose_if`) was
actually exercised by the tests. The new code path handles only
`diagnose_if`; the long-term solution is likely to propagate the
`ContainsUnexpandedPack` flag into `Attr`, similar to the existing
`Attr::IsPackExpansion` flag. Previously, code like the added test
crashed regardless of the attribute used. However, simpler examples
that used to work will now report an error when used inside fold
expressions. This is a potential regression, but it seems very rare,
and I think we can address it in a follow-up.

Lastly, this commit adds an assertion that exactly one of the LHS and
RHS operands passed to `CXXFoldExpr` contains an unexpanded parameter
pack. It should help catch issues like this one.
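The invariant checked by the new assertion boils down to an exclusive or over the two operands, either of which may be null for a unary fold. The sketch below illustrates that condition under assumptions: `PackExpr` and `hasExactlyOnePackOperand` are hypothetical names, not Clang API.

```cpp
#include <cassert>

// Hypothetical stand-in for clang::Expr; only the pack bit is modeled.
struct PackExpr {
  bool ContainsUnexpandedPack;
};

// True iff exactly one of the (possibly null) operands contains an
// unexpanded pack, the condition asserted in the CXXFoldExpr constructor.
inline bool hasExactlyOnePackOperand(const PackExpr *LHS, const PackExpr *RHS) {
  bool L = LHS && LHS->ContainsUnexpandedPack;
  bool R = RHS && RHS->ContainsUnexpandedPack;
  return L != R;
}
```

A mis-marked lambda in a unary right fold shows up here as an LHS without a pack and a null RHS, which fails the check.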

---
Full diff: https://github.com/llvm/llvm-project/pull/99882.diff


9 Files Affected:

- (modified) clang/include/clang/AST/ComputeDependence.h (+1-1) 
- (modified) clang/include/clang/AST/ExprCXX.h (+4-11) 
- (modified) clang/include/clang/Sema/ScopeInfo.h (+3-2) 
- (modified) clang/lib/AST/ComputeDependence.cpp (+33-2) 
- (modified) clang/lib/AST/DeclTemplate.cpp (+27-11) 
- (modified) clang/lib/AST/ExprCXX.cpp (+22-4) 
- (modified) clang/lib/Sema/SemaLambda.cpp (+7-14) 
- (modified) clang/lib/Sema/SemaTemplateVariadic.cpp (+5-1) 
- (added) clang/test/SemaCXX/fold_lambda_with_variadics.cpp (+78) 


``````````diff
diff --git a/clang/include/clang/AST/ComputeDependence.h b/clang/include/clang/AST/ComputeDependence.h
index 6d3a51c379f9d..848c1f203c296 100644
--- a/clang/include/clang/AST/ComputeDependence.h
+++ b/clang/include/clang/AST/ComputeDependence.h
@@ -166,7 +166,7 @@ ExprDependence computeDependence(CXXTemporaryObjectExpr *E);
 ExprDependence computeDependence(CXXDefaultInitExpr *E);
 ExprDependence computeDependence(CXXDefaultArgExpr *E);
 ExprDependence computeDependence(LambdaExpr *E,
-                                 bool ContainsUnexpandedParameterPack);
+                                 bool BodyContainsUnexpandedPacks);
 ExprDependence computeDependence(CXXUnresolvedConstructExpr *E);
 ExprDependence computeDependence(CXXDependentScopeMemberExpr *E);
 ExprDependence computeDependence(MaterializeTemporaryExpr *E);
diff --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h
index c2feac525c1ea..c0d58cab58b9b 100644
--- a/clang/include/clang/AST/ExprCXX.h
+++ b/clang/include/clang/AST/ExprCXX.h
@@ -1975,7 +1975,8 @@ class LambdaExpr final : public Expr,
              LambdaCaptureDefault CaptureDefault,
              SourceLocation CaptureDefaultLoc, bool ExplicitParams,
              bool ExplicitResultType, ArrayRef<Expr *> CaptureInits,
-             SourceLocation ClosingBrace, bool ContainsUnexpandedParameterPack);
+             SourceLocation ClosingBrace,
+             bool BodyContainsUnexpandedParameterPack);
 
   /// Construct an empty lambda expression.
   LambdaExpr(EmptyShell Empty, unsigned NumCaptures);
@@ -1996,7 +1997,7 @@ class LambdaExpr final : public Expr,
          LambdaCaptureDefault CaptureDefault, SourceLocation CaptureDefaultLoc,
          bool ExplicitParams, bool ExplicitResultType,
          ArrayRef<Expr *> CaptureInits, SourceLocation ClosingBrace,
-         bool ContainsUnexpandedParameterPack);
+         bool BodyContainsUnexpandedParameterPack);
 
   /// Construct a new lambda expression that will be deserialized from
   /// an external source.
@@ -4854,15 +4855,7 @@ class CXXFoldExpr : public Expr {
   CXXFoldExpr(QualType T, UnresolvedLookupExpr *Callee,
               SourceLocation LParenLoc, Expr *LHS, BinaryOperatorKind Opcode,
               SourceLocation EllipsisLoc, Expr *RHS, SourceLocation RParenLoc,
-              std::optional<unsigned> NumExpansions)
-      : Expr(CXXFoldExprClass, T, VK_PRValue, OK_Ordinary),
-        LParenLoc(LParenLoc), EllipsisLoc(EllipsisLoc), RParenLoc(RParenLoc),
-        NumExpansions(NumExpansions ? *NumExpansions + 1 : 0), Opcode(Opcode) {
-    SubExprs[SubExpr::Callee] = Callee;
-    SubExprs[SubExpr::LHS] = LHS;
-    SubExprs[SubExpr::RHS] = RHS;
-    setDependence(computeDependence(this));
-  }
+              std::optional<unsigned> NumExpansions);
 
   CXXFoldExpr(EmptyShell Empty) : Expr(CXXFoldExprClass, Empty) {}
 
diff --git a/clang/include/clang/Sema/ScopeInfo.h b/clang/include/clang/Sema/ScopeInfo.h
index 700e361ef83f1..c998670377208 100644
--- a/clang/include/clang/Sema/ScopeInfo.h
+++ b/clang/include/clang/Sema/ScopeInfo.h
@@ -895,8 +895,9 @@ class LambdaScopeInfo final :
   /// Whether any of the capture expressions requires cleanups.
   CleanupInfo Cleanup;
 
-  /// Whether the lambda contains an unexpanded parameter pack.
-  bool ContainsUnexpandedParameterPack = false;
+  /// Whether the lambda body contains an unexpanded parameter pack.
+  /// Note that the captures and template parameters are handled separately.
+  bool BodyContainsUnexpandedParameterPack = false;
 
   /// Packs introduced by this lambda, if any.
   SmallVector<NamedDecl*, 4> LocalPacks;
diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp
index 62ca15ea398f5..fbb3c72438a56 100644
--- a/clang/lib/AST/ComputeDependence.cpp
+++ b/clang/lib/AST/ComputeDependence.cpp
@@ -850,9 +850,40 @@ ExprDependence clang::computeDependence(CXXDefaultArgExpr *E) {
 }
 
 ExprDependence clang::computeDependence(LambdaExpr *E,
-                                        bool ContainsUnexpandedParameterPack) {
+                                        bool BodyContainsUnexpandedPacks) {
   auto D = toExprDependenceForImpliedType(E->getType()->getDependence());
-  if (ContainsUnexpandedParameterPack)
+
+  // Record the presence of unexpanded packs.
+  bool ContainsUnexpandedPack =
+      BodyContainsUnexpandedPacks ||
+      (E->getTemplateParameterList() &&
+       E->getTemplateParameterList()->containsUnexpandedParameterPack());
+  if (!ContainsUnexpandedPack) {
+    // Also look at captures.
+    for (const auto &C : E->explicit_captures()) {
+      if (!C.capturesVariable() || C.isPackExpansion())
+        continue;
+      auto *Var = C.getCapturedVar();
+      if ((!Var->isInitCapture() && Var->isParameterPack()) ||
+          (Var->isInitCapture() && !Var->isParameterPack() &&
+           cast<VarDecl>(Var)->getInit()->containsUnexpandedParameterPack())) {
+        ContainsUnexpandedPack = true;
+        break;
+      }
+    }
+  }
+  // FIXME: Support other attributes, e.g. by storing corresponding flag inside
+  // Attr (similar to Attr::IsPackExpansion).
+  if (!ContainsUnexpandedPack) {
+    for (auto *A : E->getCallOperator()->specific_attrs<DiagnoseIfAttr>()) {
+      if (A->getCond() && A->getCond()->containsUnexpandedParameterPack()) {
+        ContainsUnexpandedPack = true;
+        break;
+      }
+    }
+  }
+
+  if (ContainsUnexpandedPack)
     D |= ExprDependence::UnexpandedPack;
   return D;
 }
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index 722c7fcf0b0df..f95be88e6c087 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -61,27 +61,43 @@ TemplateParameterList::TemplateParameterList(const ASTContext& C,
 
     bool IsPack = P->isTemplateParameterPack();
     if (const auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(P)) {
-      if (!IsPack && NTTP->getType()->containsUnexpandedParameterPack())
-        ContainsUnexpandedParameterPack = true;
+      if (!IsPack) {
+        if (NTTP->getType()->containsUnexpandedParameterPack())
+          ContainsUnexpandedParameterPack = true;
+        else if (NTTP->hasDefaultArgument() &&
+                 NTTP->getDefaultArgument()
+                     .getArgument()
+                     .containsUnexpandedParameterPack())
+          ContainsUnexpandedParameterPack = true;
+      }
       if (NTTP->hasPlaceholderTypeConstraint())
         HasConstrainedParameters = true;
     } else if (const auto *TTP = dyn_cast<TemplateTemplateParmDecl>(P)) {
-      if (!IsPack &&
-          TTP->getTemplateParameters()->containsUnexpandedParameterPack())
-        ContainsUnexpandedParameterPack = true;
-    } else if (const auto *TTP = dyn_cast<TemplateTypeParmDecl>(P)) {
-      if (const TypeConstraint *TC = TTP->getTypeConstraint()) {
-        if (TC->getImmediatelyDeclaredConstraint()
-            ->containsUnexpandedParameterPack())
+      if (!IsPack) {
+        if (TTP->getTemplateParameters()->containsUnexpandedParameterPack())
           ContainsUnexpandedParameterPack = true;
+        else if (TTP->hasDefaultArgument() &&
+                 TTP->getDefaultArgument()
+                     .getArgument()
+                     .containsUnexpandedParameterPack())
+          ContainsUnexpandedParameterPack = true;
+      }
+    } else if (const auto *TTP = dyn_cast<TemplateTypeParmDecl>(P)) {
+      if (!IsPack && TTP->hasDefaultArgument() &&
+          TTP->getDefaultArgument()
+              .getArgument()
+              .containsUnexpandedParameterPack()) {
+        ContainsUnexpandedParameterPack = true;
+      } else if (const TypeConstraint *TC = TTP->getTypeConstraint();
+                 TC && TC->getImmediatelyDeclaredConstraint()
+                           ->containsUnexpandedParameterPack()) {
+        ContainsUnexpandedParameterPack = true;
       }
       if (TTP->hasTypeConstraint())
         HasConstrainedParameters = true;
     } else {
       llvm_unreachable("unexpected template parameter type");
     }
-    // FIXME: If a default argument contains an unexpanded parameter pack, the
-    // template parameter list does too.
   }
 
   if (HasRequiresClause) {
diff --git a/clang/lib/AST/ExprCXX.cpp b/clang/lib/AST/ExprCXX.cpp
index 8d2a1b5611ccc..a8f7aee7df1b3 100644
--- a/clang/lib/AST/ExprCXX.cpp
+++ b/clang/lib/AST/ExprCXX.cpp
@@ -1254,7 +1254,7 @@ LambdaExpr::LambdaExpr(QualType T, SourceRange IntroducerRange,
                        SourceLocation CaptureDefaultLoc, bool ExplicitParams,
                        bool ExplicitResultType, ArrayRef<Expr *> CaptureInits,
                        SourceLocation ClosingBrace,
-                       bool ContainsUnexpandedParameterPack)
+                       bool BodyContainsUnexpandedParameterPack)
     : Expr(LambdaExprClass, T, VK_PRValue, OK_Ordinary),
       IntroducerRange(IntroducerRange), CaptureDefaultLoc(CaptureDefaultLoc),
       ClosingBrace(ClosingBrace) {
@@ -1276,7 +1276,7 @@ LambdaExpr::LambdaExpr(QualType T, SourceRange IntroducerRange,
   // Copy the body of the lambda.
   *Stored++ = getCallOperator()->getBody();
 
-  setDependence(computeDependence(this, ContainsUnexpandedParameterPack));
+  setDependence(computeDependence(this, BodyContainsUnexpandedParameterPack));
 }
 
 LambdaExpr::LambdaExpr(EmptyShell Empty, unsigned NumCaptures)
@@ -1295,7 +1295,7 @@ LambdaExpr *LambdaExpr::Create(const ASTContext &Context, CXXRecordDecl *Class,
                                bool ExplicitParams, bool ExplicitResultType,
                                ArrayRef<Expr *> CaptureInits,
                                SourceLocation ClosingBrace,
-                               bool ContainsUnexpandedParameterPack) {
+                               bool BodyContainsUnexpandedParameterPack) {
   // Determine the type of the expression (i.e., the type of the
   // function object we're creating).
   QualType T = Context.getTypeDeclType(Class);
@@ -1305,7 +1305,7 @@ LambdaExpr *LambdaExpr::Create(const ASTContext &Context, CXXRecordDecl *Class,
   return new (Mem)
       LambdaExpr(T, IntroducerRange, CaptureDefault, CaptureDefaultLoc,
                  ExplicitParams, ExplicitResultType, CaptureInits, ClosingBrace,
-                 ContainsUnexpandedParameterPack);
+                 BodyContainsUnexpandedParameterPack);
 }
 
 LambdaExpr *LambdaExpr::CreateDeserialized(const ASTContext &C,
@@ -1944,3 +1944,21 @@ CXXParenListInitExpr *CXXParenListInitExpr::CreateEmpty(ASTContext &C,
                          alignof(CXXParenListInitExpr));
   return new (Mem) CXXParenListInitExpr(Empty, NumExprs);
 }
+
+CXXFoldExpr::CXXFoldExpr(QualType T, UnresolvedLookupExpr *Callee,
+                         SourceLocation LParenLoc, Expr *LHS,
+                         BinaryOperatorKind Opcode, SourceLocation EllipsisLoc,
+                         Expr *RHS, SourceLocation RParenLoc,
+                         std::optional<unsigned> NumExpansions)
+    : Expr(CXXFoldExprClass, T, VK_PRValue, OK_Ordinary), LParenLoc(LParenLoc),
+      EllipsisLoc(EllipsisLoc), RParenLoc(RParenLoc),
+      NumExpansions(NumExpansions ? *NumExpansions + 1 : 0), Opcode(Opcode) {
+  // We rely on this invariant to distinguish left and right folds.
+  assert(((LHS && LHS->containsUnexpandedParameterPack()) !=
+          (RHS && RHS->containsUnexpandedParameterPack())) &&
+         "Exactly one of LHS or RHS should contain an unexpanded pack");
+  SubExprs[SubExpr::Callee] = Callee;
+  SubExprs[SubExpr::LHS] = LHS;
+  SubExprs[SubExpr::RHS] = RHS;
+  setDependence(computeDependence(this));
+}
diff --git a/clang/lib/Sema/SemaLambda.cpp b/clang/lib/Sema/SemaLambda.cpp
index 601077e9f3334..b4cb38c133a54 100644
--- a/clang/lib/Sema/SemaLambda.cpp
+++ b/clang/lib/Sema/SemaLambda.cpp
@@ -1109,8 +1109,6 @@ void Sema::ActOnLambdaExpressionAfterIntroducer(LambdaIntroducer &Intro,
 
   PushDeclContext(CurScope, Method);
 
-  bool ContainsUnexpandedParameterPack = false;
-
   // Distinct capture names, for diagnostics.
   llvm::DenseMap<IdentifierInfo *, ValueDecl *> CaptureNames;
 
@@ -1312,8 +1310,6 @@ void Sema::ActOnLambdaExpressionAfterIntroducer(LambdaIntroducer &Intro,
 
         // Just ignore the ellipsis.
       }
-    } else if (Var->isParameterPack()) {
-      ContainsUnexpandedParameterPack = true;
     }
 
     if (C->Init.isUsable()) {
@@ -1328,7 +1324,6 @@ void Sema::ActOnLambdaExpressionAfterIntroducer(LambdaIntroducer &Intro,
       LSI->ExplicitCaptureRanges[LSI->Captures.size() - 1] = C->ExplicitRange;
   }
   finishLambdaExplicitCaptures(LSI);
-  LSI->ContainsUnexpandedParameterPack |= ContainsUnexpandedParameterPack;
   PopDeclContext();
 }
 
@@ -1380,8 +1375,6 @@ void Sema::ActOnLambdaClosureParameters(
     AddTemplateParametersToLambdaCallOperator(LSI->CallOperator, LSI->Lambda,
                                               TemplateParams);
     LSI->Lambda->setLambdaIsGeneric(true);
-    LSI->ContainsUnexpandedParameterPack |=
-        TemplateParams->containsUnexpandedParameterPack();
   }
   LSI->AfterParameterList = true;
 }
@@ -2079,7 +2072,7 @@ ExprResult Sema::BuildLambdaExpr(SourceLocation StartLoc, SourceLocation EndLoc,
   bool ExplicitParams;
   bool ExplicitResultType;
   CleanupInfo LambdaCleanup;
-  bool ContainsUnexpandedParameterPack;
+  bool BodyContainsUnexpandedParameterPack;
   bool IsGenericLambda;
   {
     CallOperator = LSI->CallOperator;
@@ -2088,7 +2081,8 @@ ExprResult Sema::BuildLambdaExpr(SourceLocation StartLoc, SourceLocation EndLoc,
     ExplicitParams = LSI->ExplicitParams;
     ExplicitResultType = !LSI->HasImplicitReturnType;
     LambdaCleanup = LSI->Cleanup;
-    ContainsUnexpandedParameterPack = LSI->ContainsUnexpandedParameterPack;
+    BodyContainsUnexpandedParameterPack =
+        LSI->BodyContainsUnexpandedParameterPack;
     IsGenericLambda = Class->isGenericLambda();
 
     CallOperator->setLexicalDeclContext(Class);
@@ -2227,11 +2221,10 @@ ExprResult Sema::BuildLambdaExpr(SourceLocation StartLoc, SourceLocation EndLoc,
 
   Cleanup.mergeFrom(LambdaCleanup);
 
-  LambdaExpr *Lambda = LambdaExpr::Create(Context, Class, IntroducerRange,
-                                          CaptureDefault, CaptureDefaultLoc,
-                                          ExplicitParams, ExplicitResultType,
-                                          CaptureInits, EndLoc,
-                                          ContainsUnexpandedParameterPack);
+  LambdaExpr *Lambda = LambdaExpr::Create(
+      Context, Class, IntroducerRange, CaptureDefault, CaptureDefaultLoc,
+      ExplicitParams, ExplicitResultType, CaptureInits, EndLoc,
+      BodyContainsUnexpandedParameterPack);
   // If the lambda expression's call operator is not explicitly marked constexpr
   // and we are not in a dependent context, analyze the call operator to infer
   // its constexpr-ness, suppressing diagnostics while doing so.
diff --git a/clang/lib/Sema/SemaTemplateVariadic.cpp b/clang/lib/Sema/SemaTemplateVariadic.cpp
index 6df7f2223d267..8e47e9ee339f2 100644
--- a/clang/lib/Sema/SemaTemplateVariadic.cpp
+++ b/clang/lib/Sema/SemaTemplateVariadic.cpp
@@ -353,7 +353,11 @@ Sema::DiagnoseUnexpandedParameterPacks(SourceLocation Loc,
       }
 
       if (!EnclosingStmtExpr) {
-        LSI->ContainsUnexpandedParameterPack = true;
+        // It is ok to have unexpanded packs in captures, template parameters
+        // and parameters too, but only the body statement does not store this
+        // flag, so we have to propagate it through LambdaScopeInfo.
+        if (LSI->AfterParameterList)
+          LSI->BodyContainsUnexpandedParameterPack = true;
         return false;
       }
     } else {
diff --git a/clang/test/SemaCXX/fold_lambda_with_variadics.cpp b/clang/test/SemaCXX/fold_lambda_with_variadics.cpp
new file mode 100644
index 0000000000000..45505d4f4c434
--- /dev/null
+++ b/clang/test/SemaCXX/fold_lambda_with_variadics.cpp
@@ -0,0 +1,78 @@
+// RUN: %clang_cc1 -fsyntax-only -std=c++20 -verify %s
+// expected-no-diagnostics
+struct tuple {
+    int x[3];
+};
+
+template <class F>
+int apply(F f, tuple v) {
+    return f(v.x[0], v.x[1], v.x[2]);
+}
+
+int Cartesian1(auto x, auto y) {
+    return apply([&](auto... xs) {
+        return (apply([xs](auto... ys) {
+            return (ys + ...);
+        }, y) + ...);
+    }, x);
+}
+
+int Cartesian2(auto x, auto y) {
+    return apply([&](auto... xs) {
+        return (apply([zs = xs](auto... ys) {
+            return (ys + ...);
+        }, y) + ...);
+    }, x);
+}
+
+template <int ...> struct Ints{};
+template <int> struct Choose {
+  template<class> struct Templ;
+};
+template <int ...x>
+int Cartesian3(auto y) {
+    return [&]<int ...xs>(Ints<xs...>) {
+        // check in default template arguments for
+        // - type template parameters,
+        (void)(apply([]<class = decltype(xs)>(auto... ys) {
+          return (ys + ...);
+        }, y) + ...);
+        // - template template parameters.
+        (void)(apply([]<template<class> class = Choose<xs>::template Templ>(auto... ys) {
+          return (ys + ...);
+        }, y) + ...);
+        // - non-type template parameters,
+        return (apply([]<int = xs>(auto... ys) {
+            return (ys + ...);
+        }, y) + ...);
+
+    }(Ints<x...>());
+}
+
+template <int ...x>
+int Cartesian4(auto y) {
+    return [&]<int ...xs>(Ints<xs...>) {
+        return (apply([]<decltype(xs) xx = 1>(auto... ys) {
+            return (ys + ...);
+        }, y) + ...);
+    }(Ints<x...>());
+}
+
+int Cartesian5(auto x, auto y) {
+    return apply([&](auto... xs) {
+        return (apply([](auto... ys) __attribute__((diagnose_if(!__is_same(decltype(xs), int), "message", "error"))) {
+            return (ys + ...);
+        }, y) + ...);
+    }, x);
+}
+
+
+int main() {
+    auto x = tuple({1, 2, 3});
+    auto y = tuple({4, 5, 6});
+    Cartesian1(x, y);
+    Cartesian2(x, y);
+    Cartesian3<1,2,3>(y);
+    Cartesian4<1,2,3>(y);
+    Cartesian5(x, y);
+}

``````````

</details>
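Besides the lambda changes, the `DeclTemplate.cpp` hunk resolves the old FIXME about default arguments: a non-pack template parameter whose type (for non-type parameters), nested template parameter list (for template template parameters), or default argument contains an unexpanded pack now marks the whole list, and a type-constraint containing one still does as well. A hypothetical, flattened summary of that rule might look like this; `TemplateParamModel` and `listContainsUnexpandedPack` are illustrative names, not Clang types.

```cpp
#include <cassert>
#include <vector>

// Hypothetical summary of one template parameter.
struct TemplateParamModel {
  bool IsPack;                         // e.g. `class... Ts`
  bool TypeOrParamsHaveUnexpandedPack; // NTTP type or template-template params
  bool DefaultArgHasUnexpandedPack;    // e.g. `class = decltype(xs)`
  bool ConstraintHasUnexpandedPack;    // type-constraint on a type parameter
};

// Sketch of the rule: type, nested parameters and default arguments only
// count for non-pack parameters; the type-constraint counts either way.
inline bool listContainsUnexpandedPack(
    const std::vector<TemplateParamModel> &Params) {
  for (const TemplateParamModel &P : Params) {
    if (!P.IsPack &&
        (P.TypeOrParamsHaveUnexpandedPack || P.DefaultArgHasUnexpandedPack))
      return true;
    if (P.ConstraintHasUnexpandedPack)
      return true;
  }
  return false;
}
```

In the added test, `Cartesian3` exercises the default-argument cases and `Cartesian4` (`decltype(xs) xx = 1`) exercises the non-type parameter's type.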


https://github.com/llvm/llvm-project/pull/99882

