[flang-commits] [flang] [llvm] [mlir] [OpenMPIRBuilder] Emit __atomic_load and __atomic_compare_exchange libcalls for complex types in atomic update (PR #92364)

via flang-commits flang-commits at lists.llvm.org
Thu Sep 19 05:24:55 PDT 2024


https://github.com/NimishMishra updated https://github.com/llvm/llvm-project/pull/92364

>From d6fa629104a116d78ae03fa3249c5c61b3b514a0 Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Thu, 19 Sep 2024 17:53:51 +0530
Subject: [PATCH] [flang][mlir][llvm] Emit __atomic_load and
 __atomic_compare_exchange libcalls for complex types in atomic update

---
 flang/lib/Lower/DirectivesCommon.h            |   3 +-
 .../OpenMP/atomic-update-complex.f90          |  42 ++++++
 .../test/Lower/OpenMP/Todo/atomic-complex.f90 |   8 --
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   9 ++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 122 ++++++++++++++++++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  31 ++---
 6 files changed, 187 insertions(+), 28 deletions(-)
 create mode 100644 flang/test/Integration/OpenMP/atomic-update-complex.f90
 delete mode 100644 flang/test/Lower/OpenMP/Todo/atomic-complex.f90

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index d2060e77ce5305..c2bc96e78c6057 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -157,8 +157,7 @@ static void processOmpAtomicTODO(mlir::Type elementType,
     // Based on assertion for supported element types in OMPIRBuilder.cpp
     // createAtomicRead
     mlir::Type unwrappedEleTy = fir::unwrapRefType(elementType);
-    bool supportedAtomicType =
-        (!fir::isa_complex(unwrappedEleTy) && fir::isa_trivial(unwrappedEleTy));
+    bool supportedAtomicType = fir::isa_trivial(unwrappedEleTy);
     if (!supportedAtomicType)
       TODO(loc, "Unsupported atomic type");
   }
diff --git a/flang/test/Integration/OpenMP/atomic-update-complex.f90 b/flang/test/Integration/OpenMP/atomic-update-complex.f90
new file mode 100644
index 00000000000000..6e0b419d11f952
--- /dev/null
+++ b/flang/test/Integration/OpenMP/atomic-update-complex.f90
@@ -0,0 +1,42 @@
+!===----------------------------------------------------------------------===!
+! This directory can be used to add Integration tests involving multiple
+! stages of the compiler (e.g., from Fortran to LLVM IR). It should not
+! contain executable tests. We should only add tests here sparingly and only
+! if there is no other way to test. Repeat this message in each test that is
+! added to this directory and sub-directories.
+!===----------------------------------------------------------------------===!
+
+!RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s
+
+!CHECK: define void @_QQmain() {
+!CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8
+!CHECK: {{.*}} = alloca { float, float }, i64 1, align 8
+!CHECK: %[[ORIG_VAL:.*]] = alloca { float, float }, i64 1, align 8
+!CHECK: store { float, float } { float 2.000000e+00, float 2.000000e+00 }, ptr %[[ORIG_VAL]], align 4
+!CHECK: br label %entry
+
+!CHECK: entry:
+!CHECK: %[[ATOMIC_TEMP_LOAD:.*]] = alloca { float, float }, align 16
+!CHECK: call void @__atomic_load(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], i32 0)
+!CHECK: %[[PHI_NODE_ENTRY_1:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 16
+!CHECK: br label %.atomic.cont
+
+!CHECK: .atomic.cont
+!CHECK: %[[VAL_4:.*]] = phi { float, float } [ %[[PHI_NODE_ENTRY_1]], %entry ], [ %{{.*}}, %.atomic.cont ]
+!CHECK: %[[VAL_5:.*]] = extractvalue { float, float } %[[VAL_4]], 0
+!CHECK: %[[VAL_6:.*]] = extractvalue { float, float } %[[VAL_4]], 1
+!CHECK: %[[VAL_7:.*]] = fadd contract float %[[VAL_5]], 1.000000e+00
+!CHECK: %[[VAL_8:.*]] = fadd contract float %[[VAL_6]], 1.000000e+00
+!CHECK: %[[VAL_9:.*]] = insertvalue { float, float } undef, float %[[VAL_7]], 0
+!CHECK: %[[VAL_10:.*]] = insertvalue { float, float } %[[VAL_9]], float %[[VAL_8]], 1
+!CHECK: store { float, float } %[[VAL_10]], ptr %[[X_NEW_VAL]], align 4
+!CHECK: %[[VAL_11:.*]] = call i1 @__atomic_compare_exchange(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], ptr %[[X_NEW_VAL]], i32 2, i32 2)
+!CHECK: %[[VAL_12:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 4
+!CHECK: br i1 %[[VAL_11]], label %.atomic.exit, label %.atomic.cont
+program main
+      complex*8 ia, ib
+      ia = (2, 2)
+      !$omp atomic update
+        ia = ia + (1, 1)
+      !$omp end atomic
+end program
diff --git a/flang/test/Lower/OpenMP/Todo/atomic-complex.f90 b/flang/test/Lower/OpenMP/Todo/atomic-complex.f90
deleted file mode 100644
index 6d6e4399ee192e..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/atomic-complex.f90
+++ /dev/null
@@ -1,8 +0,0 @@
-! RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Unsupported atomic type
-subroutine complex_atomic
-  complex :: l, r
-  !$omp atomic read
-    l = r
-end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4be0159fb1dd9f..7e9fca9feca37b 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3039,6 +3039,15 @@ class OpenMPIRBuilder {
                    AtomicUpdateCallbackTy &UpdateOp, bool VolatileX,
                    bool IsXBinopExpr);
 
+  std::pair<llvm::LoadInst *, llvm::AllocaInst *>
+  EmitAtomicLoadLibcall(Value *X, Type *XElemTy, llvm::AtomicOrdering AO,
+                        uint64_t AtomicSizeInBits);
+
+  std::pair<llvm::Value *, llvm::Value *> EmitAtomicCompareExchangeLibcall(
+      Value *X, Type *XElemTy, uint64_t AtomicSizeInBits,
+      llvm::Value *ExpectedVal, llvm::Value *DesiredVal,
+      llvm::AtomicOrdering Success, llvm::AtomicOrdering Failure);
+
   /// Emit the binary op. described by \p RMWOp, using \p Src1 and \p Src2 .
   ///
   /// \Return The instruction
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 515b74cbb75883..025ce179464759 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7943,6 +7943,83 @@ Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
   llvm_unreachable("Unsupported atomic update operation");
 }
 
+std::pair<llvm::LoadInst *, llvm::AllocaInst *>
+OpenMPIRBuilder::EmitAtomicLoadLibcall(Value *X, Type *XElemTy,
+                                       llvm::AtomicOrdering AO,
+                                       uint64_t AtomicSizeInBits) {
+  LLVMContext &Ctx = Builder.getContext();
+  Type *SizedIntTy = Type::getIntNTy(Ctx, AtomicSizeInBits * 8);
+  Type *ResultTy;
+  SmallVector<Value *, 6> Args;
+  AttributeList Attr;
+  Module *M = Builder.GetInsertBlock()->getModule();
+  const DataLayout &DL = M->getDataLayout();
+  Args.push_back(ConstantInt::get(DL.getIntPtrType(Ctx), AtomicSizeInBits / 8));
+  Value *PtrVal = X;
+  PtrVal = Builder.CreateAddrSpaceCast(PtrVal, PointerType::getUnqual(Ctx));
+  Args.push_back(PtrVal);
+  llvm::AllocaInst *allocaInst = Builder.CreateAlloca(XElemTy);
+  allocaInst->setName(X->getName() + "atomic.temp.load");
+  const Align AllocaAlignment = DL.getPrefTypeAlign(SizedIntTy);
+  allocaInst->setAlignment(AllocaAlignment);
+  Args.push_back(allocaInst);
+  Constant *OrderingVal =
+      ConstantInt::get(Type::getInt32Ty(Ctx), (int)toCABI(AO));
+  Args.push_back(OrderingVal);
+  ResultTy = Type::getVoidTy(Ctx);
+  SmallVector<Type *, 6> ArgTys;
+  for (Value *Arg : Args)
+    ArgTys.push_back(Arg->getType());
+  FunctionType *FnType = FunctionType::get(ResultTy, ArgTys, false);
+  FunctionCallee LibcallFn =
+      M->getOrInsertFunction("__atomic_load", FnType, Attr);
+  CallInst *Call = Builder.CreateCall(LibcallFn, Args);
+  Call->setAttributes(Attr);
+  return std::make_pair(
+      Builder.CreateAlignedLoad(XElemTy, allocaInst, AllocaAlignment),
+      allocaInst);
+}
+
+std::pair<llvm::Value *, llvm::Value *>
+OpenMPIRBuilder::EmitAtomicCompareExchangeLibcall(
+    Value *X, Type *XElemTy, uint64_t AtomicSizeInBits,
+    llvm::Value *ExpectedVal, llvm::Value *DesiredVal,
+    llvm::AtomicOrdering Success, llvm::AtomicOrdering Failure) {
+  LLVMContext &Ctx = Builder.getContext();
+  uint64_t IntBits = 32;
+  uint16_t SizeTBits = 64;
+  uint16_t BitsPerByte = 8;
+  llvm::Value *AtomicSizeValue = llvm::ConstantInt::get(
+      llvm::IntegerType::get(Ctx, SizeTBits), AtomicSizeInBits / BitsPerByte);
+  llvm::Value *Args[6] = {
+      AtomicSizeValue,
+      X,
+      ExpectedVal,
+      DesiredVal,
+      llvm::Constant::getIntegerValue(
+          llvm::IntegerType::get(Ctx, IntBits),
+          llvm::APInt(IntBits, (uint64_t)Success, /*signed=*/true)),
+      llvm::Constant::getIntegerValue(
+          llvm::IntegerType::get(Ctx, IntBits),
+          llvm::APInt(IntBits, (uint64_t)Failure, /*signed=*/true)),
+  };
+  SmallVector<Type *, 6> ArgTys;
+  Type *ResultType = llvm::IntegerType::getInt1Ty(Ctx);
+  llvm::AttrBuilder Attr(Ctx);
+  Attr.addAttribute(llvm::Attribute::NoUnwind);
+  Attr.addAttribute(llvm::Attribute::WillReturn);
+  llvm::AttributeList fnAttrs =
+      llvm::AttributeList::get(Ctx, llvm::AttributeList::FunctionIndex, Attr);
+  Module *M = Builder.GetInsertBlock()->getModule();
+  for (Value *Arg : Args)
+    ArgTys.push_back(Arg->getType());
+  FunctionType *FnType = FunctionType::get(ResultType, ArgTys, false);
+  FunctionCallee LibcallFn =
+      M->getOrInsertFunction("__atomic_compare_exchange", FnType, fnAttrs);
+  CallInst *Call = Builder.CreateCall(LibcallFn, Args);
+  return std::make_pair(ExpectedVal, Call);
+}
+
 std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
     InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
     AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
@@ -7977,6 +8054,51 @@ std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
       Res.second = Res.first;
     else
       Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
+  } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
+             XElemTy->isStructTy()) {
+    LoadInst *OldVal =
+        Builder.CreateLoad(XElemTy, X, X->getName() + ".atomic.load");
+    OldVal->setAtomic(AO);
+    const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
+    unsigned LoadSize =
+        LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
+    auto AtomicLoadRes = EmitAtomicLoadLibcall(X, XElemTy, AO, LoadSize * 8);
+    BasicBlock *CurBB = Builder.GetInsertBlock();
+    Instruction *CurBBTI = CurBB->getTerminator();
+    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
+    BasicBlock *ExitBB =
+        CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
+    BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
+                                                X->getName() + ".atomic.cont");
+    ContBB->getTerminator()->eraseFromParent();
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
+    NewAtomicAddr->setName(X->getName() + "x.new.val");
+    Builder.SetInsertPoint(ContBB);
+    llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
+    PHI->addIncoming(AtomicLoadRes.first, CurBB);
+    Value *OldExprVal = PHI;
+    Value *Upd = UpdateOp(OldExprVal, Builder);
+    Builder.CreateStore(Upd, NewAtomicAddr);
+    AtomicOrdering Failure =
+        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
+    auto Result = EmitAtomicCompareExchangeLibcall(X, XElemTy, LoadSize * 8,
+                                                   AtomicLoadRes.second,
+                                                   NewAtomicAddr, AO, Failure);
+    LoadInst *PHILoad = Builder.CreateLoad(XElemTy, Result.first);
+    PHI->addIncoming(PHILoad, Builder.GetInsertBlock());
+    Builder.CreateCondBr(Result.second, ExitBB, ContBB);
+    OldVal->eraseFromParent();
+    Res.first = OldExprVal;
+    Res.second = Upd;
+
+    if (UnreachableInst *ExitTI =
+            dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
+      CurBBTI->eraseFromParent();
+      Builder.SetInsertPoint(ExitBB);
+    } else {
+      Builder.SetInsertPoint(ExitTI);
+    }
   } else {
     IntegerType *IntCastTy =
         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0cba8d80681f13..a0016485979226 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2050,28 +2050,23 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
     isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                       atomicCaptureOp.getAtomicUpdateOp().getOperation();
     auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
-    bool isRegionArgUsed{false};
     // Find the binary update operation that uses the region argument
     // and get the expression to update
-    for (Operation &innerOp : innerOpList) {
-      if (innerOp.getNumOperands() == 2) {
-        binop = convertBinOpToAtomic(innerOp);
-        if (!llvm::is_contained(innerOp.getOperands(),
-                                atomicUpdateOp.getRegion().getArgument(0)))
-          continue;
-        isRegionArgUsed = true;
-        isXBinopExpr =
-            innerOp.getNumOperands() > 0 &&
-            innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
-        mlirExpr =
-            (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
-        break;
+    if (innerOpList.size() == 2) {
+      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
+      if (!llvm::is_contained(innerOp.getOperands(),
+                              atomicUpdateOp.getRegion().getArgument(0))) {
+        return atomicUpdateOp.emitError(
+            "no atomic update operation with region argument"
+            " as operand found inside atomic.update region");
       }
+      binop = convertBinOpToAtomic(innerOp);
+      isXBinopExpr =
+          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
+      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
+    } else {
+      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
     }
-    if (!isRegionArgUsed)
-      return atomicUpdateOp.emitError(
-          "no atomic update operation with region argument"
-          " as operand found inside atomic.update region");
   }
 
   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);



More information about the flang-commits mailing list