[Mlir-commits] [flang] [mlir] [Flang] [OpenMP] atomic compare (PR #184761)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 17 11:12:12 PDT 2026
https://github.com/SunilKuravinakop updated https://github.com/llvm/llvm-project/pull/184761
>From e4ded216ad9d798246f5dd8e609d8ec806dfcea6 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Wed, 4 Mar 2026 13:17:11 -0600
Subject: [PATCH 1/8] Support for "!omp atomic compare".
---
flang/lib/Lower/OpenMP/Atomic.cpp | 150 +++++++++++++++-
.../Integration/OpenMP/atomic-compare.f90 | 116 +++++++++++++
.../Lower/OpenMP/Todo/atomic-compare-fail.f90 | 2 +-
.../test/Lower/OpenMP/Todo/atomic-compare.f90 | 11 --
flang/test/Lower/OpenMP/atomic-compare.f90 | 123 +++++++++++++
.../test/Semantics/OpenMP/atomic-compare.f90 | 71 ++++++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 49 +++++-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 24 +++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 162 +++++++++++++++++-
mlir/test/Target/LLVMIR/openmp-llvm.mlir | 67 ++++++++
10 files changed, 758 insertions(+), 17 deletions(-)
create mode 100644 flang/test/Integration/OpenMP/atomic-compare.f90
delete mode 100644 flang/test/Lower/OpenMP/Todo/atomic-compare.f90
create mode 100644 flang/test/Lower/OpenMP/atomic-compare.f90
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index f31de82fc2a5f..87d354d504656 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -485,6 +485,47 @@ genAtomicOperation(lower::AbstractConverter &converter,
}
}
+/// Translate a Fortran relational operator into a signed arith.cmpi predicate.
+static mlir::arith::CmpIPredicate
+mapRelationalOpToIntPredicate(Fortran::common::RelationalOperator relOpr) {
+ switch (relOpr) {
+ case Fortran::common::RelationalOperator::EQ:
+ return mlir::arith::CmpIPredicate::eq;
+ case Fortran::common::RelationalOperator::NE:
+ return mlir::arith::CmpIPredicate::ne;
+ case Fortran::common::RelationalOperator::LT:
+ return mlir::arith::CmpIPredicate::slt;
+ case Fortran::common::RelationalOperator::LE:
+ return mlir::arith::CmpIPredicate::sle;
+ case Fortran::common::RelationalOperator::GT:
+ return mlir::arith::CmpIPredicate::sgt;
+ case Fortran::common::RelationalOperator::GE:
+ return mlir::arith::CmpIPredicate::sge;
+ }
+ llvm_unreachable("unexpected relational operator");
+}
+
+/// Map a Fortran relational operator to an MLIR floating-point comparison
+/// predicate: ordered, except NE which is unordered so NaN /= x stays true.
+static mlir::arith::CmpFPredicate
+mapRelationalOpToFPPredicate(Fortran::common::RelationalOperator relOpr) {
+ switch (relOpr) {
+ case Fortran::common::RelationalOperator::EQ:
+ return mlir::arith::CmpFPredicate::OEQ;
+ case Fortran::common::RelationalOperator::NE:
+ return mlir::arith::CmpFPredicate::UNE;
+ case Fortran::common::RelationalOperator::LT:
+ return mlir::arith::CmpFPredicate::OLT;
+ case Fortran::common::RelationalOperator::LE:
+ return mlir::arith::CmpFPredicate::OLE;
+ case Fortran::common::RelationalOperator::GT:
+ return mlir::arith::CmpFPredicate::OGT;
+ case Fortran::common::RelationalOperator::GE:
+ return mlir::arith::CmpFPredicate::OGE;
+ }
+ llvm_unreachable("unexpected relational operator");
+}
+
void Fortran::lower::omp::lowerAtomic(
AbstractConverter &converter, SymMap &symTable,
semantics::SemanticsContext &semaCtx, pft::Evaluation &eval,
@@ -521,8 +562,113 @@ void Fortran::lower::omp::lowerAtomic(
memOrder = makeValidForAction(memOrder, action0, action1, version);
if (auto *cond = get(analysis.cond)) {
- (void)cond;
- TODO(loc, "OpenMP ATOMIC COMPARE");
+ // atomic compare: if (x == e) x = d
+ // e : expectedVal
+ // d : desiredVal
+
+ // Check for compound clauses (fail, capture, weak) that are not yet
+ // supported with atomic compare.
+ bool hasCompoundClause = false;
+ for (const omp::Clause &clause : clauses) {
+ if (clause.id == llvm::omp::Clause::OMPC_fail ||
+ clause.id == llvm::omp::Clause::OMPC_capture ||
+ clause.id == llvm::omp::Clause::OMPC_weak) {
+ hasCompoundClause = true;
+ break;
+ }
+ }
+ if (hasCompoundClause)
+ TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
+
+ Fortran::common::RelationalOperator relOpr =
+ Fortran::common::RelationalOperator::EQ;
+ std::optional<semantics::SomeExpr> expectedExprStorage;
+
+ if (const auto *rel = Fortran::evaluate::UnwrapExpr<
+ Fortran::evaluate::Relational<Fortran::evaluate::SomeType>>(
+ *cond)) {
+ std::visit(
+ [&](const auto &relImpl) {
+ relOpr = relImpl.opr;
+ using Operand = typename std::decay_t<decltype(relImpl)>::Operand;
+ expectedExprStorage = Fortran::evaluate::AsGenericExpr(
+ Fortran::evaluate::Expr<Operand>{relImpl.right()});
+ },
+ rel->u);
+ }
+ if (!expectedExprStorage) {
+ // The condition expression exists but isn't a recognized relational form.
+ mlir::emitError(loc, "internal error: atomic compare condition is not a "
+ "recognized relational expression");
+ return;
+ }
+
+ mlir::UnitAttr weakAttr = nullptr;
+ mlir::Operation *atomicOp = mlir::omp::AtomicCompareOp::create(
+ builder, loc, atomAddr, weakAttr, hint,
+ makeMemOrderAttr(converter, memOrder));
+ mlir::Type elemTypeOfX = fir::unwrapRefType(atomAddr.getType());
+ mlir::Block *block = builder.createBlock(&atomicOp->getRegion(0));
+ mlir::Value blockArg = block->addArgument(elemTypeOfX, loc);
+ builder.setInsertionPointToEnd(block);
+
+ mlir::Value expectedVal = fir::getBase(
+ converter.genExprValue(*expectedExprStorage, stmtCtx, &loc));
+ if (expectedVal.getType() != elemTypeOfX) {
+ expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
+ }
+
+ // Generate comparison: e.g. x == e
+ mlir::Value cmpResult;
+ if (mlir::isa<mlir::IntegerType>(elemTypeOfX)) {
+ auto pred = mapRelationalOpToIntPredicate(relOpr);
+ cmpResult = mlir::arith::CmpIOp::create(builder, loc, pred, blockArg,
+ expectedVal);
+ } else if (mlir::isa<mlir::FloatType>(elemTypeOfX)) {
+ auto pred = mapRelationalOpToFPPredicate(relOpr);
+ cmpResult = mlir::arith::CmpFOp::create(builder, loc, pred, blockArg,
+ expectedVal);
+ } else {
+ llvm_unreachable("unsupported type for atomic compare");
+ }
+
+ // Check for the presence of an assignment (x = d) and whether it is
+ // performed only under the IfTrue condition.
+
+ // writeActionCond is a bitmask combining the following flags:
+ // 1) the action type (Read/Write/Update)
+ // 2) condition (IfTrue/IfFalse)
+ int writeActionCond = 0;
+ const evaluate::Assignment *writeAssign = nullptr;
+ if (analysis.op0.what & analysis.Write) {
+ writeAssign = get(analysis.op0.assign);
+ writeActionCond = analysis.op0.what;
+ }
+ if (!writeAssign && (analysis.op1.what & analysis.Write)) {
+ writeAssign = get(analysis.op1.assign);
+ writeActionCond = analysis.op1.what;
+ }
+ if (!writeAssign) {
+ mlir::emitError(loc,
+ "internal error: atomic compare has no write assignment");
+ return;
+ }
+ assert((writeActionCond & analysis.IfTrue) &&
+ "atomic compare write should be conditioned on IfTrue");
+
+ // Generate new/desired value of x e.g. x = d
+ mlir::Value desiredVal =
+ fir::getBase(converter.genExprValue(writeAssign->rhs, stmtCtx, &loc));
+ if (desiredVal.getType() != elemTypeOfX)
+ desiredVal = builder.createConvert(loc, elemTypeOfX, desiredVal);
+ mlir::Value newVal = mlir::arith::SelectOp::create(builder, loc, cmpResult,
+ desiredVal, blockArg);
+
+ // Generate omp.yield
+ mlir::omp::YieldOp::create(builder, loc, newVal);
+ builder.setInsertionPointAfter(atomicOp);
+
+ // END omp atomic compare
} else {
mlir::Operation *captureOp = nullptr;
fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
diff --git a/flang/test/Integration/OpenMP/atomic-compare.f90 b/flang/test/Integration/OpenMP/atomic-compare.f90
new file mode 100644
index 0000000000000..c5be037b7533f
--- /dev/null
+++ b/flang/test/Integration/OpenMP/atomic-compare.f90
@@ -0,0 +1,116 @@
+!===----------------------------------------------------------------------===!
+! This directory can be used to add Integration tests involving multiple
+! stages of the compiler (for eg. from Fortran to LLVM IR). It should not
+! contain executable tests. We should only add tests here sparingly and only
+! if there is no other way to test. Repeat this message in each test that is
+! added to this directory and sub-directories.
+!===----------------------------------------------------------------------===!
+
+!RUN: %flang_fc1 -triple x86_64-unknown-linux-gnu -emit-llvm -fopenmp -fopenmp-version=51 %s -o - | FileCheck %s
+
+! Int "==" → cmpxchg, default (monotonic) ordering
+!CHECK-LABEL: define void @atomic_compare_integer_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]], align 4
+!CHECK: cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+subroutine atomic_compare_integer(x, e, d)
+ integer :: x, e, d
+ !$omp atomic compare
+ if (x == e) x = d
+end
+
+! seq_cst ordering → cmpxchg seq_cst + flush
+!CHECK-LABEL: define void @atomic_compare_seq_cst_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]], align 4
+!CHECK: cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] seq_cst seq_cst
+!CHECK: call void @__kmpc_flush(
+subroutine atomic_compare_seq_cst(x, e, d)
+ integer :: x, e, d
+ !$omp atomic compare seq_cst
+ if (x == e) x = d
+end
+
+! acquire ordering → cmpxchg acquire
+!CHECK-LABEL: define void @atomic_compare_acquire_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]], align 4
+!CHECK: cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] acquire acquire
+subroutine atomic_compare_acquire(x, e, d)
+ integer :: x, e, d
+ !$omp atomic compare acquire
+ if (x == e) x = d
+end
+
+! release ordering → cmpxchg release + flush
+!CHECK-LABEL: define void @atomic_compare_release_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]], align 4
+!CHECK: cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] release monotonic
+!CHECK: call void @__kmpc_flush(
+subroutine atomic_compare_release(x, e, d)
+ integer :: x, e, d
+ !$omp atomic compare release
+ if (x == e) x = d
+end
+
+! relaxed ordering → cmpxchg monotonic
+!CHECK-LABEL: define void @atomic_compare_relaxed_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]], align 4
+!CHECK: cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+subroutine atomic_compare_relaxed(x, e, d)
+ integer :: x, e, d
+ !$omp atomic compare relaxed
+ if (x == e) x = d
+end
+
+! Less-than comparison → atomicrmw umax
+!CHECK-LABEL: define void @atomic_compare_lt_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: atomicrmw umax ptr %[[X]], i32 %[[EVAL]] monotonic
+subroutine atomic_compare_lt(x, e)
+ integer :: x, e
+ !$omp atomic compare
+ if (x < e) x = e
+end
+
+! Less-than with seq_cst → atomicrmw umax seq_cst + flush
+!CHECK-LABEL: define void @atomic_compare_lt_seq_cst_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: atomicrmw umax ptr %[[X]], i32 %[[EVAL]] seq_cst
+!CHECK: call void @__kmpc_flush(
+subroutine atomic_compare_lt_seq_cst(x, e)
+ integer :: x, e
+ !$omp atomic compare seq_cst
+ if (x < e) x = e
+end
+
+! Less-than with acquire → atomicrmw umax acquire
+!CHECK-LABEL: define void @atomic_compare_lt_acquire_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: atomicrmw umax ptr %[[X]], i32 %[[EVAL]] acquire
+subroutine atomic_compare_lt_acquire(x, e)
+ integer :: x, e
+ !$omp atomic compare acquire
+ if (x < e) x = e
+end
+
+! Greater-than comparison → atomicrmw umin
+!CHECK-LABEL: define void @atomic_compare_gt_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
+!CHECK: atomicrmw umin ptr %[[X]], i32 %[[EVAL]] monotonic
+subroutine atomic_compare_gt(x, e)
+ integer :: x, e
+ !$omp atomic compare
+ if (x > e) x = e
+end
diff --git a/flang/test/Lower/OpenMP/Todo/atomic-compare-fail.f90 b/flang/test/Lower/OpenMP/Todo/atomic-compare-fail.f90
index 6f58e0939a787..3369a6223cf73 100644
--- a/flang/test/Lower/OpenMP/Todo/atomic-compare-fail.f90
+++ b/flang/test/Lower/OpenMP/Todo/atomic-compare-fail.f90
@@ -1,6 +1,6 @@
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s
-! CHECK: not yet implemented: OpenMP ATOMIC COMPARE
+! CHECK: not yet implemented: Compound clauses of OpenMP ATOMIC COMPARE
program p
integer :: x
logical :: r
diff --git a/flang/test/Lower/OpenMP/Todo/atomic-compare.f90 b/flang/test/Lower/OpenMP/Todo/atomic-compare.f90
deleted file mode 100644
index 6729be6e5cf8b..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/atomic-compare.f90
+++ /dev/null
@@ -1,11 +0,0 @@
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: OpenMP ATOMIC COMPARE
-program p
- integer :: x
- logical :: r
- !$omp atomic compare
- if (x .eq. 0) then
- x = 2
- end if
-end program p
diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90
new file mode 100644
index 0000000000000..e69c0b55d5351
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-compare.f90
@@ -0,0 +1,123 @@
+! REQUIRES: openmp_runtime
+
+! This test checks lowering of atomic compare constructs.
+! RUN: bbc %openmp_flags -fopenmp-version=51 -emit-hlfir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=51 %s -o - | FileCheck %s
+
+! CHECK-LABEL: func.func @_QPatomic_compare_int_eq(
+! CHECK-SAME: %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME: %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME: %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"})
+! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[XVAL:.*]]: i32):
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32
+! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK: omp.yield(%[[SEL]] : i32)
+! CHECK: }
+subroutine atomic_compare_int_eq(x, e, d)
+ integer :: x, e, d
+ !$omp atomic compare
+ if (x .eq. e) x = d
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_float_eq(
+! CHECK-SAME: %[[X:.*]]: !fir.ref<f32> {fir.bindc_name = "x"},
+! CHECK-SAME: %[[E:.*]]: !fir.ref<f32> {fir.bindc_name = "e"},
+! CHECK-SAME: %[[D:.*]]: !fir.ref<f32> {fir.bindc_name = "d"})
+! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<f32> {
+! CHECK: ^bb0(%[[XVAL:.*]]: f32):
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
+! CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[XVAL]], %[[EVAL]] fastmath<contract> : f32
+! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<f32>
+! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : f32
+! CHECK: omp.yield(%[[SEL]] : f32)
+! CHECK: }
+subroutine atomic_compare_float_eq(x, e, d)
+ real :: x, e, d
+ !$omp atomic compare
+ if (x .eq. e) x = d
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_int_lt(
+! CHECK-SAME: %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME: %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"})
+! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[XVAL:.*]]: i32):
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[XVAL]], %[[EVAL]] : i32
+! CHECK: %[[EVAL2:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[EVAL2]], %[[XVAL]] : i32
+! CHECK: omp.yield(%[[SEL]] : i32)
+! CHECK: }
+subroutine atomic_compare_int_lt(x, e)
+ integer :: x, e
+ !$omp atomic compare
+ if (x .lt. e) x = e
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_int_gt(
+! CHECK-SAME: %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME: %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"})
+! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK: ^bb0(%[[XVAL:.*]]: i32):
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32
+! CHECK: %[[EVAL2:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[EVAL2]], %[[XVAL]] : i32
+! CHECK: omp.yield(%[[SEL]] : i32)
+! CHECK: }
+subroutine atomic_compare_int_gt(x, e)
+ integer :: x, e
+ !$omp atomic compare
+ if (x .gt. e) x = e
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_float_lt(
+! CHECK-SAME: %[[X:.*]]: !fir.ref<f32> {fir.bindc_name = "x"},
+! CHECK-SAME: %[[E:.*]]: !fir.ref<f32> {fir.bindc_name = "e"})
+! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<f32> {
+! CHECK: ^bb0(%[[XVAL:.*]]: f32):
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
+! CHECK: %[[CMP:.*]] = arith.cmpf olt, %[[XVAL]], %[[EVAL]] fastmath<contract> : f32
+! CHECK: %[[EVAL2:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
+! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[EVAL2]], %[[XVAL]] : f32
+! CHECK: omp.yield(%[[SEL]] : f32)
+! CHECK: }
+subroutine atomic_compare_float_lt(x, e)
+ real :: x, e
+ !$omp atomic compare
+ if (x .lt. e) x = e
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_float_gt(
+! CHECK-SAME: %[[X:.*]]: !fir.ref<f32> {fir.bindc_name = "x"},
+! CHECK-SAME: %[[E:.*]]: !fir.ref<f32> {fir.bindc_name = "e"})
+! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<f32> {
+! CHECK: ^bb0(%[[XVAL:.*]]: f32):
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
+! CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[XVAL]], %[[EVAL]] fastmath<contract> : f32
+! CHECK: %[[EVAL2:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
+! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[EVAL2]], %[[XVAL]] : f32
+! CHECK: omp.yield(%[[SEL]] : f32)
+! CHECK: }
+subroutine atomic_compare_float_gt(x, e)
+ real :: x, e
+ !$omp atomic compare
+ if (x .gt. e) x = e
+end
diff --git a/flang/test/Semantics/OpenMP/atomic-compare.f90 b/flang/test/Semantics/OpenMP/atomic-compare.f90
index 6a4fbe7ffe81b..0e53729c2f02a 100644
--- a/flang/test/Semantics/OpenMP/atomic-compare.f90
+++ b/flang/test/Semantics/OpenMP/atomic-compare.f90
@@ -8,6 +8,7 @@
real a, b, c
+ logical :: r, s
a = 1.0
b = 2.0
c = 3.0
@@ -43,6 +44,20 @@
if (c .eq. a) a = b
!$omp end atomic
+ ! Less-than comparison.
+ !$omp atomic compare
+ if (b .lt. a) b = c
+
+ ! Greater-than comparison.
+ !$omp atomic compare
+ if (b .gt. a) b = c
+
+ ! Two-statement form: r = cond; if (r) update.
+ !$omp atomic compare
+ r = b .eq. a
+ if (r) b = c
+ !$omp end atomic
+
! Check for error conditions:
!ERROR: At most one SEQ_CST clause can appear on the ATOMIC directive
!$omp atomic seq_cst seq_cst compare
@@ -79,5 +94,61 @@
if (c .eq. a) a = b
!$omp end atomic
+ ! The /= operator is not valid for atomic compare.
+ !$omp atomic compare
+ !ERROR: The /= operator is not a valid condition for ATOMIC operation
+ if (b .ne. a) b = c
+
+ ! The <= operator is not valid for atomic compare.
+ !$omp atomic compare
+ !ERROR: The <= operator is not a valid condition for ATOMIC operation
+ if (b .le. a) b = c
+
+ ! The >= operator is not valid for atomic compare.
+ !$omp atomic compare
+ !ERROR: The >= operator is not a valid condition for ATOMIC operation
+ if (b .ge. a) b = c
+
+ ! ELSE branch is not allowed.
+ !$omp atomic compare
+ if (b .eq. a) then
+ b = c
+ else
+ !ERROR: In ATOMIC UPDATE COMPARE the update statement should not have an ELSE branch
+ a = b
+ end if
+
+ ! Not a conditional statement.
+ !ERROR: In ATOMIC UPDATE COMPARE the update statement should be a conditional statement
+ !$omp atomic compare
+ b = c
+
+ ! Too many statements.
+ !ERROR: ATOMIC UPDATE COMPARE operation should contain one or two statements
+ !$omp atomic compare
+ r = b .eq. a
+ if (r) b = c
+ a = b
+ !$omp end atomic
+
+ ! Two-statement form with wrong condition variable.
+ !$omp atomic compare
+ r = b .eq. a
+ !ERROR: In ATOMIC UPDATE COMPARE the conditional statement must use r as the condition
+ if (s) b = c
+ !$omp end atomic
+
+ ! Neither argument of the condition is the target of the assignment.
+ !$omp atomic compare
+ !ERROR: An argument of the == operator should be the target of the assignment
+ if (a .eq. c) b = c
+
+ ! First statement is not a comparison, condition uses wrong variable.
+ !$omp atomic compare
+ b = c
+ !ERROR: In ATOMIC UPDATE COMPARE the conditional statement must use b as the condition
+ if (r) b = c
+ !$omp end atomic
+
!$omp end parallel
end
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index dfec6609e1161..494a267077c0d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -821,8 +821,8 @@ def SimdOp : OpenMP_Op<"simd", traits = [
def YieldOp : OpenMP_Op<"yield",
[Pure, ReturnLike, Terminator,
- ParentOneOf<["AtomicUpdateOp", "DeclareReductionOp", "LoopNestOp",
- "PrivateClauseOp"]>]> {
+ ParentOneOf<["AtomicUpdateOp", "AtomicCompareOp", "DeclareReductionOp",
+ "LoopNestOp", "PrivateClauseOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"omp.yield" yields SSA values from the OpenMP dialect op region and
@@ -1814,6 +1814,51 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// [5.1] 2.17.7 atomic Directive - compare clause
+//===----------------------------------------------------------------------===//
+
+def AtomicCompareOp : OpenMP_Op<"atomic.compare", traits = [
+ RecursiveMemoryEffects,
+ SingleBlockImplicitTerminator<"YieldOp">
+ ], clauses = [
+ OpenMP_HintClause, OpenMP_MemoryOrderClause
+ ], singleRegion = 1> {
+ let summary = "performs an atomic compare";
+ let description = [{
+ This operation performs an atomic compare-and-swap.
+
+ The `atomic compare` construct implements atomic conditional update
+ semantics. The operand `x` is the address of the variable that is being
+ compared and potentially updated. The region describes the comparison
+ and update logic.
+
+ The region has the following structure:
+ ```
+ omp.atomic.compare {
+ if (x == e) x = d
+ omp.yield
+ }
+ ```
+ }] # clausesDescription;
+
+ let arguments = !con(
+ (ins Arg<OpenMP_PointerLikeType,
+ "Address of variable to be compared/updated", [MemRead, MemWrite]>:$x,
+ UnitAttr:$weak),
+ clausesArgs);
+
+ // Override region definition.
+ let regions = (region SizedRegion<1>:$region);
+
+ // Override clause-based assemblyFormat.
+ let assemblyFormat = clausesAssemblyFormat #
+ "$x `:` type($x) $region attr-dict";
+
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// [5.1] 2.21.2 threadprivate Directive
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c3916219d1c93..416aecd05ee46 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4042,6 +4042,30 @@ LogicalResult AtomicCaptureOp::verifyRegions() {
return success();
}
+//===----------------------------------------------------------------------===//
+// AtomicCompareOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicCompareOp::verify() { // op-level check: hint only
+ return verifySynchronizationHint(*this, getHint());
+}
+
+LogicalResult AtomicCompareOp::verifyRegions() { // region-level checks
+ Region &region = getRegion(); // region holding the compare/update body
+ if (region.empty())
+ return emitOpError("region for atomic compare must not be empty");
+
+ Block &block = region.front();
+ if (block.empty())
+ return emitOpError("region body for atomic compare must not be empty");
+
+ Operation *terminator = block.getTerminator(); // last op of the body
+ if (!terminator || !isa<YieldOp>(terminator))
+ return emitOpError("region must be terminated with omp.yield");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// CancelOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f04d614633965..b40cf8e02e5c1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -445,7 +445,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::SimdOp op) { checkReduction(op, result); })
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
- omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
+ omp::AtomicCaptureOp, omp::AtomicCompareOp>(
+ [&](auto op) { checkHint(op, result); })
.Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
[&](auto op) { checkDepend(op, result); })
.Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
@@ -4012,6 +4013,162 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
return success();
}
+/// Map an integer predicate to an OMPAtomicCompareOp: eq -> EQ, slt/ult -> MIN,
+/// sgt/ugt -> MAX (signedness is dropped). Others yield std::nullopt.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
+ switch (predicate) {
+ case LLVM::ICmpPredicate::eq:
+ return llvm::omp::OMPAtomicCompareOp::EQ;
+ case LLVM::ICmpPredicate::slt:
+ case LLVM::ICmpPredicate::ult:
+ return llvm::omp::OMPAtomicCompareOp::MIN;
+ case LLVM::ICmpPredicate::sgt:
+ case LLVM::ICmpPredicate::ugt:
+ return llvm::omp::OMPAtomicCompareOp::MAX;
+ default:
+ return std::nullopt;
+ }
+}
+
+/// Map a floating-point predicate to an OMPAtomicCompareOp, treating ordered and
+/// unordered forms alike: oeq/ueq -> EQ, olt/ult -> MIN, ogt/ugt -> MAX; else nullopt.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
+ switch (predicate) {
+ case LLVM::FCmpPredicate::oeq:
+ case LLVM::FCmpPredicate::ueq:
+ return llvm::omp::OMPAtomicCompareOp::EQ;
+ case LLVM::FCmpPredicate::olt:
+ case LLVM::FCmpPredicate::ult:
+ return llvm::omp::OMPAtomicCompareOp::MIN;
+ case LLVM::FCmpPredicate::ogt:
+ case LLVM::FCmpPredicate::ugt:
+ return llvm::omp::OMPAtomicCompareOp::MAX;
+ default:
+ return std::nullopt;
+ }
+}
+
+/// Converts an omp.atomic.compare operation to LLVM IR.
+///
+/// if (x == e) x = d
+/// The region contains a comparison + select pattern:
+/// ^bb0(%xval: T):
+/// %cmp = llvm.icmp/fcmp <pred> %xval, %e : T
+/// %sel = llvm.select %cmp, %d, %xval : i1, T
+/// omp.yield(%sel : T)
+///
+/// This function walks the MLIR region to extract the comparison predicate,
+/// the expected value (e), and the desired value (d), then delegates to
+/// OpenMPIRBuilder::createAtomicCompare which generates the actual
+/// cmpxchg / atomicrmw instruction.
+///
+static LogicalResult
+convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
+ llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ if (failed(checkImplementationStatus(*atomicCompareOp)))
+ return failure();
+
+ Region &region = atomicCompareOp.getRegion();
+ Block &block = region.front();
+
+ // Determine element type from the region block argument
+ llvm::Type *llvmXElementType =
+ moduleTranslation.convertType(block.getArgument(0).getType());
+ if (!llvmXElementType)
+ return atomicCompareOp.emitError(
+ "unable to determine element type for atomic compare");
+
+ llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+ llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
+ false, false}; // NOTE(review): 3rd field looks like IsSigned; false yields umin/umax even for slt/sgt - confirm
+
+ llvm::AtomicOrdering atomicOrdering =
+ convertAtomicOrdering(atomicCompareOp.getMemoryOrder());
+
+ // Trace back through load operations and generate load instructions
+ auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
+ if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
+ llvm::Value *loadAddr = moduleTranslation.lookupValue(loadOp.getAddr());
+ llvm::Type *loadType =
+ moduleTranslation.convertType(loadOp.getResult().getType());
+ return builder.CreateLoad(loadType, loadAddr);
+ }
+ return moduleTranslation.lookupValue(val);
+ };
+
+ // Walk the region to extract comparison predicate, eVal, and dVal.
+ // if (x == eVal) x = dVal
+ llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+ llvm::Value *eVal = nullptr;
+ llvm::Value *dVal = nullptr;
+ bool isXBinopExpr = false;
+
+ for (Operation &op : block.getOperations()) {
+ if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
+ auto maybeOp =
+ convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
+ if (!maybeOp)
+ return atomicCompareOp.emitError(
+ "unsupported comparison predicate in atomic compare");
+ compareOp = *maybeOp;
+
+ // Identify which operand is the block argument (x) and which is e.
+ isXBinopExpr = (icmpOp.getOperand(0) == block.getArgument(0));
+ mlir::Value eOperand =
+ isXBinopExpr ? icmpOp.getOperand(1) : icmpOp.getOperand(0);
+ eVal = materializeValue(eOperand);
+ } else if (auto fcmpOp = dyn_cast<LLVM::FCmpOp>(op)) {
+ auto maybeOp =
+ convertFCmpPredicateToAtomicCompareOp(fcmpOp.getPredicate());
+ if (!maybeOp)
+ return atomicCompareOp.emitError(
+ "unsupported comparison predicate in atomic compare");
+ compareOp = *maybeOp;
+
+ isXBinopExpr = (fcmpOp.getOperand(0) == block.getArgument(0));
+ mlir::Value eOperand =
+ isXBinopExpr ? fcmpOp.getOperand(1) : fcmpOp.getOperand(0);
+ eVal = materializeValue(eOperand);
+ } else if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
+ if (!dVal) // NOTE(review): first select wins; its condition and false value are not checked
+ dVal = materializeValue(selectOp.getTrueValue());
+ }
+ }
+
+ if (!eVal)
+ return atomicCompareOp.emitError(
+ "failed to extract expected value (e) from atomic compare region");
+ if (!dVal) {
+ // Fall back to the yield operand.
+ auto yieldOp = cast<omp::YieldOp>(block.getTerminator());
+ if (yieldOp.getResults().empty())
+ return atomicCompareOp.emitError(
+ "failed to extract desired value (d) from atomic compare region");
+ dVal = materializeValue(yieldOp.getResults()[0]);
+ }
+
+ llvm::OpenMPIRBuilder::AtomicOpValue vOpVal = {nullptr, nullptr, false,
+ false};
+ llvm::OpenMPIRBuilder::AtomicOpValue rOpVal = {nullptr, nullptr, false,
+ false};
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+ ompBuilder->createAtomicCompare(ompLoc, llvmAtomicX, vOpVal, rOpVal, eVal,
+ dVal, atomicOrdering, compareOp,
+ isXBinopExpr, false, false);
+
+ if (failed(handleError(afterIP, *atomicCompareOp)))
+ return failure();
+
+ builder.restoreIP(*afterIP);
+ return success();
+}
+
static llvm::omp::Directive convertCancellationConstructType(
omp::ClauseCancellationConstructType directive) {
switch (directive) {
@@ -7112,6 +7269,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
.Case([&](omp::AtomicCaptureOp op) {
return convertOmpAtomicCapture(op, builder, moduleTranslation);
})
+ .Case([&](omp::AtomicCompareOp op) {
+ return convertOmpAtomicCompare(op, builder, moduleTranslation);
+ })
.Case([&](omp::CancelOp op) {
return convertOmpCancel(op, builder, moduleTranslation);
})
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 0b8a9765a4b87..bcf0d1f3de954 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2453,6 +2453,73 @@ llvm.func @omp_atomic_capture_misc(
// -----
+// CHECK-LABEL: @omp_atomic_compare
+// CHECK-SAME: (ptr %[[X:.*]], i32 %[[E:.*]], i32 %[[D:.*]], ptr %[[XF:.*]], float %[[EF:.*]], float %[[DF:.*]])
+llvm.func @omp_atomic_compare(
+ %x : !llvm.ptr, %e : i32, %d : i32,
+ %xf : !llvm.ptr, %ef : f32, %df : f32) {
+
+ // Integer equality → cmpxchg
+ // CHECK: cmpxchg ptr %[[X]], i32 %[[E]], i32 %[[D]] monotonic monotonic
+ omp.atomic.compare %x : !llvm.ptr {
+ ^bb0(%xval : i32):
+ %cmp0 = llvm.icmp "eq" %xval, %e : i32
+ %sel0 = llvm.select %cmp0, %d, %xval : i1, i32
+ omp.yield(%sel0 : i32)
+ }
+
+ // Float equality → bitcast + cmpxchg
+ // CHECK: %[[EBC:.*]] = bitcast float %[[EF]] to i32
+ // CHECK: %[[DBC:.*]] = bitcast float %[[DF]] to i32
+ // CHECK: cmpxchg ptr %[[XF]], i32 %[[EBC]], i32 %[[DBC]] monotonic monotonic
+ omp.atomic.compare %xf : !llvm.ptr {
+ ^bb0(%xval : f32):
+ %cmp1 = llvm.fcmp "oeq" %xval, %ef : f32
+ %sel1 = llvm.select %cmp1, %df, %xval : i1, f32
+ omp.yield(%sel1 : f32)
+ }
+
+ // Integer x < e → atomicrmw umax (reversed, unsigned)
+ // CHECK: atomicrmw umax ptr %[[X]], i32 %[[E]] monotonic
+ omp.atomic.compare %x : !llvm.ptr {
+ ^bb0(%xval : i32):
+ %cmp2 = llvm.icmp "slt" %xval, %e : i32
+ %sel2 = llvm.select %cmp2, %e, %xval : i1, i32
+ omp.yield(%sel2 : i32)
+ }
+
+ // Integer x > e → atomicrmw umin (reversed, unsigned)
+ // CHECK: atomicrmw umin ptr %[[X]], i32 %[[E]] monotonic
+ omp.atomic.compare %x : !llvm.ptr {
+ ^bb0(%xval : i32):
+ %cmp3 = llvm.icmp "sgt" %xval, %e : i32
+ %sel3 = llvm.select %cmp3, %e, %xval : i1, i32
+ omp.yield(%sel3 : i32)
+ }
+
+ // Float x < e → atomicrmw fmax (reversed)
+ // CHECK: atomicrmw fmax ptr %[[XF]], float %[[EF]] monotonic, align 4
+ omp.atomic.compare %xf : !llvm.ptr {
+ ^bb0(%xval : f32):
+ %cmp4 = llvm.fcmp "olt" %xval, %ef : f32
+ %sel4 = llvm.select %cmp4, %ef, %xval : i1, f32
+ omp.yield(%sel4 : f32)
+ }
+
+ // Float x > e → atomicrmw fmin (reversed)
+ // CHECK: atomicrmw fmin ptr %[[XF]], float %[[EF]] monotonic, align 4
+ omp.atomic.compare %xf : !llvm.ptr {
+ ^bb0(%xval : f32):
+ %cmp5 = llvm.fcmp "ogt" %xval, %ef : f32
+ %sel5 = llvm.select %cmp5, %ef, %xval : i1, f32
+ omp.yield(%sel5 : f32)
+ }
+
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @omp_sections_empty
llvm.func @omp_sections_empty() -> () {
omp.sections {
>From 3669073f727b2f3d9c80a80f9d74e3c46e3932c2 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Thu, 5 Mar 2026 03:27:10 -0600
Subject: [PATCH 2/8] Minor changes to OpenMP "atomic compare".
---
flang/lib/Lower/OpenMP/Atomic.cpp | 17 ++++++-----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 ++++++----
2 files changed, 12 insertions(+), 15 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 87d354d504656..78c2f618fbad4 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -568,17 +568,13 @@ void Fortran::lower::omp::lowerAtomic(
// Check for compound clauses (fail, capture, weak) that are not yet
// supported with atomic compare.
- bool hasCompoundClause = false;
- for (const omp::Clause &clause : clauses) {
- if (clause.id == llvm::omp::Clause::OMPC_fail ||
- clause.id == llvm::omp::Clause::OMPC_capture ||
- clause.id == llvm::omp::Clause::OMPC_weak) {
- hasCompoundClause = true;
- break;
- }
- }
- if (hasCompoundClause)
+ if (llvm::any_of(clauses, [](const omp::Clause &clause) {
+ return clause.id == llvm::omp::Clause::OMPC_fail ||
+ clause.id == llvm::omp::Clause::OMPC_capture ||
+ clause.id == llvm::omp::Clause::OMPC_weak;
+ })) {
TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
+ }
Fortran::common::RelationalOperator relOpr =
Fortran::common::RelationalOperator::EQ;
@@ -597,7 +593,6 @@ void Fortran::lower::omp::lowerAtomic(
rel->u);
}
if (!expectedExprStorage) {
- // The condition expression exists but isn't a recognized relational form.
mlir::emitError(loc, "internal error: atomic compare condition is not a "
"recognized relational expression");
return;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index b40cf8e02e5c1..939530e74a2e5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4059,10 +4059,12 @@ convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
/// %sel = llvm.select %cmp, %d, %xval : i1, T
/// omp.yield(%sel : T)
///
-/// This function walks the MLIR region to extract the comparison predicate,
-/// the expected value (e), and the desired value (d), then delegates to
-/// OpenMPIRBuilder::createAtomicCompare which generates the actual
-/// cmpxchg / atomicrmw instruction.
+/// From MLIR extract:
+/// 1) comparison operator
+/// 2) expected value (e)
+/// 3) desired value (d)
+/// These are passed to OpenMPIRBuilder::createAtomicCompare which generates
+/// the actual cmpxchg / atomicrmw instruction.
///
static LogicalResult
convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
>From d607043b0e5aad65ee1c28bfebaf31062817ce57 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Fri, 13 Mar 2026 15:20:08 -0500
Subject: [PATCH 3/8] Taking care of the feedback from Tom Eccles.
---
flang/include/flang/Lower/ConvertType.h | 8 +
flang/lib/Lower/ConvertExprToHLFIR.cpp | 60 +-
flang/lib/Lower/ConvertType.cpp | 66 +++
flang/lib/Lower/OpenMP/Atomic.cpp | 75 ++-
.../Lower/OpenMP/Todo/atomic-compare-fail.f90 | 19 +-
.../Interfaces/AtomicInterfaces.h | 1 +
.../Interfaces/AtomicInterfaces.td | 156 ++++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 2 +-
.../OpenACCMPCommon/Interfaces/CMakeLists.txt | 2 +
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 14 +-
mlir/test/Dialect/OpenMP/invalid.mlir | 517 ++++++++++++++++++
11 files changed, 820 insertions(+), 100 deletions(-)
diff --git a/flang/include/flang/Lower/ConvertType.h b/flang/include/flang/Lower/ConvertType.h
index 3c726595c0f76..16b299e2ab319 100644
--- a/flang/include/flang/Lower/ConvertType.h
+++ b/flang/include/flang/Lower/ConvertType.h
@@ -23,6 +23,7 @@
#include "flang/Evaluate/type.h"
#include "flang/Support/Fortran.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"
namespace mlir {
@@ -130,6 +131,13 @@ class ComponentReverseIterator {
name_iterator componentIt{};
name_iterator componentItEnd{};
};
+
+mlir::arith::CmpIPredicate
+translateSignedRelational(Fortran::common::RelationalOperator rop);
+mlir::arith::CmpIPredicate
+translateUnsignedRelational(Fortran::common::RelationalOperator rop);
+mlir::arith::CmpFPredicate
+translateFloatRelational(Fortran::common::RelationalOperator rop);
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 0c015bc9a2f1b..2b5fd0b63e633 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1191,52 +1191,6 @@ translateSignedRelational(Fortran::common::RelationalOperator rop) {
llvm_unreachable("unhandled INTEGER relational operator");
}
-static mlir::arith::CmpIPredicate
-translateUnsignedRelational(Fortran::common::RelationalOperator rop) {
- switch (rop) {
- case Fortran::common::RelationalOperator::LT:
- return mlir::arith::CmpIPredicate::ult;
- case Fortran::common::RelationalOperator::LE:
- return mlir::arith::CmpIPredicate::ule;
- case Fortran::common::RelationalOperator::EQ:
- return mlir::arith::CmpIPredicate::eq;
- case Fortran::common::RelationalOperator::NE:
- return mlir::arith::CmpIPredicate::ne;
- case Fortran::common::RelationalOperator::GT:
- return mlir::arith::CmpIPredicate::ugt;
- case Fortran::common::RelationalOperator::GE:
- return mlir::arith::CmpIPredicate::uge;
- }
- llvm_unreachable("unhandled UNSIGNED relational operator");
-}
-
-/// Convert parser's REAL relational operators to MLIR.
-/// The choice of order (O prefix) vs unorder (U prefix) follows Fortran 2018
-/// requirements in the IEEE context (table 17.1 of F2018). This choice is
-/// also applied in other contexts because it is easier and in line with
-/// other Fortran compilers.
-/// FIXME: The signaling/quiet aspect of the table 17.1 requirement is not
-/// fully enforced. FIR and LLVM `fcmp` instructions do not give any guarantee
-/// whether the comparison will signal or not in case of quiet NaN argument.
-static mlir::arith::CmpFPredicate
-translateFloatRelational(Fortran::common::RelationalOperator rop) {
- switch (rop) {
- case Fortran::common::RelationalOperator::LT:
- return mlir::arith::CmpFPredicate::OLT;
- case Fortran::common::RelationalOperator::LE:
- return mlir::arith::CmpFPredicate::OLE;
- case Fortran::common::RelationalOperator::EQ:
- return mlir::arith::CmpFPredicate::OEQ;
- case Fortran::common::RelationalOperator::NE:
- return mlir::arith::CmpFPredicate::UNE;
- case Fortran::common::RelationalOperator::GT:
- return mlir::arith::CmpFPredicate::OGT;
- case Fortran::common::RelationalOperator::GE:
- return mlir::arith::CmpFPredicate::OGE;
- }
- llvm_unreachable("unhandled REAL relational operator");
-}
-
template <int KIND>
struct BinaryOp<Fortran::evaluate::Relational<
Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer, KIND>>> {
@@ -1247,7 +1201,8 @@ struct BinaryOp<Fortran::evaluate::Relational<
const Op &op, hlfir::Entity lhs,
hlfir::Entity rhs) {
auto cmp = mlir::arith::CmpIOp::create(
- builder, loc, translateSignedRelational(op.opr), lhs, rhs);
+ builder, loc, Fortran::lower::translateSignedRelational(op.opr), lhs,
+ rhs);
return hlfir::EntityWithAttributes{cmp};
}
};
@@ -1269,7 +1224,8 @@ struct BinaryOp<Fortran::evaluate::Relational<
mlir::Value lhsSL = builder.createConvert(loc, signlessType, lhs);
mlir::Value rhsSL = builder.createConvert(loc, signlessType, rhs);
auto cmp = mlir::arith::CmpIOp::create(
- builder, loc, translateUnsignedRelational(op.opr), lhsSL, rhsSL);
+ builder, loc, Fortran::lower::translateUnsignedRelational(op.opr),
+ lhsSL, rhsSL);
return hlfir::EntityWithAttributes{cmp};
}
};
@@ -1284,7 +1240,8 @@ struct BinaryOp<Fortran::evaluate::Relational<
const Op &op, hlfir::Entity lhs,
hlfir::Entity rhs) {
auto cmp = mlir::arith::CmpFOp::create(
- builder, loc, translateFloatRelational(op.opr), lhs, rhs);
+ builder, loc, Fortran::lower::translateFloatRelational(op.opr), lhs,
+ rhs);
return hlfir::EntityWithAttributes{cmp};
}
};
@@ -1298,8 +1255,9 @@ struct BinaryOp<Fortran::evaluate::Relational<
fir::FirOpBuilder &builder,
const Op &op, hlfir::Entity lhs,
hlfir::Entity rhs) {
- auto cmp = fir::CmpcOp::create(builder, loc,
- translateFloatRelational(op.opr), lhs, rhs);
+ auto cmp = fir::CmpcOp::create(
+ builder, loc, Fortran::lower::translateFloatRelational(op.opr), lhs,
+ rhs);
return hlfir::EntityWithAttributes{cmp};
}
};
diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp
index 0d343968374f0..a3c978c00769b 100644
--- a/flang/lib/Lower/ConvertType.cpp
+++ b/flang/lib/Lower/ConvertType.cpp
@@ -699,3 +699,69 @@ void Fortran::lower::ComponentReverseIterator::setCurrentType(
using namespace Fortran::evaluate;
using namespace Fortran::common;
FOR_EACH_SPECIFIC_TYPE(template class Fortran::lower::TypeBuilder, )
+
+/// Convert parser's INTEGER relational operators to MLIR.
+mlir::arith::CmpIPredicate Fortran::lower::translateSignedRelational(
+ Fortran::common::RelationalOperator rop) {
+ switch (rop) {
+ case Fortran::common::RelationalOperator::LT:
+ return mlir::arith::CmpIPredicate::slt;
+ case Fortran::common::RelationalOperator::LE:
+ return mlir::arith::CmpIPredicate::sle;
+ case Fortran::common::RelationalOperator::EQ:
+ return mlir::arith::CmpIPredicate::eq;
+ case Fortran::common::RelationalOperator::NE:
+ return mlir::arith::CmpIPredicate::ne;
+ case Fortran::common::RelationalOperator::GT:
+ return mlir::arith::CmpIPredicate::sgt;
+ case Fortran::common::RelationalOperator::GE:
+ return mlir::arith::CmpIPredicate::sge;
+ }
+ llvm_unreachable("unhandled INTEGER relational operator");
+}
+
+mlir::arith::CmpIPredicate Fortran::lower::translateUnsignedRelational(
+ Fortran::common::RelationalOperator rop) {
+ switch (rop) {
+ case Fortran::common::RelationalOperator::LT:
+ return mlir::arith::CmpIPredicate::ult;
+ case Fortran::common::RelationalOperator::LE:
+ return mlir::arith::CmpIPredicate::ule;
+ case Fortran::common::RelationalOperator::EQ:
+ return mlir::arith::CmpIPredicate::eq;
+ case Fortran::common::RelationalOperator::NE:
+ return mlir::arith::CmpIPredicate::ne;
+ case Fortran::common::RelationalOperator::GT:
+ return mlir::arith::CmpIPredicate::ugt;
+ case Fortran::common::RelationalOperator::GE:
+ return mlir::arith::CmpIPredicate::uge;
+ }
+ llvm_unreachable("unhandled UNSIGNED relational operator");
+}
+
+/// Convert parser's REAL relational operators to MLIR.
+/// The choice of order (O prefix) vs unorder (U prefix) follows Fortran 2018
+/// requirements in the IEEE context (table 17.1 of F2018). This choice is
+/// also applied in other contexts because it is easier and in line with
+/// other Fortran compilers.
+/// FIXME: The signaling/quiet aspect of the table 17.1 requirement is not
+/// fully enforced. FIR and LLVM `fcmp` instructions do not give any guarantee
+/// whether the comparison will signal or not in case of quiet NaN argument.
+mlir::arith::CmpFPredicate Fortran::lower::translateFloatRelational(
+ Fortran::common::RelationalOperator rop) {
+ switch (rop) {
+ case Fortran::common::RelationalOperator::LT:
+ return mlir::arith::CmpFPredicate::OLT;
+ case Fortran::common::RelationalOperator::LE:
+ return mlir::arith::CmpFPredicate::OLE;
+ case Fortran::common::RelationalOperator::EQ:
+ return mlir::arith::CmpFPredicate::OEQ;
+ case Fortran::common::RelationalOperator::NE:
+ return mlir::arith::CmpFPredicate::UNE;
+ case Fortran::common::RelationalOperator::GT:
+ return mlir::arith::CmpFPredicate::OGT;
+ case Fortran::common::RelationalOperator::GE:
+ return mlir::arith::CmpFPredicate::OGE;
+ }
+ llvm_unreachable("unhandled REAL relational operator");
+}
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 78c2f618fbad4..c5195ee088ad2 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -13,6 +13,7 @@
#include "flang/Evaluate/traverse.h"
#include "flang/Evaluate/type.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
@@ -485,45 +486,24 @@ genAtomicOperation(lower::AbstractConverter &converter,
}
}
-/// Map a Fortran relational operator to an MLIR integer comparison predicate.
-static mlir::arith::CmpIPredicate
-mapRelationalOpToIntPredicate(Fortran::common::RelationalOperator relOpr) {
- switch (relOpr) {
- case Fortran::common::RelationalOperator::EQ:
- return mlir::arith::CmpIPredicate::eq;
- case Fortran::common::RelationalOperator::NE:
- return mlir::arith::CmpIPredicate::ne;
- case Fortran::common::RelationalOperator::LT:
- return mlir::arith::CmpIPredicate::slt;
- case Fortran::common::RelationalOperator::LE:
- return mlir::arith::CmpIPredicate::sle;
- case Fortran::common::RelationalOperator::GT:
- return mlir::arith::CmpIPredicate::sgt;
- case Fortran::common::RelationalOperator::GE:
- return mlir::arith::CmpIPredicate::sge;
- }
- llvm_unreachable("unexpected relational operator");
-}
-
-/// Map a Fortran relational operator to an MLIR floating-point comparison
-/// predicate (ordered).
-static mlir::arith::CmpFPredicate
-mapRelationalOpToFPPredicate(Fortran::common::RelationalOperator relOpr) {
- switch (relOpr) {
- case Fortran::common::RelationalOperator::EQ:
- return mlir::arith::CmpFPredicate::OEQ;
- case Fortran::common::RelationalOperator::NE:
- return mlir::arith::CmpFPredicate::ONE;
- case Fortran::common::RelationalOperator::LT:
- return mlir::arith::CmpFPredicate::OLT;
- case Fortran::common::RelationalOperator::LE:
- return mlir::arith::CmpFPredicate::OLE;
- case Fortran::common::RelationalOperator::GT:
- return mlir::arith::CmpFPredicate::OGT;
- case Fortran::common::RelationalOperator::GE:
- return mlir::arith::CmpFPredicate::OGE;
+/// Reverse a relational operator as if the operands were swapped.
+/// e.g. LT becomes GT, LE becomes GE. Symmetric operators (EQ, NE)
+/// are returned unchanged.
+static Fortran::common::RelationalOperator
+reverseRelOp(Fortran::common::RelationalOperator op) {
+ using RO = Fortran::common::RelationalOperator;
+ switch (op) {
+ case RO::LT:
+ return RO::GT;
+ case RO::LE:
+ return RO::GE;
+ case RO::GT:
+ return RO::LT;
+ case RO::GE:
+ return RO::LE;
+ default:
+ return op;
}
- llvm_unreachable("unexpected relational operator");
}
void Fortran::lower::omp::lowerAtomic(
@@ -587,8 +567,21 @@ void Fortran::lower::omp::lowerAtomic(
[&](const auto &relImpl) {
relOpr = relImpl.opr;
using Operand = typename std::decay_t<decltype(relImpl)>::Operand;
- expectedExprStorage = Fortran::evaluate::AsGenericExpr(
+ auto leftExpr = Fortran::evaluate::AsGenericExpr(
+ Fortran::evaluate::Expr<Operand>{relImpl.left()});
+ auto rightExpr = Fortran::evaluate::AsGenericExpr(
Fortran::evaluate::Expr<Operand>{relImpl.right()});
+ if (Fortran::evaluate::IsSameOrConvertOf(rightExpr, atom)) {
+ // e.g. e == x (atom is on the right)
+ // left operand is expected value (e)
+ // reverse the operator so that the comparison becomes
+ // x <reversed-op> e.
+ expectedExprStorage = std::move(leftExpr);
+ relOpr = reverseRelOp(relOpr);
+ } else {
+ // Form: x == e (atom is on the left, or default)
+ expectedExprStorage = std::move(rightExpr);
+ }
},
rel->u);
}
@@ -616,11 +609,11 @@ void Fortran::lower::omp::lowerAtomic(
// Generate comparison: e.g. x == e
mlir::Value cmpResult;
if (mlir::isa<mlir::IntegerType>(elemTypeOfX)) {
- auto pred = mapRelationalOpToIntPredicate(relOpr);
+ auto pred = Fortran::lower::translateSignedRelational(relOpr);
cmpResult = mlir::arith::CmpIOp::create(builder, loc, pred, blockArg,
expectedVal);
} else if (mlir::isa<mlir::FloatType>(elemTypeOfX)) {
- auto pred = mapRelationalOpToFPPredicate(relOpr);
+ auto pred = Fortran::lower::translateFloatRelational(relOpr);
cmpResult = mlir::arith::CmpFOp::create(builder, loc, pred, blockArg,
expectedVal);
} else {
diff --git a/flang/test/Lower/OpenMP/Todo/atomic-compare-fail.f90 b/flang/test/Lower/OpenMP/Todo/atomic-compare-fail.f90
index 3369a6223cf73..27c7b1fb1a432 100644
--- a/flang/test/Lower/OpenMP/Todo/atomic-compare-fail.f90
+++ b/flang/test/Lower/OpenMP/Todo/atomic-compare-fail.f90
@@ -3,9 +3,26 @@
! CHECK: not yet implemented: Compound clauses of OpenMP ATOMIC COMPARE
program p
integer :: x
- logical :: r
+ integer :: r
+ integer :: d
+ integer :: v
!$omp atomic compare fail(relaxed)
if (x .eq. 0) then
x = 2
end if
+ !$omp end atomic
+
+ !$omp atomic compare capture
+ v = x
+ if (x > r) then
+ x = d
+ end if
+ !$omp end atomic
+
+ !$omp atomic compare fail(relaxed)
+ if (x > r) then
+ x = d
+ end if
+ !$omp end atomic
+
end program p
diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h
index cfe0ec5185bc8..c27ec7cf29c74 100644
--- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h
+++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h
@@ -14,6 +14,7 @@
#ifndef OPENACC_MP_COMMON_INTERFACES_ATOMICINTERFACES_H_
#define OPENACC_MP_COMMON_INTERFACES_ATOMICINTERFACES_H_
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
index 223bee9ab1c27..63f280606126b 100644
--- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
@@ -317,4 +317,160 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
];
}
+def AtomicCompareOpInterface : OpInterface<"AtomicCompareOpInterface"> {
+ let description = [{
+ This interface is used for OpenMP dialect operation that performs an
+ atomic compare.
+
+ The interface terminology uses `x`, `e`, and `d` like the directive
+ specifications:
+ `if (x == e) x = d`
+ `x` is the address of the variable that is being compared and updated.
+ The region describes the comparison and update logic. It takes
+ the current value of `x` as a single block argument.
+
+ The region has the following structure:
+ ```
+ atomic.compare {
+ ^bb0(%val_x):
+ <compare %val_x with e>
+ <conditionally yield d or %val_x>
+ }
+ ```
+ }];
+ let cppNamespace = "::mlir::accomp";
+
+ let methods = [
+ InterfaceMethod<[{
+ Obtains `x` which is the address of the variable that is being
+ compared and potentially updated.
+ }],
+ /*retTy=*/"::mlir::Value",
+ /*methodName=*/"getX",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<[{
+ Returns the first operation in the atomic compare region.
+ }],
+ /*retTy=*/"::mlir::Operation *",
+ /*methodName=*/"getFirstOp",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return &($_op.getRegion().front().getOperations().front());
+ }]
+ >,
+ InterfaceMethod<[{
+ Common verifier for operation that implements atomic compare interface.
+ }],
+ /*retTy=*/"::llvm::LogicalResult",
+ /*methodName=*/"verifyCommon",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ if ($_op.getRegion().getNumArguments() != 1)
+ return $_op.emitError("the region must accept exactly one argument");
+
+ Type elementType = $_op.getX().getType().getElementType();
+ if (elementType && elementType != $_op.getRegion().getArgument(0).getType()) {
+ return $_op.emitError("the type of the operand must be a pointer type whose "
+ "element type is the same as that of the region argument");
+ }
+
+ return mlir::success();
+ }]
+ >,
+ InterfaceMethod<[{
+ Common verifier of the required region for operation that implements
+ atomic compare interface.
+ }],
+ /*retTy=*/"::llvm::LogicalResult",
+ /*methodName=*/"verifyRegionsCommon",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ mlir::Region &region = $_op.getRegion();
+ if (region.empty())
+ return $_op.emitError(
+ "region for atomic compare must not be empty");
+
+ mlir::Block &block = region.front();
+ if (block.empty())
+ return $_op.emitError(
+ "region body for atomic compare must not be empty");
+
+ // The region must contain at least a comparison operation and a
+ // terminator. A region with only a terminator is missing the
+ // required comparison logic.
+ if (block.getOperations().size() < 2)
+ return $_op.emitError(
+ "region must contain a comparison operation");
+
+ return mlir::success();
+ }]
+ >,
+ InterfaceMethod<[{
+ Common verifier for operator that implements atomic compare interface.
+ Checks that the comparison operation in the region uses:
+ 1) supported predicate for integer comparison : eq, slt, or sgt
+ 2) supported predicate for float comparison : oeq, ogt, or olt
+ }],
+ /*retTy=*/"::llvm::LogicalResult",
+ /*methodName=*/"verifyOperator",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ mlir::Region &region = $_op.getRegion();
+ if (region.empty())
+ return $_op.emitError(
+ "region for atomic compare must not be empty");
+
+ mlir::Block &block = region.front();
+ bool foundComparison = false;
+ for (mlir::Operation &op : block.getOperations()) {
+ llvm::StringRef opName = op.getName().getStringRef();
+ if (opName == "arith.cmpi" || opName == "llvm.icmp") {
+ foundComparison = true;
+ auto predAttr = op.getAttrOfType<mlir::IntegerAttr>("predicate");
+ if (predAttr) {
+ auto predName = mlir::arith::stringifyCmpIPredicate(
+ static_cast<mlir::arith::CmpIPredicate>(predAttr.getInt()));
+ if (predName != "eq" && predName != "slt" && predName != "sgt") {
+ return $_op.emitError(
+ "unsupported comparison operator '")
+ << predName
+ << "' in atomic compare region, "
+ "supported operators are: eq, slt, sgt";
+ }
+ }
+ break;
+ } else if (opName == "arith.cmpf" || opName == "llvm.fcmp") {
+ foundComparison = true;
+ auto predAttr = op.getAttrOfType<mlir::IntegerAttr>("predicate");
+ if (predAttr) {
+ auto predName = mlir::arith::stringifyCmpFPredicate(
+ static_cast<mlir::arith::CmpFPredicate>(predAttr.getInt()));
+ if (predName != "oeq" && predName != "ogt" && predName != "olt") {
+ return $_op.emitError(
+ "unsupported comparison operator '")
+ << predName
+ << "' in atomic compare region, "
+ "supported operators are: oeq, ogt, olt";
+ }
+ }
+ break;
+ }
+ }
+
+ if (!foundComparison)
+ return $_op.emitError(
+ "atomic compare region must contain a comparison operation "
+ "(arith.cmpi or llvm.icmp)");
+
+ return mlir::success();
+ }]
+ >,
+ ];
+}
+
#endif // OPENACC_MP_COMMON_INTERFACES_ATOMICINTERFACES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index c6ad7872ba17d..e3fa552cf3bef 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1847,7 +1847,7 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
//===----------------------------------------------------------------------===//
def AtomicCompareOp : OpenMP_Op<"atomic.compare", traits = [
- RecursiveMemoryEffects,
+ AtomicCompareOpInterface, RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"YieldOp">
], clauses = [
OpenMP_HintClause, OpenMP_MemoryOrderClause
diff --git a/mlir/lib/Dialect/OpenACCMPCommon/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/OpenACCMPCommon/Interfaces/CMakeLists.txt
index 6da04424231aa..2a82ff2150287 100644
--- a/mlir/lib/Dialect/OpenACCMPCommon/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACCMPCommon/Interfaces/CMakeLists.txt
@@ -5,8 +5,10 @@ ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACCMPCommon/Interfaces
DEPENDS
+MLIRArithOpsIncGen
MLIRAtomicInterfacesIncGen
LINK_LIBS PUBLIC
+MLIRArithDialect
MLIRIR
)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 44976d087e7aa..e5084022f51bb 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4114,17 +4114,19 @@ LogicalResult AtomicCaptureOp::verifyRegions() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicCompareOp::verify() {
+ if (verifyCommon().failed())
+ return mlir::failure();
return verifySynchronizationHint(*this, getHint());
}
LogicalResult AtomicCompareOp::verifyRegions() {
- Region &region = getRegion();
- if (region.empty())
- return emitOpError("region for atomic compare must not be empty");
+ if (verifyRegionsCommon().failed())
+ return mlir::failure();
+
+ if (verifyOperator().failed())
+ return mlir::failure();
- Block &block = region.front();
- if (block.empty())
- return emitOpError("region body for atomic compare must not be empty");
+ Block &block = getRegion().front();
Operation *terminator = block.getTerminator();
if (!terminator || !isa<YieldOp>(terminator))
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index bc508d66fbd5f..bdbfa0cc5a334 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1420,6 +1420,523 @@ func.func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
// -----
+func.func @omp_atomic_compare_no_block_arg(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{the region must accept exactly one argument}}
+ omp.atomic.compare %x : memref<i32> {
+ omp.yield
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_empty_region(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{region must contain a comparison operation}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ omp.yield(%xval : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_hint(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}}
+ omp.atomic.compare hint(contended, uncontended) %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "eq" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_hint2(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}}
+ omp.atomic.compare hint(nonspeculative, speculative) %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "eq" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+// float comparison operators mentioned in ArithBase.td not permitted for
+// !omp atomic compare
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'one' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = arith.cmpf one, %xval, %e : f32
+ %sel = arith.select %cmp, %d, %xval : f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'one' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = llvm.fcmp "one" %xval, %e : f64
+ %sel = llvm.select %cmp, %d, %xval : i1, f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'oge' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = arith.cmpf oge, %xval, %e : f32
+ %sel = arith.select %cmp, %d, %xval : f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'oge' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = llvm.fcmp "oge" %xval, %e : f64
+ %sel = llvm.select %cmp, %d, %xval : i1, f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'ole' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = arith.cmpf ole, %xval, %e : f64
+ %sel = arith.select %cmp, %d, %xval : f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'ole' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = llvm.fcmp "ole" %xval, %e : f32
+ %sel = llvm.select %cmp, %d, %xval : i1, f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'ord' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = arith.cmpf ord, %xval, %e : f64
+ %sel = arith.select %cmp, %d, %xval : f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'ord' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = llvm.fcmp "ord" %xval, %e : f32
+ %sel = llvm.select %cmp, %d, %xval : i1, f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'ueq' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = arith.cmpf ueq, %xval, %e : f32
+ %sel = arith.select %cmp, %d, %xval : f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'ueq' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = llvm.fcmp "ueq" %xval, %e : f64
+ %sel = llvm.select %cmp, %d, %xval : i1, f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'ugt' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = arith.cmpf ugt, %xval, %e : f32
+ %sel = arith.select %cmp, %d, %xval : f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'ugt' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = llvm.fcmp "ugt" %xval, %e : f64
+ %sel = llvm.select %cmp, %d, %xval : i1, f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'uge' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = arith.cmpf uge, %xval, %e : f32
+ %sel = arith.select %cmp, %d, %xval : f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'uge' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = llvm.fcmp "uge" %xval, %e : f32
+ %sel = llvm.select %cmp, %d, %xval : i1, f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'ult' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = arith.cmpf ult, %xval, %e : f64
+ %sel = arith.select %cmp, %d, %xval : f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'ult' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = llvm.fcmp "ult" %xval, %e : f32
+ %sel = llvm.select %cmp, %d, %xval : i1, f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'ule' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = arith.cmpf ule, %xval, %e : f32
+ %sel = arith.select %cmp, %d, %xval : f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'ule' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = llvm.fcmp "ule" %xval, %e : f64
+ %sel = llvm.select %cmp, %d, %xval : i1, f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'une' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = arith.cmpf une, %xval, %e : f32
+ %sel = arith.select %cmp, %d, %xval : f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'une' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = llvm.fcmp "une" %xval, %e : f32
+ %sel = llvm.select %cmp, %d, %xval : i1, f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpf_predicate(%x: memref<f64>, %e: f64, %d: f64) {
+ // expected-error @below {{unsupported comparison operator 'uno' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f64> {
+ ^bb0(%xval: f64):
+ %cmp = arith.cmpf uno, %xval, %e : f64
+ %sel = arith.select %cmp, %d, %xval : f64
+ omp.yield(%sel : f64)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_fcmp_predicate(%x: memref<f32>, %e: f32, %d: f32) {
+ // expected-error @below {{unsupported comparison operator 'uno' in atomic compare region, supported operators are: oeq, ogt, olt}}
+ omp.atomic.compare %x : memref<f32> {
+ ^bb0(%xval: f32):
+ %cmp = llvm.fcmp "uno" %xval, %e : f32
+ %sel = llvm.select %cmp, %d, %xval : i1, f32
+ omp.yield(%sel : f32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpi_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'ne' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = arith.cmpi ne, %xval, %e : i32
+ %sel = arith.select %cmp, %d, %xval : i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_icmp_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'ne' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "ne" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpi_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'sle' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = arith.cmpi sle, %xval, %e : i32
+ %sel = arith.select %cmp, %d, %xval : i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_icmp_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'sle' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "sle" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpi_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'sge' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = arith.cmpi sge, %xval, %e : i32
+ %sel = arith.select %cmp, %d, %xval : i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_icmp_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'sge' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "sge" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpi_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'ult' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = arith.cmpi ult, %xval, %e : i32
+ %sel = arith.select %cmp, %d, %xval : i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_icmp_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'ult' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "ult" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpi_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'ule' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = arith.cmpi ule, %xval, %e : i32
+ %sel = arith.select %cmp, %d, %xval : i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_icmp_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'ule' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "ule" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpi_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'ugt' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = arith.cmpi ugt, %xval, %e : i32
+ %sel = arith.select %cmp, %d, %xval : i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_icmp_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'ugt' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "ugt" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_cmpi_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'uge' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = arith.cmpi uge, %xval, %e : i32
+ %sel = arith.select %cmp, %d, %xval : i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
+func.func @omp_atomic_compare_invalid_icmp_predicate(%x: memref<i32>, %e: i32, %d: i32) {
+ // expected-error @below {{unsupported comparison operator 'uge' in atomic compare region, supported operators are: eq, slt, sgt}}
+ omp.atomic.compare %x : memref<i32> {
+ ^bb0(%xval: i32):
+ %cmp = llvm.icmp "uge" %xval, %e : i32
+ %sel = llvm.select %cmp, %d, %xval : i1, i32
+ omp.yield(%sel : i32)
+ }
+ return
+}
+
+// -----
+
func.func @omp_teams_parent() {
omp.parallel {
// expected-error @below {{expected to be nested inside of omp.target or not nested in any OpenMP dialect operations}}
>From 3e324996754a530043993e701233a4351b140287 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Tue, 17 Mar 2026 11:39:54 -0500
Subject: [PATCH 4/8] 1) Moving load, in Atomic.cpp, before atomic op. 2) Since
namespace Fortran is used Fortran:: qualification is removed.
---
flang/lib/Lower/ConvertType.cpp | 48 +++++++++----------
flang/lib/Lower/OpenMP/Atomic.cpp | 42 ++++++++--------
flang/test/Lower/OpenMP/atomic-compare.f90 | 12 ++---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 11 +++--
4 files changed, 57 insertions(+), 56 deletions(-)
diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp
index a3c978c00769b..fe82b14b3390d 100644
--- a/flang/lib/Lower/ConvertType.cpp
+++ b/flang/lib/Lower/ConvertType.cpp
@@ -701,39 +701,39 @@ using namespace Fortran::common;
FOR_EACH_SPECIFIC_TYPE(template class Fortran::lower::TypeBuilder, )
/// Convert parser's INTEGER relational operators to MLIR.
-mlir::arith::CmpIPredicate Fortran::lower::translateSignedRelational(
- Fortran::common::RelationalOperator rop) {
+mlir::arith::CmpIPredicate
+Fortran::lower::translateSignedRelational(RelationalOperator rop) {
switch (rop) {
- case Fortran::common::RelationalOperator::LT:
+ case RelationalOperator::LT:
return mlir::arith::CmpIPredicate::slt;
- case Fortran::common::RelationalOperator::LE:
+ case RelationalOperator::LE:
return mlir::arith::CmpIPredicate::sle;
- case Fortran::common::RelationalOperator::EQ:
+ case RelationalOperator::EQ:
return mlir::arith::CmpIPredicate::eq;
- case Fortran::common::RelationalOperator::NE:
+ case RelationalOperator::NE:
return mlir::arith::CmpIPredicate::ne;
- case Fortran::common::RelationalOperator::GT:
+ case RelationalOperator::GT:
return mlir::arith::CmpIPredicate::sgt;
- case Fortran::common::RelationalOperator::GE:
+ case RelationalOperator::GE:
return mlir::arith::CmpIPredicate::sge;
}
llvm_unreachable("unhandled INTEGER relational operator");
}
-mlir::arith::CmpIPredicate Fortran::lower::translateUnsignedRelational(
- Fortran::common::RelationalOperator rop) {
+mlir::arith::CmpIPredicate
+Fortran::lower::translateUnsignedRelational(RelationalOperator rop) {
switch (rop) {
- case Fortran::common::RelationalOperator::LT:
+ case RelationalOperator::LT:
return mlir::arith::CmpIPredicate::ult;
- case Fortran::common::RelationalOperator::LE:
+ case RelationalOperator::LE:
return mlir::arith::CmpIPredicate::ule;
- case Fortran::common::RelationalOperator::EQ:
+ case RelationalOperator::EQ:
return mlir::arith::CmpIPredicate::eq;
- case Fortran::common::RelationalOperator::NE:
+ case RelationalOperator::NE:
return mlir::arith::CmpIPredicate::ne;
- case Fortran::common::RelationalOperator::GT:
+ case RelationalOperator::GT:
return mlir::arith::CmpIPredicate::ugt;
- case Fortran::common::RelationalOperator::GE:
+ case RelationalOperator::GE:
return mlir::arith::CmpIPredicate::uge;
}
llvm_unreachable("unhandled UNSIGNED relational operator");
@@ -747,20 +747,20 @@ mlir::arith::CmpIPredicate Fortran::lower::translateUnsignedRelational(
/// FIXME: The signaling/quiet aspect of the table 17.1 requirement is not
/// fully enforced. FIR and LLVM `fcmp` instructions do not give any guarantee
/// whether the comparison will signal or not in case of quiet NaN argument.
-mlir::arith::CmpFPredicate Fortran::lower::translateFloatRelational(
- Fortran::common::RelationalOperator rop) {
+mlir::arith::CmpFPredicate
+Fortran::lower::translateFloatRelational(RelationalOperator rop) {
switch (rop) {
- case Fortran::common::RelationalOperator::LT:
+ case RelationalOperator::LT:
return mlir::arith::CmpFPredicate::OLT;
- case Fortran::common::RelationalOperator::LE:
+ case RelationalOperator::LE:
return mlir::arith::CmpFPredicate::OLE;
- case Fortran::common::RelationalOperator::EQ:
+ case RelationalOperator::EQ:
return mlir::arith::CmpFPredicate::OEQ;
- case Fortran::common::RelationalOperator::NE:
+ case RelationalOperator::NE:
return mlir::arith::CmpFPredicate::UNE;
- case Fortran::common::RelationalOperator::GT:
+ case RelationalOperator::GT:
return mlir::arith::CmpFPredicate::OGT;
- case Fortran::common::RelationalOperator::GE:
+ case RelationalOperator::GE:
return mlir::arith::CmpFPredicate::OGE;
}
llvm_unreachable("unhandled REAL relational operator");
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index c5195ee088ad2..87168cb57a51f 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -489,9 +489,8 @@ genAtomicOperation(lower::AbstractConverter &converter,
/// Reverse a relational operator as if the operands were swapped.
/// e.g. LT becomes GT, LE becomes GE. Symmetric operators (EQ, NE)
/// are returned unchanged.
-static Fortran::common::RelationalOperator
-reverseRelOp(Fortran::common::RelationalOperator op) {
- using RO = Fortran::common::RelationalOperator;
+static common::RelationalOperator reverseRelOp(common::RelationalOperator op) {
+ using RO = common::RelationalOperator;
switch (op) {
case RO::LT:
return RO::GT;
@@ -556,22 +555,21 @@ void Fortran::lower::omp::lowerAtomic(
TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
}
- Fortran::common::RelationalOperator relOpr =
- Fortran::common::RelationalOperator::EQ;
+ common::RelationalOperator relOpr = common::RelationalOperator::EQ;
std::optional<semantics::SomeExpr> expectedExprStorage;
- if (const auto *rel = Fortran::evaluate::UnwrapExpr<
- Fortran::evaluate::Relational<Fortran::evaluate::SomeType>>(
- *cond)) {
+ if (const auto *rel =
+ evaluate::UnwrapExpr<evaluate::Relational<evaluate::SomeType>>(
+ *cond)) {
std::visit(
[&](const auto &relImpl) {
relOpr = relImpl.opr;
using Operand = typename std::decay_t<decltype(relImpl)>::Operand;
- auto leftExpr = Fortran::evaluate::AsGenericExpr(
- Fortran::evaluate::Expr<Operand>{relImpl.left()});
- auto rightExpr = Fortran::evaluate::AsGenericExpr(
- Fortran::evaluate::Expr<Operand>{relImpl.right()});
- if (Fortran::evaluate::IsSameOrConvertOf(rightExpr, atom)) {
+ auto leftExpr = evaluate::AsGenericExpr(
+ evaluate::Expr<Operand>{relImpl.left()});
+ auto rightExpr = evaluate::AsGenericExpr(
+ evaluate::Expr<Operand>{relImpl.right()});
+ if (evaluate::IsSameOrConvertOf(rightExpr, atom)) {
// e.g. e == x (atom is on the right)
// left operand is expected value (e)
// reverse the operator so that the comparison becomes
@@ -591,29 +589,29 @@ void Fortran::lower::omp::lowerAtomic(
return;
}
+ mlir::Type elemTypeOfX = fir::unwrapRefType(atomAddr.getType());
+ mlir::Value expectedVal = fir::getBase(
+ converter.genExprValue(*expectedExprStorage, stmtCtx, &loc));
+ if (expectedVal.getType() != elemTypeOfX) {
+ expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
+ }
+
mlir::UnitAttr weakAttr = nullptr;
mlir::Operation *atomicOp = mlir::omp::AtomicCompareOp::create(
builder, loc, atomAddr, weakAttr, hint,
makeMemOrderAttr(converter, memOrder));
- mlir::Type elemTypeOfX = fir::unwrapRefType(atomAddr.getType());
mlir::Block *block = builder.createBlock(&atomicOp->getRegion(0));
mlir::Value blockArg = block->addArgument(elemTypeOfX, loc);
builder.setInsertionPointToEnd(block);
- mlir::Value expectedVal = fir::getBase(
- converter.genExprValue(*expectedExprStorage, stmtCtx, &loc));
- if (expectedVal.getType() != elemTypeOfX) {
- expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
- }
-
// Generate comparison: e.g. x == e
mlir::Value cmpResult;
if (mlir::isa<mlir::IntegerType>(elemTypeOfX)) {
- auto pred = Fortran::lower::translateSignedRelational(relOpr);
+ auto pred = lower::translateSignedRelational(relOpr);
cmpResult = mlir::arith::CmpIOp::create(builder, loc, pred, blockArg,
expectedVal);
} else if (mlir::isa<mlir::FloatType>(elemTypeOfX)) {
- auto pred = Fortran::lower::translateFloatRelational(relOpr);
+ auto pred = lower::translateFloatRelational(relOpr);
cmpResult = mlir::arith::CmpFOp::create(builder, loc, pred, blockArg,
expectedVal);
} else {
diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90
index e69c0b55d5351..a376935ce2f52 100644
--- a/flang/test/Lower/OpenMP/atomic-compare.f90
+++ b/flang/test/Lower/OpenMP/atomic-compare.f90
@@ -11,9 +11,9 @@
! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<i32> {
! CHECK: ^bb0(%[[XVAL:.*]]: i32):
-! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32
! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
@@ -32,9 +32,9 @@ subroutine atomic_compare_int_eq(x, e, d)
! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<f32> {
! CHECK: ^bb0(%[[XVAL:.*]]: f32):
-! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
! CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[XVAL]], %[[EVAL]] fastmath<contract> : f32
! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<f32>
! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : f32
@@ -51,9 +51,9 @@ subroutine atomic_compare_float_eq(x, e, d)
! CHECK-SAME: %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"})
! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<i32> {
! CHECK: ^bb0(%[[XVAL:.*]]: i32):
-! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[XVAL]], %[[EVAL]] : i32
! CHECK: %[[EVAL2:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[EVAL2]], %[[XVAL]] : i32
@@ -70,9 +70,9 @@ subroutine atomic_compare_int_lt(x, e)
! CHECK-SAME: %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"})
! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<i32> {
! CHECK: ^bb0(%[[XVAL:.*]]: i32):
-! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32
! CHECK: %[[EVAL2:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[EVAL2]], %[[XVAL]] : i32
@@ -89,9 +89,9 @@ subroutine atomic_compare_int_gt(x, e)
! CHECK-SAME: %[[E:.*]]: !fir.ref<f32> {fir.bindc_name = "e"})
! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<f32> {
! CHECK: ^bb0(%[[XVAL:.*]]: f32):
-! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
! CHECK: %[[CMP:.*]] = arith.cmpf olt, %[[XVAL]], %[[EVAL]] fastmath<contract> : f32
! CHECK: %[[EVAL2:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[EVAL2]], %[[XVAL]] : f32
@@ -108,9 +108,9 @@ subroutine atomic_compare_float_lt(x, e)
! CHECK-SAME: %[[E:.*]]: !fir.ref<f32> {fir.bindc_name = "e"})
! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<f32> {
! CHECK: ^bb0(%[[XVAL:.*]]: f32):
-! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
! CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[XVAL]], %[[EVAL]] fastmath<contract> : f32
! CHECK: %[[EVAL2:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<f32>
! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[EVAL2]], %[[XVAL]] : f32
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index fa04a2b3d2e25..9d3799cf87e9b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4228,10 +4228,13 @@ convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
// Trace back through load operations and generate load instructions
auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
- llvm::Value *loadAddr = moduleTranslation.lookupValue(loadOp.getAddr());
- llvm::Type *loadType =
- moduleTranslation.convertType(loadOp.getResult().getType());
- return builder.CreateLoad(loadType, loadAddr);
+ if (loadOp->getParentRegion() == &region) {
+ llvm::Value *loadAddr =
+ moduleTranslation.lookupValue(loadOp.getAddr());
+ llvm::Type *loadType =
+ moduleTranslation.convertType(loadOp.getResult().getType());
+ return builder.CreateLoad(loadType, loadAddr);
+ }
}
return moduleTranslation.lookupValue(val);
};
>From 7731b4e6a1b412bf696cec66baaa0baa5505709b Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Tue, 17 Mar 2026 12:12:49 -0500
Subject: [PATCH 5/8] Fixing git-clang-format error.
---
.../Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9d3799cf87e9b..f11babc24db6f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4229,8 +4229,7 @@ convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
if (loadOp->getParentRegion() == &region) {
- llvm::Value *loadAddr =
- moduleTranslation.lookupValue(loadOp.getAddr());
+ llvm::Value *loadAddr = moduleTranslation.lookupValue(loadOp.getAddr());
llvm::Type *loadType =
moduleTranslation.convertType(loadOp.getResult().getType());
return builder.CreateLoad(loadType, loadAddr);
>From 761b7c035870a86a3d20b34187894fc0b7c725a9 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Mon, 6 Apr 2026 12:26:42 -0500
Subject: [PATCH 6/8] Handling complex variables in omp atomic compare.
---
flang/lib/Lower/OpenMP/Atomic.cpp | 7 +-
flang/test/Lower/OpenMP/atomic-compare.f90 | 23 ++-
.../Interfaces/AtomicInterfaces.td | 17 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 156 +++++++++++++++---
4 files changed, 173 insertions(+), 30 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 87168cb57a51f..7edeed26bc273 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -614,8 +614,13 @@ void Fortran::lower::omp::lowerAtomic(
auto pred = lower::translateFloatRelational(relOpr);
cmpResult = mlir::arith::CmpFOp::create(builder, loc, pred, blockArg,
expectedVal);
+ } else if (fir::isa_complex(elemTypeOfX)) {
+ auto pred = lower::translateFloatRelational(relOpr);
+ cmpResult =
+ fir::CmpcOp::create(builder, loc, pred, blockArg, expectedVal);
} else {
- llvm_unreachable("unsupported type for atomic compare");
+ mlir::emitError(loc, "unsupported type for atomic compare");
+ return;
}
// Check for presence of Assignment (x = d) and whether it is being invoked
diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90
index a376935ce2f52..62388d11f491c 100644
--- a/flang/test/Lower/OpenMP/atomic-compare.f90
+++ b/flang/test/Lower/OpenMP/atomic-compare.f90
@@ -1,5 +1,3 @@
-! REQUIRES: openmp_runtime
-
! This test checks lowering of atomic compare constructs.
! RUN: bbc %openmp_flags -fopenmp-version=51 -emit-hlfir %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=51 %s -o - | FileCheck %s
@@ -46,6 +44,27 @@ subroutine atomic_compare_float_eq(x, e, d)
if (x .eq. e) x = d
end
+! CHECK-LABEL: func.func @_QPatomic_compare_complex_eq(
+! CHECK-SAME: %[[X:.*]]: !fir.ref<complex<f32>> {fir.bindc_name = "x"},
+! CHECK-SAME: %[[E:.*]]: !fir.ref<complex<f32>> {fir.bindc_name = "e"},
+! CHECK-SAME: %[[D:.*]]: !fir.ref<complex<f32>> {fir.bindc_name = "d"})
+! CHECK: %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK: %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK: %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<complex<f32>>
+! CHECK: omp.atomic.compare memory_order(relaxed) %[[X_DECL]]#0 : !fir.ref<complex<f32>> {
+! CHECK: ^bb0(%[[XVAL:.*]]: complex<f32>):
+! CHECK: %[[CMP:.*]] = fir.cmpc "oeq", %[[XVAL]], %[[EVAL]] {fastmath = #arith.fastmath<contract>} : complex<f32>
+! CHECK: %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<complex<f32>>
+! CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : complex<f32>
+! CHECK: omp.yield(%[[SEL]] : complex<f32>)
+! CHECK: }
+subroutine atomic_compare_complex_eq(x, e, d)
+ complex :: x, e, d
+ !$omp atomic compare
+ if (x .eq. e) x = d
+end
+
! CHECK-LABEL: func.func @_QPatomic_compare_int_lt(
! CHECK-SAME: %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
! CHECK-SAME: %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"})
diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
index 63f280606126b..abb21705b3c1c 100644
--- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
@@ -459,13 +459,28 @@ def AtomicCompareOpInterface : OpInterface<"AtomicCompareOpInterface"> {
}
}
break;
+ } else if (opName == "fir.cmpc") {
+ foundComparison = true;
+ auto predAttr = op.getAttrOfType<mlir::IntegerAttr>("predicate");
+ if (predAttr) {
+ auto predName = mlir::arith::stringifyCmpFPredicate(
+ static_cast<mlir::arith::CmpFPredicate>(predAttr.getInt()));
+ if (predName != "oeq") {
+ return $_op.emitError(
+ "unsupported comparison operator '")
+ << predName
+ << "' in atomic compare region for complex type, "
+ "only 'oeq' is supported";
+ }
+ }
+ break;
}
}
if (!foundComparison)
return $_op.emitError(
"atomic compare region must contain a comparison operation "
- "(arith.cmpi or llvm.icmp)");
+ "(arith.cmpi, arith.cmpf, llvm.icmp, llvm.fcmp, or fir.cmpc)");
return mlir::success();
}]
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f11babc24db6f..d42d74844a2dc 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4245,35 +4245,139 @@ convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
llvm::Value *dVal = nullptr;
bool isXBinopExpr = false;
+ auto traceToAggregate = [](mlir::Value v) -> mlir::Value {
+ if (auto extractOp = v.getDefiningOp<LLVM::ExtractValueOp>())
+ return extractOp.getContainer();
+ return nullptr;
+ };
+
+ // Check for a decomposed complex comparison pattern:
+ // %re_x = llvm.extractvalue %xval[0]
+ // %re_e = llvm.extractvalue %eStruct[0]
+ // %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
+ // %im_x = llvm.extractvalue %xval[1]
+ // %im_e = llvm.extractvalue %eStruct[1]
+ // %cmp_im = llvm.fcmp "oeq" %im_x, %im_e
+ // %cmp = llvm.and %cmp_re, %cmp_im (for EQ)
+ // Detect this by looking for AndOp/OrOp whose operands are both FCmpOps
+ // operating on ExtractValueOps from the block argument.
+ bool isComplexPattern = false;
for (Operation &op : block.getOperations()) {
- if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
- auto maybeOp =
- convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
- if (!maybeOp)
- return atomicCompareOp.emitError(
- "unsupported comparison predicate in atomic compare");
- compareOp = *maybeOp;
-
- // Identify which operand is the block argument (x) and which is e.
- isXBinopExpr = (icmpOp.getOperand(0) == block.getArgument(0));
- mlir::Value eOperand =
- isXBinopExpr ? icmpOp.getOperand(1) : icmpOp.getOperand(0);
- eVal = materializeValue(eOperand);
- } else if (auto fcmpOp = dyn_cast<LLVM::FCmpOp>(op)) {
- auto maybeOp =
- convertFCmpPredicateToAtomicCompareOp(fcmpOp.getPredicate());
- if (!maybeOp)
+ if (!isa<LLVM::AndOp, LLVM::OrOp>(op))
+ continue;
+
+ // Using : %cmp = llvm.and %cmp_re, %cmp_im
+ auto lhsFcmp = op.getOperand(0).getDefiningOp<LLVM::FCmpOp>();
+ auto rhsFcmp = op.getOperand(1).getDefiningOp<LLVM::FCmpOp>();
+ if (!lhsFcmp || !rhsFcmp)
+ continue;
+
+ // Using : %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
+ // Check presence of x (block argument) and get e.
+ mlir::Value lhsAgg0 = traceToAggregate(lhsFcmp.getOperand(0));
+ mlir::Value lhsAgg1 = traceToAggregate(lhsFcmp.getOperand(1));
+ bool lhsXIsOp0 = (lhsAgg0 == block.getArgument(0));
+ bool lhsXIsOp1 = (lhsAgg1 == block.getArgument(0));
+ if (!lhsXIsOp0 && !lhsXIsOp1)
+ continue;
+ mlir::Value eAggregate = lhsXIsOp0 ? lhsAgg1 : lhsAgg0;
+ if (!eAggregate)
+ continue;
+
+ if (isa<LLVM::AndOp>(op))
+ compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+ else
+ // OrOp corresponds to NE, which is not a valid atomic compare op.
+ return atomicCompareOp.emitError(
+ "unsupported comparison predicate (NE) for complex atomic compare");
+
+ isXBinopExpr = lhsXIsOp0;
+ eVal = materializeValue(eAggregate);
+ isComplexPattern = true;
+ break;
+ }
+
+ if (isComplexPattern) {
+ // dVal from SelectOp or YieldOp.
+ for (Operation &op : block.getOperations()) {
+ if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
+ dVal = materializeValue(selectOp.getTrueValue());
+ break;
+ }
+ }
+ if (!dVal) {
+ auto yieldOp = cast<omp::YieldOp>(block.getTerminator());
+ if (yieldOp.getResults().empty())
return atomicCompareOp.emitError(
- "unsupported comparison predicate in atomic compare");
- compareOp = *maybeOp;
-
- isXBinopExpr = (fcmpOp.getOperand(0) == block.getArgument(0));
- mlir::Value eOperand =
- isXBinopExpr ? fcmpOp.getOperand(1) : fcmpOp.getOperand(0);
- eVal = materializeValue(eOperand);
- } else if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
- if (!dVal)
+ "failed to extract desired value (d) from atomic compare region");
+ dVal = materializeValue(yieldOp.getResults()[0]);
+ }
+
+ const llvm::DataLayout &DL =
+ builder.GetInsertBlock()->getModule()->getDataLayout();
+ unsigned totalBits =
+ DL.getTypeStoreSizeInBits(llvmXElementType).getFixedValue();
+ llvm::IntegerType *intTy =
+ llvm::IntegerType::get(builder.getContext(), totalBits);
+
+ llvm::Value *eAlloca =
+ builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.e");
+ llvm::Value *dAlloca =
+ builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.d");
+
+ builder.CreateStore(eVal, eAlloca);
+ llvm::Value *eInt = builder.CreateLoad(intTy, eAlloca, "cmplx.e.int");
+ builder.CreateStore(dVal, dAlloca);
+ llvm::Value *dInt = builder.CreateLoad(intTy, dAlloca, "cmplx.d.int");
+
+ llvm::AtomicOrdering failOrdering =
+ llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(atomicOrdering);
+ builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, llvm::MaybeAlign(),
+ atomicOrdering, failOrdering);
+
+ return success();
+ } else {
+
+ for (Operation &op : block.getOperations()) {
+ if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
+ auto maybeOp =
+ convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
+ if (!maybeOp)
+ return atomicCompareOp.emitError(
+ "unsupported comparison predicate in atomic compare");
+ compareOp = *maybeOp;
+
+ // Identify which operand is the block argument (x) and which is e.
+ isXBinopExpr = (icmpOp.getOperand(0) == block.getArgument(0));
+ mlir::Value eOperand =
+ isXBinopExpr ? icmpOp.getOperand(1) : icmpOp.getOperand(0);
+ eVal = materializeValue(eOperand);
+ } else if (auto fcmpOp = dyn_cast<LLVM::FCmpOp>(op)) {
+ auto maybeOp =
+ convertFCmpPredicateToAtomicCompareOp(fcmpOp.getPredicate());
+ if (!maybeOp)
+ return atomicCompareOp.emitError(
+ "unsupported comparison predicate in atomic compare");
+ compareOp = *maybeOp;
+
+ isXBinopExpr = (fcmpOp.getOperand(0) == block.getArgument(0));
+ mlir::Value eOperand =
+ isXBinopExpr ? fcmpOp.getOperand(1) : fcmpOp.getOperand(0);
+ eVal = materializeValue(eOperand);
+ } else if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
+ if (!dVal)
+ dVal = materializeValue(selectOp.getTrueValue());
+ }
+ }
+ }
+
+ // For non-complex patterns, also extract dVal from SelectOp.
+ if (!dVal) {
+ for (Operation &op : block.getOperations()) {
+ if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
dVal = materializeValue(selectOp.getTrueValue());
+ break;
+ }
}
}
>From 55e6b3f7853bb26a5ba59b9b7f9d88e7a2a0948b Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Thu, 9 Apr 2026 04:16:11 -0500
Subject: [PATCH 7/8] Adding a check for complex variable as part of "!$omp
atomic compare" in mlir/test/Target/LLVMIR/openmp-llvm.mlir.
---
mlir/test/Target/LLVMIR/openmp-llvm.mlir | 24 ++++++++++++++++++++++--
1 file changed, 22 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 675f060a85ba9..b14f40a690d58 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2453,10 +2453,11 @@ llvm.func @omp_atomic_capture_misc(
// -----
// CHECK-LABEL: @omp_atomic_compare
-// CHECK-SAME: (ptr %[[X:.*]], i32 %[[E:.*]], i32 %[[D:.*]], ptr %[[XF:.*]], float %[[EF:.*]], float %[[DF:.*]])
+// CHECK-SAME: (ptr %[[X:.*]], i32 %[[E:.*]], i32 %[[D:.*]], ptr %[[XF:.*]], float %[[EF:.*]], float %[[DF:.*]], ptr %[[XC:.*]], { float, float } %[[EC:.*]], { float, float } %[[DC:.*]])
llvm.func @omp_atomic_compare(
%x : !llvm.ptr, %e : i32, %d : i32,
- %xf : !llvm.ptr, %ef : f32, %df : f32) {
+ %xf : !llvm.ptr, %ef : f32, %df : f32,
+ %xc : !llvm.ptr, %ec : !llvm.struct<(f32, f32)>, %dc : !llvm.struct<(f32, f32)>) {
// Integer equality → cmpxchg
// CHECK: cmpxchg ptr %[[X]], i32 %[[E]], i32 %[[D]] monotonic monotonic
@@ -2478,6 +2479,25 @@ llvm.func @omp_atomic_compare(
omp.yield(%sel1 : f32)
}
+ // Complex equality → bitcasted integer cmpxchg
+ // CHECK: store { float, float } %[[EC]], ptr %{{.*}}
+ // CHECK: %[[EINT:.*]] = load i64, ptr %{{.*}}
+ // CHECK: store { float, float } %[[DC]], ptr %{{.*}}
+ // CHECK: %[[DINT:.*]] = load i64, ptr %{{.*}}
+ // CHECK: cmpxchg ptr %[[XC]], i64 %[[EINT]], i64 %[[DINT]] monotonic monotonic
+ omp.atomic.compare %xc : !llvm.ptr {
+ ^bb0(%xval : !llvm.struct<(f32, f32)>):
+ %re_x = llvm.extractvalue %xval[0] : !llvm.struct<(f32, f32)>
+ %re_e = llvm.extractvalue %ec[0] : !llvm.struct<(f32, f32)>
+ %cmp_re = llvm.fcmp "oeq" %re_x, %re_e : f32
+ %im_x = llvm.extractvalue %xval[1] : !llvm.struct<(f32, f32)>
+ %im_e = llvm.extractvalue %ec[1] : !llvm.struct<(f32, f32)>
+ %cmp_im = llvm.fcmp "oeq" %im_x, %im_e : f32
+ %cmp = llvm.and %cmp_re, %cmp_im : i1
+ %sel = llvm.select %cmp, %dc, %xval : i1, !llvm.struct<(f32, f32)>
+ omp.yield(%sel : !llvm.struct<(f32, f32)>)
+ }
+
// Integer x < e → atomicrmw umax (reversed, unsigned)
// CHECK: atomicrmw umax ptr %[[X]], i32 %[[E]] monotonic
omp.atomic.compare %x : !llvm.ptr {
>From 9748365a4d7f64b3ca642131844dde82529051d1 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Fri, 17 Apr 2026 12:50:42 -0500
Subject: [PATCH 8/8] Incorporating feedback: 1) Fortran integers are signed.
2) flang support of volatile. 3) alignment requirements. 4) Error message for
complex types wider than 128 bits, because cmpxchg is not supported for
wider than 128 bits.
---
.../Integration/OpenMP/atomic-compare.f90 | 48 +++++++++++++++----
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 41 ++++++++++++----
mlir/test/Target/LLVMIR/openmp-llvm.mlir | 22 +++++----
3 files changed, 84 insertions(+), 27 deletions(-)
diff --git a/flang/test/Integration/OpenMP/atomic-compare.f90 b/flang/test/Integration/OpenMP/atomic-compare.f90
index c5be037b7533f..1a26a31191efd 100644
--- a/flang/test/Integration/OpenMP/atomic-compare.f90
+++ b/flang/test/Integration/OpenMP/atomic-compare.f90
@@ -70,22 +70,22 @@ subroutine atomic_compare_relaxed(x, e, d)
if (x == e) x = d
end
-! Less-than comparison → atomicrmw umax
+! Less-than comparison → atomicrmw max (signed)
!CHECK-LABEL: define void @atomic_compare_lt_(
!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]])
!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
-!CHECK: atomicrmw umax ptr %[[X]], i32 %[[EVAL]] monotonic
+!CHECK: atomicrmw max ptr %[[X]], i32 %[[EVAL]] monotonic
subroutine atomic_compare_lt(x, e)
integer :: x, e
!$omp atomic compare
if (x < e) x = e
end
-! Less-than with seq_cst → atomicrmw umax seq_cst + flush
+! Less-than with seq_cst → atomicrmw max seq_cst + flush (signed)
!CHECK-LABEL: define void @atomic_compare_lt_seq_cst_(
!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]])
!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
-!CHECK: atomicrmw umax ptr %[[X]], i32 %[[EVAL]] seq_cst
+!CHECK: atomicrmw max ptr %[[X]], i32 %[[EVAL]] seq_cst
!CHECK: call void @__kmpc_flush(
subroutine atomic_compare_lt_seq_cst(x, e)
integer :: x, e
@@ -93,24 +93,56 @@ subroutine atomic_compare_lt_seq_cst(x, e)
if (x < e) x = e
end
-! Less-than with acquire → atomicrmw umax acquire
+! Less-than with acquire → atomicrmw max acquire (signed)
!CHECK-LABEL: define void @atomic_compare_lt_acquire_(
!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]])
!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
-!CHECK: atomicrmw umax ptr %[[X]], i32 %[[EVAL]] acquire
+!CHECK: atomicrmw max ptr %[[X]], i32 %[[EVAL]] acquire
subroutine atomic_compare_lt_acquire(x, e)
integer :: x, e
!$omp atomic compare acquire
if (x < e) x = e
end
-! Greater-than comparison → atomicrmw umin
+! Greater-than comparison → atomicrmw min (signed)
!CHECK-LABEL: define void @atomic_compare_gt_(
!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]])
!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]], align 4
-!CHECK: atomicrmw umin ptr %[[X]], i32 %[[EVAL]] monotonic
+!CHECK: atomicrmw min ptr %[[X]], i32 %[[EVAL]] monotonic
subroutine atomic_compare_gt(x, e)
integer :: x, e
!$omp atomic compare
if (x > e) x = e
end
+
+! Complex(4) equality → type-punned i64 cmpxchg with consistent alignment
+!CHECK-LABEL: define void @atomic_compare_complex4_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]])
+!CHECK: %[[EALLOCA:.*]] = alloca { float, float }, align [[ALIGN:[0-9]+]]
+!CHECK: %[[DALLOCA:.*]] = alloca { float, float }, align [[ALIGN]]
+!CHECK: store { float, float } %{{.*}}, ptr %[[EALLOCA]], align [[ALIGN]]
+!CHECK: %[[EINT:.*]] = load i64, ptr %[[EALLOCA]], align [[ALIGN]]
+!CHECK: store { float, float } %{{.*}}, ptr %[[DALLOCA]], align [[ALIGN]]
+!CHECK: %[[DINT:.*]] = load i64, ptr %[[DALLOCA]], align [[ALIGN]]
+!CHECK: cmpxchg ptr %[[X]], i64 %[[EINT]], i64 %[[DINT]] monotonic monotonic, align [[ALIGN]]
+subroutine atomic_compare_complex4(x, e, d)
+ complex :: x, e, d
+ !$omp atomic compare
+ if (x == e) x = d
+end
+
+! Complex(8) equality → type-punned i128 cmpxchg with consistent alignment
+!CHECK-LABEL: define void @atomic_compare_complex8_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]])
+!CHECK: %[[EALLOCA:.*]] = alloca { double, double }, align [[ALIGN:[0-9]+]]
+!CHECK: %[[DALLOCA:.*]] = alloca { double, double }, align [[ALIGN]]
+!CHECK: store { double, double } %{{.*}}, ptr %[[EALLOCA]], align [[ALIGN]]
+!CHECK: %[[EINT:.*]] = load i128, ptr %[[EALLOCA]], align [[ALIGN]]
+!CHECK: store { double, double } %{{.*}}, ptr %[[DALLOCA]], align [[ALIGN]]
+!CHECK: %[[DINT:.*]] = load i128, ptr %[[DALLOCA]], align [[ALIGN]]
+!CHECK: cmpxchg ptr %[[X]], i128 %[[EINT]], i128 %[[DINT]] monotonic monotonic, align [[ALIGN]]
+subroutine atomic_compare_complex8(x, e, d)
+ complex(8) :: x, e, d
+ !$omp atomic compare
+ if (x == e) x = d
+end
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9c4e333d9839d..6c329ce1c386d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4532,8 +4532,13 @@ convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
"unable to determine element type for atomic compare");
llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+
+ // Fortran integers are signed, and the OpenMPIRBuilder may use signedness
+ // for GT/LT atomic compare operations. Set IsSigned=true to produce correct
+ // code for all valid inputs.
llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
- false, false};
+ /*IsSigned=*/true,
+ /*IsVolatile=*/false};
llvm::AtomicOrdering atomicOrdering =
convertAtomicOrdering(atomicCompareOp.getMemoryOrder());
@@ -4630,23 +4635,41 @@ convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
builder.GetInsertBlock()->getModule()->getDataLayout();
unsigned totalBits =
DL.getTypeStoreSizeInBits(llvmXElementType).getFixedValue();
+
+ // Reject complex types wider than 128 bits (e.g. COMPLEX(16) would be
+ // i256). Major LLVM backends do not natively support cmpxchg for such
+ // wide integer types, resulting in code having libcall within a lock/mutex.
+ if (totalBits > 128)
+ return atomicCompareOp.emitError(
+ "atomic compare for complex types wider than 128 bits is not "
+ "supported (requires i" +
+ llvm::Twine(totalBits) + " cmpxchg)");
+
llvm::IntegerType *intTy =
llvm::IntegerType::get(builder.getContext(), totalBits);
- llvm::Value *eAlloca =
+ llvm::Align complexAlign = DL.getABITypeAlign(llvmXElementType);
+ llvm::Align intAlign = DL.getABITypeAlign(intTy);
+ llvm::Align maxAlign = std::max(complexAlign, intAlign);
+
+ llvm::AllocaInst *eAlloca =
builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.e");
- llvm::Value *dAlloca =
+ eAlloca->setAlignment(maxAlign);
+ llvm::AllocaInst *dAlloca =
builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.d");
+ dAlloca->setAlignment(maxAlign);
- builder.CreateStore(eVal, eAlloca);
- llvm::Value *eInt = builder.CreateLoad(intTy, eAlloca, "cmplx.e.int");
- builder.CreateStore(dVal, dAlloca);
- llvm::Value *dInt = builder.CreateLoad(intTy, dAlloca, "cmplx.d.int");
+ builder.CreateAlignedStore(eVal, eAlloca, maxAlign);
+ llvm::Value *eInt =
+ builder.CreateAlignedLoad(intTy, eAlloca, maxAlign, "cmplx.e.int");
+ builder.CreateAlignedStore(dVal, dAlloca, maxAlign);
+ llvm::Value *dInt =
+ builder.CreateAlignedLoad(intTy, dAlloca, maxAlign, "cmplx.d.int");
llvm::AtomicOrdering failOrdering =
llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(atomicOrdering);
- builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, llvm::MaybeAlign(),
- atomicOrdering, failOrdering);
+ builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, maxAlign, atomicOrdering,
+ failOrdering);
return success();
} else {
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index b14f40a690d58..5cd1bda0caa7d 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2479,12 +2479,14 @@ llvm.func @omp_atomic_compare(
omp.yield(%sel1 : f32)
}
- // Complex equality → bitcasted integer cmpxchg
- // CHECK: store { float, float } %[[EC]], ptr %{{.*}}
- // CHECK: %[[EINT:.*]] = load i64, ptr %{{.*}}
- // CHECK: store { float, float } %[[DC]], ptr %{{.*}}
- // CHECK: %[[DINT:.*]] = load i64, ptr %{{.*}}
- // CHECK: cmpxchg ptr %[[XC]], i64 %[[EINT]], i64 %[[DINT]] monotonic monotonic
+ // Complex equality → bitcasted integer cmpxchg with consistent alignment
+ // CHECK: %[[EALLOCA:.*]] = alloca { float, float }, align [[ALIGN:[0-9]+]]
+ // CHECK: %[[DALLOCA:.*]] = alloca { float, float }, align [[ALIGN]]
+ // CHECK: store { float, float } %[[EC]], ptr %[[EALLOCA]], align [[ALIGN]]
+ // CHECK: %[[EINT:.*]] = load i64, ptr %[[EALLOCA]], align [[ALIGN]]
+ // CHECK: store { float, float } %[[DC]], ptr %[[DALLOCA]], align [[ALIGN]]
+ // CHECK: %[[DINT:.*]] = load i64, ptr %[[DALLOCA]], align [[ALIGN]]
+ // CHECK: cmpxchg ptr %[[XC]], i64 %[[EINT]], i64 %[[DINT]] monotonic monotonic, align [[ALIGN]]
omp.atomic.compare %xc : !llvm.ptr {
^bb0(%xval : !llvm.struct<(f32, f32)>):
%re_x = llvm.extractvalue %xval[0] : !llvm.struct<(f32, f32)>
@@ -2498,8 +2500,8 @@ llvm.func @omp_atomic_compare(
omp.yield(%sel : !llvm.struct<(f32, f32)>)
}
- // Integer x < e → atomicrmw umax (reversed, unsigned)
- // CHECK: atomicrmw umax ptr %[[X]], i32 %[[E]] monotonic
+ // Integer x < e → atomicrmw max (signed)
+ // CHECK: atomicrmw max ptr %[[X]], i32 %[[E]] monotonic
omp.atomic.compare %x : !llvm.ptr {
^bb0(%xval : i32):
%cmp2 = llvm.icmp "slt" %xval, %e : i32
@@ -2507,8 +2509,8 @@ llvm.func @omp_atomic_compare(
omp.yield(%sel2 : i32)
}
- // Integer x > e → atomicrmw umin (reversed, unsigned)
- // CHECK: atomicrmw umin ptr %[[X]], i32 %[[E]] monotonic
+ // Integer x > e → atomicrmw min (signed)
+ // CHECK: atomicrmw min ptr %[[X]], i32 %[[E]] monotonic
omp.atomic.compare %x : !llvm.ptr {
^bb0(%xval : i32):
%cmp3 = llvm.icmp "sgt" %xval, %e : i32
More information about the Mlir-commits
mailing list