[clang] CTAD: build aggregate deduction guides for alias templates. (PR #85904)

Haojian Wu via cfe-commits cfe-commits at lists.llvm.org
Wed Mar 20 01:29:37 PDT 2024


https://github.com/hokein created https://github.com/llvm/llvm-project/pull/85904

Fixes https://github.com/llvm/llvm-project/issues/85767.



>From ff697e2a159a4959deb11031111d9d8442ef544a Mon Sep 17 00:00:00 2001
From: Haojian Wu <hokein.wu at gmail.com>
Date: Fri, 15 Mar 2024 10:47:09 +0100
Subject: [PATCH 1/2] [clang] Move the aggregate deduction guide cache logic to
 SemaTemplate.cpp, NFC.

This is an NFC refactoring change, which is needed for the upcoming fix
for alias templates.
---
 clang/lib/Sema/SemaInit.cpp     | 27 +++++----------------------
 clang/lib/Sema/SemaTemplate.cpp | 24 ++++++++++++++++++++++--
 2 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index aa470adb30b47f..763cc5dd17129c 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -10918,32 +10918,15 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
                   Context.getLValueReferenceType(ElementTypes[I].withConst());
           }
 
-        llvm::FoldingSetNodeID ID;
-        ID.AddPointer(Template);
-        for (auto &T : ElementTypes)
-          T.getCanonicalType().Profile(ID);
-        unsigned Hash = ID.ComputeHash();
-        if (AggregateDeductionCandidates.count(Hash) == 0) {
-          if (FunctionTemplateDecl *TD =
-                  DeclareImplicitDeductionGuideFromInitList(
-                      Template, ElementTypes,
-                      TSInfo->getTypeLoc().getEndLoc())) {
-            auto *GD = cast<CXXDeductionGuideDecl>(TD->getTemplatedDecl());
-            GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
-            AggregateDeductionCandidates[Hash] = GD;
-            addDeductionCandidate(TD, GD, DeclAccessPair::make(TD, AS_public),
-                                  OnlyListConstructors,
-                                  /*AllowAggregateDeductionCandidate=*/true);
-          }
-        } else {
-          CXXDeductionGuideDecl *GD = AggregateDeductionCandidates[Hash];
-          FunctionTemplateDecl *TD = GD->getDescribedFunctionTemplate();
-          assert(TD && "aggregate deduction candidate is function template");
+        if (FunctionTemplateDecl *TD =
+                DeclareImplicitDeductionGuideFromInitList(
+                    Template, ElementTypes, TSInfo->getTypeLoc().getEndLoc())) {
+          auto *GD = cast<CXXDeductionGuideDecl>(TD->getTemplatedDecl());
           addDeductionCandidate(TD, GD, DeclAccessPair::make(TD, AS_public),
                                 OnlyListConstructors,
                                 /*AllowAggregateDeductionCandidate=*/true);
+          HasAnyDeductionGuide = true;
         }
-        HasAnyDeductionGuide = true;
       }
     };
 
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 51e8db2dfbaac8..0a4c16f2f09f09 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -2987,6 +2987,23 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
 FunctionTemplateDecl *Sema::DeclareImplicitDeductionGuideFromInitList(
     TemplateDecl *Template, MutableArrayRef<QualType> ParamTypes,
     SourceLocation Loc) {
+  llvm::FoldingSetNodeID ID;
+  ID.AddPointer(Template);
+  for (auto &T : ParamTypes)
+    T.getCanonicalType().Profile(ID);
+  unsigned Hash = ID.ComputeHash();
+  
+  auto Found = AggregateDeductionCandidates.find(Hash);
+  if (Found != AggregateDeductionCandidates.end()) {
+    CXXDeductionGuideDecl *GD = Found->getSecond();
+    return GD->getDescribedFunctionTemplate();
+  }
+  
+  // if (auto *AliasTemplate = llvm::dyn_cast<TypeAliasTemplateDecl>(Template)) {
+  //   DeclareImplicitDeductionGuidesForTypeAlias(*this, AliasTemplate, Loc);
+  //   return;
+  // }
+
   if (CXXRecordDecl *DefRecord =
           cast<CXXRecordDecl>(Template->getTemplatedDecl())->getDefinition()) {
     if (TemplateDecl *DescribedTemplate =
@@ -3019,10 +3036,13 @@ FunctionTemplateDecl *Sema::DeclareImplicitDeductionGuideFromInitList(
       Transform.NestedPattern ? Transform.NestedPattern : Transform.Template;
   ContextRAII SavedContext(*this, Pattern->getTemplatedDecl());
 
-  auto *DG = cast<FunctionTemplateDecl>(
+  auto *FTD = cast<FunctionTemplateDecl>(
       Transform.buildSimpleDeductionGuide(ParamTypes));
   SavedContext.pop();
-  return DG;
+  auto *GD = cast<CXXDeductionGuideDecl>(FTD->getTemplatedDecl());
+  GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
+  AggregateDeductionCandidates[Hash] = GD;
+  return FTD;
 }
 
 void Sema::DeclareImplicitDeductionGuides(TemplateDecl *Template,

>From 9355ff3d75221659a33ca5e7571849ff3135353a Mon Sep 17 00:00:00 2001
From: Haojian Wu <hokein.wu at gmail.com>
Date: Tue, 19 Mar 2024 23:07:09 +0100
Subject: [PATCH 2/2] [clang] CTAD: build aggregate deduction guides for alias
 templates.

Fixes https://github.com/llvm/llvm-project/issues/85767.
---
 clang/lib/Sema/SemaInit.cpp                 |   2 +-
 clang/lib/Sema/SemaTemplate.cpp             | 183 +++++++++++++++-----
 clang/test/SemaTemplate/deduction-guide.cpp |  15 ++
 3 files changed, 153 insertions(+), 47 deletions(-)

diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 763cc5dd17129c..a2d5bad1300d1f 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -10920,7 +10920,7 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
 
         if (FunctionTemplateDecl *TD =
                 DeclareImplicitDeductionGuideFromInitList(
-                    Template, ElementTypes, TSInfo->getTypeLoc().getEndLoc())) {
+                    LookupTemplateDecl, ElementTypes, TSInfo->getTypeLoc().getEndLoc())) {
           auto *GD = cast<CXXDeductionGuideDecl>(TD->getTemplatedDecl());
           addDeductionCandidate(TD, GD, DeclAccessPair::make(TD, AS_public),
                                 OnlyListConstructors,
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 0a4c16f2f09f09..86d07dd3a2187a 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -2725,21 +2725,43 @@ bool hasDeclaredDeductionGuides(DeclarationName Name, DeclContext *DC) {
   return false;
 }
 
-// Build deduction guides for a type alias template.
-void DeclareImplicitDeductionGuidesForTypeAlias(
-    Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate, SourceLocation Loc) {
-  auto &Context = SemaRef.Context;
-  // FIXME: if there is an explicit deduction guide after the first use of the
-  // type alias usage, we will not cover this explicit deduction guide. fix this
-  // case.
-  if (hasDeclaredDeductionGuides(
-          Context.DeclarationNames.getCXXDeductionGuideName(AliasTemplate),
-          AliasTemplate->getDeclContext()))
-    return;
+NamedDecl *TransformTemplateParameter(Sema &SemaRef, DeclContext *DC,
+                                      NamedDecl *TemplateParam,
+                                      MultiLevelTemplateArgumentList &Args,
+                                      unsigned NewIndex) {
+  if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
+    return transformTemplateTypeParam(SemaRef, DC, TTP, Args, TTP->getDepth(),
+                                      NewIndex);
+  if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
+    return transformTemplateParam(SemaRef, DC, TTP, Args, NewIndex,
+                                  TTP->getDepth());
+  if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
+    return transformTemplateParam(SemaRef, DC, NTTP, Args, NewIndex,
+                                  NTTP->getDepth());
+  return nullptr;
+}
+
+Expr *TransformRequireClause(Sema &SemaRef, FunctionTemplateDecl *FTD,
+                             llvm::ArrayRef<TemplateArgument> TransformedArgs) {
+  Expr *RC =
+          FTD->getTemplateParameters()->getRequiresClause();
+  if (!RC)
+    return nullptr;
+  MultiLevelTemplateArgumentList Args;
+  Args.setKind(TemplateSubstitutionKind::Rewrite);
+  Args.addOuterTemplateArguments(TransformedArgs);
+  ExprResult E = SemaRef.SubstExpr(RC, Args);
+  if (E.isInvalid())
+    return nullptr;
+  return E.getAs<Expr>();
+}
+
+std::pair<TemplateDecl *, llvm::ArrayRef<TemplateArgument>>
+GetRHSTemplateDeclAndArgs(Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate) {
   // Unwrap the sugared ElaboratedType.
   auto RhsType = AliasTemplate->getTemplatedDecl()
                      ->getUnderlyingType()
-                     .getSingleStepDesugaredType(Context);
+                     .getSingleStepDesugaredType(SemaRef.Context);
   TemplateDecl *Template = nullptr;
   llvm::ArrayRef<TemplateArgument> AliasRhsTemplateArgs;
   if (const auto *TST = RhsType->getAs<TemplateSpecializationType>()) {
@@ -2760,6 +2782,22 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
   } else {
     assert(false && "unhandled RHS type of the alias");
   }
+  return {Template, AliasRhsTemplateArgs};
+}
+
+// Build deduction guides for a type alias template.
+void DeclareImplicitDeductionGuidesForTypeAlias(
+    Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate, SourceLocation Loc) {
+  auto &Context = SemaRef.Context;
+  // FIXME: if there is an explicit deduction guide after the first use of the
+  // type alias usage, we will not cover this explicit deduction guide. fix this
+  // case.
+  if (hasDeclaredDeductionGuides(
+          Context.DeclarationNames.getCXXDeductionGuideName(AliasTemplate),
+          AliasTemplate->getDeclContext()))
+    return;
+  auto [Template, AliasRhsTemplateArgs] =
+      GetRHSTemplateDeclAndArgs(SemaRef, AliasTemplate);
   if (!Template)
     return;
   DeclarationNameInfo NameInfo(
@@ -2854,21 +2892,6 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
     // parameters, used for building `TemplateArgsForBuildingFPrime`.
     SmallVector<TemplateArgument, 16> TransformedDeducedAliasArgs(
         AliasTemplate->getTemplateParameters()->size());
-    auto TransformTemplateParameter =
-        [&SemaRef](DeclContext *DC, NamedDecl *TemplateParam,
-                   MultiLevelTemplateArgumentList &Args,
-                   unsigned NewIndex) -> NamedDecl * {
-      if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
-        return transformTemplateTypeParam(SemaRef, DC, TTP, Args,
-                                          TTP->getDepth(), NewIndex);
-      if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
-        return transformTemplateParam(SemaRef, DC, TTP, Args, NewIndex,
-                                      TTP->getDepth());
-      if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
-        return transformTemplateParam(SemaRef, DC, NTTP, Args, NewIndex,
-                                      NTTP->getDepth());
-      return nullptr;
-    };
 
     for (unsigned AliasTemplateParamIdx : DeducedAliasTemplateParams) {
       auto *TP = AliasTemplate->getTemplateParameters()->getParam(
@@ -2878,9 +2901,9 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
       MultiLevelTemplateArgumentList Args;
       Args.setKind(TemplateSubstitutionKind::Rewrite);
       Args.addOuterTemplateArguments(TransformedDeducedAliasArgs);
-      NamedDecl *NewParam =
-          TransformTemplateParameter(AliasTemplate->getDeclContext(), TP, Args,
-                                     /*NewIndex*/ FPrimeTemplateParams.size());
+      NamedDecl *NewParam = TransformTemplateParameter(
+          SemaRef, AliasTemplate->getDeclContext(), TP, Args,
+          /*NewIndex*/ FPrimeTemplateParams.size());
       FPrimeTemplateParams.push_back(NewParam);
 
       auto NewTemplateArgument = Context.getCanonicalTemplateArgument(
@@ -2897,7 +2920,7 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
       // TemplateArgsForBuildingFPrime.
       Args.addOuterTemplateArguments(TemplateArgsForBuildingFPrime);
       NamedDecl *NewParam = TransformTemplateParameter(
-          F->getDeclContext(), TP, Args, FPrimeTemplateParams.size());
+          SemaRef, F->getDeclContext(), TP, Args, FPrimeTemplateParams.size());
       FPrimeTemplateParams.push_back(NewParam);
 
       assert(TemplateArgsForBuildingFPrime[FTemplateParamIdx].isNull() &&
@@ -2907,16 +2930,8 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
               Context.getInjectedTemplateArg(NewParam));
     }
     // Substitute new template parameters into requires-clause if present.
-    Expr *RequiresClause = nullptr;
-    if (Expr *InnerRC = F->getTemplateParameters()->getRequiresClause()) {
-      MultiLevelTemplateArgumentList Args;
-      Args.setKind(TemplateSubstitutionKind::Rewrite);
-      Args.addOuterTemplateArguments(TemplateArgsForBuildingFPrime);
-      ExprResult E = SemaRef.SubstExpr(InnerRC, Args);
-      if (E.isInvalid())
-        return;
-      RequiresClause = E.getAs<Expr>();
-    }
+    Expr *RequiresClause =
+        TransformRequireClause(SemaRef, F, TemplateArgsForBuildingFPrime);
     // FIXME: implement the is_deducible constraint per C++
     // [over.match.class.deduct]p3.3:
     //    ... and a constraint that is satisfied if and only if the arguments
@@ -2982,8 +2997,79 @@ void DeclareImplicitDeductionGuidesForTypeAlias(
   }
 }
 
+// Build an aggregate deduction guide for a type alias template.
+FunctionTemplateDecl *DeclareAggrecateDeductionGuideForTypeAlias(
+    Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate,
+    MutableArrayRef<QualType> ParamTypes, SourceLocation Loc) {
+  TemplateDecl *RHSTemplate =
+      GetRHSTemplateDeclAndArgs(SemaRef, AliasTemplate).first;
+  if (!RHSTemplate)
+    return nullptr;
+  auto *RHSDeductionGuide =
+      SemaRef.DeclareImplicitDeductionGuideFromInitList(RHSTemplate, ParamTypes,
+                                                        Loc);
+  if (!RHSDeductionGuide)
+    return nullptr;
+
+  LocalInstantiationScope Scope(SemaRef);
+
+  // Build a new template parameter list for the synthesized aggregate deduction
+  // guide by transforming the one from RHSDeductionGuide.
+  SmallVector<NamedDecl *> TransformedTemplateParams;
+  // Template args that refer to the rebuilt template parameters.
+  // All template arguments must be initialized in advance.
+  SmallVector<TemplateArgument> TransformedTemplateArgs(
+      RHSDeductionGuide->getTemplateParameters()->size());
+  for (auto *TP : *RHSDeductionGuide->getTemplateParameters()) {
+    // Rebuild any internal references to earlier parameters and reindex as
+    // we go.
+    MultiLevelTemplateArgumentList Args;
+    Args.setKind(TemplateSubstitutionKind::Rewrite);
+    Args.addOuterTemplateArguments(TransformedTemplateArgs);
+    NamedDecl *NewParam = TransformTemplateParameter(
+        SemaRef, AliasTemplate->getDeclContext(), TP, Args,
+        /*NewIndex=*/ TransformedTemplateParams.size());
+
+    TransformedTemplateArgs[TransformedTemplateParams.size()] =
+        SemaRef.Context.getCanonicalTemplateArgument(
+            SemaRef.Context.getInjectedTemplateArg(NewParam));
+    TransformedTemplateParams.push_back(NewParam);
+  }
+  // FIXME: implement the is_deducible constraint per C++
+  // [over.match.class.deduct]p3.3.
+  Expr *TransformedRequiresClause = TransformRequireClause(
+      SemaRef, RHSDeductionGuide, TransformedTemplateArgs);
+  auto *TransformedTemplateParameterList = TemplateParameterList::Create(
+      SemaRef.Context, AliasTemplate->getTemplateParameters()->getTemplateLoc(),
+      AliasTemplate->getTemplateParameters()->getLAngleLoc(),
+      TransformedTemplateParams,
+      AliasTemplate->getTemplateParameters()->getRAngleLoc(),
+      TransformedRequiresClause);
+  auto *TransformedTemplateArgList = TemplateArgumentList::CreateCopy(
+      SemaRef.Context, TransformedTemplateArgs);
+
+  if (auto *TransformedDeductionGuide = SemaRef.InstantiateFunctionDeclaration(
+          RHSDeductionGuide, TransformedTemplateArgList,
+          AliasTemplate->getLocation(),
+          Sema::CodeSynthesisContext::BuildingDeductionGuides)) {
+    auto *GD =
+        llvm::dyn_cast<clang::CXXDeductionGuideDecl>(TransformedDeductionGuide);
+    FunctionTemplateDecl *Result = buildDeductionGuide(
+        SemaRef, AliasTemplate, TransformedTemplateParameterList,
+        GD->getCorrespondingConstructor(), GD->getExplicitSpecifier(),
+        GD->getTypeSourceInfo(), AliasTemplate->getBeginLoc(),
+        AliasTemplate->getLocation(), AliasTemplate->getEndLoc(),
+        GD->isImplicit());
+    cast<CXXDeductionGuideDecl>(Result->getTemplatedDecl())
+        ->setDeductionCandidateKind(DeductionCandidate::Aggregate);
+    return Result;
+  }
+  return nullptr;
+}
+
 } // namespace
 
+// FIXME: rename to DeclareAggregateDeductionGuide.
 FunctionTemplateDecl *Sema::DeclareImplicitDeductionGuideFromInitList(
     TemplateDecl *Template, MutableArrayRef<QualType> ParamTypes,
     SourceLocation Loc) {
@@ -2998,11 +3084,16 @@ FunctionTemplateDecl *Sema::DeclareImplicitDeductionGuideFromInitList(
     CXXDeductionGuideDecl *GD = Found->getSecond();
     return GD->getDescribedFunctionTemplate();
   }
-  
-  // if (auto *AliasTemplate = llvm::dyn_cast<TypeAliasTemplateDecl>(Template)) {
-  //   DeclareImplicitDeductionGuidesForTypeAlias(*this, AliasTemplate, Loc);
-  //   return;
-  // }
+
+  if (auto *AliasTemplate = llvm::dyn_cast<TypeAliasTemplateDecl>(Template)) {
+    if (auto *FTD = DeclareAggrecateDeductionGuideForTypeAlias(
+            *this, AliasTemplate, ParamTypes, Loc)) {
+      auto *GD = cast<CXXDeductionGuideDecl>(FTD->getTemplatedDecl());
+      GD->setDeductionCandidateKind(DeductionCandidate::Aggregate);
+      AggregateDeductionCandidates[Hash] = GD;
+      return FTD;
+    }
+  }
 
   if (CXXRecordDecl *DefRecord =
           cast<CXXRecordDecl>(Template->getTemplatedDecl())->getDefinition()) {
diff --git a/clang/test/SemaTemplate/deduction-guide.cpp b/clang/test/SemaTemplate/deduction-guide.cpp
index 16c7083df29d0c..bf1a6939203f4e 100644
--- a/clang/test/SemaTemplate/deduction-guide.cpp
+++ b/clang/test/SemaTemplate/deduction-guide.cpp
@@ -239,3 +239,18 @@ F s(0);
 // CHECK: |-InjectedClassNameType {{.*}} 'F<>' dependent
 // CHECK: | `-CXXRecord {{.*}} 'F'
 // CHECK: `-TemplateTypeParmType {{.*}} 'type-parameter-0-1' dependent depth 0 index 1
+
+template <typename T>
+struct G { T t1; };
+template<typename X>
+using AG = G<X>;
+
+AG ag = {1};
+// Verify that the aggregate deduction guide for alias templates is built.
+// CHECK-LABEL: Dumping <deduction guide for AG>
+// CHECK: FunctionTemplateDecl
+// CHECK: |-CXXDeductionGuideDecl {{.*}} 'auto (type-parameter-0-0) -> G<type-parameter-0-0>'
+// CHECK: `-CXXDeductionGuideDecl {{.*}} 'auto (int) -> G<int>' implicit_instantiation
+// CHECK:   |-TemplateArgument type 'int'
+// CHECK:   | `-BuiltinType {{.*}} 'int'
+// CHECK:   `-ParmVarDecl {{.*}} 'int'



More information about the cfe-commits mailing list