[clang] [Clang] [CodeGen] Perform derived-to-base conversion on explicit object parameter in lambda (PR #89828)

via cfe-commits cfe-commits at lists.llvm.org
Wed Apr 24 16:19:20 PDT 2024


https://github.com/Sirraide updated https://github.com/llvm/llvm-project/pull/89828

>From b5422012a65165f27bb31be7e9490892f663acfe Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Tue, 23 Apr 2024 22:45:29 +0200
Subject: [PATCH 1/2] [Clang] [CodeGen] Perform derived-to-base conversion on
 explicit object parameter in lambda

---
 clang/docs/ReleaseNotes.rst                   |  3 +
 clang/lib/CodeGen/CGExpr.cpp                  | 23 +++++++
 clang/test/CodeGenCXX/cxx2b-deducing-this.cpp | 63 +++++++++++++++++++
 3 files changed, 89 insertions(+)

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index d1f7293a842bb6..34aad4abf39619 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -562,6 +562,9 @@ Bug Fixes to C++ Support
 - Fixed a crash when trying to evaluate a user-defined ``static_assert`` message whose ``size()``
   function returns a large or negative value. Fixes (#GH89407).
 - Fixed a use-after-free bug in parsing of type constraints with default arguments that involve lambdas. (#GH67235)
+- Fixed a crash when emitting captures in a lambda call operator with an explicit object
+  parameter that is called on an object whose type is derived from the lambda.
+  Fixes (#GH87210), (#GH89541).
 
 Bug Fixes to AST Handling
 ^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 931cb391342ea2..33795d7d4d1921 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -4684,6 +4684,29 @@ LValue CodeGenFunction::EmitLValueForLambdaField(const FieldDecl *Field,
     else
       LambdaLV = MakeAddrLValue(AddrOfExplicitObject,
                                 D->getType().getNonReferenceType());
+
+    // Make sure we have an lvalue to the lambda itself and not a derived class.
+    auto *ThisTy = D->getType().getNonReferenceType()->getAsCXXRecordDecl();
+    auto *LambdaTy = cast<CXXRecordDecl>(Field->getParent());
+    if (ThisTy != LambdaTy) {
+      CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/true,
+                         /*DetectVirtual=*/false);
+
+      [[maybe_unused]] bool Derived = ThisTy->isDerivedFrom(LambdaTy, Paths);
+      assert(Derived && "Type not derived from lambda type?");
+
+      const CXXBasePath *Path = &Paths.front();
+      CXXCastPath BasePathArray;
+      for (unsigned I = 0, E = Path->size(); I != E; ++I)
+        BasePathArray.push_back(
+            const_cast<CXXBaseSpecifier *>((*Path)[I].Base));
+
+      Address Base = GetAddressOfBaseClass(
+          LambdaLV.getAddress(*this), ThisTy, BasePathArray.begin(),
+          BasePathArray.end(), /*NullCheckValue=*/false, SourceLocation());
+
+      LambdaLV = MakeAddrLValue(Base, QualType{LambdaTy->getTypeForDecl(), 0});
+    }
   } else {
     QualType LambdaTagType = getContext().getTagDeclType(Field->getParent());
     LambdaLV = MakeNaturalAlignAddrLValue(ThisValue, LambdaTagType);
diff --git a/clang/test/CodeGenCXX/cxx2b-deducing-this.cpp b/clang/test/CodeGenCXX/cxx2b-deducing-this.cpp
index b755e80db35a12..649fe2afbf4e91 100644
--- a/clang/test/CodeGenCXX/cxx2b-deducing-this.cpp
+++ b/clang/test/CodeGenCXX/cxx2b-deducing-this.cpp
@@ -182,3 +182,66 @@ auto dothing(int num)
   fun();
 }
 }
+
+namespace GH87210 {
+template <typename... Ts>
+struct Overloaded : Ts... {
+  using Ts::operator()...;
+};
+
+template <typename... Ts>
+Overloaded(Ts...) -> Overloaded<Ts...>;
+
+// CHECK-LABEL: define dso_local void @_ZN7GH872101fEv()
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[X:%.*]] = alloca i32
+// CHECK-NEXT:    [[Over:%.*]] = alloca %"{{.*}}Overloaded"
+// CHECK:         call noundef ptr @"_ZZN7GH872101fEvENH3$_0clINS_10OverloadedIJS0_EEEEEDaRT_"(ptr {{.*}} [[Over]])
+void f() {
+  int x;
+  Overloaded o {
+    // CHECK: define internal noundef ptr @"_ZZN7GH872101fEvENH3$_0clINS_10OverloadedIJS0_EEEEEDaRT_"(ptr {{.*}} [[Self:%.*]])
+    // CHECK-NEXT:  entry:
+    // CHECK-NEXT:    [[SelfAddr:%.*]] = alloca ptr
+    // CHECK-NEXT:    store ptr [[Self]], ptr [[SelfAddr]]
+    // CHECK-NEXT:    [[SelfPtr:%.*]] = load ptr, ptr [[SelfAddr]]
+    // CHECK-NEXT:    [[XRef:%.*]] = getelementptr inbounds %{{.*}}, ptr [[SelfPtr]], i32 0, i32 0
+    // CHECK-NEXT:    [[X:%.*]] = load ptr, ptr [[XRef]]
+    // CHECK-NEXT:    ret ptr [[X]]
+    [&](this auto& self) {
+      return &x;
+    }
+  };
+  o();
+}
+
+void g() {
+  int x;
+  Overloaded o {
+    [=](this auto& self) {
+      return x;
+    }
+  };
+  o();
+}
+}
+
+namespace GH89541 {
+// Same as above; just check that this doesn't crash.
+int one = 1;
+auto factory(int& x = one) {
+  return [&](this auto self) {
+    x;
+  };
+};
+
+using Base = decltype(factory());
+struct Derived : Base {
+  Derived() : Base(factory()) {}
+};
+
+void f() {
+  Derived d;
+  d();
+}
+}

>From 90d73ea88016307532bb38c4b2e8fa8f082bea75 Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Thu, 25 Apr 2024 01:01:18 +0200
Subject: [PATCH 2/2] [Clang] Tentative implementation of CWG 2881

---
 clang/include/clang/AST/ASTContext.h          |  9 +++
 .../clang/Basic/DiagnosticSemaKinds.td        |  5 ++
 clang/include/clang/Sema/Sema.h               |  4 +-
 clang/lib/CodeGen/CGExpr.cpp                  | 17 +----
 clang/lib/Sema/SemaLambda.cpp                 | 61 ++++++++++++----
 clang/lib/Sema/SemaOverload.cpp               | 15 ++--
 clang/test/CXX/drs/dr28xx.cpp                 | 71 +++++++++++++++++++
 7 files changed, 147 insertions(+), 35 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index d5ed20ff50157d..3210ef2dfe12be 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -110,6 +110,9 @@ class VarTemplateDecl;
 class VTableContextBase;
 class XRayFunctionFilter;
 
+/// A simple array of base specifiers.
+typedef SmallVector<CXXBaseSpecifier *, 4> CXXCastPath;
+
 namespace Builtin {
 
 class Context;
@@ -1168,6 +1171,12 @@ class ASTContext : public RefCountedBase<ASTContext> {
   /// in device compilation.
   llvm::DenseSet<const FunctionDecl *> CUDAImplicitHostDeviceFunUsedByDevice;
 
+  /// Derived-to-base cast paths for capturing lambda call operators whose
+  /// explicit object parameter type is derived from the closure type; CodeGen
+  /// needs them to access the captures. An empty path marks a call operator
+  /// whose explicit object parameter has already been diagnosed as invalid.
+  llvm::DenseMap<const CXXMethodDecl *, CXXCastPath> LambdaCastPaths;
+
   ASTContext(LangOptions &LOpts, SourceManager &SM, IdentifierTable &idents,
              SelectorTable &sels, Builtin::Context &builtins,
              TranslationUnitKind TUKind);
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 63e951daec7477..5e04ec82ea152b 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7497,6 +7497,11 @@ def err_explicit_object_parameter_mutable: Error<
 def err_invalid_explicit_object_type_in_lambda: Error<
   "invalid explicit object parameter type %0 in lambda with capture; "
   "the type must be the same as, or derived from, the lambda">;
+def err_explicit_object_lambda_ambiguous_base : Error<
+  "lambda %0 is inaccessible due to ambiguity:%1">;
+def err_explicit_object_lambda_inaccessible_base : Error<
+  "invalid explicit object parameter type %0 in lambda with capture; "
+  "the type must derive publicly from the lambda">;
 
 def err_ref_qualifier_overload : Error<
   "cannot overload a member function %select{without a ref-qualifier|with "
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 1ca523ec88c2f9..fa450a7868282a 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -7144,7 +7144,9 @@ class Sema final : public SemaBase {
       StorageClass SC, ArrayRef<ParmVarDecl *> Params,
       bool HasExplicitResultType);
 
-  void DiagnoseInvalidExplicitObjectParameterInLambda(CXXMethodDecl *Method);
+  /// Returns true if the explicit object parameter was invalid.
+  bool DiagnoseInvalidExplicitObjectParameterInLambda(CXXMethodDecl *Method,
+                                                      SourceLocation CallLoc);
 
   /// Perform initialization analysis of the init-capture and perform
   /// any implicit conversions such as an lvalue-to-rvalue conversion if
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 33795d7d4d1921..73f8a67c10fe82 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -4667,7 +4667,8 @@ LValue CodeGenFunction::EmitMemberExpr(const MemberExpr *E) {
 LValue CodeGenFunction::EmitLValueForLambdaField(const FieldDecl *Field,
                                                  llvm::Value *ThisValue) {
   bool HasExplicitObjectParameter = false;
-  if (const auto *MD = dyn_cast_if_present<CXXMethodDecl>(CurCodeDecl)) {
+  const auto *MD = dyn_cast_if_present<CXXMethodDecl>(CurCodeDecl);
+  if (MD) {
     HasExplicitObjectParameter = MD->isExplicitObjectMemberFunction();
     assert(MD->getParent()->isLambda());
     assert(MD->getParent() == Field->getParent());
@@ -4689,22 +4690,10 @@ LValue CodeGenFunction::EmitLValueForLambdaField(const FieldDecl *Field,
     auto *ThisTy = D->getType().getNonReferenceType()->getAsCXXRecordDecl();
     auto *LambdaTy = cast<CXXRecordDecl>(Field->getParent());
     if (ThisTy != LambdaTy) {
-      CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/true,
-                         /*DetectVirtual=*/false);
-
-      [[maybe_unused]] bool Derived = ThisTy->isDerivedFrom(LambdaTy, Paths);
-      assert(Derived && "Type not derived from lambda type?");
-
-      const CXXBasePath *Path = &Paths.front();
-      CXXCastPath BasePathArray;
-      for (unsigned I = 0, E = Path->size(); I != E; ++I)
-        BasePathArray.push_back(
-            const_cast<CXXBaseSpecifier *>((*Path)[I].Base));
-
+      const CXXCastPath &BasePathArray = getContext().LambdaCastPaths.at(MD);
       Address Base = GetAddressOfBaseClass(
           LambdaLV.getAddress(*this), ThisTy, BasePathArray.begin(),
           BasePathArray.end(), /*NullCheckValue=*/false, SourceLocation());
-
       LambdaLV = MakeAddrLValue(Base, QualType{LambdaTy->getTypeForDecl(), 0});
     }
   } else {
diff --git a/clang/lib/Sema/SemaLambda.cpp b/clang/lib/Sema/SemaLambda.cpp
index 1743afaf15287f..c96f376d1e2bea 100644
--- a/clang/lib/Sema/SemaLambda.cpp
+++ b/clang/lib/Sema/SemaLambda.cpp
@@ -12,6 +12,7 @@
 #include "clang/Sema/SemaLambda.h"
 #include "TypeLocBuilder.h"
 #include "clang/AST/ASTLambda.h"
+#include "clang/AST/CXXInheritance.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Sema/DeclSpec.h"
@@ -386,30 +387,62 @@ buildTypeForLambdaCallOperator(Sema &S, clang::CXXRecordDecl *Class,
 //  parameter, if any, of the lambda's function call operator (possibly
 //  instantiated from a function call operator template) shall be either:
 //  - the closure type,
-//  - class type derived from the closure type, or
+//  - class type publicly and unambiguously derived from the closure type, or
 //  - a reference to a possibly cv-qualified such type.
-void Sema::DiagnoseInvalidExplicitObjectParameterInLambda(
-    CXXMethodDecl *Method) {
+bool Sema::DiagnoseInvalidExplicitObjectParameterInLambda(
+    CXXMethodDecl *Method, SourceLocation CallLoc) {
   if (!isLambdaCallWithExplicitObjectParameter(Method))
-    return;
+    return false;
   CXXRecordDecl *RD = Method->getParent();
   if (Method->getType()->isDependentType())
-    return;
+    return false;
   if (RD->isCapturelessLambda())
-    return;
-  QualType ExplicitObjectParameterType = Method->getParamDecl(0)
-                                             ->getType()
+    return false;
+
+  ParmVarDecl *Param = Method->getParamDecl(0);
+  QualType ExplicitObjectParameterType = Param->getType()
                                              .getNonReferenceType()
                                              .getUnqualifiedType()
                                              .getDesugaredType(getASTContext());
   QualType LambdaType = getASTContext().getRecordType(RD);
   if (LambdaType == ExplicitObjectParameterType)
-    return;
-  if (IsDerivedFrom(RD->getLocation(), ExplicitObjectParameterType, LambdaType))
-    return;
-  Diag(Method->getParamDecl(0)->getLocation(),
-       diag::err_invalid_explicit_object_type_in_lambda)
-      << ExplicitObjectParameterType;
+    return false;
+
+  // Don't check the same instantiation twice.
+  //
+  // If this call operator is ill-formed, there is no point in issuing
+  // a diagnostic every time it is called because the problem is in the
+  // definition of the derived type, not at the call site.
+  //
+  // FIXME: Move this check to where we instantiate the method?
+  if (auto It = Context.LambdaCastPaths.find(Method);
+      It != Context.LambdaCastPaths.end())
+    return It->second.empty();
+
+  CXXCastPath &Path = Context.LambdaCastPaths[Method];
+  CXXBasePaths Paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
+                     /*DetectVirtual=*/false);
+  if (!IsDerivedFrom(RD->getLocation(), ExplicitObjectParameterType, LambdaType,
+                     Paths)) {
+    Diag(Param->getLocation(), diag::err_invalid_explicit_object_type_in_lambda)
+        << ExplicitObjectParameterType;
+    return true;
+  }
+
+  if (Paths.isAmbiguous(LambdaType->getCanonicalTypeUnqualified())) {
+    std::string PathsDisplay = getAmbiguousPathsDisplayString(Paths);
+    Diag(CallLoc, diag::err_explicit_object_lambda_ambiguous_base)
+        << LambdaType << PathsDisplay;
+    return true;
+  }
+
+  if (CheckBaseClassAccess(CallLoc, LambdaType, ExplicitObjectParameterType,
+                           Paths.front(),
+                           diag::err_explicit_object_lambda_inaccessible_base))
+    return true;
+
+  BuildBasePathArray(Paths, Path);
+  return false;
 }
 
 void Sema::handleLambdaNumbering(
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 04cd9e78739d20..88782cf64c95df 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -6492,17 +6492,20 @@ ExprResult Sema::InitializeExplicitObjectArgument(Sema &S, Expr *Obj,
       Obj->getExprLoc(), Obj);
 }
 
-static void PrepareExplicitObjectArgument(Sema &S, CXXMethodDecl *Method,
+static bool PrepareExplicitObjectArgument(Sema &S, CXXMethodDecl *Method,
                                           Expr *Object, MultiExprArg &Args,
                                           SmallVectorImpl<Expr *> &NewArgs) {
   assert(Method->isExplicitObjectMemberFunction() &&
          "Method is not an explicit member function");
   assert(NewArgs.empty() && "NewArgs should be empty");
+
   NewArgs.reserve(Args.size() + 1);
   Expr *This = GetExplicitObjectExpr(S, Object, Method);
   NewArgs.push_back(This);
   NewArgs.append(Args.begin(), Args.end());
   Args = NewArgs;
+  return S.DiagnoseInvalidExplicitObjectParameterInLambda(
+      Method, Object->getBeginLoc());
 }
 
 /// Determine whether the provided type is an integral type, or an enumeration
@@ -15671,8 +15674,10 @@ ExprResult Sema::BuildCallToMemberFunction(Scope *S, Expr *MemExprE,
   CallExpr *TheCall = nullptr;
   llvm::SmallVector<Expr *, 8> NewArgs;
   if (Method->isExplicitObjectMemberFunction()) {
-    PrepareExplicitObjectArgument(*this, Method, MemExpr->getBase(), Args,
-                                  NewArgs);
+    if (PrepareExplicitObjectArgument(*this, Method, MemExpr->getBase(), Args,
+                                      NewArgs))
+      return ExprError();
+
     // Build the actual expression node.
     ExprResult FnExpr =
         CreateFunctionRefExpr(*this, Method, FoundDecl, MemExpr,
@@ -15986,9 +15991,7 @@ Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj,
   // Initialize the object parameter.
   llvm::SmallVector<Expr *, 8> NewArgs;
   if (Method->isExplicitObjectMemberFunction()) {
-    // FIXME: we should do that during the definition of the lambda when we can.
-    DiagnoseInvalidExplicitObjectParameterInLambda(Method);
-    PrepareExplicitObjectArgument(*this, Method, Obj, Args, NewArgs);
+    IsError |= PrepareExplicitObjectArgument(*this, Method, Obj, Args, NewArgs);
   } else {
     ExprResult ObjRes = PerformImplicitObjectArgumentInitialization(
         Object.get(), /*Qualifier=*/nullptr, Best->FoundDecl, Method);
diff --git a/clang/test/CXX/drs/dr28xx.cpp b/clang/test/CXX/drs/dr28xx.cpp
index 4d9b0c76758d53..ef140a91b9c494 100644
--- a/clang/test/CXX/drs/dr28xx.cpp
+++ b/clang/test/CXX/drs/dr28xx.cpp
@@ -81,3 +81,74 @@ struct A {
 #endif
 
 } // namespace cwg2858
+
+namespace cwg2881 { // cwg2881: 19
+
+#if __cplusplus >= 202302L
+
+template <typename T> struct A : T {};
+template <typename T> struct B : T {};
+template <typename T> struct C : virtual T { C(T t) : T(t) {} };
+template <typename T> struct D : virtual T { D(T t) : T(t) {} };
+
+template <typename Ts>
+struct O1 : A<Ts>, B<Ts> {
+  using A<Ts>::operator();
+  using B<Ts>::operator();
+};
+
+template <typename Ts> struct O2 : protected Ts { // expected-note {{declared protected here}}
+  using Ts::operator();
+  O2(Ts ts) : Ts(ts) {}
+};
+
+template <typename Ts> struct O3 : private Ts { // expected-note {{declared private here}}
+  using Ts::operator();
+  O3(Ts ts) : Ts(ts) {}
+};
+
+// Not ambiguous because of virtual inheritance.
+template <typename Ts>
+struct O4 : C<Ts>, D<Ts> {
+  using C<Ts>::operator();
+  using D<Ts>::operator();
+  O4(Ts t) : Ts(t), C<Ts>(t), D<Ts>(t) {}
+};
+
+// This still has a public path to the lambda, and it's also not
+// ambiguous because of virtual inheritance.
+template <typename Ts>
+struct O5 : private C<Ts>, D<Ts> {
+  using C<Ts>::operator();
+  using D<Ts>::operator();
+  O5(Ts t) : Ts(t), C<Ts>(t), D<Ts>(t) {}
+};
+
+// This is only invalid if we call T's call operator.
+template <typename T, typename U>
+struct O6 : private T, U { // expected-note {{declared private here}}
+  using T::operator();
+  using U::operator();
+  O6(T t, U u) : T(t), U(u) {}
+};
+
+void f() {
+  int x;
+  auto L1 = [=](this auto&& self) { (void) &x; };
+  auto L2 = [&](this auto&& self) { (void) &x; };
+  O1<decltype(L1)>{L1, L1}(); // expected-error {{inaccessible due to ambiguity}}
+  O1<decltype(L2)>{L2, L2}(); // expected-error {{inaccessible due to ambiguity}}
+  O2{L1}(); // expected-error {{must derive publicly from the lambda}}
+  O3{L1}(); // expected-error {{must derive publicly from the lambda}}
+  O4{L1}();
+  O5{L1}();
+  O6 o{L1, L2};
+  o.decltype(L1)::operator()(); // expected-error {{must derive publicly from the lambda}}
+  o.decltype(L1)::operator()(); // No error here because we've already diagnosed this method.
+  o.decltype(L2)::operator()();
+}
+
+#endif
+
+} // namespace cwg2881
+



More information about the cfe-commits mailing list