[Mlir-commits] [flang] [mlir] [Flang] [OpenMP] atomic compare (PR #184761)
Tom Eccles
llvmlistbot at llvm.org
Thu Apr 23 03:13:41 PDT 2026
================
@@ -4671,6 +4685,293 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
return success();
}
+/// Helper to extract the OMPAtomicCompareOp from an integer comparison
+/// predicate. Returns std::nullopt for unsupported predicates.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
+ switch (predicate) {
+ case LLVM::ICmpPredicate::eq:
+ return llvm::omp::OMPAtomicCompareOp::EQ;
+ case LLVM::ICmpPredicate::slt:
+ case LLVM::ICmpPredicate::ult:
+ return llvm::omp::OMPAtomicCompareOp::MIN;
+ case LLVM::ICmpPredicate::sgt:
+ case LLVM::ICmpPredicate::ugt:
+ return llvm::omp::OMPAtomicCompareOp::MAX;
+ default:
+ return std::nullopt;
+ }
+}
+
+/// Helper to extract the OMPAtomicCompareOp from a floating-point comparison
+/// predicate. Returns std::nullopt for unsupported predicates.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
+ switch (predicate) {
+ case LLVM::FCmpPredicate::oeq:
+ case LLVM::FCmpPredicate::ueq:
+ return llvm::omp::OMPAtomicCompareOp::EQ;
+ case LLVM::FCmpPredicate::olt:
+ case LLVM::FCmpPredicate::ult:
+ return llvm::omp::OMPAtomicCompareOp::MIN;
+ case LLVM::FCmpPredicate::ogt:
+ case LLVM::FCmpPredicate::ugt:
+ return llvm::omp::OMPAtomicCompareOp::MAX;
+ default:
+ return std::nullopt;
+ }
+}
+
+/// Converts an omp.atomic.compare operation to LLVM IR.
+///
+/// if (x == e) x = d
+/// The region contains a comparison + select pattern:
+/// ^bb0(%xval: T):
+/// %cmp = llvm.icmp/fcmp <pred> %xval, %e : T
+/// %sel = llvm.select %cmp, %d, %xval : i1, T
+/// omp.yield(%sel : T)
+///
+/// From MLIR extract:
+/// 1) comparison operator
+/// 2) expected value (e)
+/// 3) desired value (d)
+/// These are passed to OpenMPIRBuilder::createAtomicCompare which generates
+/// the actual cmpxchg / atomicrmw instruction.
+///
+static LogicalResult
+convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
+ llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ if (failed(checkImplementationStatus(*atomicCompareOp)))
+ return failure();
+
+ Region ®ion = atomicCompareOp.getRegion();
+ Block &block = region.front();
+
+ // Determine element type from the region block argument
+ llvm::Type *llvmXElementType =
+ moduleTranslation.convertType(block.getArgument(0).getType());
+ if (!llvmXElementType)
+ return atomicCompareOp.emitError(
+ "unable to determine element type for atomic compare");
+
+ llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+
+ // IsSigned is determined from the comparison predicate in the region.
+ // Signed ICmp predicates (slt/sgt) set this to true; unsigned (ult/ugt)
+ // leave it false. For EQ and float comparisons, signedness is irrelevant.
+ bool isSigned = false;
+ llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
+ isSigned,
+ /*IsVolatile=*/false};
+
+ llvm::AtomicOrdering atomicOrdering =
+ convertAtomicOrdering(atomicCompareOp.getMemoryOrder());
+
+ // Trace back through load operations and generate load instructions
+ auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
+ if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
+ if (loadOp->getParentRegion() == ®ion) {
+ llvm::Value *loadAddr = moduleTranslation.lookupValue(loadOp.getAddr());
+ llvm::Type *loadType =
+ moduleTranslation.convertType(loadOp.getResult().getType());
+ return builder.CreateLoad(loadType, loadAddr);
+ }
+ }
+ return moduleTranslation.lookupValue(val);
+ };
+
+ // Walk the region to extract comparison predicate, eVal, and dVal.
+ // if (x == eVal) x = dVal
+ llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+ llvm::Value *eVal = nullptr;
+ llvm::Value *dVal = nullptr;
+ bool isXBinopExpr = false;
+
+ auto traceToAggregate = [](mlir::Value v) -> mlir::Value {
+ if (auto extractOp = v.getDefiningOp<LLVM::ExtractValueOp>())
+ return extractOp.getContainer();
+ return nullptr;
+ };
+
+ // Check for a decomposed complex comparison pattern:
+ // %re_x = llvm.extractvalue %xval[0]
+ // %re_e = llvm.extractvalue %eStruct[0]
+ // %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
+ // %im_x = llvm.extractvalue %xval[1]
+ // %im_e = llvm.extractvalue %eStruct[1]
+ // %cmp_im = llvm.fcmp "oeq" %im_x, %im_e
+ // %cmp = llvm.and %cmp_re, %cmp_im (for EQ)
+ // Detect this by looking for AndOp/OrOp whose operands are both FCmpOps
+ // operating on ExtractValueOps from the block argument.
+ bool isComplexPattern = false;
+ for (Operation &op : block.getOperations()) {
+ if (!isa<LLVM::AndOp, LLVM::OrOp>(op))
+ continue;
+
+ // Using : %cmp = llvm.and %cmp_re, %cmp_im
+ auto lhsFcmp = op.getOperand(0).getDefiningOp<LLVM::FCmpOp>();
+ auto rhsFcmp = op.getOperand(1).getDefiningOp<LLVM::FCmpOp>();
+ if (!lhsFcmp || !rhsFcmp)
+ continue;
+
+ // Using : %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
+ // Check presence of x (block argument) and get e.
+ mlir::Value lhsAgg0 = traceToAggregate(lhsFcmp.getOperand(0));
+ mlir::Value lhsAgg1 = traceToAggregate(lhsFcmp.getOperand(1));
+ bool lhsXIsOp0 = (lhsAgg0 == block.getArgument(0));
+ bool lhsXIsOp1 = (lhsAgg1 == block.getArgument(0));
+ if (!lhsXIsOp0 && !lhsXIsOp1)
+ continue;
+ mlir::Value eAggregate = lhsXIsOp0 ? lhsAgg1 : lhsAgg0;
+ if (!eAggregate)
+ continue;
+
+ if (isa<LLVM::AndOp>(op))
+ compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+ else
+ // OrOp corresponds to NE, which is not a valid atomic compare op.
+ return atomicCompareOp.emitError(
+ "unsupported comparison predicate (NE) for complex atomic compare");
+
+ isXBinopExpr = lhsXIsOp0;
+ eVal = materializeValue(eAggregate);
+ isComplexPattern = true;
+ break;
+ }
+
+ if (isComplexPattern) {
+ // dVal from SelectOp or YieldOp.
+ for (Operation &op : block.getOperations()) {
+ if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
+ dVal = materializeValue(selectOp.getTrueValue());
+ break;
+ }
+ }
+ if (!dVal) {
+ auto yieldOp = cast<omp::YieldOp>(block.getTerminator());
+ if (yieldOp.getResults().empty())
+ return atomicCompareOp.emitError(
+ "failed to extract desired value (d) from atomic compare region");
+ dVal = materializeValue(yieldOp.getResults()[0]);
+ }
+
+ const llvm::DataLayout &DL =
+ builder.GetInsertBlock()->getModule()->getDataLayout();
+ unsigned totalBits =
+ DL.getTypeStoreSizeInBits(llvmXElementType).getFixedValue();
+
+ llvm::IntegerType *intTy =
+ llvm::IntegerType::get(builder.getContext(), totalBits);
+
+ llvm::Align complexAlign = DL.getABITypeAlign(llvmXElementType);
+ llvm::Align intAlign = DL.getABITypeAlign(intTy);
+ llvm::Align maxAlign = std::max(complexAlign, intAlign);
+
+ llvm::AllocaInst *eAlloca =
+ builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.e");
+ eAlloca->setAlignment(maxAlign);
+ llvm::AllocaInst *dAlloca =
+ builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.d");
+ dAlloca->setAlignment(maxAlign);
+
+ builder.CreateAlignedStore(eVal, eAlloca, maxAlign);
+ llvm::Value *eInt =
+ builder.CreateAlignedLoad(intTy, eAlloca, maxAlign, "cmplx.e.int");
+ builder.CreateAlignedStore(dVal, dAlloca, maxAlign);
+ llvm::Value *dInt =
+ builder.CreateAlignedLoad(intTy, dAlloca, maxAlign, "cmplx.d.int");
+
+ llvm::AtomicOrdering failOrdering =
+ llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(atomicOrdering);
+ builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, maxAlign, atomicOrdering,
+ failOrdering);
----------------
tblah wrote:
Without a call to `OpenMPIRBuilder::createAtomicCompare`, we never end up running `OpenMPIRBuilder::checkAndEmitFlushAfterAtomic` and so I think no flush call is generated.
https://github.com/llvm/llvm-project/pull/184761
More information about the Mlir-commits
mailing list