[clang] [clang] Implement CTAD for type alias template. (PR #77890)

via cfe-commits cfe-commits at lists.llvm.org
Wed Mar 6 01:18:06 PST 2024


================
@@ -2612,44 +2669,312 @@ struct ConvertConstructorToDeductionGuideTransform {
     SemaRef.CurrentInstantiationScope->InstantiatedLocal(OldParam, NewParam);
     return NewParam;
   }
+};
 
-  FunctionTemplateDecl *buildDeductionGuide(
-      TemplateParameterList *TemplateParams, CXXConstructorDecl *Ctor,
-      ExplicitSpecifier ES, TypeSourceInfo *TInfo, SourceLocation LocStart,
-      SourceLocation Loc, SourceLocation LocEnd,
-      llvm::ArrayRef<TypedefNameDecl *> MaterializedTypedefs = {}) {
-    DeclarationNameInfo Name(DeductionGuideName, Loc);
-    ArrayRef<ParmVarDecl *> Params =
-        TInfo->getTypeLoc().castAs<FunctionProtoTypeLoc>().getParams();
+// Find all template parameters that appear in the given DeducedArgs.
+// Return the indices of the template parameters in the TemplateParams.
+SmallVector<unsigned> TemplateParamsReferencedInTemplateArgumentList(
+    ArrayRef<NamedDecl *> TemplateParams,
+    ArrayRef<TemplateArgument> DeducedArgs) {
+  struct FindAppearedTemplateParams
+      : public RecursiveASTVisitor<FindAppearedTemplateParams> {
+    llvm::DenseSet<NamedDecl *> TemplateParams;
+    llvm::DenseSet<const NamedDecl *> AppearedTemplateParams;
+
+    FindAppearedTemplateParams(ArrayRef<NamedDecl *> TemplateParams)
+        : TemplateParams(TemplateParams.begin(), TemplateParams.end()) {}
+
+    bool VisitTemplateTypeParmType(TemplateTypeParmType *TTP) {
+      // Mark the referenced template type parameter declaration.
+      // NOTE(review): a canonical TemplateTypeParmType has no decl (getDecl()
+      // returns null), so nothing is marked for it; if DeducedArgs can hold
+      // canonical types, match by depth/index against TemplateParams instead.
+      MarkAppeared(TTP->getDecl());
+      return true;
+    }
+    bool VisitDeclRefExpr(DeclRefExpr *DRE) { // presumably NTTP uses, e.g. N
+      MarkAppeared(DRE->getFoundDecl());
+      return true;
+    }
 
-    // Build the implicit deduction guide template.
-    auto *Guide =
-        CXXDeductionGuideDecl::Create(SemaRef.Context, DC, LocStart, ES, Name,
-                                      TInfo->getType(), TInfo, LocEnd, Ctor);
-    Guide->setImplicit();
-    Guide->setParams(Params);
+    void MarkAppeared(NamedDecl *ND) { // only track decls in TemplateParams
+      if (TemplateParams.contains(ND))
+        AppearedTemplateParams.insert(ND);
+    }
+  };
+  FindAppearedTemplateParams MarkAppeared(TemplateParams);
+  MarkAppeared.TraverseTemplateArguments(DeducedArgs);
 
-    for (auto *Param : Params)
-      Param->setDeclContext(Guide);
-    for (auto *TD : MaterializedTypedefs)
-      TD->setDeclContext(Guide);
+  SmallVector<unsigned> Results;
+  for (unsigned Index = 0; Index < TemplateParams.size(); ++Index) {
+    if (MarkAppeared.AppearedTemplateParams.contains(
+            TemplateParams[Index]))
+      Results.push_back(Index);
+  }
+  return Results;
+}
+
+bool hasDeclaredDeductionGuides(DeclarationName Name, DeclContext *DC) {
+  // Return true if DC already holds an implicit deduction guide named Name.
+  // FIXME: Consider storing a flag on the template to indicate this.
+  assert(Name.getNameKind() ==
+             DeclarationName::NameKind::CXXDeductionGuideName &&
+         "name must be a deduction guide name");
+  auto Existing = DC->lookup(Name);
+  for (auto *D : Existing)
+    if (D->isImplicit()) // user-declared (explicit) guides do not count
+      return true;
+  return false;
+}
+
+// Build deduction guides for a type alias template.
+void DeclareImplicitDeductionGuidesForTypeAlias(
+    Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate, SourceLocation Loc) {
+  auto &Context = SemaRef.Context;
+  // FIXME: if there is an explicit deduction guide after the first use of the
+  // type alias usage, we will not cover this explicit deduction guide. fix this
+  // case.
+  if (hasDeclaredDeductionGuides(
+          Context.DeclarationNames.getCXXDeductionGuideName(AliasTemplate),
+          AliasTemplate->getDeclContext()))
+    return;
+  // Unwrap the sugared ElaboratedType.
+  auto RhsType = AliasTemplate->getTemplatedDecl()
+                     ->getUnderlyingType()
+                     .getSingleStepDesugaredType(Context);
+  TemplateDecl *Template = nullptr;
+  llvm::ArrayRef<TemplateArgument> AliasRhsTemplateArgs;
+  if (const auto *TST = RhsType->getAs<TemplateSpecializationType>()) {
+    // Cases where the RHS of the alias is dependent. e.g.
+    //   template<typename T>
+    //   using AliasFoo1 = Foo<T>; // a class/type alias template specialization
+    Template = TST->getTemplateName().getAsTemplateDecl();
+    AliasRhsTemplateArgs = TST->template_arguments();
+  } else if (const auto *RT = RhsType->getAs<RecordType>()) {
+    // Cases where template arguments in the RHS of the alias are not
+    // dependent. e.g.
+    //   using AliasFoo = Foo<bool>;
+    if (const auto *CTSD = llvm::dyn_cast<ClassTemplateSpecializationDecl>(
+            RT->getAsCXXRecordDecl())) {
+      Template = CTSD->getSpecializedTemplate();
+      AliasRhsTemplateArgs = CTSD->getTemplateArgs().asArray();
+    }
+  } else {
+    assert(false && "unhandled RHS type of the alias");
+  }
+  if (!Template)
+    return;
+  DeclarationNameInfo NameInfo(
+      Context.DeclarationNames.getCXXDeductionGuideName(Template), Loc);
+  LookupResult Guides(SemaRef, NameInfo, clang::Sema::LookupOrdinaryName);
+  SemaRef.LookupQualifiedName(Guides, Template->getDeclContext());
+  Guides.suppressDiagnostics();
+
+  for (auto *G : Guides) {
+    FunctionTemplateDecl *F = dyn_cast<FunctionTemplateDecl>(G);
+    if (!F)
+      continue;
+    auto RType = F->getTemplatedDecl()->getReturnType();
+    // The (trailing) return type of the deduction guide.
+    const TemplateSpecializationType *FReturnType =
+        RType->getAs<TemplateSpecializationType>();
+    if (const auto *InjectedCNT = RType->getAs<InjectedClassNameType>())
+      // implicitly-generated deduction guide.
+      FReturnType = InjectedCNT->getInjectedTST();
+    else if (const auto *ET = RType->getAs<ElaboratedType>())
+      // explicit deduction guide.
+      FReturnType = ET->getNamedType()->getAs<TemplateSpecializationType>();
+    assert(FReturnType && "expected to see a return type");
+    // Deduce template arguments of the deduction guide f from the RHS of
+    // the alias.
+    //
+    // C++ [over.match.class.deduct]p3: ...For each function or function
+    // template f in the guides of the template named by the
+    // simple-template-id of the defining-type-id, the template arguments
+    // of the return type of f are deduced from the defining-type-id of A
+    // according to the process in [temp.deduct.type] with the exception
+    // that deduction does not fail if not all template arguments are
+    // deduced.
+    //
+    //
+    //  template<typename X, typename Y>
+    //  f(X, Y) -> f<Y, X>;
+    //
+    //  template<typename U>
+    //  using alias = f<int, U>;
+    //
+    // The RHS of alias is f<int, U>, we deduced the template arguments of
+    // the return type of the deduction guide from it: Y->int, X->U
+    sema::TemplateDeductionInfo TDeduceInfo(Loc);
+    // Must initialize n elements, this is required by DeduceTemplateArguments.
+    SmallVector<DeducedTemplateArgument> DeduceResults(
+        F->getTemplateParameters()->size());
+
+    // FIXME: DeduceTemplateArguments stops immediately at the first
+    // non-deducible template argument. However, this doesn't seem to casue
+    // issues for practice cases, we probably need to extend it to continue
+    // performing deduction for rest of arguments to align with the C++
+    // standard.
+    SemaRef.DeduceTemplateArguments(
+        F->getTemplateParameters(), FReturnType->template_arguments(),
+        AliasRhsTemplateArgs, TDeduceInfo, DeduceResults,
+        /*NumberOfArgumentsMustMatch=*/false);
+
+    SmallVector<TemplateArgument> DeducedArgs;
+    SmallVector<unsigned> NonDeducedTemplateParamsInFIndex;
+    // !!NOTE: DeduceResults respects the sequence of template parameters of
+    // the deduction guide f.
+    for (unsigned Index = 0; Index < DeduceResults.size(); ++Index) {
+      if (const auto &D = DeduceResults[Index]; !D.isNull()) // Deduced
+        DeducedArgs.push_back(D);
+      else
+        NonDeducedTemplateParamsInFIndex.push_back(Index);
+    }
+    auto DeducedAliasTemplateParams =
+        TemplateParamsReferencedInTemplateArgumentList(
+            AliasTemplate->getTemplateParameters()->asArray(), DeducedArgs);
+    // All template arguments null by default.
+    SmallVector<TemplateArgument> TemplateArgsForBuildingFPrime(
+        F->getTemplateParameters()->size());
+
+    Sema::InstantiatingTemplate BuildingDeductionGuides(
+        SemaRef, AliasTemplate->getLocation(), F,
+        Sema::InstantiatingTemplate::BuildingDeductionGuidesTag{});
+    if (BuildingDeductionGuides.isInvalid())
+      return;
+    LocalInstantiationScope Scope(SemaRef);
 
-    auto *GuideTemplate = FunctionTemplateDecl::Create(
-        SemaRef.Context, DC, Loc, DeductionGuideName, TemplateParams, Guide);
-    GuideTemplate->setImplicit();
-    Guide->setDescribedFunctionTemplate(GuideTemplate);
+    // Create a template parameter list for the synthesized deduction guide f'.
+    //
+    // C++ [over.match.class.deduct]p3.2:
+    //   If f is a function template, f' is a function template whose template
+    //   parameter list consists of all the template parameters of A
+    //   (including their default template arguments) that appear in the above
+    //   deductions or (recursively) in their default template arguments
+    SmallVector<NamedDecl *> FPrimeTemplateParams;
+    // Store template arguments that refer to the newly-created template
+    // parameters, used for building `TemplateArgsForBuildingFPrime`.
+    SmallVector<TemplateArgument, 16> TransformedDeducedAliasArgs(
+        AliasTemplate->getTemplateParameters()->size());
+    auto TransformTemplateParameter =
+        [&SemaRef](DeclContext *DC, NamedDecl *TemplateParam,
+                   MultiLevelTemplateArgumentList &Args,
+                   unsigned NewIndex) -> NamedDecl * {
+      if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
+        return transformTemplateTypeParam(SemaRef, DC, TTP, Args,
+                                          TTP->getDepth(), NewIndex);
+      if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
+        return transformTemplateParam(SemaRef, DC, TTP, Args, NewIndex,
+                                      TTP->getDepth());
+      if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
+        return transformTemplateParam(SemaRef, DC, NTTP, Args, NewIndex,
+                                      NTTP->getDepth());
+      return nullptr;
+    };
 
-    if (isa<CXXRecordDecl>(DC)) {
-      Guide->setAccess(AS_public);
-      GuideTemplate->setAccess(AS_public);
+    for (unsigned AliasTemplateParamIdx : DeducedAliasTemplateParams) {
+      auto *TP = AliasTemplate->getTemplateParameters()->getParam(
+          AliasTemplateParamIdx);
+      // Rebuild any internal references to earlier parameters and reindex as
+      // we go.
+      MultiLevelTemplateArgumentList Args;
+      Args.setKind(TemplateSubstitutionKind::Rewrite);
+      Args.addOuterTemplateArguments(TransformedDeducedAliasArgs);
+      NamedDecl *NewParam =
+          TransformTemplateParameter(AliasTemplate->getDeclContext(), TP, Args,
+                                     /*NewIndex*/ FPrimeTemplateParams.size());
+      FPrimeTemplateParams.push_back(NewParam);
+
+      auto NewTemplateArgument = Context.getCanonicalTemplateArgument(
+          Context.getInjectedTemplateArg(NewParam));
+      TransformedDeducedAliasArgs[AliasTemplateParamIdx] = NewTemplateArgument;
+    }
+    //   ...followed by the template parameters of f that were not deduced
+    //   (including their default template arguments)
+    for (unsigned FTemplateParamIdx : NonDeducedTemplateParamsInFIndex) {
+      auto *TP = F->getTemplateParameters()->getParam(FTemplateParamIdx);
+      MultiLevelTemplateArgumentList Args;
+      Args.setKind(TemplateSubstitutionKind::Rewrite);
+      // We take a shortcut here, it is ok to reuse the
+      // TemplateArgsForBuildingFPrime.
+      Args.addOuterTemplateArguments(TemplateArgsForBuildingFPrime);
+      NamedDecl *NewParam = TransformTemplateParameter(
+          F->getDeclContext(), TP, Args, FPrimeTemplateParams.size());
+      FPrimeTemplateParams.push_back(NewParam);
+
+      assert(TemplateArgsForBuildingFPrime[FTemplateParamIdx].isNull() &&
+             "The argument must be null before setting");
+      TemplateArgsForBuildingFPrime[FTemplateParamIdx] =
+          Context.getCanonicalTemplateArgument(
+              Context.getInjectedTemplateArg(NewParam));
+    }
+    // FIXME: support require clause.
----------------
cor3ntin wrote:

This is missing the associated-constraints rule in [over.match.class.deduct]p3.3: https://eel.is/c++draft/over.match.class.deduct#3.3
I think we need to build a binary `&&` expression that combines the requires-clause of F with a boolean literal indicating whether the arguments are deducible. (Do we implement the deducibility check?)

https://github.com/llvm/llvm-project/pull/77890


More information about the cfe-commits mailing list