[clang] [clang] Implement CTAD for type alias template. (PR #77890)

Haojian Wu via cfe-commits cfe-commits at lists.llvm.org
Wed Mar 6 02:28:02 PST 2024


================
@@ -2612,44 +2669,312 @@ struct ConvertConstructorToDeductionGuideTransform {
     SemaRef.CurrentInstantiationScope->InstantiatedLocal(OldParam, NewParam);
     return NewParam;
   }
+};
 
-  FunctionTemplateDecl *buildDeductionGuide(
-      TemplateParameterList *TemplateParams, CXXConstructorDecl *Ctor,
-      ExplicitSpecifier ES, TypeSourceInfo *TInfo, SourceLocation LocStart,
-      SourceLocation Loc, SourceLocation LocEnd,
-      llvm::ArrayRef<TypedefNameDecl *> MaterializedTypedefs = {}) {
-    DeclarationNameInfo Name(DeductionGuideName, Loc);
-    ArrayRef<ParmVarDecl *> Params =
-        TInfo->getTypeLoc().castAs<FunctionProtoTypeLoc>().getParams();
+// Find all template parameters that appear in the given DeducedArgs.
+// Return the indices of the template parameters in the TemplateParams.
+SmallVector<unsigned> TemplateParamsReferencedInTemplateArgumentList(
+    ArrayRef<NamedDecl *> TemplateParams,
+    ArrayRef<TemplateArgument> DeducedArgs) {
+  struct FindAppearedTemplateParams
+      : public RecursiveASTVisitor<FindAppearedTemplateParams> {
+    llvm::DenseSet<NamedDecl *> TemplateParams;
+    llvm::DenseSet<const NamedDecl *> AppearedTemplateParams;
+
+    FindAppearedTemplateParams(ArrayRef<NamedDecl *> TemplateParams)
+        : TemplateParams(TemplateParams.begin(), TemplateParams.end()) {}
+
+    bool VisitTemplateTypeParmType(TemplateTypeParmType *TTP) {
+      TTP->getIndex();
+      MarkAppeared(TTP->getDecl());
+      return true;
+    }
+    bool VisitDeclRefExpr(DeclRefExpr *DRE) {
+      MarkAppeared(DRE->getFoundDecl());
+      return true;
+    }
 
-    // Build the implicit deduction guide template.
-    auto *Guide =
-        CXXDeductionGuideDecl::Create(SemaRef.Context, DC, LocStart, ES, Name,
-                                      TInfo->getType(), TInfo, LocEnd, Ctor);
-    Guide->setImplicit();
-    Guide->setParams(Params);
+    void MarkAppeared(NamedDecl *ND) {
+      if (TemplateParams.contains(ND))
+        AppearedTemplateParams.insert(ND);
+    }
+  };
+  FindAppearedTemplateParams MarkAppeared(TemplateParams);
+  MarkAppeared.TraverseTemplateArguments(DeducedArgs);
 
-    for (auto *Param : Params)
-      Param->setDeclContext(Guide);
-    for (auto *TD : MaterializedTypedefs)
-      TD->setDeclContext(Guide);
+  SmallVector<unsigned> Results;
+  for (unsigned Index = 0; Index < TemplateParams.size(); ++Index) {
+    if (MarkAppeared.AppearedTemplateParams.contains(
+            TemplateParams[Index]))
+      Results.push_back(Index);
+  }
+  return Results;
+}
+
+bool hasDeclaredDeductionGuides(DeclarationName Name, DeclContext *DC) {
+  // Check whether we've already declared deduction guides for this template.
+  // FIXME: Consider storing a flag on the template to indicate this.
+  assert(Name.getNameKind() ==
+             DeclarationName::NameKind::CXXDeductionGuideName &&
+         "name must be a deduction guide name");
+  auto Existing = DC->lookup(Name);
+  for (auto *D : Existing)
+    if (D->isImplicit())
+      return true;
+  return false;
+}
+
+// Build deduction guides for a type alias template.
+void DeclareImplicitDeductionGuidesForTypeAlias(
+    Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate, SourceLocation Loc) {
+  auto &Context = SemaRef.Context;
+  // FIXME: if there is an explicit deduction guide after the first use of the
+  // type alias usage, we will not cover this explicit deduction guide. fix this
+  // case.
+  if (hasDeclaredDeductionGuides(
+          Context.DeclarationNames.getCXXDeductionGuideName(AliasTemplate),
+          AliasTemplate->getDeclContext()))
+    return;
+  // Unwrap the sugared ElaboratedType.
+  auto RhsType = AliasTemplate->getTemplatedDecl()
+                     ->getUnderlyingType()
+                     .getSingleStepDesugaredType(Context);
+  TemplateDecl *Template = nullptr;
+  llvm::ArrayRef<TemplateArgument> AliasRhsTemplateArgs;
+  if (const auto *TST = RhsType->getAs<TemplateSpecializationType>()) {
+    // Cases where the RHS of the alias is dependent. e.g.
+    //   template<typename T>
+    //   using AliasFoo1 = Foo<T>; // a class/type alias template specialization
+    Template = TST->getTemplateName().getAsTemplateDecl();
+    AliasRhsTemplateArgs = TST->template_arguments();
+  } else if (const auto *RT = RhsType->getAs<RecordType>()) {
+    // Cases where template arguments in the RHS of the alias are not
+    // dependent. e.g.
+    //   using AliasFoo = Foo<bool>;
+    if (const auto *CTSD = llvm::dyn_cast<ClassTemplateSpecializationDecl>(
+            RT->getAsCXXRecordDecl())) {
+      Template = CTSD->getSpecializedTemplate();
+      AliasRhsTemplateArgs = CTSD->getTemplateArgs().asArray();
+    }
+  } else {
+    assert(false && "unhandled RHS type of the alias");
+  }
+  if (!Template)
+    return;
+  DeclarationNameInfo NameInfo(
+      Context.DeclarationNames.getCXXDeductionGuideName(Template), Loc);
+  LookupResult Guides(SemaRef, NameInfo, clang::Sema::LookupOrdinaryName);
+  SemaRef.LookupQualifiedName(Guides, Template->getDeclContext());
+  Guides.suppressDiagnostics();
+
+  for (auto *G : Guides) {
+    FunctionTemplateDecl *F = dyn_cast<FunctionTemplateDecl>(G);
+    if (!F)
+      continue;
+    auto RType = F->getTemplatedDecl()->getReturnType();
+    // The (trailing) return type of the deduction guide.
+    const TemplateSpecializationType *FReturnType =
+        RType->getAs<TemplateSpecializationType>();
+    if (const auto *InjectedCNT = RType->getAs<InjectedClassNameType>())
+      // implicitly-generated deduction guide.
+      FReturnType = InjectedCNT->getInjectedTST();
+    else if (const auto *ET = RType->getAs<ElaboratedType>())
+      // explicit deduction guide.
+      FReturnType = ET->getNamedType()->getAs<TemplateSpecializationType>();
+    assert(FReturnType && "expected to see a return type");
+    // Deduce template arguments of the deduction guide f from the RHS of
+    // the alias.
+    //
+    // C++ [over.match.class.deduct]p3: ...For each function or function
+    // template f in the guides of the template named by the
+    // simple-template-id of the defining-type-id, the template arguments
+    // of the return type of f are deduced from the defining-type-id of A
+    // according to the process in [temp.deduct.type] with the exception
+    // that deduction does not fail if not all template arguments are
+    // deduced.
+    //
+    //
+    //  template<typename X, typename Y>
+    //  f(X, Y) -> f<Y, X>;
+    //
+    //  template<typename U>
+    //  using alias = f<int, U>;
+    //
+    // The RHS of alias is f<int, U>, we deduced the template arguments of
+    // the return type of the deduction guide from it: Y->int, X->U
+    sema::TemplateDeductionInfo TDeduceInfo(Loc);
+    // Must initialize n elements, this is required by DeduceTemplateArguments.
+    SmallVector<DeducedTemplateArgument> DeduceResults(
+        F->getTemplateParameters()->size());
+
+    // FIXME: DeduceTemplateArguments stops immediately at the first
+    // non-deducible template argument. However, this doesn't seem to casue
+    // issues for practice cases, we probably need to extend it to continue
+    // performing deduction for rest of arguments to align with the C++
+    // standard.
+    SemaRef.DeduceTemplateArguments(
+        F->getTemplateParameters(), FReturnType->template_arguments(),
+        AliasRhsTemplateArgs, TDeduceInfo, DeduceResults,
+        /*NumberOfArgumentsMustMatch=*/false);
+
+    SmallVector<TemplateArgument> DeducedArgs;
+    SmallVector<unsigned> NonDeducedTemplateParamsInFIndex;
+    // !!NOTE: DeduceResults respects the sequence of template parameters of
+    // the deduction guide f.
+    for (unsigned Index = 0; Index < DeduceResults.size(); ++Index) {
+      if (const auto &D = DeduceResults[Index]; !D.isNull()) // Deduced
+        DeducedArgs.push_back(D);
+      else
+        NonDeducedTemplateParamsInFIndex.push_back(Index);
+    }
+    auto DeducedAliasTemplateParams =
+        TemplateParamsReferencedInTemplateArgumentList(
+            AliasTemplate->getTemplateParameters()->asArray(), DeducedArgs);
+    // All template arguments null by default.
+    SmallVector<TemplateArgument> TemplateArgsForBuildingFPrime(
+        F->getTemplateParameters()->size());
+
+    Sema::InstantiatingTemplate BuildingDeductionGuides(
+        SemaRef, AliasTemplate->getLocation(), F,
+        Sema::InstantiatingTemplate::BuildingDeductionGuidesTag{});
+    if (BuildingDeductionGuides.isInvalid())
+      return;
+    LocalInstantiationScope Scope(SemaRef);
 
-    auto *GuideTemplate = FunctionTemplateDecl::Create(
-        SemaRef.Context, DC, Loc, DeductionGuideName, TemplateParams, Guide);
-    GuideTemplate->setImplicit();
-    Guide->setDescribedFunctionTemplate(GuideTemplate);
+    // Create a template parameter list for the synthesized deduction guide f'.
+    //
+    // C++ [over.match.class.deduct]p3.2:
+    //   If f is a function template, f' is a function template whose template
+    //   parameter list consists of all the template parameters of A
+    //   (including their default template arguments) that appear in the above
+    //   deductions or (recursively) in their default template arguments
+    SmallVector<NamedDecl *> FPrimeTemplateParams;
+    // Store template arguments that refer to the newly-created template
+    // parameters, used for building `TemplateArgsForBuildingFPrime`.
+    SmallVector<TemplateArgument, 16> TransformedDeducedAliasArgs(
+        AliasTemplate->getTemplateParameters()->size());
+    auto TransformTemplateParameter =
+        [&SemaRef](DeclContext *DC, NamedDecl *TemplateParam,
+                   MultiLevelTemplateArgumentList &Args,
+                   unsigned NewIndex) -> NamedDecl * {
+      if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
+        return transformTemplateTypeParam(SemaRef, DC, TTP, Args,
+                                          TTP->getDepth(), NewIndex);
+      if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
+        return transformTemplateParam(SemaRef, DC, TTP, Args, NewIndex,
+                                      TTP->getDepth());
+      if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
+        return transformTemplateParam(SemaRef, DC, NTTP, Args, NewIndex,
+                                      NTTP->getDepth());
+      return nullptr;
+    };
 
-    if (isa<CXXRecordDecl>(DC)) {
-      Guide->setAccess(AS_public);
-      GuideTemplate->setAccess(AS_public);
+    for (unsigned AliasTemplateParamIdx : DeducedAliasTemplateParams) {
+      auto *TP = AliasTemplate->getTemplateParameters()->getParam(
+          AliasTemplateParamIdx);
+      // Rebuild any internal references to earlier parameters and reindex as
+      // we go.
+      MultiLevelTemplateArgumentList Args;
+      Args.setKind(TemplateSubstitutionKind::Rewrite);
+      Args.addOuterTemplateArguments(TransformedDeducedAliasArgs);
+      NamedDecl *NewParam =
+          TransformTemplateParameter(AliasTemplate->getDeclContext(), TP, Args,
+                                     /*NewIndex*/ FPrimeTemplateParams.size());
+      FPrimeTemplateParams.push_back(NewParam);
+
+      auto NewTemplateArgument = Context.getCanonicalTemplateArgument(
+          Context.getInjectedTemplateArg(NewParam));
+      TransformedDeducedAliasArgs[AliasTemplateParamIdx] = NewTemplateArgument;
+    }
+    //   ...followed by the template parameters of f that were not deduced
+    //   (including their default template arguments)
+    for (unsigned FTemplateParamIdx : NonDeducedTemplateParamsInFIndex) {
+      auto *TP = F->getTemplateParameters()->getParam(FTemplateParamIdx);
+      MultiLevelTemplateArgumentList Args;
+      Args.setKind(TemplateSubstitutionKind::Rewrite);
+      // We take a shortcut here, it is ok to reuse the
+      // TemplateArgsForBuildingFPrime.
+      Args.addOuterTemplateArguments(TemplateArgsForBuildingFPrime);
+      NamedDecl *NewParam = TransformTemplateParameter(
+          F->getDeclContext(), TP, Args, FPrimeTemplateParams.size());
+      FPrimeTemplateParams.push_back(NewParam);
+
+      assert(TemplateArgsForBuildingFPrime[FTemplateParamIdx].isNull() &&
+             "The argument must be null before setting");
+      TemplateArgsForBuildingFPrime[FTemplateParamIdx] =
+          Context.getCanonicalTemplateArgument(
+              Context.getInjectedTemplateArg(NewParam));
+    }
+    // FIXME: support require clause.
----------------
hokein wrote:

I think the FIXME is mostly about rewriting the requires clause of `F` and attaching it to the rewritten template parameter list; it only concerns the "associated constraints of g" part of https://eel.is/c++draft/over.match.class.deduct#3.3. I have merged the FIXME below with this one.

We haven't implemented the deducible part in this patch yet. Regarding the implementation, we might not follow the standard's wording literally: we could perform the `is deducible` check as part of the overload resolution we already do for CTAD, which is more efficient.

https://github.com/llvm/llvm-project/pull/77890


More information about the cfe-commits mailing list