[clang] [Clang][NFCI] Cleanup the fix for default function substitution (PR #104911)

Younan Zhang via cfe-commits cfe-commits at lists.llvm.org
Tue Aug 20 03:02:22 PDT 2024


https://github.com/zyn0217 created https://github.com/llvm/llvm-project/pull/104911

(This is one step towards tweaking getTemplateInstantiationArgs() as discussed in https://github.com/llvm/llvm-project/pull/102922)

We don't always substitute into default arguments while transforming a function parameter. In that case, we preserve the uninstantiated expression until later, e.g. when building a CXXDefaultArgExpr, and instantiate the expression there.
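The following minimal standard-C++ sketch shows why this deferral matters. It is not Clang's implementation, and the names `Holder` and `WithValue` are hypothetical. A default argument of a class template's member function is only instantiated when a call actually relies on it:

```cpp
// Hypothetical illustration: default arguments of member functions of a
// class template are instantiated only when a call uses them.
template <class T>
struct Holder {
  // For T = int, `T::value` would be ill-formed; that is fine as long as
  // no call relies on this default argument.
  constexpr int get(int x = T::value) const { return x; }
};

struct WithValue {
  static constexpr int value = 42;
};

// Explicit argument: the default argument of Holder<int>::get is never
// substituted, so `int::value` is never formed.
constexpr int explicitArg = Holder<int>{}.get(7);

// No argument: the default argument is substituted (T = WithValue) for this call.
constexpr int defaultArg = Holder<WithValue>{}.get();
```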

For member function instantiation, this approach used to cause a problem: the default argument of an out-of-line member function specialization couldn't be instantiated properly. This is because getTemplateInstantiationArgs() stops visiting a function's declaration context if the function is a specialization of a member template. For example,

```cpp
template <class T>
struct S {
  template <class U>
  void f(T = sizeof(T));
};

template <> template <class U>
void S<int>::f(int) {}
```

The default argument `sizeof(T)` that lexically appears inside the declaration would be copied to the function declaration in the class template specialization S<int>, as well as to the function's out-of-line definition. When substituting into the default arguments, we use the template arguments collected from the out-of-line function definition. We would therefore stop the traversal at the function, resulting in a single-level template argument list containing only the arguments of f itself. However, the default argument here can still reference the template parameters of the primary template (here, `T`), hence the error.

In fact, this is similar to constraint checking in some respects: we actually want the "whole" template arguments relative to the primary template, not those relative to the function definition. So this patch adds another flag to getTemplateInstantiationArgs() to request that behavior.
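A hypothetical toy model of the traversal may help. `ToyDecl`, `collectedLevels`, `S_int` and `f_char` are made-up names, not Clang API; the sketch only counts how many template argument levels the outward walk over declaration contexts collects for a specialization such as `S<int>::f<char>`, with and without the new flag:

```cpp
#include <cstddef>

// Hypothetical toy model (not Clang's data structures) of the outward walk
// in getTemplateInstantiationArgs(): one argument list per template level.
struct ToyDecl {
  std::size_t numArgs;          // size of this level's argument list (0 = not a template level)
  bool isMemberSpecialization;  // e.g. `template <> template <class U> void S<int>::f(...)`
  const ToyDecl *parent;        // enclosing declaration context, or nullptr
};

// Returns how many argument levels the walk collects before it stops.
constexpr std::size_t collectedLevels(const ToyDecl *d,
                                      bool forDefaultArgumentSubstitution) {
  std::size_t levels = 0;
  for (; d != nullptr; d = d->parent) {
    if (d->numArgs != 0)
      ++levels;
    // Mirrors the patched check in HandleFunction(): stop at a specialized
    // member template unless we are substituting into a default argument.
    if (d->isMemberSpecialization && !forDefaultArgumentSubstitution)
      break;
  }
  return levels;
}

// S<int> contributes the level {int}; its member template f<char> contributes {char}.
constexpr ToyDecl S_int{1, false, nullptr};
constexpr ToyDecl f_char{1, true, &S_int};
```

With the flag off, the walk stops at `f<char>` and keeps only its own level; with the flag on, it also collects the level contributed by `S<int>`, which is what `sizeof(T)` needs.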

This patch also consolidates the tests for default arguments and removes some unnecessary tests.

From c8b8360fe046d38452f71479368c21e217468ddb Mon Sep 17 00:00:00 2001
From: Younan Zhang <zyn7109 at gmail.com>
Date: Tue, 20 Aug 2024 17:18:35 +0800
Subject: [PATCH] [Clang][NFCI] Cleanup the fix for default function
 substitution

---
 clang/include/clang/Sema/Sema.h               |   9 +-
 clang/lib/Sema/SemaTemplateInstantiate.cpp    |  23 +--
 .../lib/Sema/SemaTemplateInstantiateDecl.cpp  |  10 +-
 clang/test/SemaTemplate/default-arguments.cpp |  55 ++++++
 clang/test/SemaTemplate/default-parm-init.cpp | 186 ------------------
 5 files changed, 77 insertions(+), 206 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 2ec6367eccea01..84df847726e6d2 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13053,12 +13053,19 @@ class Sema final : public SemaBase {
   /// ForConstraintInstantiation indicates we should continue looking when
   /// encountering a lambda generic call operator, and continue looking for
   /// arguments on an enclosing class template.
+  ///
+  /// \param SkipForSpecialization when specified, any template specializations
+  /// in a traversal would be ignored.
+  /// \param ForDefaultArgumentSubstitution indicates we should continue looking
+  /// when encountering a specialized member function template, rather than
+  /// returning immediately.
   MultiLevelTemplateArgumentList getTemplateInstantiationArgs(
       const NamedDecl *D, const DeclContext *DC = nullptr, bool Final = false,
       std::optional<ArrayRef<TemplateArgument>> Innermost = std::nullopt,
       bool RelativeToPrimary = false, const FunctionDecl *Pattern = nullptr,
       bool ForConstraintInstantiation = false,
-      bool SkipForSpecialization = false);
+      bool SkipForSpecialization = false,
+      bool ForDefaultArgumentSubstitution = false);
 
   /// RAII object to handle the state changes required to synthesize
   /// a function body.
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index de470739ab78e7..feed797de838dd 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -255,7 +255,8 @@ HandleClassTemplateSpec(const ClassTemplateSpecializationDecl *ClassTemplSpec,
 Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function,
                         MultiLevelTemplateArgumentList &Result,
                         const FunctionDecl *Pattern, bool RelativeToPrimary,
-                        bool ForConstraintInstantiation) {
+                        bool ForConstraintInstantiation,
+                        bool ForDefaultArgumentSubstitution) {
   // Add template arguments from a function template specialization.
   if (!RelativeToPrimary &&
       Function->getTemplateSpecializationKindForInstantiation() ==
@@ -285,7 +286,8 @@ Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function,
     // If this function was instantiated from a specialized member that is
     // a function template, we're done.
     assert(Function->getPrimaryTemplate() && "No function template?");
-    if (Function->getPrimaryTemplate()->isMemberSpecialization())
+    if (!ForDefaultArgumentSubstitution &&
+        Function->getPrimaryTemplate()->isMemberSpecialization())
       return Response::Done();
 
     // If this function is a generic lambda specialization, we are done.
@@ -467,7 +469,7 @@ MultiLevelTemplateArgumentList Sema::getTemplateInstantiationArgs(
     const NamedDecl *ND, const DeclContext *DC, bool Final,
     std::optional<ArrayRef<TemplateArgument>> Innermost, bool RelativeToPrimary,
     const FunctionDecl *Pattern, bool ForConstraintInstantiation,
-    bool SkipForSpecialization) {
+    bool SkipForSpecialization, bool ForDefaultArgumentSubstitution) {
   assert((ND || DC) && "Can't find arguments for a decl if one isn't provided");
   // Accumulate the set of template argument lists in this structure.
   MultiLevelTemplateArgumentList Result;
@@ -509,7 +511,8 @@ MultiLevelTemplateArgumentList Sema::getTemplateInstantiationArgs(
                                   SkipForSpecialization);
     } else if (const auto *Function = dyn_cast<FunctionDecl>(CurDecl)) {
       R = HandleFunction(*this, Function, Result, Pattern, RelativeToPrimary,
-                         ForConstraintInstantiation);
+                         ForConstraintInstantiation,
+                         ForDefaultArgumentSubstitution);
     } else if (const auto *Rec = dyn_cast<CXXRecordDecl>(CurDecl)) {
       R = HandleRecordDecl(*this, Rec, Result, Context,
                            ForConstraintInstantiation);
@@ -3229,7 +3232,6 @@ bool Sema::SubstDefaultArgument(
     //   default argument expression appears.
     ContextRAII SavedContext(*this, FD);
     std::unique_ptr<LocalInstantiationScope> LIS;
-    MultiLevelTemplateArgumentList NewTemplateArgs = TemplateArgs;
 
     if (ForCallExpr) {
       // When instantiating a default argument due to use in a call expression,
@@ -3242,19 +3244,10 @@ bool Sema::SubstDefaultArgument(
           /*ForDefinition*/ false);
       if (addInstantiatedParametersToScope(FD, PatternFD, *LIS, TemplateArgs))
         return true;
-      const FunctionTemplateDecl *PrimaryTemplate = FD->getPrimaryTemplate();
-      if (PrimaryTemplate && PrimaryTemplate->isOutOfLine()) {
-        TemplateArgumentList *CurrentTemplateArgumentList =
-            TemplateArgumentList::CreateCopy(getASTContext(),
-                                             TemplateArgs.getInnermost());
-        NewTemplateArgs = getTemplateInstantiationArgs(
-            FD, FD->getDeclContext(), /*Final=*/false,
-            CurrentTemplateArgumentList->asArray(), /*RelativeToPrimary=*/true);
-      }
     }
 
     runWithSufficientStackSpace(Loc, [&] {
-      Result = SubstInitializer(PatternExpr, NewTemplateArgs,
+      Result = SubstInitializer(PatternExpr, TemplateArgs,
                                 /*DirectInit*/ false);
     });
   }
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index f93cd113988ae4..ad2ad3b1d1a790 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -4659,10 +4659,12 @@ bool Sema::InstantiateDefaultArgument(SourceLocation CallLoc, FunctionDecl *FD,
   //
   // template<typename T>
   // A<T> Foo(int a = A<T>::FooImpl());
-  MultiLevelTemplateArgumentList TemplateArgs =
-      getTemplateInstantiationArgs(FD, FD->getLexicalDeclContext(),
-                                   /*Final=*/false, /*Innermost=*/std::nullopt,
-                                   /*RelativeToPrimary=*/true);
+  MultiLevelTemplateArgumentList TemplateArgs = getTemplateInstantiationArgs(
+      FD, FD->getLexicalDeclContext(),
+      /*Final=*/false, /*Innermost=*/std::nullopt,
+      /*RelativeToPrimary=*/true, /*Pattern=*/nullptr,
+      /*ForConstraintInstantiation=*/false, /*SkipForSpecialization=*/false,
+      /*ForDefaultArgumentSubstitution=*/true);
 
   if (SubstDefaultArgument(CallLoc, Param, TemplateArgs, /*ForCallExpr*/ true))
     return true;
diff --git a/clang/test/SemaTemplate/default-arguments.cpp b/clang/test/SemaTemplate/default-arguments.cpp
index d5d9687cc90f49..53f7a8d2f40adf 100644
--- a/clang/test/SemaTemplate/default-arguments.cpp
+++ b/clang/test/SemaTemplate/default-arguments.cpp
@@ -229,3 +229,58 @@ namespace unevaluated {
   template<int = 0> int f(int = a); // expected-warning 0-1{{extension}}
   int k = sizeof(f());
 }
+
+#if __cplusplus >= 201103L
+namespace GH68490 {
+
+template <typename T> struct Problem {
+  template <typename U>
+  constexpr int UseAlign(int param = alignof(U)) const;
+
+  template <typename U>
+  constexpr int UseSizeof(int param = sizeof(T)) const;
+};
+
+template <typename T> struct Problem<T *> {
+  template <typename U>
+  constexpr int UseAlign(int param = alignof(U)) const;
+
+  template <typename U>
+  constexpr int UseSizeof(int param = sizeof(T)) const;
+};
+
+template <typename T>
+template <typename U>
+constexpr int Problem<T *>::UseAlign(int param) const {
+  return 2 * param;
+}
+
+template <typename T>
+template <typename U>
+constexpr int Problem<T *>::UseSizeof(int param) const {
+  return 2 * param;
+}
+
+template <>
+template <typename T>
+constexpr int Problem<int>::UseAlign(int param) const {
+  return param;
+}
+
+template <>
+template <typename T>
+constexpr int Problem<int>::UseSizeof(int param) const {
+  return param;
+}
+
+void foo() {
+  static_assert(Problem<int>().UseAlign<char>() == alignof(char), "");
+  static_assert(Problem<int>().UseSizeof<char>() == sizeof(char), "");
+  // expected-error@-1 {{failed}} expected-note@-1 {{evaluates to '4 == 1'}}
+  static_assert(Problem<short *>().UseAlign<char>() == 2U * alignof(char), "");
+  static_assert(Problem<short *>().UseSizeof<char>() == 2U * sizeof(char), "");
+  // expected-error@-1 {{failed}} expected-note@-1 {{evaluates to '4 == 2'}}
+}
+
+} // namespace GH68490
+#endif
diff --git a/clang/test/SemaTemplate/default-parm-init.cpp b/clang/test/SemaTemplate/default-parm-init.cpp
index 73ba8998df6a98..d1f407ad15c677 100644
--- a/clang/test/SemaTemplate/default-parm-init.cpp
+++ b/clang/test/SemaTemplate/default-parm-init.cpp
@@ -2,189 +2,3 @@
 // RUN: %clang_cc1 -fsyntax-only -std=c++20 -verify %s
 // expected-no-diagnostics
 
-namespace std {
-
-template<typename Signature> class function;
-
-template<typename R, typename... Args> class invoker_base {
-public: 
-  virtual ~invoker_base() { } 
-  virtual R invoke(Args...) = 0; 
-  virtual invoker_base* clone() = 0;
-};
-
-template<typename F, typename R, typename... Args> 
-class functor_invoker : public invoker_base<R, Args...> {
-public: 
-  explicit functor_invoker(const F& f) : f(f) { } 
-  R invoke(Args... args) { return f(args...); } 
-  functor_invoker* clone() { return new functor_invoker(f); }
-
-private:
-  F f;
-};
-
-template<typename R, typename... Args>
-class function<R (Args...)> {
-public: 
-  typedef R result_type;
-  function() : invoker (0) { }
-  function(const function& other) : invoker(0) { 
-    if (other.invoker)
-      invoker = other.invoker->clone();
-  }
-
-  template<typename F> function(const F& f) : invoker(0) {
-    invoker = new functor_invoker<F, R, Args...>(f);
-  }
-
-  ~function() { 
-    if (invoker)
-      delete invoker;
-  }
-
-  function& operator=(const function& other) { 
-    function(other).swap(*this); 
-    return *this;
-  }
-
-  template<typename F> 
-  function& operator=(const F& f) {
-    function(f).swap(*this); 
-    return *this;
-  }
-
-  void swap(function& other) { 
-    invoker_base<R, Args...>* tmp = invoker; 
-    invoker = other.invoker; 
-    other.invoker = tmp;
-  }
-
-  result_type operator()(Args... args) const { 
-    return invoker->invoke(args...);
-  }
-
-private: 
-  invoker_base<R, Args...>* invoker;
-};
-
-}
-
-template<typename TemplateParam>
-struct Problem {
-  template<typename FunctionTemplateParam>
-  constexpr int FuncAlign(int param = alignof(FunctionTemplateParam));
-
-  template<typename FunctionTemplateParam>
-  constexpr int FuncSizeof(int param = sizeof(FunctionTemplateParam));
-
-  template<typename FunctionTemplateParam>
-  constexpr int FuncAlign2(int param = alignof(TemplateParam));
-
-  template<typename FunctionTemplateParam>
-  constexpr int FuncSizeof2(int param = sizeof(TemplateParam));
-};
-
-template<typename TemplateParam>
-struct Problem<TemplateParam*> {
-  template<typename FunctionTemplateParam>
-  constexpr int FuncAlign(int param = alignof(FunctionTemplateParam));
-
-  template<typename FunctionTemplateParam>
-  constexpr int FuncSizeof(int param = sizeof(FunctionTemplateParam));
-
-  template<typename FunctionTemplateParam>
-  constexpr int FuncAlign2(int param = alignof(TemplateParam));
-
-  template<typename FunctionTemplateParam>
-  constexpr int FuncSizeof2(int param = sizeof(TemplateParam));
-};
-
-template<typename TemplateParam>
-template<typename FunctionTemplateParam>
-constexpr int Problem<TemplateParam*>::FuncAlign(int param) {
-	return 2U*param;
-}
-
-template<typename TemplateParam>
-template<typename FunctionTemplateParam>
-constexpr int Problem<TemplateParam*>::FuncSizeof(int param) {
-    return 2U*param;
-}
-
-template<typename TemplateParam>
-template<typename FunctionTemplateParam>
-constexpr int Problem<TemplateParam*>::FuncAlign2(int param) {
-	return 2U*param;
-}
-
-template<typename TemplateParam>
-template<typename FunctionTemplateParam>
-constexpr int Problem<TemplateParam*>::FuncSizeof2(int param) {
-	return 2U*param;
-}
-
-template <>
-template<typename FunctionTemplateParam>
-constexpr int Problem<int>::FuncAlign(int param) {
-	return param;
-}
-
-template <>
-template<typename FunctionTemplateParam>
-constexpr int Problem<int>::FuncSizeof(int param) {
-	return param;
-}
-
-template <>
-template<typename FunctionTemplateParam>
-constexpr int Problem<int>::FuncAlign2(int param) {
-	return param;
-}
-
-template <>
-template<typename FunctionTemplateParam>
-constexpr int Problem<int>::FuncSizeof2(int param) {
-	return param;
-}
-
-void foo() {
-    Problem<int> p = {};
-    static_assert(p.FuncAlign<char>() == alignof(char));
-    static_assert(p.FuncSizeof<char>() == sizeof(char));
-    static_assert(p.FuncAlign2<char>() == alignof(int));
-    static_assert(p.FuncSizeof2<char>() == sizeof(int));
-    Problem<short*> q = {};
-    static_assert(q.FuncAlign<char>() == 2U * alignof(char));
-    static_assert(q.FuncSizeof<char>() == 2U * sizeof(char));
-    static_assert(q.FuncAlign2<char>() == 2U *alignof(short));
-    static_assert(q.FuncSizeof2<char>() == 2U * sizeof(short));
-}
-
-template <typename T>
-class A {
- public:
-  void run(
-    std::function<void(T&)> f1 = [](auto&&) {},
-    std::function<void(T&)> f2 = [](auto&&) {});
- private:
-  class Helper {
-   public:
-    explicit Helper(std::function<void(T&)> f2) : f2_(f2) {}
-    std::function<void(T&)> f2_;
-  };
-};
-
-template <typename T>
-void A<T>::run(std::function<void(T&)> f1,
-               std::function<void(T&)> f2) {
-  Helper h(f2);
-}
-
-struct B {};
-
-int main() {
-    A<B> a;
-    a.run([&](auto& l) {});
-    return 0;
-}


