[Mlir-commits] [flang] [mlir] [Flang] [OpenMP] atomic compare (PR #184761)

Tom Eccles llvmlistbot at llvm.org
Thu Apr 23 03:13:41 PDT 2026


================
@@ -4671,6 +4685,293 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
   return success();
 }
 
+/// Maps an integer comparison predicate onto the OMPAtomicCompareOp it
+/// selects: `eq` yields EQ, `slt`/`ult` yield MIN and `sgt`/`ugt` yield MAX.
+/// Any other predicate is unsupported and yields std::nullopt.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
+  using Pred = LLVM::ICmpPredicate;
+  using CompareKind = llvm::omp::OMPAtomicCompareOp;
+
+  if (predicate == Pred::eq)
+    return CompareKind::EQ;
+  // Signed and unsigned orderings collapse onto the same compare kind.
+  if (predicate == Pred::slt || predicate == Pred::ult)
+    return CompareKind::MIN;
+  if (predicate == Pred::sgt || predicate == Pred::ugt)
+    return CompareKind::MAX;
+  return std::nullopt;
+}
+
+/// Maps a floating-point comparison predicate onto the OMPAtomicCompareOp it
+/// selects: `oeq`/`ueq` yield EQ, `olt`/`ult` yield MIN and `ogt`/`ugt` yield
+/// MAX. Any other predicate is unsupported and yields std::nullopt.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
+  using Pred = LLVM::FCmpPredicate;
+  using CompareKind = llvm::omp::OMPAtomicCompareOp;
+
+  // Ordered and unordered variants of a comparison share one compare kind.
+  if (predicate == Pred::oeq || predicate == Pred::ueq)
+    return CompareKind::EQ;
+  if (predicate == Pred::olt || predicate == Pred::ult)
+    return CompareKind::MIN;
+  if (predicate == Pred::ogt || predicate == Pred::ugt)
+    return CompareKind::MAX;
+  return std::nullopt;
+}
+
+/// Converts an omp.atomic.compare operation to LLVM IR.
+///
+///      if (x == e) x = d
+/// The region contains a comparison + select pattern:
+///   ^bb0(%xval: T):
+///     %cmp = llvm.icmp/fcmp <pred> %xval, %e : T
+///     %sel = llvm.select %cmp, %d, %xval : i1, T
+///     omp.yield(%sel : T)
+///
+/// From MLIR extract:
+///   1) comparison operator
+///   2) expected value (e)
+///   3) desired value (d)
+/// These are passed to OpenMPIRBuilder::createAtomicCompare which generates
+/// the actual cmpxchg / atomicrmw instruction.
+///
+static LogicalResult
+convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
+                        llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  if (failed(checkImplementationStatus(*atomicCompareOp)))
+    return failure();
+
+  Region &region = atomicCompareOp.getRegion();
+  Block &block = region.front();
+
+  // Determine element type from the region block argument
+  llvm::Type *llvmXElementType =
+      moduleTranslation.convertType(block.getArgument(0).getType());
+  if (!llvmXElementType)
+    return atomicCompareOp.emitError(
+        "unable to determine element type for atomic compare");
+
+  llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+
+  // IsSigned is determined from the comparison predicate in the region.
+  // Signed ICmp predicates (slt/sgt) set this to true; unsigned (ult/ugt)
+  // leave it false. For EQ and float comparisons, signedness is irrelevant.
+  bool isSigned = false;
+  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
+                                                      isSigned,
+                                                      /*IsVolatile=*/false};
+
+  llvm::AtomicOrdering atomicOrdering =
+      convertAtomicOrdering(atomicCompareOp.getMemoryOrder());
+
+  // Trace back through load operations and generate load instructions
+  auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
+    if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
+      if (loadOp->getParentRegion() == &region) {
+        llvm::Value *loadAddr = moduleTranslation.lookupValue(loadOp.getAddr());
+        llvm::Type *loadType =
+            moduleTranslation.convertType(loadOp.getResult().getType());
+        return builder.CreateLoad(loadType, loadAddr);
+      }
+    }
+    return moduleTranslation.lookupValue(val);
+  };
+
+  // Walk the region to extract comparison predicate, eVal, and dVal.
+  // if (x == eVal) x = dVal
+  llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+  llvm::Value *eVal = nullptr;
+  llvm::Value *dVal = nullptr;
+  bool isXBinopExpr = false;
+
+  auto traceToAggregate = [](mlir::Value v) -> mlir::Value {
+    if (auto extractOp = v.getDefiningOp<LLVM::ExtractValueOp>())
+      return extractOp.getContainer();
+    return nullptr;
+  };
+
+  // Check for a decomposed complex comparison pattern:
+  //   %re_x = llvm.extractvalue %xval[0]
+  //   %re_e = llvm.extractvalue %eStruct[0]
+  //   %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
+  //   %im_x = llvm.extractvalue %xval[1]
+  //   %im_e = llvm.extractvalue %eStruct[1]
+  //   %cmp_im = llvm.fcmp "oeq" %im_x, %im_e
+  //   %cmp = llvm.and %cmp_re, %cmp_im   (for EQ)
+  // Detect this by looking for AndOp/OrOp whose operands are both FCmpOps
+  // operating on ExtractValueOps from the block argument.
+  bool isComplexPattern = false;
+  for (Operation &op : block.getOperations()) {
+    if (!isa<LLVM::AndOp, LLVM::OrOp>(op))
+      continue;
+
+    // Using : %cmp = llvm.and %cmp_re, %cmp_im
+    auto lhsFcmp = op.getOperand(0).getDefiningOp<LLVM::FCmpOp>();
+    auto rhsFcmp = op.getOperand(1).getDefiningOp<LLVM::FCmpOp>();
+    if (!lhsFcmp || !rhsFcmp)
+      continue;
+
+    // Using : %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
+    // Check presence of x (block argument) and get e.
+    mlir::Value lhsAgg0 = traceToAggregate(lhsFcmp.getOperand(0));
+    mlir::Value lhsAgg1 = traceToAggregate(lhsFcmp.getOperand(1));
+    bool lhsXIsOp0 = (lhsAgg0 == block.getArgument(0));
+    bool lhsXIsOp1 = (lhsAgg1 == block.getArgument(0));
+    if (!lhsXIsOp0 && !lhsXIsOp1)
+      continue;
+    mlir::Value eAggregate = lhsXIsOp0 ? lhsAgg1 : lhsAgg0;
+    if (!eAggregate)
+      continue;
+
+    if (isa<LLVM::AndOp>(op))
+      compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+    else
+      // OrOp corresponds to NE, which is not a valid atomic compare op.
+      return atomicCompareOp.emitError(
+          "unsupported comparison predicate (NE) for complex atomic compare");
+
+    isXBinopExpr = lhsXIsOp0;
+    eVal = materializeValue(eAggregate);
+    isComplexPattern = true;
+    break;
+  }
+
+  if (isComplexPattern) {
+    // dVal from SelectOp or YieldOp.
+    for (Operation &op : block.getOperations()) {
+      if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
+        dVal = materializeValue(selectOp.getTrueValue());
+        break;
+      }
+    }
+    if (!dVal) {
+      auto yieldOp = cast<omp::YieldOp>(block.getTerminator());
+      if (yieldOp.getResults().empty())
+        return atomicCompareOp.emitError(
+            "failed to extract desired value (d) from atomic compare region");
+      dVal = materializeValue(yieldOp.getResults()[0]);
+    }
+
+    const llvm::DataLayout &DL =
+        builder.GetInsertBlock()->getModule()->getDataLayout();
+    unsigned totalBits =
+        DL.getTypeStoreSizeInBits(llvmXElementType).getFixedValue();
+
+    llvm::IntegerType *intTy =
+        llvm::IntegerType::get(builder.getContext(), totalBits);
+
+    llvm::Align complexAlign = DL.getABITypeAlign(llvmXElementType);
+    llvm::Align intAlign = DL.getABITypeAlign(intTy);
+    llvm::Align maxAlign = std::max(complexAlign, intAlign);
+
+    llvm::AllocaInst *eAlloca =
+        builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.e");
+    eAlloca->setAlignment(maxAlign);
+    llvm::AllocaInst *dAlloca =
+        builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.d");
+    dAlloca->setAlignment(maxAlign);
+
+    builder.CreateAlignedStore(eVal, eAlloca, maxAlign);
+    llvm::Value *eInt =
+        builder.CreateAlignedLoad(intTy, eAlloca, maxAlign, "cmplx.e.int");
+    builder.CreateAlignedStore(dVal, dAlloca, maxAlign);
+    llvm::Value *dInt =
+        builder.CreateAlignedLoad(intTy, dAlloca, maxAlign, "cmplx.d.int");
+
+    llvm::AtomicOrdering failOrdering =
+        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(atomicOrdering);
+    builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, maxAlign, atomicOrdering,
+                                failOrdering);
----------------
tblah wrote:

Without a call to `OpenMPIRBuilder::createAtomicCompare`, we never end up running `OpenMPIRBuilder::checkAndEmitFlushAfterAtomic` and so I think no flush call is generated.

https://github.com/llvm/llvm-project/pull/184761


More information about the Mlir-commits mailing list