[clang] [Clang] C++ Templates: Refactor and fix `TransformLambdaExpr`'s mishandling of TypeLocs (PR #78801)

Yuxuan Chen via cfe-commits cfe-commits at lists.llvm.org
Wed Jan 24 11:27:18 PST 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/78801

>From 157e7f107cc7d275f4b3919654d8c0dde9d341df Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Fri, 19 Jan 2024 13:42:46 -0800
Subject: [PATCH] Refactor and fix lambda prototype instantiation when it's
 attributed or macro-qualified

---
 clang/lib/Sema/SemaTemplateInstantiate.cpp    | 19 +++-
 clang/lib/Sema/TreeTransform.h                | 91 +++++++------------
 clang/test/SemaCXX/template-instantiation.cpp | 16 +++-
 3 files changed, 60 insertions(+), 66 deletions(-)

diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index e12186d7d82f8d..8990e345b9d1c6 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -1203,6 +1203,12 @@ namespace {
     // Whether to evaluate the C++20 constraints or simply substitute into them.
     bool EvaluateConstraints = true;
 
+    // While a lambda expression is being transformed, the TypeLoc of the
+    // prototype of its original call operator; a null TypeLoc otherwise.
+    // TransformFunctionProtoType must not push a new LocalInstantiationScope
+    // for exactly that prototype.
+    TypeLoc LambdaCallOpProtoTL;
+
   public:
     typedef TreeTransform<TemplateInstantiator> inherited;
 
@@ -1454,7 +1460,18 @@ namespace {
     ExprResult TransformLambdaExpr(LambdaExpr *E) {
       LocalInstantiationScope Scope(SemaRef, /*CombineWithOuterScope=*/true);
       Sema::ConstraintEvalRAII<TemplateInstantiator> RAII(*this);
 
+      // Remember which prototype belongs to this lambda's call operator so
+      // that TransformFunctionProtoType transforms it directly in Scope: the
+      // mapping from the original to the new parameters must still exist when
+      // the lambda body is transformed. Restore the previous value afterwards
+      // so the marker never outlives this lambda.
+      TypeLoc SavedLambdaCallOpProtoTL = LambdaCallOpProtoTL;
+      LambdaCallOpProtoTL = E->getCallOperator()
+                                ->getTypeSourceInfo()
+                                ->getTypeLoc()
+                                .getAsAdjusted<FunctionProtoTypeLoc>();
       ExprResult Result = inherited::TransformLambdaExpr(E);
+      LambdaCallOpProtoTL = SavedLambdaCallOpProtoTL;
       if (Result.isInvalid())
         return Result;
@@ -2178,10 +2195,20 @@ QualType TemplateInstantiator::TransformFunctionProtoType(TypeLocBuilder &TLB,
                                  CXXRecordDecl *ThisContext,
                                  Qualifiers ThisTypeQuals,
                                  Fn TransformExceptionSpec) {
+  // The prototype of the call operator of the lambda being transformed is
+  // instantiated directly in the lambda's LocalInstantiationScope (pushed by
+  // TransformLambdaExpr), because the lambda body refers to the new
+  // parameters. Match the exact TypeLoc: other prototypes (e.g. inside
+  // init-captures or template parameters) may be transformed earlier while
+  // handling the same lambda, and those still need their own scope.
+  if (TL == LambdaCallOpProtoTL)
+    return inherited::TransformFunctionProtoType(
+        TLB, TL, ThisContext, ThisTypeQuals, TransformExceptionSpec);
+
   // We need a local instantiation scope for this function prototype.
   LocalInstantiationScope Scope(SemaRef, /*CombineWithOuterScope=*/true);
   return inherited::TransformFunctionProtoType(
       TLB, TL, ThisContext, ThisTypeQuals, TransformExceptionSpec);
 }
 
 ParmVarDecl *TemplateInstantiator::TransformFunctionTypeParam(
 
 ParmVarDecl *TemplateInstantiator::TransformFunctionTypeParam(
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index e55e752b9cc354..7a6edd079b9351 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -674,10 +674,6 @@ class TreeTransform {
                                       Qualifiers ThisTypeQuals,
                                       Fn TransformExceptionSpec);
 
-  template <typename Fn>
-  QualType TransformAttributedType(TypeLocBuilder &TLB, AttributedTypeLoc TL,
-                                   Fn TransformModifiedType);
-
   bool TransformExceptionSpec(SourceLocation Loc,
                               FunctionProtoType::ExceptionSpecInfo &ESI,
                               SmallVectorImpl<QualType> &Exceptions,
@@ -7069,11 +7065,10 @@ TreeTransform<Derived>::TransformElaboratedType(TypeLocBuilder &TLB,
 }
 
 template <typename Derived>
-template <typename Fn>
-QualType TreeTransform<Derived>::TransformAttributedType(
-    TypeLocBuilder &TLB, AttributedTypeLoc TL, Fn TransformModifiedTypeFn) {
+QualType TreeTransform<Derived>::TransformAttributedType(TypeLocBuilder &TLB,
+                                                         AttributedTypeLoc TL) {
   const AttributedType *oldType = TL.getTypePtr();
-  QualType modifiedType = TransformModifiedTypeFn(TLB, TL.getModifiedLoc());
+  QualType modifiedType = getDerived().TransformType(TLB, TL.getModifiedLoc());
   if (modifiedType.isNull())
     return QualType();
 
@@ -7117,15 +7112,6 @@ QualType TreeTransform<Derived>::TransformAttributedType(
   return result;
 }
 
-template <typename Derived>
-QualType TreeTransform<Derived>::TransformAttributedType(TypeLocBuilder &TLB,
-                                                         AttributedTypeLoc TL) {
-  return getDerived().TransformAttributedType(
-      TLB, TL, [&](TypeLocBuilder &TLB, TypeLoc ModifiedLoc) -> QualType {
-        return getDerived().TransformType(TLB, ModifiedLoc);
-      });
-}
-
 template <typename Derived>
 QualType TreeTransform<Derived>::TransformBTFTagAttributedType(
     TypeLocBuilder &TLB, BTFTagAttributedTypeLoc TL) {
@@ -13636,58 +13622,35 @@ TreeTransform<Derived>::TransformLambdaExpr(LambdaExpr *E) {
   // The transformation MUST be done in the CurrentInstantiationScope since
   // it introduces a mapping of the original to the newly created
   // transformed parameters.
-  TypeSourceInfo *NewCallOpTSI = nullptr;
-  {
-    auto OldCallOpTypeLoc =
-        E->getCallOperator()->getTypeSourceInfo()->getTypeLoc();
-
-    auto TransformFunctionProtoTypeLoc =
-        [this](TypeLocBuilder &TLB, FunctionProtoTypeLoc FPTL) -> QualType {
-      SmallVector<QualType, 4> ExceptionStorage;
-      return this->TransformFunctionProtoType(
-          TLB, FPTL, nullptr, Qualifiers(),
-          [&](FunctionProtoType::ExceptionSpecInfo &ESI, bool &Changed) {
-            return TransformExceptionSpec(FPTL.getBeginLoc(), ESI,
-                                          ExceptionStorage, Changed);
-          });
-    };
+  auto OldCallOpTypeLoc =
+      E->getCallOperator()->getTypeSourceInfo()->getTypeLoc();
 
-    QualType NewCallOpType;
-    TypeLocBuilder NewCallOpTLBuilder;
+  TypeLocBuilder NewCallOpTLBuilder;
+  QualType NewCallOpType =
+      getDerived().TransformType(NewCallOpTLBuilder, OldCallOpTypeLoc);
 
-    if (auto ATL = OldCallOpTypeLoc.getAs<AttributedTypeLoc>()) {
-      NewCallOpType = this->TransformAttributedType(
-          NewCallOpTLBuilder, ATL,
-          [&](TypeLocBuilder &TLB, TypeLoc TL) -> QualType {
-            return TransformFunctionProtoTypeLoc(
-                TLB, TL.castAs<FunctionProtoTypeLoc>());
-          });
-    } else {
-      auto FPTL = OldCallOpTypeLoc.castAs<FunctionProtoTypeLoc>();
-      NewCallOpType = TransformFunctionProtoTypeLoc(NewCallOpTLBuilder, FPTL);
-    }
-
-    if (NewCallOpType.isNull())
-      return ExprError();
-    NewCallOpTSI =
-        NewCallOpTLBuilder.getTypeSourceInfo(getSema().Context, NewCallOpType);
-  }
+  if (NewCallOpType.isNull())
+    return ExprError();
 
-  ArrayRef<ParmVarDecl *> Params;
-  if (auto ATL = NewCallOpTSI->getTypeLoc().getAs<AttributedTypeLoc>()) {
-    Params = ATL.getModifiedLoc().castAs<FunctionProtoTypeLoc>().getParams();
-  } else {
-    auto FPTL = NewCallOpTSI->getTypeLoc().castAs<FunctionProtoTypeLoc>();
-    Params = FPTL.getParams();
-  }
+  TypeSourceInfo *NewCallOpTSI =
+      NewCallOpTLBuilder.getTypeSourceInfo(getSema().Context, NewCallOpType);
+
+  // The call operator's type may be wrapped in type sugar, e.g. an
+  // AttributedType for `[]() __attribute__((preserve_most)) {}`, possibly
+  // together with a MacroQualifiedType when the attribute comes from a macro.
+  // getAsAdjusted looks through such sugar to the underlying prototype.
+  auto NewCallOpFPTL =
+      NewCallOpTSI->getTypeLoc().getAsAdjusted<FunctionProtoTypeLoc>();
+  assert(NewCallOpFPTL && "lambda call operator must have a prototype");
+  ArrayRef<ParmVarDecl *> Params = NewCallOpFPTL.getParams();
 
   getSema().CompleteLambdaCallOperator(
       NewCallOperator, E->getCallOperator()->getLocation(),
       E->getCallOperator()->getInnerLocStart(),
       E->getCallOperator()->getTrailingRequiresClause(), NewCallOpTSI,
       E->getCallOperator()->getConstexprKind(),
       E->getCallOperator()->getStorageClass(), Params,
       E->hasExplicitResultType());
 
   getDerived().transformAttrs(E->getCallOperator(), NewCallOperator);
   getDerived().transformedLocalDecl(E->getCallOperator(), {NewCallOperator});
diff --git a/clang/test/SemaCXX/template-instantiation.cpp b/clang/test/SemaCXX/template-instantiation.cpp
index 8543af0d5428d0..42be88fef461ee 100644
--- a/clang/test/SemaCXX/template-instantiation.cpp
+++ b/clang/test/SemaCXX/template-instantiation.cpp
@@ -1,15 +1,30 @@
 // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -verify -fsyntax-only %s
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -ast-dump %s | FileCheck %s
 // expected-no-diagnostics
 
 namespace GH76521 {
 
+#define MYATTR __attribute__((preserve_most))
+
 template <typename T>
 void foo() {
+  // CHECK: FunctionDecl {{.*}} foo 'void ()'
   auto l = []() __attribute__((preserve_most)) {};
+  // CHECK: CXXMethodDecl {{.*}} operator() 'auto () __attribute__((preserve_most)) const' inline
+  auto l2 = [](T t) __attribute__((preserve_most)) -> T { return t; };
+  // CHECK: CXXMethodDecl {{.*}} operator() 'auto (int) const -> int __attribute__((preserve_most))':'auto (int) __attribute__((preserve_most)) const -> int' implicit_instantiation inline instantiated_from
 }
 
+template <typename T>
 void bar() {
+  // CHECK: FunctionDecl {{.*}} bar 'void ()'
+  auto l = []() MYATTR {};
+  // CHECK: CXXMethodDecl {{.*}} operator() 'auto () __attribute__((preserve_most)) const' inline
+}
+
+int main() {
   foo<int>();
+  bar<int>();
 }
 
 }



More information about the cfe-commits mailing list