[clang] [Clang] make most enums trivially equality comparable (PR #169079)

via cfe-commits cfe-commits at lists.llvm.org
Sat Nov 22 09:39:30 PST 2025


https://github.com/halbi2 updated https://github.com/llvm/llvm-project/pull/169079

>From cf66bc88862b47fd7bc97a71f0c5ca81d3f60a8a Mon Sep 17 00:00:00 2001
From: halbi2 <hehiralbi at gmail.com>
Date: Fri, 21 Nov 2025 13:22:30 -0500
Subject: [PATCH 1/4] [Clang] make most enums trivially equality comparable

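std::equal on ranges of std::byte currently has sub-optimal codegen because
enumeration types are not recognized as trivially equality comparable. Fix
this by treating an enumeration type as trivially equality comparable unless
overload resolution for `==` selects a user-declared operator==. In the
process, factor the overload-resolution check out into a standalone function,
EqualityComparisonIsDefaulted, and refactor the test cases.

An enumeration type cannot have a defaulted operator==, because only a class
can declare a defaulted comparison operator. A user-declared operator== for an
enumeration still makes it not trivially equality comparable. This includes a
hidden friend of the class that encloses a member enumeration, which
argument-dependent lookup finds.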

Fixes #132672
---
 clang/lib/Sema/SemaTypeTraits.cpp  | 83 ++++++++++++++++++------------
 clang/test/SemaCXX/type-traits.cpp | 46 +++++++----------
 2 files changed, 70 insertions(+), 59 deletions(-)

diff --git a/clang/lib/Sema/SemaTypeTraits.cpp b/clang/lib/Sema/SemaTypeTraits.cpp
index 38877967af05e..581989e6d0069 100644
--- a/clang/lib/Sema/SemaTypeTraits.cpp
+++ b/clang/lib/Sema/SemaTypeTraits.cpp
@@ -591,6 +591,43 @@ static bool HasNoThrowOperator(CXXRecordDecl *RD, OverloadedOperatorKind Op,
   return false;
 }
 
+static bool EqualityComparisonIsDefaulted(Sema &S, const TypeDecl *Decl,
+                                          SourceLocation KeyLoc) {
+  CanQualType T = S.Context.getCanonicalTagType(Decl);
+
+  EnterExpressionEvaluationContext UnevaluatedContext(
+      S, Sema::ExpressionEvaluationContext::Unevaluated);
+  Sema::SFINAETrap SFINAE(S, /*ForValidityCheck=*/true);
+  Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
+
+  // const ClassT& obj;
+  OpaqueValueExpr Operand(
+      KeyLoc, T.withConst(),
+      ExprValueKind::VK_LValue);
+  UnresolvedSet<16> Functions;
+  // obj == obj;
+  S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
+
+  auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
+                                        Functions, &Operand, &Operand);
+  if (Result.isInvalid() || SFINAE.hasErrorOccurred())
+    return false;
+  // A built-in == (no operator function selected) is trivial only for enums.
+  const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
+  if (!CallExpr)
+    return isa<EnumDecl>(Decl);
+  const auto *Callee = CallExpr->getDirectCallee();
+  auto ParamT = Callee->getParamDecl(0)->getType();
+  if (!Callee->isDefaulted())
+    return false;
+  if (!ParamT->isReferenceType()) {
+    const CXXRecordDecl *RD = dyn_cast<CXXRecordDecl>(Decl);
+    if (!RD || !RD->isTriviallyCopyable())
+      return false;
+  }
+  return S.Context.hasSameUnqualifiedType(ParamT.getNonReferenceType(), T);
+}
+
 static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
                                                      const CXXRecordDecl *Decl,
                                                      SourceLocation KeyLoc) {
@@ -599,36 +636,8 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
   if (Decl->isLambda())
     return Decl->isCapturelessLambda();
 
-  CanQualType T = S.Context.getCanonicalTagType(Decl);
-  {
-    EnterExpressionEvaluationContext UnevaluatedContext(
-        S, Sema::ExpressionEvaluationContext::Unevaluated);
-    Sema::SFINAETrap SFINAE(S, /*ForValidityCheck=*/true);
-    Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
-
-    // const ClassT& obj;
-    OpaqueValueExpr Operand(KeyLoc, T.withConst(), ExprValueKind::VK_LValue);
-    UnresolvedSet<16> Functions;
-    // obj == obj;
-    S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
-
-    auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
-                                          Functions, &Operand, &Operand);
-    if (Result.isInvalid() || SFINAE.hasErrorOccurred())
-      return false;
-
-    const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
-    if (!CallExpr)
-      return false;
-    const auto *Callee = CallExpr->getDirectCallee();
-    auto ParamT = Callee->getParamDecl(0)->getType();
-    if (!Callee->isDefaulted())
-      return false;
-    if (!ParamT->isReferenceType() && !Decl->isTriviallyCopyable())
-      return false;
-    if (!S.Context.hasSameUnqualifiedType(ParamT.getNonReferenceType(), T))
-      return false;
-  }
+  if (!EqualityComparisonIsDefaulted(S, Decl, KeyLoc))
+    return false;
 
   return llvm::all_of(Decl->bases(),
                       [&](const CXXBaseSpecifier &BS) {
@@ -643,9 +652,12 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
              Type = Type->getBaseElementTypeUnsafe()
                         ->getCanonicalTypeUnqualified();
 
-           if (Type->isReferenceType() || Type->isEnumeralType())
+           if (Type->isReferenceType())
              return false;
-           if (const auto *RD = Type->getAsCXXRecordDecl())
+           if (Type->isEnumeralType()) {
+             EnumDecl *ED = Type->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
+             return EqualityComparisonIsDefaulted(S, ED, KeyLoc);
+           } else if (const auto *RD = Type->getAsCXXRecordDecl())
              return HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc);
            return true;
          });
@@ -655,9 +667,14 @@ static bool isTriviallyEqualityComparableType(Sema &S, QualType Type,
                                               SourceLocation KeyLoc) {
   QualType CanonicalType = Type.getCanonicalType();
   if (CanonicalType->isIncompleteType() || CanonicalType->isDependentType() ||
-      CanonicalType->isEnumeralType() || CanonicalType->isArrayType())
+      CanonicalType->isArrayType())
     return false;
 
+  if (CanonicalType->isEnumeralType()) {
+    EnumDecl *ED = CanonicalType->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
+    return EqualityComparisonIsDefaulted(S, ED, KeyLoc);
+  }
+
   if (const auto *RD = CanonicalType->getAsCXXRecordDecl()) {
     if (!HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc))
       return false;
diff --git a/clang/test/SemaCXX/type-traits.cpp b/clang/test/SemaCXX/type-traits.cpp
index 9ef44d0346b48..76fa4a9c2b936 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -3993,6 +3993,10 @@ namespace is_trivially_equality_comparable {
 struct ForwardDeclared; // expected-note {{forward declaration of 'is_trivially_equality_comparable::ForwardDeclared'}}
 static_assert(!__is_trivially_equality_comparable(ForwardDeclared)); // expected-error {{incomplete type 'ForwardDeclared' used in type trait expression}}
 
+enum Enum {};
+enum EnumWithOpEq {};
+bool operator==(EnumWithOpEq, EnumWithOpEq);
+
 static_assert(!__is_trivially_equality_comparable(void));
 static_assert(__is_trivially_equality_comparable(int));
 static_assert(!__is_trivially_equality_comparable(int[]));
@@ -4000,6 +4004,8 @@ static_assert(!__is_trivially_equality_comparable(int[3]));
 static_assert(!__is_trivially_equality_comparable(float));
 static_assert(!__is_trivially_equality_comparable(double));
 static_assert(!__is_trivially_equality_comparable(long double));
+static_assert(__is_trivially_equality_comparable(Enum));
+static_assert(!__is_trivially_equality_comparable(EnumWithOpEq));
 
 struct NonTriviallyEqualityComparableNoComparator {
   int i;
@@ -4033,19 +4039,21 @@ struct TriviallyEqualityComparable {
 };
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparable));
 
-struct TriviallyEqualityComparableContainsArray {
-  int a[4];
+template <class T>
+struct TriviallyEqualityComparableContains {
+  T t;
 
-  bool operator==(const TriviallyEqualityComparableContainsArray&) const = default;
+  bool operator==(const TriviallyEqualityComparableContains&) const = default;
 };
-static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsArray));
 
-struct TriviallyEqualityComparableContainsMultiDimensionArray {
-  int a[4][4];
-
-  bool operator==(const TriviallyEqualityComparableContainsMultiDimensionArray&) const = default;
-};
-static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsMultiDimensionArray));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<float>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<double>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<long double>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4][4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq>));
 
 auto GetNonCapturingLambda() { return [](){ return 42; }; }
 
@@ -4196,13 +4204,6 @@ struct NotTriviallyEqualityComparableNonTriviallyComparableBase : NotTriviallyEq
 };
 static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyComparableBase));
 
-enum E {
-  a,
-  b
-};
-bool operator==(E, E) { return false; }
-static_assert(!__is_trivially_equality_comparable(E));
-
 struct NotTriviallyEqualityComparableHasEnum {
   E e;
   bool operator==(const NotTriviallyEqualityComparableHasEnum&) const = default;
@@ -4434,15 +4435,8 @@ struct NotTriviallyEqualityComparableHasReferenceMember {
 };
 static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasReferenceMember));
 
-enum E {
-  a,
-  b
-};
-bool operator==(E, E) { return false; }
-static_assert(!__is_trivially_equality_comparable(E));
-
 struct NotTriviallyEqualityComparableHasEnum {
-  E e;
+  EnumWithOpEq e;
   friend bool operator==(const NotTriviallyEqualityComparableHasEnum&, const NotTriviallyEqualityComparableHasEnum&) = default;
 };
 static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasEnum));
@@ -4465,7 +4459,7 @@ static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableRefC
 }
 
 #endif // __cplusplus >= 202002L
-};
+}
 
 namespace can_pass_in_regs {
 

>From 702699243051f408e0615de24f19169cb445e5d7 Mon Sep 17 00:00:00 2001
From: halbi2 <hehiralbi at gmail.com>
Date: Sat, 22 Nov 2025 11:45:45 -0500
Subject: [PATCH 2/4] fix clang-format and build error

---
 clang/lib/Sema/SemaTypeTraits.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/clang/lib/Sema/SemaTypeTraits.cpp b/clang/lib/Sema/SemaTypeTraits.cpp
index 581989e6d0069..c816b900bef1d 100644
--- a/clang/lib/Sema/SemaTypeTraits.cpp
+++ b/clang/lib/Sema/SemaTypeTraits.cpp
@@ -591,19 +591,17 @@ static bool HasNoThrowOperator(CXXRecordDecl *RD, OverloadedOperatorKind Op,
   return false;
 }
 
-static bool EqualityComparisonIsDefaulted(Sema &S, const TypeDecl *Decl,
+static bool EqualityComparisonIsDefaulted(Sema &S, const TagDecl *Decl,
                                           SourceLocation KeyLoc) {
   CanQualType T = S.Context.getCanonicalTagType(Decl);
 
   EnterExpressionEvaluationContext UnevaluatedContext(
       S, Sema::ExpressionEvaluationContext::Unevaluated);
-  Sema::SFINAETrap SFINAE(S, /*ForValidityCheck=*/true);
+  Sema::SFINAETrap SFINAE(S, /*WithAccessChecking=*/true);
   Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
 
   // const ClassT& obj;
-  OpaqueValueExpr Operand(
-      KeyLoc, T.withConst(),
-      ExprValueKind::VK_LValue);
+  OpaqueValueExpr Operand(KeyLoc, T.withConst(), ExprValueKind::VK_LValue);
   UnresolvedSet<16> Functions;
   // obj == obj;
   S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
@@ -671,7 +669,9 @@ static bool isTriviallyEqualityComparableType(Sema &S, QualType Type,
     return false;
 
   if (CanonicalType->isEnumeralType()) {
-    EnumDecl *ED = CanonicalType->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
+    EnumDecl *ED = CanonicalType->castAs<EnumType>()
+                       ->getOriginalDecl()
+                       ->getDefinitionOrSelf();
     return EqualityComparisonIsDefaulted(S, ED, KeyLoc);
   }
 

>From a76689f7c21d1fb7283df8e7df587d1251f755d3 Mon Sep 17 00:00:00 2001
From: halbi2 <hehiralbi at gmail.com>
Date: Sat, 22 Nov 2025 12:07:51 -0500
Subject: [PATCH 3/4] fix clang-format and build error

---
 clang/lib/Sema/SemaTypeTraits.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Sema/SemaTypeTraits.cpp b/clang/lib/Sema/SemaTypeTraits.cpp
index c816b900bef1d..d861181fea164 100644
--- a/clang/lib/Sema/SemaTypeTraits.cpp
+++ b/clang/lib/Sema/SemaTypeTraits.cpp
@@ -653,7 +653,8 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
            if (Type->isReferenceType())
              return false;
            if (Type->isEnumeralType()) {
-             EnumDecl *ED = Type->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
+             EnumDecl *ED =
+                 Type->castAs<EnumType>()->getDecl()->getDefinitionOrSelf();
              return EqualityComparisonIsDefaulted(S, ED, KeyLoc);
            } else if (const auto *RD = Type->getAsCXXRecordDecl())
              return HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc);
@@ -669,9 +670,8 @@ static bool isTriviallyEqualityComparableType(Sema &S, QualType Type,
     return false;
 
   if (CanonicalType->isEnumeralType()) {
-    EnumDecl *ED = CanonicalType->castAs<EnumType>()
-                       ->getOriginalDecl()
-                       ->getDefinitionOrSelf();
+    EnumDecl *ED =
+        CanonicalType->castAs<EnumType>()->getDecl()->getDefinitionOrSelf();
     return EqualityComparisonIsDefaulted(S, ED, KeyLoc);
   }
 

>From 5e95710693d541addcf13c33250c824b5343f332 Mon Sep 17 00:00:00 2001
From: halbi2 <hehiralbi at gmail.com>
Date: Sat, 22 Nov 2025 12:39:04 -0500
Subject: [PATCH 4/4] fix more tests

---
 clang/test/SemaCXX/type-traits.cpp | 56 +++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 21 deletions(-)

diff --git a/clang/test/SemaCXX/type-traits.cpp b/clang/test/SemaCXX/type-traits.cpp
index 76fa4a9c2b936..15c3bea18afc8 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -4053,7 +4053,11 @@ static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableCon
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4]>));
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4][4]>));
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum[2]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum[2][2]>));
 static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq[2]>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq[2][2]>));
 
 auto GetNonCapturingLambda() { return [](){ return 42; }; }
 
@@ -4204,27 +4208,6 @@ struct NotTriviallyEqualityComparableNonTriviallyComparableBase : NotTriviallyEq
 };
 static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyComparableBase));
 
-struct NotTriviallyEqualityComparableHasEnum {
-  E e;
-  bool operator==(const NotTriviallyEqualityComparableHasEnum&) const = default;
-};
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasEnum));
-
-struct NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs {
-  E e[1];
-
-  bool operator==(const NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs&) const = default;
-};
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs));
-
-struct NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2 {
-  E e[1][1];
-
-  bool operator==(const NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2&) const = default;
-};
-
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2));
-
 struct NotTriviallyEqualityComparablePrivateComparison {
   int i;
 
@@ -4312,6 +4295,37 @@ struct TriviallyEqualityComparable {
 };
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparable));
 
+template <class T>
+struct TriviallyEqualityComparableContains {
+  T t;
+
+  friend bool operator==(const TriviallyEqualityComparableContains&, const TriviallyEqualityComparableContains&) = default;
+};
+
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<float>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<double>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<long double>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4][4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum[2]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum[2][2]>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq[2]>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq[2][2]>));
+
+auto GetNonCapturingLambda() { return [](){ return 42; }; }
+
+struct TriviallyEqualityComparableContainsLambda {
+  [[no_unique_address]] decltype(GetNonCapturingLambda()) l;
+  int i;
+
+  friend bool operator==(const TriviallyEqualityComparableContainsLambda&, const TriviallyEqualityComparableContainsLambda&) = default;
+};
+static_assert(!__is_trivially_equality_comparable(decltype(GetNonCapturingLambda()))); // padding
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsLambda));
+
 struct TriviallyEqualityComparableNonTriviallyCopyable {
   TriviallyEqualityComparableNonTriviallyCopyable(const TriviallyEqualityComparableNonTriviallyCopyable&);
   ~TriviallyEqualityComparableNonTriviallyCopyable();



More information about the cfe-commits mailing list