[clang] [OpenMP] Fix atomic compare handling with overloaded operators (PR #141142)
Joseph Huber via cfe-commits
cfe-commits at lists.llvm.org
Wed May 28 13:16:45 PDT 2025
https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/141142
>From f2c18ba64744320a8e2a63938b17137a1b6e74d7 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Thu, 22 May 2025 16:21:34 -0500
Subject: [PATCH] [OpenMP] Fix atomic compare handling with overloaded
operators
Summary:
When overloaded C++ operators are declared in the global namespace, the AST
node for a comparison that may use them is not a `BinaryOperator` but a
`CXXOperatorCallExpr`. Update the atomic compare checks to handle this case,
essentially treating the call as a binary expression with two arguments.
---
clang/lib/Sema/SemaOpenMP.cpp | 162 +++++++++++++++-----------
clang/test/OpenMP/atomic_messages.cpp | 31 +++++
2 files changed, 126 insertions(+), 67 deletions(-)
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index f16f841d62edd..8d580d1968238 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -11763,51 +11763,61 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
X = BO->getLHS();
auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
- if (!Cond) {
+ auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond());
+ Expr *LHS = nullptr;
+ Expr *RHS = nullptr;
+ if (Cond) {
+ LHS = Cond->getLHS();
+ RHS = Cond->getRHS();
+ } else if (Call) {
+ LHS = Call->getArg(0);
+ RHS = Call->getArg(1);
+ } else {
ErrorInfo.Error = ErrorTy::NotABinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
- switch (Cond->getOpcode()) {
- case BO_EQ: {
- C = Cond;
+ if ((Cond && Cond->getOpcode() == BO_EQ) ||
+ (Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) {
+ C = S->getCond();
D = BO->getRHS();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
+ E = RHS;
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+ E = LHS;
} else {
ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+ S->getCond()->getSourceRange();
return false;
}
- break;
- }
- case BO_LT:
- case BO_GT: {
+ } else if ((Cond &&
+ (Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) ||
+ (Call &&
+ (Call->getOperator() == OverloadedOperatorKind::OO_Less ||
+ Call->getOperator() == OverloadedOperatorKind::OO_Greater))) {
E = BO->getRHS();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
- C = Cond;
- } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- C = Cond;
+ if (checkIfTwoExprsAreSame(ContextRef, X, LHS) &&
+ checkIfTwoExprsAreSame(ContextRef, E, RHS)) {
+ C = S->getCond();
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) &&
+ checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+ C = S->getCond();
IsXBinopExpr = false;
} else {
ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+ S->getCond()->getSourceRange();
return false;
}
- break;
- }
- default:
+ } else {
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
@@ -11857,52 +11867,60 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
}
auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
- if (!Cond) {
+ auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond());
+ Expr *LHS = nullptr;
+ Expr *RHS = nullptr;
+ if (Cond) {
+ LHS = Cond->getLHS();
+ RHS = Cond->getRHS();
+ } else if (Call) {
+ LHS = Call->getArg(0);
+ RHS = Call->getArg(1);
+ } else {
ErrorInfo.Error = ErrorTy::NotABinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
- CO->getCond()->getSourceRange();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange();
return false;
}
- switch (Cond->getOpcode()) {
- case BO_EQ: {
- C = Cond;
+ if ((Cond && Cond->getOpcode() == BO_EQ) ||
+ (Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) {
+ C = CO->getCond();
D = CO->getTrueExpr();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
+ E = RHS;
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+ E = LHS;
} else {
ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange();
return false;
}
- break;
- }
- case BO_LT:
- case BO_GT: {
+ } else if ((Cond &&
+ (Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) ||
+ (Call &&
+ (Call->getOperator() == OverloadedOperatorKind::OO_Less ||
+ Call->getOperator() == OverloadedOperatorKind::OO_Greater))) {
+
E = CO->getTrueExpr();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
- C = Cond;
- } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- C = Cond;
+ if (checkIfTwoExprsAreSame(ContextRef, X, LHS) &&
+ checkIfTwoExprsAreSame(ContextRef, E, RHS)) {
+ C = CO->getCond();
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) &&
+ checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+ C = CO->getCond();
IsXBinopExpr = false;
} else {
ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange();
return false;
}
- break;
- }
- default:
+ } else {
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange();
return false;
}
@@ -12063,31 +12081,41 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
D = BO->getRHS();
auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
- if (!Cond) {
+ auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond());
+ Expr *LHS = nullptr;
+ Expr *RHS = nullptr;
+ if (Cond) {
+ LHS = Cond->getLHS();
+ RHS = Cond->getRHS();
+ } else if (Call) {
+ LHS = Call->getArg(0);
+ RHS = Call->getArg(1);
+ } else {
ErrorInfo.Error = ErrorTy::NotABinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
- if (Cond->getOpcode() != BO_EQ) {
+ if ((Cond && Cond->getOpcode() != BO_EQ) ||
+ (Call && Call->getOperator() != OverloadedOperatorKind::OO_EqualEqual)) {
ErrorInfo.Error = ErrorTy::NotEQ;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
+ E = RHS;
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+ E = LHS;
} else {
ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
- C = Cond;
+ C = S->getCond();
if (!S->getElse()) {
ErrorInfo.Error = ErrorTy::NoElse;
diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp
index d492f6ee1e896..c4e240a0ebb4e 100644
--- a/clang/test/OpenMP/atomic_messages.cpp
+++ b/clang/test/OpenMP/atomic_messages.cpp
@@ -991,3 +991,34 @@ int mixed() {
// expected-note at +1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
return mixed<int>();
}
+
+#ifdef OMP51
+struct U {};
+struct U operator<(U, U);
+struct U operator>(U, U);
+struct U operator==(U, U);
+
+template <typename T> void templated() {
+ T cx, cv, ce, cd;
+#pragma omp atomic compare capture
+ if (cx == ce) {
+ cx = cd;
+ } else {
+ cv = cx;
+ }
+#pragma omp atomic compare capture
+ {
+ cv = cx;
+ if (ce > cx) {
+ cx = ce;
+ }
+ }
+#pragma omp atomic compare capture
+ {
+ cv = cx;
+ if (cx < ce) {
+ cx = ce;
+ }
+ }
+}
+#endif
More information about the cfe-commits
mailing list