[clang] 5e77dfe - [clang] CTAD: build aggregate deduction guides for alias templates. (#85904)

via cfe-commits cfe-commits at lists.llvm.org
Fri Apr 5 12:07:18 PDT 2024


Author: Haojian Wu
Date: 2024-04-05T21:07:14+02:00
New Revision: 5e77dfecd2e274117cc0a75de69c832d00130e6a

URL: https://github.com/llvm/llvm-project/commit/5e77dfecd2e274117cc0a75de69c832d00130e6a
DIFF: https://github.com/llvm/llvm-project/commit/5e77dfecd2e274117cc0a75de69c832d00130e6a.diff

LOG: [clang] CTAD: build aggregate deduction guides for alias templates. (#85904)

Fixes https://github.com/llvm/llvm-project/issues/85767.

Aggregate deduction guides are handled in a separate code path.
Previously, we did not generate dedicated aggregate deduction guides for
alias templates (we just reused the ones from the underlying template
decl by accident). This patch fixes that incorrect behavior.

Note: there is a small refactoring change in this PR, where we move the
cache logic from `Sema::DeduceTemplateSpecializationFromInitializer` to
`Sema::DeclareImplicitDeductionGuideFromInitList` (which this patch also
renames to `Sema::DeclareAggregateDeductionGuideFromInitList`).

Added: 
    

Modified: 
    clang/include/clang/Sema/Sema.h
    clang/lib/Sema/SemaInit.cpp
    clang/lib/Sema/SemaTemplate.cpp
    clang/test/SemaTemplate/deduction-guide.cpp

Removed: 
    


################################################################################
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 149b268311bc5c..f52816e12efe47 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -9713,7 +9713,8 @@ class Sema final {
   /// not already done so.
   void DeclareImplicitDeductionGuides(TemplateDecl *Template,
                                       SourceLocation Loc);
-  FunctionTemplateDecl *DeclareImplicitDeductionGuideFromInitList(
+
+  FunctionTemplateDecl *DeclareAggregateDeductionGuideFromInitList(
       TemplateDecl *Template, MutableArrayRef<QualType> ParamTypes,
       SourceLocation Loc);
 

diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index e2a1951f1062cb..a75e9925a43146 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -10930,32 +10930,16 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
                   Context.getLValueReferenceType(ElementTypes[I].withConst());
           }
 
-        llvm::FoldingSetNodeID ID;
-        ID.AddPointer(Template);
-        for (auto &T : ElementTypes)
-          T.getCanonicalType().Profile(ID);
-        unsigned Hash = ID.ComputeHash();
-        if (AggregateDeductionCandidates.count(Hash) == 0) {
-          if (FunctionTemplateDecl *TD =
-                  DeclareImplicitDeductionGuideFromInitList(
-                      Template, ElementTypes,
-                      TSInfo->getTypeLoc().getEndLoc())) {
-            auto *GD = cast<CXXDeductionGuideDecl>(TD->getTemplatedDecl());
-            GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
-            AggregateDeductionCandidates[Hash] = GD;
-            addDeductionCandidate(TD, GD, DeclAccessPair::make(TD, AS_public),
-                                  OnlyListConstructors,
-                                  /*AllowAggregateDeductionCandidate=*/true);
-          }
-        } else {
-          CXXDeductionGuideDecl *GD = AggregateDeductionCandidates[Hash];
-          FunctionTemplateDecl *TD = GD->getDescribedFunctionTemplate();
-          assert(TD && "aggregate deduction candidate is function template");
+        if (FunctionTemplateDecl *TD =
+                DeclareAggregateDeductionGuideFromInitList(
+                    LookupTemplateDecl, ElementTypes,
+                    TSInfo->getTypeLoc().getEndLoc())) {
+          auto *GD = cast<CXXDeductionGuideDecl>(TD->getTemplatedDecl());
           addDeductionCandidate(TD, GD, DeclAccessPair::make(TD, AS_public),
                                 OnlyListConstructors,
                                 /*AllowAggregateDeductionCandidate=*/true);
+          HasAnyDeductionGuide = true;
         }
-        HasAnyDeductionGuide = true;
       }
     };
 

diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 0b75f4fb401e63..73030cf4b72203 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -2754,23 +2754,42 @@ bool hasDeclaredDeductionGuides(DeclarationName Name, DeclContext *DC) {
   return false;
 }
 
-// Build deduction guides for a type alias template.
-void DeclareImplicitDeductionGuidesForTypeAlias(
-    Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate, SourceLocation Loc) {
-  if (AliasTemplate->isInvalidDecl())
-    return;
-  auto &Context = SemaRef.Context;
-  // FIXME: if there is an explicit deduction guide after the first use of the
-  // type alias usage, we will not cover this explicit deduction guide. fix this
-  // case.
-  if (hasDeclaredDeductionGuides(
-          Context.DeclarationNames.getCXXDeductionGuideName(AliasTemplate),
-          AliasTemplate->getDeclContext()))
-    return;
+NamedDecl *transformTemplateParameter(Sema &SemaRef, DeclContext *DC,
+                                      NamedDecl *TemplateParam,
+                                      MultiLevelTemplateArgumentList &Args,
+                                      unsigned NewIndex) {
+  if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
+    return transformTemplateTypeParam(SemaRef, DC, TTP, Args, TTP->getDepth(),
+                                      NewIndex);
+  if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
+    return transformTemplateParam(SemaRef, DC, TTP, Args, NewIndex,
+                                  TTP->getDepth());
+  if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
+    return transformTemplateParam(SemaRef, DC, NTTP, Args, NewIndex,
+                                  NTTP->getDepth());
+  llvm_unreachable("Unhandled template parameter types");
+}
+
+Expr *transformRequireClause(Sema &SemaRef, FunctionTemplateDecl *FTD,
+                             llvm::ArrayRef<TemplateArgument> TransformedArgs) {
+  Expr *RC = FTD->getTemplateParameters()->getRequiresClause();
+  if (!RC)
+    return nullptr;
+  MultiLevelTemplateArgumentList Args;
+  Args.setKind(TemplateSubstitutionKind::Rewrite);
+  Args.addOuterTemplateArguments(TransformedArgs);
+  ExprResult E = SemaRef.SubstExpr(RC, Args);
+  if (E.isInvalid())
+    return nullptr;
+  return E.getAs<Expr>();
+}
+
+std::pair<TemplateDecl *, llvm::ArrayRef<TemplateArgument>>
+getRHSTemplateDeclAndArgs(Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate) {
   // Unwrap the sugared ElaboratedType.
   auto RhsType = AliasTemplate->getTemplatedDecl()
                      ->getUnderlyingType()
-                     .getSingleStepDesugaredType(Context);
+                     .getSingleStepDesugaredType(SemaRef.Context);
   TemplateDecl *Template = nullptr;
   llvm::ArrayRef<TemplateArgument> AliasRhsTemplateArgs;
   if (const auto *TST = RhsType->getAs<TemplateSpecializationType>()) {
@@ -2791,6 +2810,24 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
   } else {
     assert(false && "unhandled RHS type of the alias");
   }
+  return {Template, AliasRhsTemplateArgs};
+}
+
+// Build deduction guides for a type alias template.
+void DeclareImplicitDeductionGuidesForTypeAlias(
+    Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate, SourceLocation Loc) {
+  if (AliasTemplate->isInvalidDecl())
+    return;
+  auto &Context = SemaRef.Context;
+  // FIXME: if there is an explicit deduction guide after the first use of the
+  // type alias usage, we will not cover this explicit deduction guide. fix this
+  // case.
+  if (hasDeclaredDeductionGuides(
+          Context.DeclarationNames.getCXXDeductionGuideName(AliasTemplate),
+          AliasTemplate->getDeclContext()))
+    return;
+  auto [Template, AliasRhsTemplateArgs] =
+      getRHSTemplateDeclAndArgs(SemaRef, AliasTemplate);
   if (!Template)
     return;
   DeclarationNameInfo NameInfo(
@@ -2803,6 +2840,13 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
     FunctionTemplateDecl *F = dyn_cast<FunctionTemplateDecl>(G);
     if (!F)
       continue;
+    // The **aggregate** deduction guides are handled in a different code path
+    // (DeclareImplicitDeductionGuideFromInitList), which involves the tricky
+    // cache.
+    if (cast<CXXDeductionGuideDecl>(F->getTemplatedDecl())
+            ->getDeductionCandidateKind() == DeductionCandidate::Aggregate)
+      continue;
+
     auto RType = F->getTemplatedDecl()->getReturnType();
     // The (trailing) return type of the deduction guide.
     const TemplateSpecializationType *FReturnType =
@@ -2885,21 +2929,6 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
     // parameters, used for building `TemplateArgsForBuildingFPrime`.
     SmallVector<TemplateArgument, 16> TransformedDeducedAliasArgs(
         AliasTemplate->getTemplateParameters()->size());
-    auto TransformTemplateParameter =
-        [&SemaRef](DeclContext *DC, NamedDecl *TemplateParam,
-                   MultiLevelTemplateArgumentList &Args,
-                   unsigned NewIndex) -> NamedDecl * {
-      if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
-        return transformTemplateTypeParam(SemaRef, DC, TTP, Args,
-                                          TTP->getDepth(), NewIndex);
-      if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
-        return transformTemplateParam(SemaRef, DC, TTP, Args, NewIndex,
-                                      TTP->getDepth());
-      if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
-        return transformTemplateParam(SemaRef, DC, NTTP, Args, NewIndex,
-                                      NTTP->getDepth());
-      return nullptr;
-    };
 
     for (unsigned AliasTemplateParamIdx : DeducedAliasTemplateParams) {
       auto *TP = AliasTemplate->getTemplateParameters()->getParam(
@@ -2909,9 +2938,9 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
       MultiLevelTemplateArgumentList Args;
       Args.setKind(TemplateSubstitutionKind::Rewrite);
       Args.addOuterTemplateArguments(TransformedDeducedAliasArgs);
-      NamedDecl *NewParam =
-          TransformTemplateParameter(AliasTemplate->getDeclContext(), TP, Args,
-                                     /*NewIndex*/ FPrimeTemplateParams.size());
+      NamedDecl *NewParam = transformTemplateParameter(
+          SemaRef, AliasTemplate->getDeclContext(), TP, Args,
+          /*NewIndex*/ FPrimeTemplateParams.size());
       FPrimeTemplateParams.push_back(NewParam);
 
       auto NewTemplateArgument = Context.getCanonicalTemplateArgument(
@@ -2927,8 +2956,8 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
       // We take a shortcut here, it is ok to reuse the
       // TemplateArgsForBuildingFPrime.
       Args.addOuterTemplateArguments(TemplateArgsForBuildingFPrime);
-      NamedDecl *NewParam = TransformTemplateParameter(
-          F->getDeclContext(), TP, Args, FPrimeTemplateParams.size());
+      NamedDecl *NewParam = transformTemplateParameter(
+          SemaRef, F->getDeclContext(), TP, Args, FPrimeTemplateParams.size());
       FPrimeTemplateParams.push_back(NewParam);
 
       assert(TemplateArgsForBuildingFPrime[FTemplateParamIdx].isNull() &&
@@ -2938,16 +2967,8 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
               Context.getInjectedTemplateArg(NewParam));
     }
     // Substitute new template parameters into requires-clause if present.
-    Expr *RequiresClause = nullptr;
-    if (Expr *InnerRC = F->getTemplateParameters()->getRequiresClause()) {
-      MultiLevelTemplateArgumentList Args;
-      Args.setKind(TemplateSubstitutionKind::Rewrite);
-      Args.addOuterTemplateArguments(TemplateArgsForBuildingFPrime);
-      ExprResult E = SemaRef.SubstExpr(InnerRC, Args);
-      if (E.isInvalid())
-        return;
-      RequiresClause = E.getAs<Expr>();
-    }
+    Expr *RequiresClause =
+        transformRequireClause(SemaRef, F, TemplateArgsForBuildingFPrime);
     // FIXME: implement the is_deducible constraint per C++
     // [over.match.class.deduct]p3.3:
     //    ... and a constraint that is satisfied if and only if the arguments
@@ -3013,11 +3034,102 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
   }
 }
 
+// Build an aggregate deduction guide for a type alias template.
+FunctionTemplateDecl *DeclareAggregateDeductionGuideForTypeAlias(
+    Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate,
+    MutableArrayRef<QualType> ParamTypes, SourceLocation Loc) {
+  TemplateDecl *RHSTemplate =
+      getRHSTemplateDeclAndArgs(SemaRef, AliasTemplate).first;
+  if (!RHSTemplate)
+    return nullptr;
+  auto *RHSDeductionGuide = SemaRef.DeclareAggregateDeductionGuideFromInitList(
+      RHSTemplate, ParamTypes, Loc);
+  if (!RHSDeductionGuide)
+    return nullptr;
+
+  LocalInstantiationScope Scope(SemaRef);
+
+  // Build a new template parameter list for the synthesized aggregate deduction
+  // guide by transforming the one from RHSDeductionGuide.
+  SmallVector<NamedDecl *> TransformedTemplateParams;
+  // Template args that refer to the rebuilt template parameters.
+  // All template arguments must be initialized in advance.
+  SmallVector<TemplateArgument> TransformedTemplateArgs(
+      RHSDeductionGuide->getTemplateParameters()->size());
+  for (auto *TP : *RHSDeductionGuide->getTemplateParameters()) {
+    // Rebuild any internal references to earlier parameters and reindex as
+    // we go.
+    MultiLevelTemplateArgumentList Args;
+    Args.setKind(TemplateSubstitutionKind::Rewrite);
+    Args.addOuterTemplateArguments(TransformedTemplateArgs);
+    NamedDecl *NewParam = transformTemplateParameter(
+        SemaRef, AliasTemplate->getDeclContext(), TP, Args,
+        /*NewIndex=*/TransformedTemplateParams.size());
+
+    TransformedTemplateArgs[TransformedTemplateParams.size()] =
+        SemaRef.Context.getCanonicalTemplateArgument(
+            SemaRef.Context.getInjectedTemplateArg(NewParam));
+    TransformedTemplateParams.push_back(NewParam);
+  }
+  // FIXME: implement the is_deducible constraint per C++
+  // [over.match.class.deduct]p3.3.
+  Expr *TransformedRequiresClause = transformRequireClause(
+      SemaRef, RHSDeductionGuide, TransformedTemplateArgs);
+  auto *TransformedTemplateParameterList = TemplateParameterList::Create(
+      SemaRef.Context, AliasTemplate->getTemplateParameters()->getTemplateLoc(),
+      AliasTemplate->getTemplateParameters()->getLAngleLoc(),
+      TransformedTemplateParams,
+      AliasTemplate->getTemplateParameters()->getRAngleLoc(),
+      TransformedRequiresClause);
+  auto *TransformedTemplateArgList = TemplateArgumentList::CreateCopy(
+      SemaRef.Context, TransformedTemplateArgs);
+
+  if (auto *TransformedDeductionGuide = SemaRef.InstantiateFunctionDeclaration(
+          RHSDeductionGuide, TransformedTemplateArgList,
+          AliasTemplate->getLocation(),
+          Sema::CodeSynthesisContext::BuildingDeductionGuides)) {
+    auto *GD =
+        llvm::dyn_cast<clang::CXXDeductionGuideDecl>(TransformedDeductionGuide);
+    FunctionTemplateDecl *Result = buildDeductionGuide(
+        SemaRef, AliasTemplate, TransformedTemplateParameterList,
+        GD->getCorrespondingConstructor(), GD->getExplicitSpecifier(),
+        GD->getTypeSourceInfo(), AliasTemplate->getBeginLoc(),
+        AliasTemplate->getLocation(), AliasTemplate->getEndLoc(),
+        GD->isImplicit());
+    cast<CXXDeductionGuideDecl>(Result->getTemplatedDecl())
+        ->setDeductionCandidateKind(DeductionCandidate::Aggregate);
+    return Result;
+  }
+  return nullptr;
+}
+
 } // namespace
 
-FunctionTemplateDecl *Sema::DeclareImplicitDeductionGuideFromInitList(
+FunctionTemplateDecl *Sema::DeclareAggregateDeductionGuideFromInitList(
     TemplateDecl *Template, MutableArrayRef<QualType> ParamTypes,
     SourceLocation Loc) {
+  llvm::FoldingSetNodeID ID;
+  ID.AddPointer(Template);
+  for (auto &T : ParamTypes)
+    T.getCanonicalType().Profile(ID);
+  unsigned Hash = ID.ComputeHash();
+
+  auto Found = AggregateDeductionCandidates.find(Hash);
+  if (Found != AggregateDeductionCandidates.end()) {
+    CXXDeductionGuideDecl *GD = Found->getSecond();
+    return GD->getDescribedFunctionTemplate();
+  }
+
+  if (auto *AliasTemplate = llvm::dyn_cast<TypeAliasTemplateDecl>(Template)) {
+    if (auto *FTD = DeclareAggregateDeductionGuideForTypeAlias(
+            *this, AliasTemplate, ParamTypes, Loc)) {
+      auto *GD = cast<CXXDeductionGuideDecl>(FTD->getTemplatedDecl());
+      GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
+      AggregateDeductionCandidates[Hash] = GD;
+      return FTD;
+    }
+  }
+
   if (CXXRecordDecl *DefRecord =
           cast<CXXRecordDecl>(Template->getTemplatedDecl())->getDefinition()) {
     if (TemplateDecl *DescribedTemplate =
@@ -3050,10 +3162,13 @@ FunctionTemplateDecl *Sema::DeclareImplicitDeductionGuideFromInitList(
       Transform.NestedPattern ? Transform.NestedPattern : Transform.Template;
   ContextRAII SavedContext(*this, Pattern->getTemplatedDecl());
 
-  auto *DG = cast<FunctionTemplateDecl>(
+  auto *FTD = cast<FunctionTemplateDecl>(
       Transform.buildSimpleDeductionGuide(ParamTypes));
   SavedContext.pop();
-  return DG;
+  auto *GD = cast<CXXDeductionGuideDecl>(FTD->getTemplatedDecl());
+  GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
+  AggregateDeductionCandidates[Hash] = GD;
+  return FTD;
 }
 
 void Sema::DeclareImplicitDeductionGuides(TemplateDecl *Template,

diff --git a/clang/test/SemaTemplate/deduction-guide.cpp b/clang/test/SemaTemplate/deduction-guide.cpp
index 0caef78fedbfd9..58f08aa1eed650 100644
--- a/clang/test/SemaTemplate/deduction-guide.cpp
+++ b/clang/test/SemaTemplate/deduction-guide.cpp
@@ -248,3 +248,15 @@ G g = {1};
 // CHECK: FunctionTemplateDecl
 // CHECK: |-CXXDeductionGuideDecl {{.*}} implicit <deduction guide for G> 'auto (T) -> G<T>' aggregate
 // CHECK: `-CXXDeductionGuideDecl {{.*}} implicit used <deduction guide for G> 'auto (int) -> G<int>' implicit_instantiation aggregate
+
+template<typename X>
+using AG = G<X>;
+AG ag = {1};
+// Verify that the aggregate deduction guide for alias templates is built.
+// CHECK-LABEL: Dumping <deduction guide for AG>
+// CHECK: FunctionTemplateDecl
+// CHECK: |-CXXDeductionGuideDecl {{.*}} 'auto (type-parameter-0-0) -> G<type-parameter-0-0>'
+// CHECK: `-CXXDeductionGuideDecl {{.*}} 'auto (int) -> G<int>' implicit_instantiation
+// CHECK:   |-TemplateArgument type 'int'
+// CHECK:   | `-BuiltinType {{.*}} 'int'
+// CHECK:   `-ParmVarDecl {{.*}} 'int'


        


More information about the cfe-commits mailing list