[clang] WIP: Implement CTAD for type alias template. (PR #77890)

Haojian Wu via cfe-commits cfe-commits at lists.llvm.org
Fri Jan 12 00:18:29 PST 2024


https://github.com/hokein created https://github.com/llvm/llvm-project/pull/77890

Fixes #54051

This is a preliminary, messy WIP implementation. While it is still missing many pieces (see the FIXMEs), it works for simple cases.

CC @ilya-biryukov, @sam-mccall, @usx95

>From ccf08bb5e209c98bddcb39b28256cd85cecd93cd Mon Sep 17 00:00:00 2001
From: Haojian Wu <hokein.wu at gmail.com>
Date: Wed, 26 Jul 2023 09:40:12 +0200
Subject: [PATCH] [clang] WIP: Implement CTAD for type alias template.

This is a preliminary, WIP and messy implementation.

While it is still missing many pieces (see FIXMEs), it works for happy
cases.
---
 clang/include/clang/Sema/Sema.h               |  16 +-
 clang/lib/Sema/SemaInit.cpp                   | 513 +++++++++++++++++-
 clang/lib/Sema/SemaTemplateDeduction.cpp      |   9 +
 clang/lib/Sema/SemaTemplateInstantiate.cpp    |  26 +-
 .../lib/Sema/SemaTemplateInstantiateDecl.cpp  |  17 +-
 ...xx1z-class-template-argument-deduction.cpp |   6 +-
 clang/test/SemaCXX/cxx20-ctad-type-alias.cpp  |  32 ++
 7 files changed, 604 insertions(+), 15 deletions(-)
 create mode 100644 clang/test/SemaCXX/cxx20-ctad-type-alias.cpp

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 4c464a1ae4c67f..6a6a30fc52c782 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -9251,6 +9251,14 @@ class Sema final {
                           const TemplateArgumentList &TemplateArgs,
                           sema::TemplateDeductionInfo &Info);
 
+  TemplateDeductionResult
+  DeduceTemplateArguments(TemplateParameterList *TemplateParams,
+                          ArrayRef<TemplateArgument> Ps,
+                          ArrayRef<TemplateArgument> As,
+                          sema::TemplateDeductionInfo &Info,
+                          SmallVectorImpl<DeducedTemplateArgument> &Deduced,
+                          bool NumberOfArgumentsMustMatch);
+
   TemplateDeductionResult SubstituteExplicitTemplateArguments(
       FunctionTemplateDecl *FunctionTemplate,
       TemplateArgumentListInfo &ExplicitTemplateArgs,
@@ -10403,9 +10411,11 @@ class Sema final {
       SourceLocation PointOfInstantiation, FunctionDecl *Decl,
       ArrayRef<TemplateArgument> TemplateArgs,
       ConstraintSatisfaction &Satisfaction);
-  FunctionDecl *InstantiateFunctionDeclaration(FunctionTemplateDecl *FTD,
-                                               const TemplateArgumentList *Args,
-                                               SourceLocation Loc);
+  FunctionDecl *InstantiateFunctionDeclaration(
+      FunctionTemplateDecl *FTD, const TemplateArgumentList *Args,
+      SourceLocation Loc,
+      CodeSynthesisContext::SynthesisKind CSC =
+          CodeSynthesisContext::ExplicitTemplateArgumentSubstitution);
   void InstantiateFunctionDefinition(SourceLocation PointOfInstantiation,
                                      FunctionDecl *Function,
                                      bool Recursive = false,
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 408ee5f775804b..d247fa74d553e1 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -10,12 +10,18 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TreeTransform.h"
+#include "TypeLocBuilder.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclAccessPair.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclObjC.h"
+#include "clang/AST/DeclTemplate.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
 #include "clang/AST/ExprOpenMP.h"
 #include "clang/AST/IgnoreExpr.h"
+#include "clang/AST/Type.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/SourceManager.h"
@@ -27,6 +33,7 @@
 #include "clang/Sema/Lookup.h"
 #include "clang/Sema/Ownership.h"
 #include "clang/Sema/SemaInternal.h"
+#include "clang/Sema/Template.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/FoldingSet.h"
 #include "llvm/ADT/PointerIntPair.h"
@@ -10570,6 +10577,381 @@ static bool isOrIsDerivedFromSpecializationOf(CXXRecordDecl *RD,
   return !(NotSpecialization(RD) && RD->forallBases(NotSpecialization));
 }
 
+
+/// FIXME: this is a copy-paste from SemaTemplate.cpp
+/// Tree transform to "extract" a transformed type from a class template's
+/// constructor to a deduction guide.
+class ExtractTypeForDeductionGuide
+  : public TreeTransform<ExtractTypeForDeductionGuide> {
+  llvm::SmallVectorImpl<TypedefNameDecl *> &MaterializedTypedefs;
+
+public:
+  typedef TreeTransform<ExtractTypeForDeductionGuide> Base;
+  ExtractTypeForDeductionGuide(
+      Sema &SemaRef,
+      llvm::SmallVectorImpl<TypedefNameDecl *> &MaterializedTypedefs)
+      : Base(SemaRef), MaterializedTypedefs(MaterializedTypedefs) {}
+
+  TypeSourceInfo *transform(TypeSourceInfo *TSI) { return TransformType(TSI); }
+
+  QualType TransformTypedefType(TypeLocBuilder &TLB, TypedefTypeLoc TL) {
+    ASTContext &Context = SemaRef.getASTContext();
+    TypedefNameDecl *OrigDecl = TL.getTypedefNameDecl();
+    TypedefNameDecl *Decl = OrigDecl;
+    // Transform the underlying type of the typedef and clone the Decl only if
+    // the typedef has a dependent context.
+    if (OrigDecl->getDeclContext()->isDependentContext()) {
+      TypeLocBuilder InnerTLB;
+      QualType Transformed =
+          TransformType(InnerTLB, OrigDecl->getTypeSourceInfo()->getTypeLoc());
+      TypeSourceInfo *TSI = InnerTLB.getTypeSourceInfo(Context, Transformed);
+      if (isa<TypeAliasDecl>(OrigDecl))
+        Decl = TypeAliasDecl::Create(
+            Context, Context.getTranslationUnitDecl(), OrigDecl->getBeginLoc(),
+            OrigDecl->getLocation(), OrigDecl->getIdentifier(), TSI);
+      else {
+        assert(isa<TypedefDecl>(OrigDecl) && "Not a Type alias or typedef");
+        Decl = TypedefDecl::Create(
+            Context, Context.getTranslationUnitDecl(), OrigDecl->getBeginLoc(),
+            OrigDecl->getLocation(), OrigDecl->getIdentifier(), TSI);
+      }
+      MaterializedTypedefs.push_back(Decl);
+    }
+
+    QualType TDTy = Context.getTypedefType(Decl);
+    TypedefTypeLoc TypedefTL = TLB.push<TypedefTypeLoc>(TDTy);
+    TypedefTL.setNameLoc(TL.getNameLoc());
+
+    return TDTy;
+  }
+};
+
+// Transform to form a corresponding deduction guide for type alias template decl.
+//
+// This class implements the C++ [over.match.class.deduct]p3:
+//   ... Let g denote the result of substituting these deductions into f. If
+//   substitution succeeds, form a function or function template f' with the
+//   following properties and add it to the set of guides of A...
+//
+// FIXME: This is a messy copy of ConvertConstructorToDeductionGuideTransform
+// in SemaTemplate.cpp with some adjustments.
+struct AliasTemplateDeductionGuideTransform {
+  AliasTemplateDeductionGuideTransform(Sema &S, TypeAliasTemplateDecl *Alias)
+      : SemaRef(S), AliasTemplate(Alias) {
+    DC = AliasTemplate->getDeclContext();
+  }
+  Sema &SemaRef;
+  TypeAliasTemplateDecl *AliasTemplate = nullptr;
+
+  DeclContext *DC = nullptr;
+
+  // FIXME: is this really needed?
+  ClassTemplateDecl *NestedPattern = nullptr;
+  // Instantiation arguments for the outermost depth-1 templates
+  // when the template is nested
+  MultiLevelTemplateArgumentList OuterInstantiationArgs;
+
+  // Returns the result of substituting the deduced template arguments into f.
+  NamedDecl *transformUnderlyingFunctionTemplate(
+      CXXDeductionGuideDecl *UnderlyingCDGD,
+      SmallVector<TemplateArgument> DeducedArgs) {
+    SmallVector<TemplateTypeParmDecl *> DeducedTemplateTypeParDecls;
+    // Add all parameters that appear in the deductions.
+    // FIXME: add template parameters of f that were not deduced.
+    for (auto D : DeducedArgs) {
+      // FIXME: fix other template argument kind.
+      if (D.getKind() != TemplateArgument::Type)
+        continue;
+      if (auto *TT =
+              dyn_cast<TemplateTypeParmType>(D.getAsType().getTypePtr())) {
+        if (TemplateTypeParmDecl *DD = TT->getDecl()) {
+          DeducedTemplateTypeParDecls.push_back(DD);
+        }
+      }
+    }
+
+    // C++ [over.match.class.deduct]p3.2:
+    //   If f is a function template, f' is a function template whose template
+    //   parameter list consists of all the template parameters of A (including
+    //   their default template arguments) that appear in the above deductions
+    //   or (recursively) in their default template arguments, followed by the
+    //   template parameters of f that were not deduced (including their default
+    //   template arguments)...
+    LocalInstantiationScope Scope(SemaRef);
+    SmallVector<TemplateArgument, 16> Depth1Args;
+    SmallVector<NamedDecl *, 16> AllParams;
+    SmallVector<TemplateArgument, 16> SubstArgs;
+    int Index = 0;
+    TemplateParameterList *TemplateParams = nullptr;
+
+    for (TemplateTypeParmDecl* Param : DeducedTemplateTypeParDecls) {
+      MultiLevelTemplateArgumentList Args;
+
+      Args.setKind(TemplateSubstitutionKind::Rewrite);
+      Args.addOuterTemplateArguments(Depth1Args);
+      Args.addOuterRetainedLevel();
+      NamedDecl *NewParam = transformTemplateParameter(Param, Args, Index++);
+      if (!NewParam) {
+        llvm::errs() << "Faile to generate new param!\n";
+        return nullptr;
+      }
+      auto NewArgumentForNewParam =
+          SemaRef.Context.getCanonicalTemplateArgument(
+              SemaRef.Context.getInjectedTemplateArg(NewParam));
+      Depth1Args.push_back(NewArgumentForNewParam);
+      AllParams.push_back(NewParam);
+      SubstArgs.push_back(NewArgumentForNewParam);
+    }
+    // FIXME: substitute new template parameters into the requires-clause
+
+    TemplateParams = TemplateParameterList::Create(
+        SemaRef.Context, SourceLocation(), SourceLocation(), AllParams,
+        SourceLocation(), /*requiresClause */nullptr);
+
+    MultiLevelTemplateArgumentList Args;
+    Args.setKind(TemplateSubstitutionKind::Rewrite);
+    Args.addOuterTemplateArguments(SubstArgs);
+    Args.addOuterRetainedLevel();
+
+    FunctionProtoTypeLoc FPTL = UnderlyingCDGD->getTypeSourceInfo()
+                                    ->getTypeLoc()
+                                    .getAsAdjusted<FunctionProtoTypeLoc>();
+    assert(FPTL && "no prototype for underlying deduction guides");
+
+    // Transform the type of the function, adjusting the return type and
+    // replacing references to the old parameters with references to the
+    // new ones.
+    TypeLocBuilder TLB;
+    SmallVector<ParmVarDecl*, 8> Params;
+    SmallVector<TypedefNameDecl *, 4> MaterializedTypedefs;
+    QualType NewType = transformFunctionProtoType(
+        TLB, FPTL, Params, Args, UnderlyingCDGD->getReturnType(),
+        MaterializedTypedefs);
+    if (NewType.isNull())
+      return nullptr;
+    TypeSourceInfo *NewTInfo = TLB.getTypeSourceInfo(SemaRef.Context, NewType);
+
+    return buildDeductionGuide(
+        TemplateParams, UnderlyingCDGD->getCorrespondingConstructor(),
+        UnderlyingCDGD->getExplicitSpecifier(), NewTInfo,
+        UnderlyingCDGD->getBeginLoc(), UnderlyingCDGD->getLocation(),
+        UnderlyingCDGD->getEndLoc(), UnderlyingCDGD->isImplicit(),
+        MaterializedTypedefs);
+  }
+
+private:
+  /// Transform a template parameter of underlying deduction guide into
+  /// a deduction guide template parameter, rebuilding any internal references
+  /// to earlier parameters and renumbering as we go.
+  NamedDecl *transformTemplateParameter(NamedDecl *TemplateParam,
+                                        MultiLevelTemplateArgumentList &Args,
+                                        int Index ) {
+    if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam)) {
+      // TemplateTypeParmDecl's index cannot be changed after creation, so
+      // substitute it directly.
+      auto *NewTTP = TemplateTypeParmDecl::Create(
+          SemaRef.Context, DC, TTP->getBeginLoc(), TTP->getLocation(),
+          // FIXME: is the depth/index right?
+          TTP->getDepth(),  Index,
+          TTP->getIdentifier(), TTP->wasDeclaredWithTypename(),
+          TTP->isParameterPack(), TTP->hasTypeConstraint(),
+          TTP->isExpandedParameterPack()
+              ? std::optional<unsigned>(TTP->getNumExpansionParameters())
+              : std::nullopt);
+      if (const auto *TC = TTP->getTypeConstraint())
+        SemaRef.SubstTypeConstraint(NewTTP, TC, Args,
+                                    /*EvaluateConstraint*/ true);
+      if (TTP->hasDefaultArgument()) {
+        TypeSourceInfo *InstantiatedDefaultArg =
+            SemaRef.SubstType(TTP->getDefaultArgumentInfo(), Args,
+                              TTP->getDefaultArgumentLoc(), TTP->getDeclName());
+        if (InstantiatedDefaultArg)
+          NewTTP->setDefaultArgument(InstantiatedDefaultArg);
+      }
+      SemaRef.CurrentInstantiationScope->InstantiatedLocal(TemplateParam,
+                                                           NewTTP);
+      return NewTTP;
+    }
+
+    if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
+      return transformTemplateParameterImpl(TTP, Args);
+
+    return transformTemplateParameterImpl(
+        cast<NonTypeTemplateParmDecl>(TemplateParam), Args);
+  }
+  template<typename TemplateParmDecl>
+  TemplateParmDecl *
+  transformTemplateParameterImpl(TemplateParmDecl *OldParam,
+                                 MultiLevelTemplateArgumentList &Args) {
+    // Ask the template instantiator to do the heavy lifting for us, then adjust
+    // the index of the parameter once it's done.
+    auto *NewParam =
+        cast<TemplateParmDecl>(SemaRef.SubstDecl(OldParam, DC, Args));
+    // assert(NewParam->getDepth() == OldParam->getDepth() - 1 &&
+    //        "unexpected template param depth");
+    // FIXME
+    // NewParam->setPosition(NewParam->getPosition() + Depth1IndexAdjustment);
+    return NewParam;
+  }
+
+
+  QualType transformFunctionProtoType(
+      TypeLocBuilder &TLB, FunctionProtoTypeLoc TL,
+      SmallVectorImpl<ParmVarDecl *> &Params,
+      MultiLevelTemplateArgumentList &Args,
+      QualType ReturnType,
+      SmallVectorImpl<TypedefNameDecl *> &MaterializedTypedefs) {
+    SmallVector<QualType, 4> ParamTypes;
+    const FunctionProtoType *T = TL.getTypePtr();
+
+    //    -- The types of the function parameters are those of the constructor.
+    for (auto *OldParam : TL.getParams()) {
+      ParmVarDecl *NewParam =
+          transformFunctionTypeParam(OldParam, Args, MaterializedTypedefs);
+      if (NestedPattern && NewParam)
+        NewParam = transformFunctionTypeParam(NewParam, OuterInstantiationArgs,
+                                              MaterializedTypedefs);
+      if (!NewParam)
+        return QualType();
+      ParamTypes.push_back(NewParam->getType());
+      Params.push_back(NewParam);
+    }
+
+    // The return type of the deduction guide f is InjectedClassNameType, transform
+    // it to a TemplateSpecializationType.
+    if (const auto *ET = ReturnType->getAs<InjectedClassNameType>()) {
+      ReturnType = ET->getInjectedSpecializationType();
+    }
+    auto DeductionGuideName =
+        SemaRef.Context.DeclarationNames.getCXXDeductionGuideName(
+            AliasTemplate);
+    ReturnType = SemaRef.SubstType(ReturnType, Args, SourceLocation(),
+                                   DeductionGuideName);
+
+    // Resolving a wording defect, we also inherit the variadicness of the
+    // constructor.
+    FunctionProtoType::ExtProtoInfo EPI;
+    EPI.Variadic = T->isVariadic();
+    EPI.HasTrailingReturn = true;
+
+    QualType FunctionTy = SemaRef.BuildFunctionType(
+        ReturnType, ParamTypes, TL.getBeginLoc(), DeductionGuideName, EPI);
+    if (FunctionTy.isNull())
+      return QualType();
+    assert(FunctionTy->getTypeClass() == Type::FunctionProto);
+    // Pushes spaces for the new FunctionProtoTypeLoc.
+    TLB.pushTrivial(SemaRef.Context,
+                    TypeLoc(FunctionTy, nullptr).getNextTypeLoc().getType(),
+                    SourceLocation());
+    FunctionProtoTypeLoc TargetTL = TLB.push<FunctionProtoTypeLoc>(FunctionTy);
+    TargetTL.setLocalRangeBegin(TL.getLocalRangeBegin());
+    TargetTL.setLParenLoc(TL.getLParenLoc());
+    TargetTL.setRParenLoc(TL.getRParenLoc());
+    TargetTL.setExceptionSpecRange(SourceRange());
+    TargetTL.setLocalRangeEnd(TL.getLocalRangeEnd());
+    for (unsigned I = 0, E = TargetTL.getNumParams(); I != E; ++I)
+      TargetTL.setParam(I, Params[I]);
+    return FunctionTy;
+  }
+
+  ParmVarDecl *transformFunctionTypeParam(
+      ParmVarDecl *OldParam, MultiLevelTemplateArgumentList &Args,
+      llvm::SmallVectorImpl<TypedefNameDecl *> &MaterializedTypedefs) {
+    TypeSourceInfo *OldDI = OldParam->getTypeSourceInfo();
+    TypeSourceInfo *NewDI;
+    if (auto PackTL = OldDI->getTypeLoc().getAs<PackExpansionTypeLoc>()) {
+      // Expand out the one and only element in each inner pack.
+      Sema::ArgumentPackSubstitutionIndexRAII SubstIndex(SemaRef, 0);
+      NewDI =
+          SemaRef.SubstType(PackTL.getPatternLoc(), Args,
+                            OldParam->getLocation(), OldParam->getDeclName());
+      if (!NewDI)
+        return nullptr;
+      NewDI =
+          SemaRef.CheckPackExpansion(NewDI, PackTL.getEllipsisLoc(),
+                                     PackTL.getTypePtr()->getNumExpansions());
+    } else
+      NewDI = SemaRef.SubstType(OldDI, Args, OldParam->getLocation(),
+                                OldParam->getDeclName());
+    if (!NewDI)
+      return nullptr;
+
+    // Extract the type. This (for instance) replaces references to typedef
+    // members of the current instantiations with the definitions of those
+    // typedefs, avoiding triggering instantiation of the deduced type during
+    // deduction.
+    NewDI = ExtractTypeForDeductionGuide(SemaRef, MaterializedTypedefs)
+                .transform(NewDI);
+
+    // Resolving a wording defect, we also inherit default arguments from the
+    // constructor.
+    ExprResult NewDefArg;
+    if (OldParam->hasDefaultArg()) {
+      // We don't care what the value is (we won't use it); just create a
+      // placeholder to indicate there is a default argument.
+      QualType ParamTy = NewDI->getType();
+      NewDefArg = new (SemaRef.Context)
+          OpaqueValueExpr(OldParam->getDefaultArg()->getBeginLoc(),
+                          ParamTy.getNonLValueExprType(SemaRef.Context),
+                          ParamTy->isLValueReferenceType()   ? VK_LValue
+                          : ParamTy->isRValueReferenceType() ? VK_XValue
+                                                             : VK_PRValue);
+    }
+
+    ParmVarDecl *NewParam = ParmVarDecl::Create(SemaRef.Context, DC,
+                                                OldParam->getInnerLocStart(),
+                                                OldParam->getLocation(),
+                                                OldParam->getIdentifier(),
+                                                NewDI->getType(),
+                                                NewDI,
+                                                OldParam->getStorageClass(),
+                                                NewDefArg.get());
+    NewParam->setScopeInfo(OldParam->getFunctionScopeDepth(),
+                           OldParam->getFunctionScopeIndex());
+    SemaRef.CurrentInstantiationScope->InstantiatedLocal(OldParam, NewParam);
+    return NewParam;
+  }
+
+  FunctionTemplateDecl *buildDeductionGuide(
+      TemplateParameterList *TemplateParams, CXXConstructorDecl *Ctor,
+      ExplicitSpecifier ES, TypeSourceInfo *TInfo, SourceLocation LocStart,
+      SourceLocation Loc, SourceLocation LocEnd, bool IsImplicit,
+      llvm::ArrayRef<TypedefNameDecl *> MaterializedTypedefs = {}) {
+    auto DeductionGuideName =
+        SemaRef.Context.DeclarationNames.getCXXDeductionGuideName(
+            AliasTemplate);
+
+    DeclarationNameInfo Name(DeductionGuideName, Loc);
+    ArrayRef<ParmVarDecl *> Params =
+        TInfo->getTypeLoc().castAs<FunctionProtoTypeLoc>().getParams();
+
+    // Build the implicit deduction guide template.
+    auto *Guide =
+        CXXDeductionGuideDecl::Create(SemaRef.Context, DC, LocStart, ES, Name,
+                                      TInfo->getType(), TInfo, LocEnd, Ctor);
+    Guide->setImplicit(IsImplicit);
+    Guide->setParams(Params);
+
+    for (auto *Param : Params)
+      Param->setDeclContext(Guide);
+    for (auto *TD : MaterializedTypedefs)
+      TD->setDeclContext(Guide);
+
+    auto *GuideTemplate = FunctionTemplateDecl::Create(
+        SemaRef.Context, DC, Loc, DeductionGuideName, TemplateParams, Guide);
+    GuideTemplate->setImplicit(IsImplicit);
+    Guide->setDescribedFunctionTemplate(GuideTemplate);
+
+    if (isa<CXXRecordDecl>(DC)) {
+      Guide->setAccess(AS_public);
+      GuideTemplate->setAccess(AS_public);
+    }
+
+    DC->addDecl(GuideTemplate);
+    return GuideTemplate;
+  }
+};
+
 QualType Sema::DeduceTemplateSpecializationFromInitializer(
     TypeSourceInfo *TSInfo, const InitializedEntity &Entity,
     const InitializationKind &Kind, MultiExprArg Inits) {
@@ -10584,6 +10966,21 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
   // We can only perform deduction for class templates.
   auto *Template =
       dyn_cast_or_null<ClassTemplateDecl>(TemplateName.getAsTemplateDecl());
+
+  TypeAliasTemplateDecl* AliasTemplate = nullptr;
+  if (!Template) {
+    if ((AliasTemplate = dyn_cast_or_null<TypeAliasTemplateDecl>(
+             TemplateName.getAsTemplateDecl()))) {
+      auto UnderlyingType = AliasTemplate->getTemplatedDecl()
+                                ->getUnderlyingType()
+                                .getDesugaredType(Context);
+      if (const auto *TST =
+              UnderlyingType->getAs<TemplateSpecializationType>()) {
+        Template = dyn_cast_or_null<ClassTemplateDecl>(
+            TST->getTemplateName().getAsTemplateDecl());
+      }
+    }
+  }
   if (!Template) {
     Diag(Kind.getLocation(),
          diag::err_deduced_non_class_template_specialization_type)
@@ -10623,6 +11020,112 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
   // clear on this, but they're not found by name so access does not apply.
   Guides.suppressDiagnostics();
 
+  SmallVector<DeclAccessPair> GuidesCandidates;
+  if (AliasTemplate) {
+    for (auto* Guide : Guides) {
+      if (!dyn_cast_or_null<FunctionTemplateDecl>(Guide))
+        continue;
+      auto RType = dyn_cast<FunctionTemplateDecl>(Guide)
+                       ->getTemplatedDecl()
+                       ->getReturnType();
+      // The (trailing) return type of the deduction guide.
+      const TemplateSpecializationType * ReturnTST = nullptr;
+      if (const auto *InjectedCNT = RType->getAs<InjectedClassNameType>()) {
+        // For an implicitly-generated deduction guide.
+        ReturnTST = InjectedCNT->getInjectedTST();
+      } else if (const auto *ET = RType->getAs<ElaboratedType>()) {
+        // For explicit deduction guide.
+        ReturnTST = ET->getNamedType()->getAs<TemplateSpecializationType>();
+      }
+      assert(ReturnTST);
+      if (ReturnTST) {
+          SmallVector<DeducedTemplateArgument> DeduceResults;
+          SmallVector<TemplateArgument> DeducedArgs;
+          DeduceResults.resize(ReturnTST->template_arguments().size());
+          sema::TemplateDeductionInfo TDeduceInfo({});
+          // Deduce template arguments of the deduction guide from the RHS of
+          // the alias.
+          //
+          // C++ [over.match.class.deduct]p3: ...For each function or function
+          // template f in the guides of the template named by the
+          // simple-template-id of the defining-type-id, the template arguments
+          // of the return type of f are deduced from the defining-type-id of A
+          // according to the process in [temp.deduct.type] with the exception
+          // that deduction does not fail if not all template arguments are
+          // deduced.
+          //
+          //
+          //  template<typename X, typename Y>
+          //  f(X, Y) -> f<Y, X>
+          //
+          //  template<typename U>
+          //  using alias = f<int, U>;
+          //
+          // The RHS of alias is f<int, U>; we deduce the template arguments of
+          // the return type of the deduction guide from it: Y -> int, X -> U
+          const auto* AliasRhsTST = AliasTemplate->getTemplatedDecl()
+                         ->getUnderlyingType()
+                         .getDesugaredType(this->Context)
+                         ->getAs<TemplateSpecializationType>();
+          assert(AliasRhsTST);
+
+          if (DeduceTemplateArguments(AliasTemplate->getTemplateParameters(),
+                                      ReturnTST->template_arguments(),
+                                      AliasRhsTST->template_arguments(),
+                                      TDeduceInfo, DeduceResults,
+                                      /*NumberOfArgumentsMustMatch*/ false)) {
+            // FIXME: not all template arguments are deduced, we should continue
+            // to proceed with all deduced results.
+          } else {
+            // Happy case, all template arguments are deduced.
+            for (auto D : DeduceResults)
+              DeducedArgs.push_back(D);
+
+            auto *DeducedArgList =
+                TemplateArgumentList::CreateCopy(this->Context, DeducedArgs);
+
+            AliasTemplateDeductionGuideTransform Transform(
+                *this, AliasTemplate);
+
+            // Substitute all above deduced template arguments into the
+            // deduction guide f.
+            //
+            // FIXME: is using the InstantiateFunctionDeclaration API the right
+            // implementation choice? It has the side effect of creating a
+            // specialization for the deduction guide function template, and
+            // the specialization is added to the FunctionTemplateDecl, which
+            // is not specified by the C++ standard.
+            //
+            // FIXME: Should we cache the result?
+            if (auto *K = InstantiateFunctionDeclaration(
+                    dyn_cast<FunctionTemplateDecl>(Guide), DeducedArgList,
+                    SourceLocation(),
+                    Sema::CodeSynthesisContext::BuildingDeductionGuides)) {
+              InstantiatingTemplate BuildingDeductionGuides(
+                  *this, SourceLocation(), AliasTemplate,
+                  Sema::InstantiatingTemplate::BuildingDeductionGuidesTag{});
+              auto *D = Transform.transformUnderlyingFunctionTemplate(
+                  dyn_cast<CXXDeductionGuideDecl>(K), DeducedArgs);
+              // FIXME: implement the associated constraints per C++
+              // [over.match.class.deduct]p3.3:
+              //    The associated constraints ([temp.constr.decl]) are the
+              //    conjunction of the associated constraints of g and a
+              //    constraint that is satisfied if and only if the arguments of
+              //    A are deducible (see below) from the return type.
+              // This could be implemented as part of function overload
+              // resolution below.
+              GuidesCandidates.push_back(
+                  DeclAccessPair::make(D, AccessSpecifier::AS_public));
+            }
+          }
+      }
+    }
+  }
+  else {
+    for (auto I = Guides.begin(), E = Guides.end(); I != E; ++I) {
+      GuidesCandidates.push_back(I.getPair());
+    }
+  }
   // Figure out if this is list-initialization.
   InitListExpr *ListInit =
       (Inits.size() == 1 && Kind.getKind() != InitializationKind::IK_Direct)
@@ -10773,9 +11276,8 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
         HasAnyDeductionGuide = true;
       }
     };
-
-    for (auto I = Guides.begin(), E = Guides.end(); I != E; ++I) {
-      NamedDecl *D = (*I)->getUnderlyingDecl();
+    for (auto I : GuidesCandidates) {
+      NamedDecl *D = (I)->getUnderlyingDecl();
       if (D->isInvalidDecl())
         continue;
 
@@ -10788,7 +11290,7 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
       if (!GD->isImplicit())
         HasAnyDeductionGuide = true;
 
-      addDeductionCandidate(TD, GD, I.getPair(), OnlyListConstructors,
+      addDeductionCandidate(TD, GD, I, OnlyListConstructors,
                             /*AllowAggregateDeductionCandidate=*/false);
     }
 
@@ -10828,7 +11330,8 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
     // Try list constructors unless the list is empty and the class has one or
     // more default constructors, in which case those constructors win.
     if (!ListInit->getNumInits()) {
-      for (NamedDecl *D : Guides) {
+      for (auto D : GuidesCandidates) {
+
         auto *FD = dyn_cast<FunctionDecl>(D->getUnderlyingDecl());
         if (FD && FD->getMinRequiredArguments() == 0) {
           TryListConstructors = false;
diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp
index 015b0abaf0e5ee..aa9941a955eba3 100644
--- a/clang/lib/Sema/SemaTemplateDeduction.cpp
+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2456,6 +2456,15 @@ DeduceTemplateArguments(Sema &S, TemplateParameterList *TemplateParams,
   return Sema::TDK_Success;
 }
 
+Sema::TemplateDeductionResult Sema::DeduceTemplateArguments(
+    TemplateParameterList *TemplateParams, ArrayRef<TemplateArgument> Ps,
+    ArrayRef<TemplateArgument> As, sema::TemplateDeductionInfo &Info,
+    SmallVectorImpl<DeducedTemplateArgument> &Deduced,
+    bool NumberOfArgumentsMustMatch) {
+  return ::DeduceTemplateArguments(*this, TemplateParams, Ps, As, Info, Deduced,
+                                   NumberOfArgumentsMustMatch);
+}
+
 static Sema::TemplateDeductionResult
 DeduceTemplateArguments(Sema &S, TemplateParameterList *TemplateParams,
                         const TemplateArgumentList &ParamList,
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index 7f20413c104e97..2f548b1a9ff167 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -21,6 +21,7 @@
 #include "clang/AST/ExprConcepts.h"
 #include "clang/AST/PrettyDeclStackTrace.h"
 #include "clang/AST/Type.h"
+#include "clang/AST/TypeLoc.h"
 #include "clang/AST/TypeVisitor.h"
 #include "clang/Basic/LangOptions.h"
 #include "clang/Basic/Stack.h"
@@ -518,7 +519,8 @@ Sema::InstantiatingTemplate::InstantiatingTemplate(
                             TemplateArgs, &DeductionInfo) {
   assert(
     Kind == CodeSynthesisContext::ExplicitTemplateArgumentSubstitution ||
-    Kind == CodeSynthesisContext::DeducedTemplateArgumentSubstitution);
+    Kind == CodeSynthesisContext::DeducedTemplateArgumentSubstitution ||
+    Kind == CodeSynthesisContext::BuildingDeductionGuides);
 }
 
 Sema::InstantiatingTemplate::InstantiatingTemplate(
@@ -1415,6 +1417,28 @@ namespace {
       return inherited::TransformFunctionProtoType(TLB, TL);
     }
 
+    QualType TransformInjectedClassNameType(TypeLocBuilder &TLB,
+                                            InjectedClassNameTypeLoc TL) {
+      // Return a TemplateSpecializationType for building deduction guides
+      Decl *D = TransformDecl(TL.getNameLoc(),
+                              TL.getTypePtr()->getDecl());
+      if (!D) {
+        if (SemaRef.CodeSynthesisContexts.back().Kind !=
+            Sema::CodeSynthesisContext::BuildingDeductionGuides)
+          return QualType();
+        auto *ICT = TL.getType()->getAs<InjectedClassNameType>();
+        auto TST = SemaRef.Context.getTemplateSpecializationType(
+            ICT->getTemplateName(), TemplateArgs.getOutermost());
+        TLB.pushTrivial(SemaRef.Context, TST, {});
+        return TST;
+      }
+
+      // Default implementation.
+      QualType T = SemaRef.Context.getTypeDeclType(cast<TypeDecl>(D));
+      TLB.pushTypeSpec(T).setNameLoc(TL.getNameLoc());
+      return T;
+    }
+
     template<typename Fn>
     QualType TransformFunctionProtoType(TypeLocBuilder &TLB,
                                         FunctionProtoTypeLoc TL,
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index d768bb72e07c09..71c014b65028bf 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -4852,13 +4852,13 @@ bool TemplateDeclInstantiator::SubstDefaultedFunction(FunctionDecl *New,
 FunctionDecl *
 Sema::InstantiateFunctionDeclaration(FunctionTemplateDecl *FTD,
                                      const TemplateArgumentList *Args,
-                                     SourceLocation Loc) {
+                                     SourceLocation Loc, CodeSynthesisContext::SynthesisKind CSC) {
   FunctionDecl *FD = FTD->getTemplatedDecl();
 
   sema::TemplateDeductionInfo Info(Loc);
   InstantiatingTemplate Inst(
       *this, Loc, FTD, Args->asArray(),
-      CodeSynthesisContext::ExplicitTemplateArgumentSubstitution, Info);
+      CSC, Info);
   if (Inst.isInvalid())
     return nullptr;
 
@@ -6279,7 +6279,18 @@ NamedDecl *Sema::FindInstantiatedDecl(SourceLocation Loc, NamedDecl *D,
           QualType T = CheckTemplateIdType(TemplateName(TD), Loc, Args);
           if (T.isNull())
             return nullptr;
-          auto *SubstRecord = T->getAsCXXRecordDecl();
+          CXXRecordDecl *SubstRecord = T->getAsCXXRecordDecl();
+
+          if (!SubstRecord) {
+            // FIXME: we may encounter a non-record type here as we use the
+            // `InstantiateFunctionDeclaration` API to substitute the deduced
+            // template arguments into the deduction guide.
+            if (auto TST = T->getAs<TemplateSpecializationType>())
+              // Return nullptr as a sentinel value; it is handled properly in
+              // the TemplateInstantiator::TransformInjectedClassNameType
+              // override.
+              return nullptr;
+          }
           assert(SubstRecord && "class template id not a class type?");
           // Check that this template-id names the primary template and not a
           // partial or explicit specialization. (In the latter cases, it's
diff --git a/clang/test/SemaCXX/cxx1z-class-template-argument-deduction.cpp b/clang/test/SemaCXX/cxx1z-class-template-argument-deduction.cpp
index a490d318f54b58..2da89718cec221 100644
--- a/clang/test/SemaCXX/cxx1z-class-template-argument-deduction.cpp
+++ b/clang/test/SemaCXX/cxx1z-class-template-argument-deduction.cpp
@@ -101,13 +101,13 @@ namespace dependent {
   struct B {
     template<typename T> struct X { X(T); };
     X(int) -> X<int>;
-    template<typename T> using Y = X<T>; // expected-note {{template}}
+    template<typename T> using Y = X<T>;
   };
   template<typename T> void f() {
     typename T::X tx = 0;
-    typename T::Y ty = 0; // expected-error {{alias template 'Y' requires template arguments; argument deduction only allowed for class templates}}
+    typename T::Y ty = 0;
   }
-  template void f<B>(); // expected-note {{in instantiation of}}
+  template void f<B>();
 
   template<typename T> struct C { C(T); };
   template<typename T> C(T) -> C<T>;
diff --git a/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp b/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp
new file mode 100644
index 00000000000000..6a78e7aa082954
--- /dev/null
+++ b/clang/test/SemaCXX/cxx20-ctad-type-alias.cpp
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 -fsyntax-only -Wno-c++11-narrowing -Wno-literal-conversion -std=c++20 -verify %s
+// expected-no-diagnostics
+
+template<typename T>
+struct Foo {
+  T t;
+};
+
+template<typename U>
+using Bar = Foo<U>;
+
+void test1() {
+  Bar s = {1};
+}
+
+template<typename X, typename Y>
+struct XYpair {
+  X x;
+  Y y;
+};
+// A tricky explicit deduction guide that swaps X and Y.
+template<typename X, typename Y>
+XYpair(X, Y) -> XYpair<Y, X>;
+template<typename U, typename V>
+using AliasXYpair = XYpair<U, V>;
+
+void test2() {
+  AliasXYpair xy = {1.1, 2}; // XYpair<int, double>
+
+  static_assert(__is_same(decltype(xy.x), int));
+  static_assert(__is_same(decltype(xy.y), double));
+}



More information about the cfe-commits mailing list