[flang-commits] [flang] [mlir] [flang][OpenMP] Support for "atomic compare capture" (PR #202315)

via flang-commits flang-commits at lists.llvm.org
Mon Jun 8 03:30:16 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-openacc

Author: SunilKuravinakop

<details>
<summary>Changes</summary>

Adds support for "!$omp atomic compare capture", for example:

```
subroutine compare_capture_01(var1, num1, num2, num3)
integer :: var1, num1, num2, num3
!$omp atomic compare capture
num3 = var1
if (var1 == num1) var1 = num2
!$omp end atomic
end subroutine
```

This also fixes [#<!-- -->202311](https://github.com/llvm/llvm-project/issues/202311).

---

Patch is 28.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/202315.diff


10 Files Affected:

- (modified) flang/lib/Lower/OpenMP/Atomic.cpp (+51-39) 
- (modified) flang/test/Integration/OpenMP/atomic-compare.f90 (+32) 
- (modified) flang/test/Lower/OpenMP/atomic-compare.f90 (+56) 
- (modified) flang/test/Parser/OpenMP/atomic-unparse.f90 (+29) 
- (modified) mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td (+8-1) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+10) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+6) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+177-40) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+48) 
- (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+41) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 4ce0c8f878c48..c0f97aada637d 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -560,16 +560,45 @@ void Fortran::lower::omp::lowerAtomic(
   int action1 = analysis.op1.what & analysis.Action;
   memOrder = makeValidForAction(memOrder, action0, action1, version);
 
+  // --- Shared capture scaffolding ---
+  mlir::Operation *captureOp = nullptr;
+  fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
+  fir::FirOpBuilder::InsertPoint atomicAt, postAt;
+
+  if (construct.IsCapture()) {
+    assert(action0 != analysis.None && action1 != analysis.None &&
+           "Expexcing two actions");
+    (void)action0;
+    (void)action1;
+    captureOp = mlir::omp::AtomicCaptureOp::create(
+        builder, loc, hint, makeMemOrderAttr(converter, memOrder));
+    // Set the non-atomic insertion point to before the atomic.capture.
+    preAt = getInsertionPointBefore(captureOp);
+
+    mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
+    builder.setInsertionPointToEnd(block);
+    // Set the atomic insertion point to before the terminator inside
+    // atomic.capture.
+    mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
+    atomicAt = getInsertionPointBefore(term);
+    postAt = getInsertionPointAfter(captureOp);
+    hint = nullptr;
+    memOrder = std::nullopt;
+  }
+
   if (auto *cond = get(analysis.cond)) {
     // atomic compare: if (x == e) x = d
     // e : expecteVal
     // d : desiredVal
 
-    // Check for compound clauses (fail, capture) that are not yet
+    // Restore insertion point so pre-processing code (e.g. computing
+    // expectedVal) is emitted before the capture op, not after the terminator.
+    builder.restoreInsertionPoint(preAt);
+
+    // Check for compound clause (fail) that is not yet
     // supported with atomic compare.
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
-          return clause.id == llvm::omp::Clause::OMPC_fail ||
-                 clause.id == llvm::omp::Clause::OMPC_capture;
+          return clause.id == llvm::omp::Clause::OMPC_fail;
         })) {
       TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
     }
@@ -617,6 +646,17 @@ void Fortran::lower::omp::lowerAtomic(
       expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
     }
 
+    // If this is a compare+capture, generate the read op first.
+    if (construct.IsCapture()) {
+      assert(get(analysis.op0.assign) && (analysis.op0.what & analysis.Read) &&
+             "Expected a read operation for compare capture");
+      mlir::Operation *readOp = genAtomicRead(
+          converter, semaCtx, loc, stmtCtx, atomAddr, atom,
+          *get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt);
+      assert(readOp && "Should have created an atomic read operation");
+      builder.setInsertionPointAfter(readOp);
+    }
+
     mlir::UnitAttr weakAttr = nullptr;
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
           return clause.id == llvm::omp::Clause::OMPC_weak;
@@ -685,34 +725,9 @@ void Fortran::lower::omp::lowerAtomic(
     // Generate omp.yield
     mlir::omp::YieldOp::create(builder, loc, newVal);
     builder.setInsertionPointAfter(atomicOp);
-
     // END omp atomic compare
   } else {
-    mlir::Operation *captureOp = nullptr;
-    fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
-    fir::FirOpBuilder::InsertPoint atomicAt, postAt;
-
-    if (construct.IsCapture()) {
-      // Capturing operation.
-      assert(action0 != analysis.None && action1 != analysis.None &&
-             "Expexcing two actions");
-      (void)action0;
-      (void)action1;
-      captureOp = mlir::omp::AtomicCaptureOp::create(
-          builder, loc, hint, makeMemOrderAttr(converter, memOrder));
-      // Set the non-atomic insertion point to before the atomic.capture.
-      preAt = getInsertionPointBefore(captureOp);
-
-      mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
-      builder.setInsertionPointToEnd(block);
-      // Set the atomic insertion point to before the terminator inside
-      // atomic.capture.
-      mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
-      atomicAt = getInsertionPointBefore(term);
-      postAt = getInsertionPointAfter(captureOp);
-      hint = nullptr;
-      memOrder = std::nullopt;
-    } else {
+    if (!construct.IsCapture()) {
       // Non-capturing operation.
       assert(action0 != analysis.None && action1 == analysis.None &&
              "Expexcing single action");
@@ -735,16 +750,13 @@ void Fortran::lower::omp::lowerAtomic(
           *get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt);
     }
 
-    if (construct.IsCapture()) {
-      // If this is a capture operation, the first/second ops will be inside
-      // of it. Set the insertion point to past the capture op itself.
-      builder.restoreInsertionPoint(postAt);
-    } else {
-      if (secondOp) {
-        builder.setInsertionPointAfter(secondOp);
-      } else {
-        builder.setInsertionPointAfter(firstOp);
-      }
+    if (!construct.IsCapture()) {
+      builder.setInsertionPointAfter(secondOp ? secondOp : firstOp);
     }
   }
+
+  // Shared capture cleanup.
+  if (construct.IsCapture()) {
+    builder.restoreInsertionPoint(postAt);
+  }
 }
diff --git a/flang/test/Integration/OpenMP/atomic-compare.f90 b/flang/test/Integration/OpenMP/atomic-compare.f90
index 249fb0dd8fa64..650f64b80af12 100644
--- a/flang/test/Integration/OpenMP/atomic-compare.f90
+++ b/flang/test/Integration/OpenMP/atomic-compare.f90
@@ -260,3 +260,35 @@ subroutine atomic_compare_weak(x, e, d)
   if (x == e) x = d
 end 
 
+! Integer equality compare+capture: cmpxchg + store old value
+!CHECK-LABEL: define void @atomic_compare_capture_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v  ! v captures x before the conditional update
+  !$omp atomic compare capture
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
+! Compare+capture with clause order reversed: capture compare
+!CHECK-LABEL: define void @atomic_capture_compare_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_capture_compare_int_eq(x, e, d, v)
+  integer :: x, e, d, v  ! same as above, with clause order reversed
+  !$omp atomic capture compare
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90
index ac70edbed4e60..752a221aa538d 100644
--- a/flang/test/Lower/OpenMP/atomic-compare.f90
+++ b/flang/test/Lower/OpenMP/atomic-compare.f90
@@ -161,3 +161,59 @@ subroutine atomic_compare_int_eq_weak(x, e, d)
   !$omp atomic compare weak
   if (x .eq. e) x = d
 end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_eq(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           }
+! CHECK:         }
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v  ! v captures x before the conditional update
+  !$omp atomic compare capture
+    v = x
+    if (x .eq. e) x = d
+  !$omp end atomic
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_gt(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           } {{.*}}weak{{.*}}
+! CHECK:         }
+subroutine atomic_compare_capture_int_gt(x, e, d, v)
+  integer :: x, e, d, v  ! ordered (>) compare with weak; v captures old x
+  !$omp atomic compare capture weak
+    v = x
+    if (x > e) x = d
+  !$omp end atomic
+end
diff --git a/flang/test/Parser/OpenMP/atomic-unparse.f90 b/flang/test/Parser/OpenMP/atomic-unparse.f90
index 4f3cf0eac0338..dc0cc1a62f6c2 100644
--- a/flang/test/Parser/OpenMP/atomic-unparse.f90
+++ b/flang/test/Parser/OpenMP/atomic-unparse.f90
@@ -192,6 +192,26 @@ program main
       i = j
    end if
 
+!COMPARE CAPTURE
+!$omp atomic compare capture
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare weak
+   k = i
+   if (i < j) then
+      i = k
+   end if
+!$omp end atomic
+
 !ATOMIC
 !$omp atomic
    i = j
@@ -296,6 +316,15 @@ end program main
 !CHECK: !$OMP ATOMIC WEAK COMPARE
 !CHECK: !$OMP ATOMIC COMPARE SEQ_CST WEAK
 
+!COMPARE CAPTURE
+
+!CHECK: !$OMP ATOMIC COMPARE CAPTURE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE WEAK
+!CHECK: !$OMP END ATOMIC
+
 !ATOMIC
 !CHECK: !$OMP ATOMIC
 !CHECK: !$OMP ATOMIC SEQ_CST
diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
index abb21705b3c1c..8c9015f05bb72 100644
--- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
@@ -289,10 +289,12 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
         auto secondReadStmt = dyn_cast<AtomicReadOpInterface>(secondOp);
         auto secondUpdateStmt = dyn_cast<AtomicUpdateOpInterface>(secondOp);
         auto secondWriteStmt = dyn_cast<AtomicWriteOpInterface>(secondOp);
+        auto secondCompareStmt = dyn_cast<AtomicCompareOpInterface>(secondOp);
 
         if (!((firstUpdateStmt && secondReadStmt) ||
               (firstReadStmt && secondUpdateStmt) ||
-              (firstReadStmt && secondWriteStmt)))
+              (firstReadStmt && secondWriteStmt) ||
+              (firstReadStmt && secondCompareStmt)))
           return ops.front().emitError()
                 << "invalid sequence of operations in the capture region";
         if (firstUpdateStmt && secondReadStmt &&
@@ -310,6 +312,11 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
           return firstReadStmt.emitError()
                 << "captured variable in atomic.read must be updated in "
                     "second operation";
+        if (firstReadStmt && secondCompareStmt &&
+            firstReadStmt.getX() != secondCompareStmt.getX())
+          return firstReadStmt.emitError()
+                << "captured variable in atomic.read must be updated in "
+                    "second operation";
 
         return mlir::success();
       }]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 0962b330e2f23..1241abc10298f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1928,6 +1928,12 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
         omp.atomic.write ...
         omp.terminator
       }
+
+      omp.atomic.capture {
+        omp.atomic.read ...
+        omp.atomic.compare ...
+        omp.terminator
+      }
     ```
   }] # clausesDescription;
 
@@ -1946,6 +1952,10 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
     /// Returns the `atomic.update` operation inside the region, if any.
     /// Otherwise, it returns nullptr.
     AtomicUpdateOp getAtomicUpdateOp();
+
+    /// Returns the `atomic.compare` operation inside the region, if any.
+    /// Otherwise, it returns nullptr.
+    AtomicCompareOp getAtomicCompareOp();
   }] # clausesExtraClassDeclaration;
 
   let hasRegionVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index db5fd8f2e3230..0eafd0a267b97 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4615,6 +4615,12 @@ AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
   return dyn_cast<AtomicUpdateOp>(getSecondOp());
 }
 
+AtomicCompareOp AtomicCaptureOp::getAtomicCompareOp() {
+  if (auto op = dyn_cast<AtomicCompareOp>(getFirstOp())) // check both ops
+    return op;
+  return dyn_cast<AtomicCompareOp>(getSecondOp()); // null if neither is one
+}
+
 LogicalResult AtomicCaptureOp::verify() {
   return verifySynchronizationHint(*this, getHint());
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6f93ad231cfac..c5a07a7dc6cb2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4795,6 +4795,43 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
   return success();
 }
 
+/// Maps an integer compare predicate to an OMPAtomicCompareOp (eq -> EQ,
+/// slt/ult -> MIN, sgt/ugt -> MAX); other predicates yield std::nullopt.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::ICmpPredicate::eq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::ICmpPredicate::slt:
+  case LLVM::ICmpPredicate::ult: // signedness must be passed separately
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::ICmpPredicate::sgt:
+  case LLVM::ICmpPredicate::ugt: // signedness must be passed separately
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
+/// Maps a floating-point compare predicate to an OMPAtomicCompareOp; ordered
+/// and unordered forms map alike. Other predicates yield std::nullopt.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::FCmpPredicate::oeq:
+  case LLVM::FCmpPredicate::ueq: // ueq is also true on NaN; confirm
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::FCmpPredicate::olt:
+  case LLVM::FCmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::FCmpPredicate::ogt:
+  case LLVM::FCmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
 static LogicalResult
 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                         llvm::IRBuilderBase &builder,
@@ -4803,13 +4840,150 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
   if (failed(checkImplementationStatus(*atomicCaptureOp)))
     return failure();
 
+  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
+  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
+  omp::AtomicCompareOp atomicCompareOp = atomicCaptureOp.getAtomicCompareOp();
+
+  // If the capture contains an atomic.compare, delegate to
+  // createAtomicCompare with the capture variable (V) set.
+  if (atomicCompareOp) {
+    omp::AtomicReadOp atomicReadOp = atomicCaptureOp.getAtomicReadOp();
+    assert(atomicReadOp && "expected atomic.read in capture+compare");
+
+    Region &region = atomicCompareOp.getRegion();
+    Block &block = region.front();
+
+    llvm::Type *llvmXElementType =
+        moduleTranslation.convertType(block.getArgument(0).getType());
+    llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+    llvm::Value *llvmV = moduleTranslation.lookupValue(atomicReadOp.getV());
+
+    bool isSigned = false;
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {
+        llvmX, llvmXElementType, isSigned, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {
+        llvmV, llvmXElementType, /*isSigned=*/false, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicR = {nullptr, nullptr, false,
+                                                        false};
+
+    llvm::AtomicOrdering atomicOrdering =
+        convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
+
+    // Pre-translate non-pattern operations inside the compare region.
+    auto isAtomicComparePatternOp = [](Operation &op) {
+      return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
+                       LLVM::OrOp>(op);
+    };
+    for (Operation &op : block.without_terminator()) {
+      if (isAtomicComparePatternOp(op))
+        continue;
+      bool allOperandsMapped =
+          llvm::all_of(op.getOperands(), [&](mlir::Value v) {
+            return moduleTranslation.lookupValue(v) != nullptr;
+          });
+      if (!allOperandsMapped)
+        continue;
+      if (failed(moduleTranslation.convertOperation(op, builder)))
+        return atomicCompareOp.emitError(
+            "failed to translate operation inside atomic compare region");
+    }
+
+    auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
+      if (llvm::Value *existing = moduleTranslation.lookupValue(val))
+        return existing;
+      if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
+        if (loadOp->getParentRegion() == &region) {
+          llvm::Value *loadAddr =
+              moduleTranslation.lookupValue(loadOp.getAddr());
+          if (!loadAddr)
+            return nullptr;
+          llvm::Type *loadType =
+              moduleTranslation.convertType(loadOp.getResult().getType());
+          return builder.CreateLoad(loadType, loadAddr);
+        }
+      }
+      return nullptr;
+    };
+
+    // Extract comparison predicate, eVal, and dVal from the region.
+    llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+    llvm::Value *eVal = nullptr;
+    llvm::Value *dVal = nullptr;
+    bool isXBinopExpr = false;
+
+    for (Operation &op : block.getOperations()) {
+      if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
+        auto maybeOp =
+            convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
+        if (!maybeOp)
+     ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/202315


More information about the flang-commits mailing list