[clang] [Clang][Sema] Improve support for explicit specializations of constrained member functions & member function templates (PR #88963)

Krystian Stasiowski via cfe-commits cfe-commits at lists.llvm.org
Tue May 7 08:18:20 PDT 2024


https://github.com/sdkrystian updated https://github.com/llvm/llvm-project/pull/88963

>From 01db101ca28f26181dfedeaef1ec49a5ae42ee99 Mon Sep 17 00:00:00 2001
From: Krystian Stasiowski <sdkrystian at gmail.com>
Date: Tue, 16 Apr 2024 13:36:11 -0400
Subject: [PATCH 1/7] [Clang][Sema] Improve support for explicit
 specializations of constrained member functions & member function templates

---
 clang/lib/Sema/SemaConcept.cpp                |  2 +-
 clang/lib/Sema/SemaOverload.cpp               | 11 ++-
 clang/lib/Sema/SemaTemplate.cpp               | 44 +++++++++--
 clang/lib/Sema/SemaTemplateInstantiate.cpp    |  4 +
 .../CXX/temp/temp.spec/temp.expl.spec/p8.cpp  | 74 +++++++++++++++++++
 5 files changed, 124 insertions(+), 11 deletions(-)
 create mode 100644 clang/test/CXX/temp/temp.spec/temp.expl.spec/p8.cpp

diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp
index e00c972602829..7bfec4e11f7aa 100644
--- a/clang/lib/Sema/SemaConcept.cpp
+++ b/clang/lib/Sema/SemaConcept.cpp
@@ -811,7 +811,7 @@ static const Expr *SubstituteConstraintExpressionWithoutSatisfaction(
   // this may happen while we're comparing two templates' constraint
   // equivalence.
   LocalInstantiationScope ScopeForParameters(S);
-  if (auto *FD = llvm::dyn_cast<FunctionDecl>(DeclInfo.getDecl()))
+  if (auto *FD = DeclInfo.getDecl()->getAsFunction())
     for (auto *PVD : FD->parameters())
       ScopeForParameters.InstantiatedLocal(PVD, PVD);
 
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index a416df2e97c43..5c70588bd28b8 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -1303,6 +1303,8 @@ static bool IsOverloadOrOverrideImpl(Sema &SemaRef, FunctionDecl *New,
   if (New->isMSVCRTEntryPoint())
     return false;
 
+  NamedDecl *OldDecl = Old;
+  NamedDecl *NewDecl = New;
   FunctionTemplateDecl *OldTemplate = Old->getDescribedFunctionTemplate();
   FunctionTemplateDecl *NewTemplate = New->getDescribedFunctionTemplate();
 
@@ -1347,6 +1349,8 @@ static bool IsOverloadOrOverrideImpl(Sema &SemaRef, FunctionDecl *New,
   // references to non-instantiated entities during constraint substitution.
   // GH78101.
   if (NewTemplate) {
+    OldDecl = OldTemplate;
+    NewDecl = NewTemplate;
     // C++ [temp.over.link]p4:
     //   The signature of a function template consists of its function
     //   signature, its return type and its template parameter list. The names
@@ -1506,13 +1510,14 @@ static bool IsOverloadOrOverrideImpl(Sema &SemaRef, FunctionDecl *New,
     }
   }
 
-  if (!UseOverrideRules) {
+  if (!UseOverrideRules &&
+      New->getTemplateSpecializationKind() != TSK_ExplicitSpecialization) {
     Expr *NewRC = New->getTrailingRequiresClause(),
          *OldRC = Old->getTrailingRequiresClause();
     if ((NewRC != nullptr) != (OldRC != nullptr))
       return true;
-
-    if (NewRC && !SemaRef.AreConstraintExpressionsEqual(Old, OldRC, New, NewRC))
+    if (NewRC &&
+        !SemaRef.AreConstraintExpressionsEqual(OldDecl, OldRC, NewDecl, NewRC))
       return true;
   }
 
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index e647ac267ab39..eeadc47e99069 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -10354,6 +10354,25 @@ bool Sema::CheckFunctionTemplateSpecialization(
   return false;
 }
 
+static bool IsMoreConstrainedFunction(Sema &S, FunctionDecl *FD1,
+                                      FunctionDecl *FD2) {
+  if (FunctionDecl *MF = FD1->getInstantiatedFromMemberFunction())
+    FD1 = MF;
+  if (FunctionDecl *MF = FD2->getInstantiatedFromMemberFunction())
+    FD2 = MF;
+  llvm::SmallVector<const Expr *, 3> AC1, AC2;
+  FD1->getAssociatedConstraints(AC1);
+  FD2->getAssociatedConstraints(AC2);
+  bool AtLeastAsConstrained1, AtLeastAsConstrained2;
+  if (S.IsAtLeastAsConstrained(FD1, AC1, FD2, AC2, AtLeastAsConstrained1))
+    return false;
+  if (S.IsAtLeastAsConstrained(FD2, AC2, FD1, AC1, AtLeastAsConstrained2))
+    return false;
+  if (AtLeastAsConstrained1 == AtLeastAsConstrained2)
+    return false;
+  return AtLeastAsConstrained1;
+}
+
 /// Perform semantic analysis for the given non-template member
 /// specialization.
 ///
@@ -10388,15 +10407,26 @@ Sema::CheckMemberSpecialization(NamedDecl *Member, LookupResult &Previous) {
         QualType Adjusted = Function->getType();
         if (!hasExplicitCallingConv(Adjusted))
           Adjusted = adjustCCAndNoReturn(Adjusted, Method->getType());
+        if (!Context.hasSameType(Adjusted, Method->getType()))
+          continue;
+        if (Method->getTrailingRequiresClause()) {
+          ConstraintSatisfaction Satisfaction;
+          if (CheckFunctionConstraints(Method, Satisfaction,
+                                       /*UsageLoc=*/Member->getLocation(),
+                                       /*ForOverloadResolution=*/true) ||
+              !Satisfaction.IsSatisfied)
+            continue;
+          if (Instantiation &&
+              !IsMoreConstrainedFunction(*this, Method,
+                                         cast<CXXMethodDecl>(Instantiation)))
+            continue;
+        }
         // This doesn't handle deduced return types, but both function
         // declarations should be undeduced at this point.
-        if (Context.hasSameType(Adjusted, Method->getType())) {
-          FoundInstantiation = *I;
-          Instantiation = Method;
-          InstantiatedFrom = Method->getInstantiatedFromMemberFunction();
-          MSInfo = Method->getMemberSpecializationInfo();
-          break;
-        }
+        FoundInstantiation = *I;
+        Instantiation = Method;
+        InstantiatedFrom = Method->getInstantiatedFromMemberFunction();
+        MSInfo = Method->getMemberSpecializationInfo();
       }
     }
   } else if (isa<VarDecl>(Member)) {
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index 3a9fd906b7af8..c5e832c97d8a0 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -275,6 +275,10 @@ Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function,
                                      TemplateArgs->asArray(),
                                      /*Final=*/false);
 
+    if (RelativeToPrimary &&
+        Function->getTemplateSpecializationKind() == TSK_ExplicitSpecialization)
+      return Response::UseNextDecl(Function);
+
     // If this function was instantiated from a specialized member that is
     // a function template, we're done.
     assert(Function->getPrimaryTemplate() && "No function template?");
diff --git a/clang/test/CXX/temp/temp.spec/temp.expl.spec/p8.cpp b/clang/test/CXX/temp/temp.spec/temp.expl.spec/p8.cpp
new file mode 100644
index 0000000000000..87e10d10e4b45
--- /dev/null
+++ b/clang/test/CXX/temp/temp.spec/temp.expl.spec/p8.cpp
@@ -0,0 +1,74 @@
+// RUN: %clang_cc1 -std=c++20 -fsyntax-only -verify %s
+// expected-no-diagnostics
+
+template<typename T>
+concept C = sizeof(T) <= sizeof(long);
+
+template<typename T>
+struct A {
+  template<typename U>
+  void f(U) requires C<U>;
+
+  void g() requires C<T>;
+
+  template<typename U>
+  void h(U) requires C<T>;
+
+  constexpr int i() requires C<T> {
+    return 0;
+  }
+
+  constexpr int i() requires C<T> && true {
+    return 1;
+  }
+
+  template<>
+  void f(char);
+};
+
+template<>
+template<typename U>
+void A<short>::f(U) requires C<U>;
+
+template<>
+template<typename U>
+void A<short>::h(U) requires C<short>;
+
+template<>
+template<>
+void A<int>::f(int);
+
+template<>
+void A<long>::g();
+
+template<>
+constexpr int A<long>::i() {
+  return 2;
+}
+
+static_assert(A<long>().i() == 2);
+
+template<typename T>
+struct D {
+  template<typename U>
+  static constexpr int f(U);
+
+  template<typename U>
+  static constexpr int f(U) requires (sizeof(T) == 1);
+
+  template<>
+  constexpr int f(int) {
+    return 1;
+  }
+};
+
+template<>
+template<typename U>
+constexpr int D<signed char>::f(U) requires (sizeof(signed char) == 1) {
+  return 0;
+}
+
+static_assert(D<char>::f(0) == 1);
+static_assert(D<char[2]>::f(0) == 1);
+static_assert(D<signed char>::f(0) == 1);
+static_assert(D<signed char>::f(0.0) == 0);

>From f47395a64c115c499b706017b4b86ce3dfa60afe Mon Sep 17 00:00:00 2001
From: Krystian Stasiowski <sdkrystian at gmail.com>
Date: Tue, 30 Apr 2024 16:14:47 -0400
Subject: [PATCH 2/7] [FOLD] handle friend function template specializations

---
 clang/lib/Sema/SemaTemplateInstantiate.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
index c5e832c97d8a0..07626058c7977 100644
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -276,7 +276,10 @@ Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function,
                                      /*Final=*/false);
 
     if (RelativeToPrimary &&
-        Function->getTemplateSpecializationKind() == TSK_ExplicitSpecialization)
+        (Function->getTemplateSpecializationKind() ==
+             TSK_ExplicitSpecialization ||
+         (Function->getFriendObjectKind() &&
+          !Function->getPrimaryTemplate()->getFriendObjectKind())))
       return Response::UseNextDecl(Function);
 
     // If this function was instantiated from a specialized member that is

>From 342b06c7d00d1f349ef93f5fb37db4fe86cf6762 Mon Sep 17 00:00:00 2001
From: Krystian Stasiowski <sdkrystian at gmail.com>
Date: Thu, 2 May 2024 16:16:07 -0400
Subject: [PATCH 3/7] [FOLD] add more tests, diagnose ambiguous member function
 specializations

---
 .../clang/Basic/DiagnosticSemaKinds.td        |  5 ++
 clang/include/clang/Sema/Sema.h               |  3 +
 clang/lib/Sema/SemaTemplate.cpp               | 89 ++++++++++++-------
 .../temp/temp.spec/temp.expl.spec/p14-23.cpp  | 58 ++++++++++++
 4 files changed, 122 insertions(+), 33 deletions(-)
 create mode 100644 clang/test/CXX/temp/temp.spec/temp.expl.spec/p14-23.cpp

diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 9a0bae9c216de..2d577e432cc28 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -5437,6 +5437,11 @@ def note_function_template_spec_matched : Note<
 def err_function_template_partial_spec : Error<
     "function template partial specialization is not allowed">;
 
+def err_function_member_spec_ambiguous : Error<
+    "ambiguous member function specialization of %q0">;
+def note_function_member_spec_matched : Note<
+    "member function specialization matches %0">;
+
 // C++ Template Instantiation
 def err_template_recursion_depth_exceeded : Error<
   "recursive template instantiation exceeded maximum depth of %0">,
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index a80ac6dbc7613..ddb3de2b66023 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -9739,6 +9739,9 @@ class Sema final : public SemaBase {
                      const PartialDiagnostic &CandidateDiag,
                      bool Complain = true, QualType TargetType = QualType());
 
+  FunctionDecl *getMoreConstrainedFunction(FunctionDecl *FD1,
+                                           FunctionDecl *FD2);
+
   ///@}
 
   //
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index eeadc47e99069..08096fefdc82d 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -10354,23 +10354,27 @@ bool Sema::CheckFunctionTemplateSpecialization(
   return false;
 }
 
-static bool IsMoreConstrainedFunction(Sema &S, FunctionDecl *FD1,
-                                      FunctionDecl *FD2) {
+FunctionDecl *Sema::getMoreConstrainedFunction(FunctionDecl *FD1,
+                                               FunctionDecl *FD2) {
+  assert(!FD1->getDescribedTemplate() && !FD2->getDescribedTemplate() &&
+         "not for function templates");
+  FunctionDecl *F1 = FD1;
   if (FunctionDecl *MF = FD1->getInstantiatedFromMemberFunction())
-    FD1 = MF;
+    F1 = MF;
+  FunctionDecl *F2 = FD2;
   if (FunctionDecl *MF = FD2->getInstantiatedFromMemberFunction())
-    FD2 = MF;
+    F2 = MF;
   llvm::SmallVector<const Expr *, 3> AC1, AC2;
-  FD1->getAssociatedConstraints(AC1);
-  FD2->getAssociatedConstraints(AC2);
+  F1->getAssociatedConstraints(AC1);
+  F2->getAssociatedConstraints(AC2);
   bool AtLeastAsConstrained1, AtLeastAsConstrained2;
-  if (S.IsAtLeastAsConstrained(FD1, AC1, FD2, AC2, AtLeastAsConstrained1))
-    return false;
-  if (S.IsAtLeastAsConstrained(FD2, AC2, FD1, AC1, AtLeastAsConstrained2))
-    return false;
+  if (IsAtLeastAsConstrained(F1, AC1, F2, AC2, AtLeastAsConstrained1))
+    return nullptr;
+  if (IsAtLeastAsConstrained(F2, AC2, F1, AC1, AtLeastAsConstrained2))
+    return nullptr;
   if (AtLeastAsConstrained1 == AtLeastAsConstrained2)
-    return false;
-  return AtLeastAsConstrained1;
+    return nullptr;
+  return AtLeastAsConstrained1 ? FD1 : FD2;
 }
 
 /// Perform semantic analysis for the given non-template member
@@ -10400,35 +10404,54 @@ Sema::CheckMemberSpecialization(NamedDecl *Member, LookupResult &Previous) {
   if (Previous.empty()) {
     // Nowhere to look anyway.
   } else if (FunctionDecl *Function = dyn_cast<FunctionDecl>(Member)) {
+    SmallVector<FunctionDecl *> Candidates;
+    bool Ambiguous = false;
     for (LookupResult::iterator I = Previous.begin(), E = Previous.end();
            I != E; ++I) {
-      NamedDecl *D = (*I)->getUnderlyingDecl();
-      if (CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(D)) {
-        QualType Adjusted = Function->getType();
-        if (!hasExplicitCallingConv(Adjusted))
-          Adjusted = adjustCCAndNoReturn(Adjusted, Method->getType());
-        if (!Context.hasSameType(Adjusted, Method->getType()))
-          continue;
-        if (Method->getTrailingRequiresClause()) {
-          ConstraintSatisfaction Satisfaction;
-          if (CheckFunctionConstraints(Method, Satisfaction,
-                                       /*UsageLoc=*/Member->getLocation(),
-                                       /*ForOverloadResolution=*/true) ||
-              !Satisfaction.IsSatisfied)
-            continue;
-          if (Instantiation &&
-              !IsMoreConstrainedFunction(*this, Method,
-                                         cast<CXXMethodDecl>(Instantiation)))
-            continue;
-        }
-        // This doesn't handle deduced return types, but both function
-        // declarations should be undeduced at this point.
+      CXXMethodDecl *Method =
+          dyn_cast<CXXMethodDecl>((*I)->getUnderlyingDecl());
+      if (!Method)
+        continue;
+      QualType Adjusted = Function->getType();
+      if (!hasExplicitCallingConv(Adjusted))
+        Adjusted = adjustCCAndNoReturn(Adjusted, Method->getType());
+      // This doesn't handle deduced return types, but both function
+      // declarations should be undeduced at this point.
+      if (!Context.hasSameType(Adjusted, Function->getType()))
+        continue;
+      // FIXME: What if neither function is more constrained than the other?
+      if (ConstraintSatisfaction Satisfaction;
+          Method->getTrailingRequiresClause() &&
+          (CheckFunctionConstraints(Method, Satisfaction,
+                                    /*UsageLoc=*/Member->getLocation(),
+                                    /*ForOverloadResolution=*/true) ||
+           !Satisfaction.IsSatisfied))
+        continue;
+      Candidates.push_back(Method);
+      FunctionDecl *MoreConstrained =
+          Instantiation ? getMoreConstrainedFunction(
+                              Method, cast<FunctionDecl>(Instantiation))
+                        : Method;
+      if (!MoreConstrained) {
+        Ambiguous = true;
+        continue;
+      }
+      if (MoreConstrained == Method) {
+        Ambiguous = false;
         FoundInstantiation = *I;
         Instantiation = Method;
         InstantiatedFrom = Method->getInstantiatedFromMemberFunction();
         MSInfo = Method->getMemberSpecializationInfo();
       }
     }
+    if (Ambiguous) {
+      Diag(Member->getLocation(), diag::err_function_member_spec_ambiguous)
+          << Member;
+      for (FunctionDecl *Candidate : Candidates)
+        Diag(Candidate->getLocation(), diag::note_function_member_spec_matched)
+            << Candidate;
+      return true;
+    }
   } else if (isa<VarDecl>(Member)) {
     VarDecl *PrevVar;
     if (Previous.isSingleResult() &&
diff --git a/clang/test/CXX/temp/temp.spec/temp.expl.spec/p14-23.cpp b/clang/test/CXX/temp/temp.spec/temp.expl.spec/p14-23.cpp
new file mode 100644
index 0000000000000..21974fdf358c2
--- /dev/null
+++ b/clang/test/CXX/temp/temp.spec/temp.expl.spec/p14-23.cpp
@@ -0,0 +1,58 @@
+// RUN: %clang_cc1 -std=c++20 -verify %s
+
+template<int I>
+concept C = I >= 4;
+
+template<int I>
+concept D = I < 8;
+
+template<int I>
+struct A {
+  constexpr static int f() { return 0; }
+  constexpr static int f() requires C<I> && D<I> { return 1; }
+  constexpr static int f() requires C<I> { return 2; }
+
+  constexpr static int g() requires C<I> { return 0; } // expected-note {{member function specialization matches 'g'}}
+  constexpr static int g() requires D<I> { return 1; } // expected-note {{member function specialization matches 'g'}}
+
+  constexpr static int h() requires C<I> { return 0; } // expected-note {{member declaration nearly matches}}
+};
+
+template<>
+constexpr int A<2>::f() { return 3; }
+
+template<>
+constexpr int A<4>::f() { return 4; }
+
+template<>
+constexpr int A<8>::f() { return 5; }
+
+static_assert(A<3>::f() == 0);
+static_assert(A<5>::f() == 1);
+static_assert(A<9>::f() == 2);
+static_assert(A<2>::f() == 3);
+static_assert(A<4>::f() == 4);
+static_assert(A<8>::f() == 5);
+
+template<>
+constexpr int A<0>::g() { return 2; }
+
+template<>
+constexpr int A<8>::g() { return 3; }
+
+template<>
+constexpr int A<6>::g() { return 4; } // expected-error {{ambiguous member function specialization of 'A<6>::g'}}
+
+static_assert(A<9>::g() == 0);
+static_assert(A<1>::g() == 1);
+static_assert(A<0>::g() == 2);
+static_assert(A<8>::g() == 3);
+
+template<>
+constexpr int A<4>::h() { return 1; }
+
+template<>
+constexpr int A<0>::h() { return 2; } // expected-error {{out-of-line definition of 'h' does not match any declaration in 'A<0>'}}
+
+static_assert(A<5>::h() == 0);
+static_assert(A<4>::h() == 1);

>From b88d161b67c20806da3682214c8abff095772e06 Mon Sep 17 00:00:00 2001
From: Krystian Stasiowski <sdkrystian at gmail.com>
Date: Thu, 2 May 2024 16:29:07 -0400
Subject: [PATCH 4/7] [FOLD] remove outdated comment

---
 clang/lib/Sema/SemaTemplate.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 08096fefdc82d..f104ee4aeb72f 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -10419,7 +10419,6 @@ Sema::CheckMemberSpecialization(NamedDecl *Member, LookupResult &Previous) {
       // declarations should be undeduced at this point.
       if (!Context.hasSameType(Adjusted, Function->getType()))
         continue;
-      // FIXME: What if neither function is more constrained than the other?
       if (ConstraintSatisfaction Satisfaction;
           Method->getTrailingRequiresClause() &&
           (CheckFunctionConstraints(Method, Satisfaction,

>From 1249b562ee5b8ea06e05d5592013c76499e520e5 Mon Sep 17 00:00:00 2001
From: Krystian Stasiowski <sdkrystian at gmail.com>
Date: Mon, 6 May 2024 07:12:37 -0400
Subject: [PATCH 5/7] [FOLD] use correct function type for comparison

---
 clang/lib/Sema/SemaTemplate.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index f104ee4aeb72f..d6209857d8b85 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -10417,7 +10417,7 @@ Sema::CheckMemberSpecialization(NamedDecl *Member, LookupResult &Previous) {
         Adjusted = adjustCCAndNoReturn(Adjusted, Method->getType());
       // This doesn't handle deduced return types, but both function
       // declarations should be undeduced at this point.
-      if (!Context.hasSameType(Adjusted, Function->getType()))
+      if (!Context.hasSameType(Adjusted, Method->getType()))
         continue;
       if (ConstraintSatisfaction Satisfaction;
           Method->getTrailingRequiresClause() &&

>From 6c14b675e601f0f46d9388438ef57f7b29f3a49d Mon Sep 17 00:00:00 2001
From: Krystian Stasiowski <sdkrystian at gmail.com>
Date: Mon, 6 May 2024 08:44:35 -0400
Subject: [PATCH 6/7] [FOLD] move getMoreConstrainedFunction to
 SemaTemplateDeduction.cpp

---
 clang/lib/Sema/SemaTemplate.cpp          | 23 -----------------
 clang/lib/Sema/SemaTemplateDeduction.cpp | 32 ++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index d6209857d8b85..d617ee5b9d8b1 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -10354,29 +10354,6 @@ bool Sema::CheckFunctionTemplateSpecialization(
   return false;
 }
 
-FunctionDecl *Sema::getMoreConstrainedFunction(FunctionDecl *FD1,
-                                               FunctionDecl *FD2) {
-  assert(!FD1->getDescribedTemplate() && !FD2->getDescribedTemplate() &&
-         "not for function templates");
-  FunctionDecl *F1 = FD1;
-  if (FunctionDecl *MF = FD1->getInstantiatedFromMemberFunction())
-    F1 = MF;
-  FunctionDecl *F2 = FD2;
-  if (FunctionDecl *MF = FD2->getInstantiatedFromMemberFunction())
-    F2 = MF;
-  llvm::SmallVector<const Expr *, 3> AC1, AC2;
-  F1->getAssociatedConstraints(AC1);
-  F2->getAssociatedConstraints(AC2);
-  bool AtLeastAsConstrained1, AtLeastAsConstrained2;
-  if (IsAtLeastAsConstrained(F1, AC1, F2, AC2, AtLeastAsConstrained1))
-    return nullptr;
-  if (IsAtLeastAsConstrained(F2, AC2, F1, AC1, AtLeastAsConstrained2))
-    return nullptr;
-  if (AtLeastAsConstrained1 == AtLeastAsConstrained2)
-    return nullptr;
-  return AtLeastAsConstrained1 ? FD1 : FD2;
-}
-
 /// Perform semantic analysis for the given non-template member
 /// specialization.
 ///
diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp
index 9f9e442282717..04998b564e3f8 100644
--- a/clang/lib/Sema/SemaTemplateDeduction.cpp
+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -5878,6 +5878,38 @@ UnresolvedSetIterator Sema::getMostSpecialized(
   return SpecEnd;
 }
 
+/// Returns the more constrained function according to the rules of
+/// partial ordering by constraints (C++ [temp.constr.order]).
+///
+/// \param FD1 the first function
+///
+/// \param FD2 the second function
+///
+/// \returns the more constrained function, or nullptr if neither function is
+/// more constrained than the other.
+FunctionDecl *Sema::getMoreConstrainedFunction(FunctionDecl *FD1,
+                                               FunctionDecl *FD2) {
+  assert(!FD1->getDescribedTemplate() && !FD2->getDescribedTemplate() &&
+         "not for function templates");
+  FunctionDecl *F1 = FD1;
+  if (FunctionDecl *MF = FD1->getInstantiatedFromMemberFunction())
+    F1 = MF;
+  FunctionDecl *F2 = FD2;
+  if (FunctionDecl *MF = FD2->getInstantiatedFromMemberFunction())
+    F2 = MF;
+  llvm::SmallVector<const Expr *, 3> AC1, AC2;
+  F1->getAssociatedConstraints(AC1);
+  F2->getAssociatedConstraints(AC2);
+  bool AtLeastAsConstrained1, AtLeastAsConstrained2;
+  if (IsAtLeastAsConstrained(F1, AC1, F2, AC2, AtLeastAsConstrained1))
+    return nullptr;
+  if (IsAtLeastAsConstrained(F2, AC2, F1, AC1, AtLeastAsConstrained2))
+    return nullptr;
+  if (AtLeastAsConstrained1 == AtLeastAsConstrained2)
+    return nullptr;
+  return AtLeastAsConstrained1 ? FD1 : FD2;
+}
+
 /// Determine whether one partial specialization, P1, is at least as
 /// specialized than another, P2.
 ///

>From 4e49baecfa86817e0822fc79504e1019f936eda8 Mon Sep 17 00:00:00 2001
From: Krystian Stasiowski <sdkrystian at gmail.com>
Date: Tue, 7 May 2024 08:31:02 -0400
Subject: [PATCH 7/7] [FOLD] use getMoreConstrainedFunction in more places

---
 clang/lib/Sema/SemaOverload.cpp          | 61 +++++-------------------
 clang/lib/Sema/SemaTemplateDeduction.cpp |  2 +-
 2 files changed, 12 insertions(+), 51 deletions(-)

diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 5c70588bd28b8..f173300b5c96c 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -10700,29 +10700,10 @@ bool clang::isBetterOverloadCandidate(
   //   -— F1 and F2 are non-template functions with the same
   //      parameter-type-lists, and F1 is more constrained than F2 [...],
   if (!Cand1IsSpecialization && !Cand2IsSpecialization &&
-      sameFunctionParameterTypeLists(S, Cand1, Cand2)) {
-    FunctionDecl *Function1 = Cand1.Function;
-    FunctionDecl *Function2 = Cand2.Function;
-    if (FunctionDecl *MF = Function1->getInstantiatedFromMemberFunction())
-      Function1 = MF;
-    if (FunctionDecl *MF = Function2->getInstantiatedFromMemberFunction())
-      Function2 = MF;
-
-    const Expr *RC1 = Function1->getTrailingRequiresClause();
-    const Expr *RC2 = Function2->getTrailingRequiresClause();
-    if (RC1 && RC2) {
-      bool AtLeastAsConstrained1, AtLeastAsConstrained2;
-      if (S.IsAtLeastAsConstrained(Function1, RC1, Function2, RC2,
-                                   AtLeastAsConstrained1) ||
-          S.IsAtLeastAsConstrained(Function2, RC2, Function1, RC1,
-                                   AtLeastAsConstrained2))
-        return false;
-      if (AtLeastAsConstrained1 != AtLeastAsConstrained2)
-        return AtLeastAsConstrained1;
-    } else if (RC1 || RC2) {
-      return RC1 != nullptr;
-    }
-  }
+      sameFunctionParameterTypeLists(S, Cand1, Cand2) &&
+      S.getMoreConstrainedFunction(Cand1.Function, Cand2.Function) ==
+          Cand1.Function)
+    return true;
 
   //   -- F1 is a constructor for a class D, F2 is a constructor for a base
   //      class B of D, and for all arguments the corresponding parameters of
@@ -13390,25 +13371,6 @@ Sema::resolveAddressOfSingleOverloadCandidate(Expr *E, DeclAccessPair &Pair) {
            static_cast<int>(CUDA().IdentifyPreference(Caller, FD2));
   };
 
-  auto CheckMoreConstrained = [&](FunctionDecl *FD1,
-                                  FunctionDecl *FD2) -> std::optional<bool> {
-    if (FunctionDecl *MF = FD1->getInstantiatedFromMemberFunction())
-      FD1 = MF;
-    if (FunctionDecl *MF = FD2->getInstantiatedFromMemberFunction())
-      FD2 = MF;
-    SmallVector<const Expr *, 1> AC1, AC2;
-    FD1->getAssociatedConstraints(AC1);
-    FD2->getAssociatedConstraints(AC2);
-    bool AtLeastAsConstrained1, AtLeastAsConstrained2;
-    if (IsAtLeastAsConstrained(FD1, AC1, FD2, AC2, AtLeastAsConstrained1))
-      return std::nullopt;
-    if (IsAtLeastAsConstrained(FD2, AC2, FD1, AC1, AtLeastAsConstrained2))
-      return std::nullopt;
-    if (AtLeastAsConstrained1 == AtLeastAsConstrained2)
-      return std::nullopt;
-    return AtLeastAsConstrained1;
-  };
-
   // Don't use the AddressOfResolver because we're specifically looking for
   // cases where we have one overload candidate that lacks
   // enable_if/pass_object_size/...
@@ -13445,15 +13407,14 @@ Sema::resolveAddressOfSingleOverloadCandidate(Expr *E, DeclAccessPair &Pair) {
       }
       // FD has the same CUDA prefernece than Result. Continue check
       // constraints.
-      std::optional<bool> MoreConstrainedThanPrevious =
-          CheckMoreConstrained(FD, Result);
-      if (!MoreConstrainedThanPrevious) {
-        IsResultAmbiguous = true;
-        AmbiguousDecls.push_back(FD);
+      FunctionDecl *MoreConstrained = getMoreConstrainedFunction(FD, Result);
+      if (MoreConstrained != FD) {
+        if (!MoreConstrained) {
+          IsResultAmbiguous = true;
+          AmbiguousDecls.push_back(FD);
+        }
         continue;
       }
-      if (!*MoreConstrainedThanPrevious)
-        continue;
       // FD is more constrained - replace Result with it.
     }
     FoundBetter();
@@ -13472,7 +13433,7 @@ Sema::resolveAddressOfSingleOverloadCandidate(Expr *E, DeclAccessPair &Pair) {
       // constraints.
       if (getLangOpts().CUDA && CheckCUDAPreference(Skipped, Result) != 0)
         continue;
-      if (!CheckMoreConstrained(Skipped, Result))
+      if (!getMoreConstrainedFunction(Skipped, Result))
         return nullptr;
     }
     Pair = DAP;
diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp
index 04998b564e3f8..237ce27713376 100644
--- a/clang/lib/Sema/SemaTemplateDeduction.cpp
+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -5897,7 +5897,7 @@ FunctionDecl *Sema::getMoreConstrainedFunction(FunctionDecl *FD1,
   FunctionDecl *F2 = FD2;
   if (FunctionDecl *MF = FD2->getInstantiatedFromMemberFunction())
     F2 = MF;
-  llvm::SmallVector<const Expr *, 3> AC1, AC2;
+  llvm::SmallVector<const Expr *, 1> AC1, AC2;
   F1->getAssociatedConstraints(AC1);
   F2->getAssociatedConstraints(AC2);
   bool AtLeastAsConstrained1, AtLeastAsConstrained2;



More information about the cfe-commits mailing list