[clang] [Clang] Make enums trivially equality comparable (PR #133587)
via cfe-commits
cfe-commits at lists.llvm.org
Sun Mar 30 01:59:34 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Nikolas Klauser (philnik777)
<details>
<summary>Changes</summary>
Fixes #132672
---
Full diff: https://github.com/llvm/llvm-project/pull/133587.diff
2 Files Affected:
- (modified) clang/lib/Sema/SemaExprCXX.cpp (+48-35)
- (modified) clang/test/SemaCXX/type-traits.cpp (+12)
``````````diff
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 46895db4a0756..d4a9900d3fa8a 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -5174,6 +5174,43 @@ static bool HasNoThrowOperator(const RecordType *RT, OverloadedOperatorKind Op,
return false;
}
+static bool EqualityComparisonIsDefaulted(Sema &S, const TypeDecl *Decl,
+                                          SourceLocation KeyLoc) {
+  EnterExpressionEvaluationContext UnevaluatedContext(
+      S, Sema::ExpressionEvaluationContext::Unevaluated);
+  Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/true);
+  Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
+  // True iff `obj == obj` (obj: const lvalue of Decl's type T) selects a
+  // defaulted operator== for T, or the built-in == when T is an enumeration.
+  OpaqueValueExpr Operand(
+      KeyLoc, Decl->getTypeForDecl()->getCanonicalTypeUnqualified().withConst(),
+      ExprValueKind::VK_LValue);
+  UnresolvedSet<16> Functions;
+  // obj == obj;
+  S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
+  auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
+                                        Functions, &Operand, &Operand);
+  if (Result.isInvalid() || SFINAE.hasErrorOccurred())
+    return false;
+  // No operator function call was formed (e.g. the built-in == was chosen):
+  // only enumerations are accepted in that case.
+  const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
+  if (!CallExpr)
+    return isa<EnumDecl>(Decl);
+  const FunctionDecl *Callee = CallExpr->getDirectCallee();
+  if (!Callee || !Callee->isDefaulted())
+    return false;
+  QualType ParamT = Callee->getParamDecl(0)->getType();
+  // By-value parameter: a class must be trivially copyable (enums always are).
+  if (const auto *RD = dyn_cast<CXXRecordDecl>(Decl);
+      RD && !ParamT->isReferenceType() && !RD->isTriviallyCopyable())
+    return false;
+  if (ParamT.getNonReferenceType()->getUnqualifiedDesugaredType() !=
+      Decl->getTypeForDecl())
+    return false;
+  return true;
+}
+
static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
const CXXRecordDecl *Decl,
SourceLocation KeyLoc) {
@@ -5182,39 +5219,8 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
if (Decl->isLambda())
return Decl->isCapturelessLambda();
- {
- EnterExpressionEvaluationContext UnevaluatedContext(
- S, Sema::ExpressionEvaluationContext::Unevaluated);
- Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/true);
- Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
-
- // const ClassT& obj;
- OpaqueValueExpr Operand(
- KeyLoc,
- Decl->getTypeForDecl()->getCanonicalTypeUnqualified().withConst(),
- ExprValueKind::VK_LValue);
- UnresolvedSet<16> Functions;
- // obj == obj;
- S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
-
- auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
- Functions, &Operand, &Operand);
- if (Result.isInvalid() || SFINAE.hasErrorOccurred())
- return false;
-
- const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
- if (!CallExpr)
- return false;
- const auto *Callee = CallExpr->getDirectCallee();
- auto ParamT = Callee->getParamDecl(0)->getType();
- if (!Callee->isDefaulted())
- return false;
- if (!ParamT->isReferenceType() && !Decl->isTriviallyCopyable())
- return false;
- if (ParamT.getNonReferenceType()->getUnqualifiedDesugaredType() !=
- Decl->getTypeForDecl())
- return false;
- }
+ if (!EqualityComparisonIsDefaulted(S, Decl, KeyLoc))
+ return false;
return llvm::all_of(Decl->bases(),
[&](const CXXBaseSpecifier &BS) {
@@ -5229,7 +5235,10 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
Type = Type->getBaseElementTypeUnsafe()
->getCanonicalTypeUnqualified();
- if (Type->isReferenceType() || Type->isEnumeralType())
+ if (Type->isReferenceType() ||
+ (Type->isEnumeralType() &&
+ !EqualityComparisonIsDefaulted(
+ S, cast<EnumDecl>(Type->getAsTagDecl()), KeyLoc)))
return false;
if (const auto *RD = Type->getAsCXXRecordDecl())
return HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc);
@@ -5240,9 +5249,13 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
static bool isTriviallyEqualityComparableType(Sema &S, QualType Type, SourceLocation KeyLoc) {
QualType CanonicalType = Type.getCanonicalType();
if (CanonicalType->isIncompleteType() || CanonicalType->isDependentType() ||
- CanonicalType->isEnumeralType() || CanonicalType->isArrayType())
+ CanonicalType->isArrayType())
return false;
+ if (CanonicalType->isEnumeralType())
+ return EqualityComparisonIsDefaulted(
+ S, cast<EnumDecl>(CanonicalType->getAsTagDecl()), KeyLoc);
+
if (const auto *RD = CanonicalType->getAsCXXRecordDecl()) {
if (!HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc))
return false;
diff --git a/clang/test/SemaCXX/type-traits.cpp b/clang/test/SemaCXX/type-traits.cpp
index b130024503101..657d5bcf07343 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -3873,6 +3873,11 @@ static_assert(!__is_trivially_equality_comparable(NonTriviallyEqualityComparable
#if __cplusplus >= 202002L
+// The only == available for this enum is the built-in one, which compares
+// the underlying values, so the trait must hold.
+enum TriviallyEqualityComparableEnum { x, y };
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableEnum));
+
struct TriviallyEqualityComparable {
int i;
int j;
@@ -3891,6 +3896,13 @@ struct TriviallyEqualityComparableContainsArray {
};
static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsArray));
+// Defaulted == compares the enum member using the enum's built-in ==.
+struct TriviallyEqualityComparableContainsEnum {
+  TriviallyEqualityComparableEnum e;
+  bool operator==(const TriviallyEqualityComparableContainsEnum&) const = default;
+};
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsEnum));
+
struct TriviallyEqualityComparableContainsMultiDimensionArray {
int a[4][4];
``````````
</details>
https://github.com/llvm/llvm-project/pull/133587
More information about the cfe-commits mailing list