[clang] [OpenMP] Fix atomic compare handling with overloaded operators (PR #141142)
Joseph Huber via cfe-commits
cfe-commits at lists.llvm.org
Tue May 27 09:15:28 PDT 2025
https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/141142
>From a45dc43315631f28ced9cf5a14890e46e011e6d2 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Thu, 22 May 2025 16:21:34 -0500
Subject: [PATCH] [OpenMP] Fix atomic compare handling with overloaded
operators
Summary:
When there are overloaded C++ operators in the global namespace, the AST
node for these is not a `BinaryOperator` but a `CXXOperatorCallExpr`. Modify
the uses to handle this case, basically just treating it as a binary
expression with two arguments.
---
clang/lib/Sema/SemaOpenMP.cpp | 320 ++++++++++++++++++--------
clang/test/OpenMP/atomic_messages.cpp | 31 +++
2 files changed, 249 insertions(+), 102 deletions(-)
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index f16f841d62edd..a0ad814c366d8 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -11762,52 +11762,98 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
X = BO->getLHS();
- auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
- if (!Cond) {
- ErrorInfo.Error = ErrorTy::NotABinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
- return false;
- }
-
- switch (Cond->getOpcode()) {
- case BO_EQ: {
- C = Cond;
- D = BO->getRHS();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
+ if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) {
+ switch (Cond->getOpcode()) {
+ case BO_EQ: {
+ C = Cond;
+ D = BO->getRHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+ E = Cond->getRHS();
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ E = Cond->getLHS();
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ case BO_LT:
+ case BO_GT: {
+ E = BO->getRHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+ checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+ C = Cond;
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+ checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ C = Cond;
+ IsXBinopExpr = false;
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ default:
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
return false;
}
- break;
- }
- case BO_LT:
- case BO_GT: {
- E = BO->getRHS();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
- C = Cond;
- } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- C = Cond;
- IsXBinopExpr = false;
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) {
+ if (Call->getNumArgs() != 2) { // Only binary operator calls qualify.
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
return false;
}
- break;
- }
- default:
- ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ switch (Call->getOperator()) {
+ case clang::OverloadedOperatorKind::OO_EqualEqual: {
+ C = Call;
+ D = BO->getRHS(); // 'd' in 'x = d;', as in the BO_EQ case.
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+ E = Call->getArg(1);
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ E = Call->getArg(0);
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ case clang::OverloadedOperatorKind::OO_Greater:
+ case clang::OverloadedOperatorKind::OO_Less: {
+ E = BO->getRHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) &&
+ checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) {
+ C = Call;
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) &&
+ checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ C = Call;
+ IsXBinopExpr = false;
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ default:
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ } else {
+ ErrorInfo.Error = ErrorTy::NotABinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
@@ -11856,53 +11902,99 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
return false;
}
- auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
- if (!Cond) {
- ErrorInfo.Error = ErrorTy::NotABinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
- CO->getCond()->getSourceRange();
- return false;
- }
-
- switch (Cond->getOpcode()) {
- case BO_EQ: {
- C = Cond;
- D = CO->getTrueExpr();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
+ if (auto *Cond = dyn_cast<BinaryOperator>(CO->getCond())) {
+ switch (Cond->getOpcode()) {
+ case BO_EQ: {
+ C = Cond;
+ D = CO->getTrueExpr();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+ E = Cond->getRHS();
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ E = Cond->getLHS();
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ case BO_LT:
+ case BO_GT: {
+ E = CO->getTrueExpr();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+ checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+ C = Cond;
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+ checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ C = Cond;
+ IsXBinopExpr = false;
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ default:
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
return false;
}
- break;
- }
- case BO_LT:
- case BO_GT: {
- E = CO->getTrueExpr();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
- C = Cond;
- } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- C = Cond;
- IsXBinopExpr = false;
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond())) {
+ if (Call->getNumArgs() != 2) { // Only binary operator calls qualify.
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
return false;
}
- break;
- }
- default:
- ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ switch (Call->getOperator()) {
+ case clang::OverloadedOperatorKind::OO_EqualEqual: {
+ C = Call;
+ D = CO->getTrueExpr();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+ E = Call->getArg(1);
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ E = Call->getArg(0);
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ case clang::OverloadedOperatorKind::OO_Less:
+ case clang::OverloadedOperatorKind::OO_Greater: {
+ E = CO->getTrueExpr();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) &&
+ checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) {
+ C = Call;
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) &&
+ checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ C = Call;
+ IsXBinopExpr = false;
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ default:
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ } else {
+ ErrorInfo.Error = ErrorTy::NotABinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+ CO->getCond()->getSourceRange();
return false;
}
@@ -12062,32 +12154,56 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
X = BO->getLHS();
D = BO->getRHS();
- auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
- if (!Cond) {
+ if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) {
+ C = Cond;
+ if (Cond->getOpcode() != BO_EQ) {
+ ErrorInfo.Error = ErrorTy::NotEQ;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+ E = Cond->getRHS();
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ E = Cond->getLHS();
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) {
+ C = Call;
+ if (Call->getNumArgs() != 2) { // Only binary operator calls qualify.
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ if (Call->getOperator() != clang::OverloadedOperatorKind::OO_EqualEqual) {
+ ErrorInfo.Error = ErrorTy::NotEQ;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+ E = Call->getArg(1);
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ E = Call->getArg(0);
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ } else {
ErrorInfo.Error = ErrorTy::NotABinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
- if (Cond->getOpcode() != BO_EQ) {
- ErrorInfo.Error = ErrorTy::NotEQ;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
- return false;
- }
-
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
- return false;
- }
-
- C = Cond;
if (!S->getElse()) {
ErrorInfo.Error = ErrorTy::NoElse;
diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp
index d492f6ee1e896..c4e240a0ebb4e 100644
--- a/clang/test/OpenMP/atomic_messages.cpp
+++ b/clang/test/OpenMP/atomic_messages.cpp
@@ -991,3 +991,34 @@ int mixed() {
// expected-note at +1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
return mixed<int>();
}
+
+#ifdef OMP51
+struct U {}; // Global operators make these comparisons CXXOperatorCallExprs.
+struct U operator<(U, U);
+struct U operator>(U, U);
+struct U operator==(U, U);
+
+template <typename T> void templated() {
+ T cx, cv, ce, cd;
+#pragma omp atomic compare capture
+ if (cx == ce) {
+ cx = cd;
+ } else {
+ cv = cx;
+ }
+#pragma omp atomic compare capture
+ {
+ cv = cx;
+ if (ce > cx) {
+ cx = ce;
+ }
+ }
+#pragma omp atomic compare capture
+ {
+ cv = cx;
+ if (cx < ce) {
+ cx = ce;
+ }
+ }
+}
+#endif
More information about the cfe-commits
mailing list