[clang] [Clang][NFCI] Cleanup the fix for default function argument substitution (PR #104911)
via cfe-commits
cfe-commits at lists.llvm.org
Tue Aug 20 03:55:05 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-clang
Author: Younan Zhang (zyn0217)
<details>
<summary>Changes</summary>
(This is one step towards tweaking `getTemplateInstantiationArgs()` as discussed in https://github.com/llvm/llvm-project/pull/102922)
We don't always substitute into default arguments while transforming a function parameter. In that case, we preserve the uninstantiated expression and instantiate it later, e.g. when building up a CXXDefaultArgExpr.
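The language rule behind this deferral can be seen in plain C++: a default argument of a class template's member function is only instantiated when a call actually relies on it. This is a minimal sketch; `Wrapper`, `NoValue`, and `HasValue` are hypothetical names, and the block illustrates the standard rule rather than Clang's internal CXXDefaultArgExpr handling.

```cpp
// NoValue has no member `value`, so `T::value` would be ill-formed for it.
struct NoValue {};
struct HasValue { static constexpr int value = 7; };

template <class T>
struct Wrapper {
  // The default argument is only instantiated when a call uses it.
  constexpr int get(int x = T::value) const { return x; }
};

// OK: an explicit argument is passed, so T::value is never instantiated
// for T = NoValue.
static_assert(Wrapper<NoValue>().get(42) == 42, "");

// Here the default argument is used, so T::value is substituted with
// T = HasValue.
static_assert(Wrapper<HasValue>().get() == 7, "");
```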
For member function instantiation, this approach used to cause a problem: the default argument of an out-of-line member function specialization couldn't be instantiated properly. This is because `getTemplateInstantiationArgs()` stops visiting a function's declaration context if the function is a specialization of a member template. For example,
```cpp
template <class T>
struct S {
template <class U>
void f(T = sizeof(T));
};
template <> template <class U>
void S<int>::f(int) {}
```
The default argument `sizeof(T)`, which lexically appears inside the declaration, would be copied both to the function declaration in the class template specialization `S<int>` and to the function's out-of-line definition. When substituting into the default argument, we use the template arguments collected from the out-of-line function definition. Since that function is a specialization of a member template, the traversal stops right after the function, leaving a single level of template arguments: those of `f` itself. However, the default argument can still reference the template parameters of the primary template (here `T`), hence the error.
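A self-contained variant of the example that actually triggers the default-argument substitution is sketched below. It assumes, for illustration only, that `f` is `constexpr` and returns its parameter so the substituted value can be checked; the shape (an out-of-line `template <> template <class U>` specialization whose default argument names the enclosing `T`) is the same as above. A compiler affected by the problem described here may reject the call.

```cpp
template <class T>
struct S {
  template <class U>
  constexpr T f(T x = sizeof(T)) const;
};

// Out-of-line explicit specialization of the member template f for S<int>.
// The default argument is not repeated here; it comes from the class.
template <>
template <class U>
constexpr int S<int>::f(int x) const {
  return x;
}

// The call relies on the default argument, so sizeof(T) must be substituted
// with T = int from the enclosing class, although only U = char is named.
static_assert(S<int>().f<char>() == static_cast<int>(sizeof(int)), "");
```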
In fact, this is similar to constraint checking in some respects: we want the complete template arguments relative to the primary template, not only those relative to the function definition. So this patch adds another flag, `ForDefaultArgumentSubstitution`, that tells `getTemplateInstantiationArgs()` to keep looking when it encounters a specialized member function template instead of returning immediately.
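The effect of the flag on the traversal can be pictured with a toy model. This is a hypothetical simplification, not Clang's data structures: `ToyDecl` and `countLevels` are made-up names, and each `ToyDecl` stands for one declaration context that may contribute one level of template arguments. Mirroring the early `Response::Done()` in `HandleFunction`, the walk stops at a specialized member template unless `forDefaultArgumentSubstitution` is set.

```cpp
// Hypothetical toy model, not Clang's AST: one node per declaration context.
struct ToyDecl {
  const char *args;            // argument level this context contributes, or nullptr
  bool isMemberSpecialization; // e.g. `template <> template <class U> S<int>::f`
  const ToyDecl *parent;       // enclosing declaration context, or nullptr
};

// Walks outward from `d` and counts the template argument levels collected.
// Requires C++14 (loops inside constexpr functions).
constexpr int countLevels(const ToyDecl *d, bool forDefaultArgumentSubstitution) {
  int levels = 0;
  for (; d != nullptr; d = d->parent) {
    if (d->args != nullptr)
      ++levels;
    // Analogue of the early return for a specialized member template.
    if (d->isMemberSpecialization && !forDefaultArgumentSubstitution)
      break;
  }
  return levels;
}

// The S<int>::f example: f contributes U = char, S<int> contributes T = int.
constexpr ToyDecl classS{"T = int", false, nullptr};
constexpr ToyDecl memberF{"U = char", true, &classS};

static_assert(countLevels(&memberF, false) == 1, "default walk stops at f");
static_assert(countLevels(&memberF, true) == 2, "flag keeps walking to S<int>");
```

Without the flag only `f`'s own level is collected, so `T` in `sizeof(T)` has no argument; with it, the walk also reaches `S<int>` and picks up `T = int`.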
This patch also consolidates the default-argument tests into `default-arguments.cpp` and removes some unnecessary ones from `default-parm-init.cpp`.
---
Full diff: https://github.com/llvm/llvm-project/pull/104911.diff
5 Files Affected:
- (modified) clang/include/clang/Sema/Sema.h (+8-1)
- (modified) clang/lib/Sema/SemaTemplateInstantiate.cpp (+8-15)
- (modified) clang/lib/Sema/SemaTemplateInstantiateDecl.cpp (+6-4)
- (modified) clang/test/SemaTemplate/default-arguments.cpp (+55)
- (modified) clang/test/SemaTemplate/default-parm-init.cpp (-186)
``````````diff
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 2ec6367eccea01..84df847726e6d2 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13053,12 +13053,19 @@ class Sema final : public SemaBase {
/// ForConstraintInstantiation indicates we should continue looking when
/// encountering a lambda generic call operator, and continue looking for
/// arguments on an enclosing class template.
+ ///
+ /// \param SkipForSpecialization when specified, any template specializations
+ /// in a traversal would be ignored.
+ /// \param ForDefaultArgumentSubstitution indicates we should continue looking
+ /// when encountering a specialized member function template, rather than
+ /// returning immediately.
MultiLevelTemplateArgumentList getTemplateInstantiationArgs(
const NamedDecl *D, const DeclContext *DC = nullptr, bool Final = false,
std::optional<ArrayRef<TemplateArgument>> Innermost = std::nullopt,
bool RelativeToPrimary = false, const FunctionDecl *Pattern = nullptr,
bool ForConstraintInstantiation = false,
- bool SkipForSpecialization = false);
+ bool SkipForSpecialization = false,
+ bool ForDefaultArgumentSubstitution = false);
/// RAII object to handle the state changes required to synthesize
/// a function body.
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index de470739ab78e7..feed797de838dd 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -255,7 +255,8 @@ HandleClassTemplateSpec(const ClassTemplateSpecializationDecl *ClassTemplSpec,
Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function,
MultiLevelTemplateArgumentList &Result,
const FunctionDecl *Pattern, bool RelativeToPrimary,
- bool ForConstraintInstantiation) {
+ bool ForConstraintInstantiation,
+ bool ForDefaultArgumentSubstitution) {
// Add template arguments from a function template specialization.
if (!RelativeToPrimary &&
Function->getTemplateSpecializationKindForInstantiation() ==
@@ -285,7 +286,8 @@ Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function,
// If this function was instantiated from a specialized member that is
// a function template, we're done.
assert(Function->getPrimaryTemplate() && "No function template?");
- if (Function->getPrimaryTemplate()->isMemberSpecialization())
+ if (!ForDefaultArgumentSubstitution &&
+ Function->getPrimaryTemplate()->isMemberSpecialization())
return Response::Done();
// If this function is a generic lambda specialization, we are done.
@@ -467,7 +469,7 @@ MultiLevelTemplateArgumentList Sema::getTemplateInstantiationArgs(
const NamedDecl *ND, const DeclContext *DC, bool Final,
std::optional<ArrayRef<TemplateArgument>> Innermost, bool RelativeToPrimary,
const FunctionDecl *Pattern, bool ForConstraintInstantiation,
- bool SkipForSpecialization) {
+ bool SkipForSpecialization, bool ForDefaultArgumentSubstitution) {
assert((ND || DC) && "Can't find arguments for a decl if one isn't provided");
// Accumulate the set of template argument lists in this structure.
MultiLevelTemplateArgumentList Result;
@@ -509,7 +511,8 @@ MultiLevelTemplateArgumentList Sema::getTemplateInstantiationArgs(
SkipForSpecialization);
} else if (const auto *Function = dyn_cast<FunctionDecl>(CurDecl)) {
R = HandleFunction(*this, Function, Result, Pattern, RelativeToPrimary,
- ForConstraintInstantiation);
+ ForConstraintInstantiation,
+ ForDefaultArgumentSubstitution);
} else if (const auto *Rec = dyn_cast<CXXRecordDecl>(CurDecl)) {
R = HandleRecordDecl(*this, Rec, Result, Context,
ForConstraintInstantiation);
@@ -3229,7 +3232,6 @@ bool Sema::SubstDefaultArgument(
// default argument expression appears.
ContextRAII SavedContext(*this, FD);
std::unique_ptr<LocalInstantiationScope> LIS;
- MultiLevelTemplateArgumentList NewTemplateArgs = TemplateArgs;
if (ForCallExpr) {
// When instantiating a default argument due to use in a call expression,
@@ -3242,19 +3244,10 @@ bool Sema::SubstDefaultArgument(
/*ForDefinition*/ false);
if (addInstantiatedParametersToScope(FD, PatternFD, *LIS, TemplateArgs))
return true;
- const FunctionTemplateDecl *PrimaryTemplate = FD->getPrimaryTemplate();
- if (PrimaryTemplate && PrimaryTemplate->isOutOfLine()) {
- TemplateArgumentList *CurrentTemplateArgumentList =
- TemplateArgumentList::CreateCopy(getASTContext(),
- TemplateArgs.getInnermost());
- NewTemplateArgs = getTemplateInstantiationArgs(
- FD, FD->getDeclContext(), /*Final=*/false,
- CurrentTemplateArgumentList->asArray(), /*RelativeToPrimary=*/true);
- }
}
runWithSufficientStackSpace(Loc, [&] {
- Result = SubstInitializer(PatternExpr, NewTemplateArgs,
+ Result = SubstInitializer(PatternExpr, TemplateArgs,
/*DirectInit*/ false);
});
}
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index f93cd113988ae4..ad2ad3b1d1a790 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -4659,10 +4659,12 @@ bool Sema::InstantiateDefaultArgument(SourceLocation CallLoc, FunctionDecl *FD,
//
// template<typename T>
// A<T> Foo(int a = A<T>::FooImpl());
- MultiLevelTemplateArgumentList TemplateArgs =
- getTemplateInstantiationArgs(FD, FD->getLexicalDeclContext(),
- /*Final=*/false, /*Innermost=*/std::nullopt,
- /*RelativeToPrimary=*/true);
+ MultiLevelTemplateArgumentList TemplateArgs = getTemplateInstantiationArgs(
+ FD, FD->getLexicalDeclContext(),
+ /*Final=*/false, /*Innermost=*/std::nullopt,
+ /*RelativeToPrimary=*/true, /*Pattern=*/nullptr,
+ /*ForConstraintInstantiation=*/false, /*SkipForSpecialization=*/false,
+ /*ForDefaultArgumentSubstitution=*/true);
if (SubstDefaultArgument(CallLoc, Param, TemplateArgs, /*ForCallExpr*/ true))
return true;
diff --git a/clang/test/SemaTemplate/default-arguments.cpp b/clang/test/SemaTemplate/default-arguments.cpp
index d5d9687cc90f49..c90787c4255a4a 100644
--- a/clang/test/SemaTemplate/default-arguments.cpp
+++ b/clang/test/SemaTemplate/default-arguments.cpp
@@ -229,3 +229,58 @@ namespace unevaluated {
template<int = 0> int f(int = a); // expected-warning 0-1{{extension}}
int k = sizeof(f());
}
+
+#if __cplusplus >= 201103L
+namespace GH68490 {
+
+template <typename T> struct Problem {
+ template <typename U>
+ constexpr int UseAlignOf(int param = alignof(U)) const;
+
+ template <typename U>
+ constexpr int UseSizeOf(int param = sizeof(T)) const;
+};
+
+template <typename T> struct Problem<T *> {
+ template <typename U>
+ constexpr int UseAlignOf(int param = alignof(U)) const;
+
+ template <typename U>
+ constexpr int UseSizeOf(int param = sizeof(T)) const;
+};
+
+template <typename T>
+template <typename U>
+constexpr int Problem<T *>::UseAlignOf(int param) const {
+ return 2 * param;
+}
+
+template <typename T>
+template <typename U>
+constexpr int Problem<T *>::UseSizeOf(int param) const {
+ return 2 * param;
+}
+
+template <>
+template <typename T>
+constexpr int Problem<int>::UseAlignOf(int param) const {
+ return param;
+}
+
+template <>
+template <typename T>
+constexpr int Problem<int>::UseSizeOf(int param) const {
+ return param;
+}
+
+void foo() {
+ static_assert(Problem<int>().UseAlignOf<char>() == alignof(char), "");
+ static_assert(Problem<int>().UseSizeOf<char>() == sizeof(char), "");
+ // expected-error@-1 {{failed}} expected-note@-1 {{evaluates to '4 == 1'}}
+ static_assert(Problem<short *>().UseAlignOf<char>() == 2U * alignof(char), "");
+ static_assert(Problem<short *>().UseSizeOf<char>() == 2U * sizeof(char), "");
+ // expected-error@-1 {{failed}} expected-note@-1 {{evaluates to '4 == 2'}}
+}
+
+} // namespace GH68490
+#endif
diff --git a/clang/test/SemaTemplate/default-parm-init.cpp b/clang/test/SemaTemplate/default-parm-init.cpp
index 73ba8998df6a98..d1f407ad15c677 100644
--- a/clang/test/SemaTemplate/default-parm-init.cpp
+++ b/clang/test/SemaTemplate/default-parm-init.cpp
@@ -2,189 +2,3 @@
// RUN: %clang_cc1 -fsyntax-only -std=c++20 -verify %s
// expected-no-diagnostics
-namespace std {
-
-template<typename Signature> class function;
-
-template<typename R, typename... Args> class invoker_base {
-public:
- virtual ~invoker_base() { }
- virtual R invoke(Args...) = 0;
- virtual invoker_base* clone() = 0;
-};
-
-template<typename F, typename R, typename... Args>
-class functor_invoker : public invoker_base<R, Args...> {
-public:
- explicit functor_invoker(const F& f) : f(f) { }
- R invoke(Args... args) { return f(args...); }
- functor_invoker* clone() { return new functor_invoker(f); }
-
-private:
- F f;
-};
-
-template<typename R, typename... Args>
-class function<R (Args...)> {
-public:
- typedef R result_type;
- function() : invoker (0) { }
- function(const function& other) : invoker(0) {
- if (other.invoker)
- invoker = other.invoker->clone();
- }
-
- template<typename F> function(const F& f) : invoker(0) {
- invoker = new functor_invoker<F, R, Args...>(f);
- }
-
- ~function() {
- if (invoker)
- delete invoker;
- }
-
- function& operator=(const function& other) {
- function(other).swap(*this);
- return *this;
- }
-
- template<typename F>
- function& operator=(const F& f) {
- function(f).swap(*this);
- return *this;
- }
-
- void swap(function& other) {
- invoker_base<R, Args...>* tmp = invoker;
- invoker = other.invoker;
- other.invoker = tmp;
- }
-
- result_type operator()(Args... args) const {
- return invoker->invoke(args...);
- }
-
-private:
- invoker_base<R, Args...>* invoker;
-};
-
-}
-
-template<typename TemplateParam>
-struct Problem {
- template<typename FunctionTemplateParam>
- constexpr int FuncAlign(int param = alignof(FunctionTemplateParam));
-
- template<typename FunctionTemplateParam>
- constexpr int FuncSizeof(int param = sizeof(FunctionTemplateParam));
-
- template<typename FunctionTemplateParam>
- constexpr int FuncAlign2(int param = alignof(TemplateParam));
-
- template<typename FunctionTemplateParam>
- constexpr int FuncSizeof2(int param = sizeof(TemplateParam));
-};
-
-template<typename TemplateParam>
-struct Problem<TemplateParam*> {
- template<typename FunctionTemplateParam>
- constexpr int FuncAlign(int param = alignof(FunctionTemplateParam));
-
- template<typename FunctionTemplateParam>
- constexpr int FuncSizeof(int param = sizeof(FunctionTemplateParam));
-
- template<typename FunctionTemplateParam>
- constexpr int FuncAlign2(int param = alignof(TemplateParam));
-
- template<typename FunctionTemplateParam>
- constexpr int FuncSizeof2(int param = sizeof(TemplateParam));
-};
-
-template<typename TemplateParam>
-template<typename FunctionTemplateParam>
-constexpr int Problem<TemplateParam*>::FuncAlign(int param) {
- return 2U*param;
-}
-
-template<typename TemplateParam>
-template<typename FunctionTemplateParam>
-constexpr int Problem<TemplateParam*>::FuncSizeof(int param) {
- return 2U*param;
-}
-
-template<typename TemplateParam>
-template<typename FunctionTemplateParam>
-constexpr int Problem<TemplateParam*>::FuncAlign2(int param) {
- return 2U*param;
-}
-
-template<typename TemplateParam>
-template<typename FunctionTemplateParam>
-constexpr int Problem<TemplateParam*>::FuncSizeof2(int param) {
- return 2U*param;
-}
-
-template <>
-template<typename FunctionTemplateParam>
-constexpr int Problem<int>::FuncAlign(int param) {
- return param;
-}
-
-template <>
-template<typename FunctionTemplateParam>
-constexpr int Problem<int>::FuncSizeof(int param) {
- return param;
-}
-
-template <>
-template<typename FunctionTemplateParam>
-constexpr int Problem<int>::FuncAlign2(int param) {
- return param;
-}
-
-template <>
-template<typename FunctionTemplateParam>
-constexpr int Problem<int>::FuncSizeof2(int param) {
- return param;
-}
-
-void foo() {
- Problem<int> p = {};
- static_assert(p.FuncAlign<char>() == alignof(char));
- static_assert(p.FuncSizeof<char>() == sizeof(char));
- static_assert(p.FuncAlign2<char>() == alignof(int));
- static_assert(p.FuncSizeof2<char>() == sizeof(int));
- Problem<short*> q = {};
- static_assert(q.FuncAlign<char>() == 2U * alignof(char));
- static_assert(q.FuncSizeof<char>() == 2U * sizeof(char));
- static_assert(q.FuncAlign2<char>() == 2U *alignof(short));
- static_assert(q.FuncSizeof2<char>() == 2U * sizeof(short));
-}
-
-template <typename T>
-class A {
- public:
- void run(
- std::function<void(T&)> f1 = [](auto&&) {},
- std::function<void(T&)> f2 = [](auto&&) {});
- private:
- class Helper {
- public:
- explicit Helper(std::function<void(T&)> f2) : f2_(f2) {}
- std::function<void(T&)> f2_;
- };
-};
-
-template <typename T>
-void A<T>::run(std::function<void(T&)> f1,
- std::function<void(T&)> f2) {
- Helper h(f2);
-}
-
-struct B {};
-
-int main() {
- A<B> a;
- a.run([&](auto& l) {});
- return 0;
-}
``````````
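The `expected-note` values in the new `GH68490` test follow from which `T` the default argument `sizeof(T)` sees. Assuming a target with 4-byte `int` and 2-byte `short` (as the notes imply), `Problem<int>().UseSizeOf<char>()` returns `sizeof(int)`, i.e. 4, rather than `sizeof(char)`, i.e. 1; and `Problem<short *>().UseSizeOf<char>()` returns `2 * sizeof(short)`, i.e. 4, rather than 2. The assertions are therefore written to fail on purpose. A passing counterpart of the partial-specialization case, reduced from the test (the primary template is only declared here), is sketched below.

```cpp
template <typename T> struct Problem;

template <typename T> struct Problem<T *> {
  template <typename U>
  constexpr int UseSizeOf(int param = sizeof(T)) const;
};

// Out-of-line member template of the partial specialization: doubles its input.
template <typename T>
template <typename U>
constexpr int Problem<T *>::UseSizeOf(int param) const {
  return 2 * param;
}

// For Problem<short *>, T = short, so the default argument is sizeof(short),
// not sizeof(U) = sizeof(char).
static_assert(Problem<short *>().UseSizeOf<char>() ==
                  2 * static_cast<int>(sizeof(short)),
              "");
```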
</details>
https://github.com/llvm/llvm-project/pull/104911