[flang-commits] [flang] [llvm] [mlir] [OpenMPIRBuilder] Emit __atomic_load and __atomic_compare_exchange libcalls for complex types in atomic update (PR #92364)

via flang-commits flang-commits at lists.llvm.org
Wed Oct 2 22:11:58 PDT 2024


https://github.com/NimishMishra updated https://github.com/llvm/llvm-project/pull/92364

>From ed2d7ead81b224875abeee3a635b753c122be4b0 Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Thu, 3 Oct 2024 10:40:05 +0530
Subject: [PATCH] [flang][mlir][llvm][OpenMP] Emit __atomic_load and
 __atomic_compare_exchange libcalls for complex types in atomic update

---
 flang/lib/Lower/DirectivesCommon.h            |   3 +-
 .../OpenMP/atomic-capture-complex.f90         |  47 ++++
 .../OpenMP/atomic-update-complex.f90          |  42 ++++
 .../test/Lower/OpenMP/Todo/atomic-complex.f90 |   8 -
 llvm/include/llvm/Frontend/Atomic/Atomic.h    | 232 ++++++++++++++++++
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  31 +++
 llvm/lib/Frontend/Atomic/Atomic.cpp           |  19 ++
 llvm/lib/Frontend/Atomic/CMakeLists.txt       |  15 ++
 llvm/lib/Frontend/CMakeLists.txt              |   1 +
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  48 ++++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  31 +--
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      | 114 +++++++++
 12 files changed, 563 insertions(+), 28 deletions(-)
 create mode 100644 flang/test/Integration/OpenMP/atomic-capture-complex.f90
 create mode 100644 flang/test/Integration/OpenMP/atomic-update-complex.f90
 delete mode 100644 flang/test/Lower/OpenMP/Todo/atomic-complex.f90
 create mode 100644 llvm/include/llvm/Frontend/Atomic/Atomic.h
 create mode 100644 llvm/lib/Frontend/Atomic/Atomic.cpp
 create mode 100644 llvm/lib/Frontend/Atomic/CMakeLists.txt

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index a32f0b287e049a..da192ded4aa971 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -129,8 +129,7 @@ static void processOmpAtomicTODO(mlir::Type elementType,
     // Based on assertion for supported element types in OMPIRBuilder.cpp
     // createAtomicRead
     mlir::Type unwrappedEleTy = fir::unwrapRefType(elementType);
-    bool supportedAtomicType =
-        (!fir::isa_complex(unwrappedEleTy) && fir::isa_trivial(unwrappedEleTy));
+    bool supportedAtomicType = fir::isa_trivial(unwrappedEleTy);
     if (!supportedAtomicType)
       TODO(loc, "Unsupported atomic type");
   }
diff --git a/flang/test/Integration/OpenMP/atomic-capture-complex.f90 b/flang/test/Integration/OpenMP/atomic-capture-complex.f90
new file mode 100644
index 00000000000000..72329f0b2eb10d
--- /dev/null
+++ b/flang/test/Integration/OpenMP/atomic-capture-complex.f90
@@ -0,0 +1,47 @@
+!===----------------------------------------------------------------------===!
+! This directory can be used to add Integration tests involving multiple
+! stages of the compiler (for eg. from Fortran to LLVM IR). It should not
+! contain executable tests. We should only add tests here sparingly and only
+! if there is no other way to test. Repeat this message in each test that is
+! added to this directory and sub-directories.
+!===----------------------------------------------------------------------===!
+
+!RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s
+
+!CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8
+!CHECK: %[[VAL_1:.*]] = alloca { float, float }, i64 1, align 8
+!CHECK: %[[ORIG_VAL:.*]] = alloca { float, float }, i64 1, align 8
+!CHECK: store { float, float } { float 2.000000e+00, float 2.000000e+00 }, ptr %[[ORIG_VAL]], align 4
+!CHECK: br label %entry
+
+!CHECK: entry:
+!CHECK: %[[ATOMIC_TEMP_LOAD:.*]] = alloca { float, float }, align 8
+!CHECK: call void @__atomic_load(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], i32 0)
+!CHECK: %[[PHI_NODE_ENTRY_1:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 8
+!CHECK: br label %.atomic.cont
+
+!CHECK: .atomic.cont
+!CHECK: %[[VAL_4:.*]] = phi { float, float } [ %[[PHI_NODE_ENTRY_1]], %entry ], [ %{{.*}}, %.atomic.cont ]
+!CHECK: %[[VAL_5:.*]] = extractvalue { float, float } %[[VAL_4]], 0
+!CHECK: %[[VAL_6:.*]] = extractvalue { float, float } %[[VAL_4]], 1
+!CHECK: %[[VAL_7:.*]] = fadd contract float %[[VAL_5]], 1.000000e+00
+!CHECK: %[[VAL_8:.*]] = fadd contract float %[[VAL_6]], 1.000000e+00
+!CHECK: %[[VAL_9:.*]] = insertvalue { float, float } undef, float %[[VAL_7]], 0
+!CHECK: %[[VAL_10:.*]] = insertvalue { float, float } %[[VAL_9]], float %[[VAL_8]], 1
+!CHECK: store { float, float } %[[VAL_10]], ptr %[[X_NEW_VAL]], align 4
+!CHECK: %[[VAL_11:.*]] = call i1 @__atomic_compare_exchange(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], ptr %[[X_NEW_VAL]],
+!i32 2, i32 2)
+!CHECK: %[[VAL_12:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 4
+!CHECK: br i1 %[[VAL_11]], label %.atomic.exit, label %.atomic.cont
+
+!CHECK: .atomic.exit
+!CHECK: store { float, float } %[[VAL_10]], ptr %[[VAL_1]], align 4
+
+! Atomic capture on a complex variable (complex*8, i.e. {float, float} in
+! the IR above). The update is expected to lower to the __atomic_load /
+! __atomic_compare_exchange retry loop matched above, with the updated value
+! stored to ib once the loop exits.
+program main
+      complex*8 ia, ib
+      ia = (2, 2)
+      !$omp atomic capture
+        ia = ia + (1, 1)
+        ib = ia
+      !$omp end atomic
+end program
diff --git a/flang/test/Integration/OpenMP/atomic-update-complex.f90 b/flang/test/Integration/OpenMP/atomic-update-complex.f90
new file mode 100644
index 00000000000000..827e84a011f53b
--- /dev/null
+++ b/flang/test/Integration/OpenMP/atomic-update-complex.f90
@@ -0,0 +1,42 @@
+!===----------------------------------------------------------------------===!
+! This directory can be used to add Integration tests involving multiple
+! stages of the compiler (for eg. from Fortran to LLVM IR). It should not
+! contain executable tests. We should only add tests here sparingly and only
+! if there is no other way to test. Repeat this message in each test that is
+! added to this directory and sub-directories.
+!===----------------------------------------------------------------------===!
+
+!RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s
+
+!CHECK: define void @_QQmain() {
+!CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8
+!CHECK: {{.*}} = alloca { float, float }, i64 1, align 8
+!CHECK: %[[ORIG_VAL:.*]] = alloca { float, float }, i64 1, align 8
+!CHECK: store { float, float } { float 2.000000e+00, float 2.000000e+00 }, ptr %[[ORIG_VAL]], align 4
+!CHECK: br label %entry
+
+!CHECK: entry:
+!CHECK: %[[ATOMIC_TEMP_LOAD:.*]] = alloca { float, float }, align 8
+!CHECK: call void @__atomic_load(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], i32 0)
+!CHECK: %[[PHI_NODE_ENTRY_1:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 8
+!CHECK: br label %.atomic.cont
+
+!CHECK: .atomic.cont
+!CHECK: %[[VAL_4:.*]] = phi { float, float } [ %[[PHI_NODE_ENTRY_1]], %entry ], [ %{{.*}}, %.atomic.cont ]
+!CHECK: %[[VAL_5:.*]] = extractvalue { float, float } %[[VAL_4]], 0
+!CHECK: %[[VAL_6:.*]] = extractvalue { float, float } %[[VAL_4]], 1
+!CHECK: %[[VAL_7:.*]] = fadd contract float %[[VAL_5]], 1.000000e+00
+!CHECK: %[[VAL_8:.*]] = fadd contract float %[[VAL_6]], 1.000000e+00
+!CHECK: %[[VAL_9:.*]] = insertvalue { float, float } undef, float %[[VAL_7]], 0
+!CHECK: %[[VAL_10:.*]] = insertvalue { float, float } %[[VAL_9]], float %[[VAL_8]], 1
+!CHECK: store { float, float } %[[VAL_10]], ptr %[[X_NEW_VAL]], align 4
+!CHECK: %[[VAL_11:.*]] = call i1 @__atomic_compare_exchange(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], ptr %[[X_NEW_VAL]], i32 2, i32 2)
+!CHECK: %[[VAL_12:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 4
+!CHECK: br i1 %[[VAL_11]], label %.atomic.exit, label %.atomic.cont
+! Atomic update on a complex variable (complex*8, i.e. {float, float} in the
+! IR above). The update is expected to lower to the __atomic_load /
+! __atomic_compare_exchange retry loop matched above. ib is declared but not
+! used in this test.
+program main
+      complex*8 ia, ib
+      ia = (2, 2)
+      !$omp atomic update
+        ia = ia + (1, 1)
+      !$omp end atomic
+end program
diff --git a/flang/test/Lower/OpenMP/Todo/atomic-complex.f90 b/flang/test/Lower/OpenMP/Todo/atomic-complex.f90
deleted file mode 100644
index 6d6e4399ee192e..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/atomic-complex.f90
+++ /dev/null
@@ -1,8 +0,0 @@
-! RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Unsupported atomic type
-subroutine complex_atomic
-  complex :: l, r
-  !$omp atomic read
-    l = r
-end subroutine
diff --git a/llvm/include/llvm/Frontend/Atomic/Atomic.h b/llvm/include/llvm/Frontend/Atomic/Atomic.h
new file mode 100644
index 00000000000000..3942d06144ce17
--- /dev/null
+++ b/llvm/include/llvm/Frontend/Atomic/Atomic.h
@@ -0,0 +1,232 @@
+//===--- Atomic.h - Codegen of atomic operations ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_FRONTEND_ATOMIC_ATOMIC_H
+#define LLVM_FRONTEND_ATOMIC_ATOMIC_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/RuntimeLibcalls.h"
+
+namespace llvm {
+
+/// Shared helper for lowering an atomic operation on one memory location.
+///
+/// The caller describes the location's size and alignment. Subclasses supply
+/// the location itself (getAtomicPointer), how temporaries are created
+/// (CreateAlloca) and optional TBAA decoration (decorateWithTBAA).
+/// Operations are emitted either as native atomic instructions or as calls
+/// to the generic `__atomic_load` / `__atomic_compare_exchange` runtime
+/// functions; EmitAtomicCompareExchange chooses between the two via
+/// shouldUseLibcall().
+template <typename IRBuilderTy> struct AtomicInfo {
+
+  IRBuilderTy *Builder;
+  Type *Ty;
+  uint64_t AtomicSizeInBits;
+  uint64_t ValueSizeInBits;
+  llvm::Align AtomicAlign;
+  llvm::Align ValueAlign;
+  bool UseLibcall;
+
+public:
+  AtomicInfo(IRBuilderTy *Builder, Type *Ty, uint64_t AtomicSizeInBits,
+             uint64_t ValueSizeInBits, llvm::Align AtomicAlign,
+             llvm::Align ValueAlign, bool UseLibcall)
+      : Builder(Builder), Ty(Ty), AtomicSizeInBits(AtomicSizeInBits),
+        ValueSizeInBits(ValueSizeInBits), AtomicAlign(AtomicAlign),
+        ValueAlign(ValueAlign), UseLibcall(UseLibcall) {}
+
+  virtual ~AtomicInfo() = default;
+
+  llvm::Align getAtomicAlignment() const { return AtomicAlign; }
+  uint64_t getAtomicSizeInBits() const { return AtomicSizeInBits; }
+  uint64_t getValueSizeInBits() const { return ValueSizeInBits; }
+  bool shouldUseLibcall() const { return UseLibcall; }
+  llvm::Type *getAtomicTy() const { return Ty; }
+
+  /// Address of the atomic object.
+  virtual llvm::Value *getAtomicPointer() const = 0;
+  /// Attach TBAA metadata (if the frontend has any) to a memory access.
+  virtual void decorateWithTBAA(Instruction *I) = 0;
+  /// Create a temporary of type \p Ty named \p Name.
+  virtual llvm::AllocaInst *CreateAlloca(llvm::Type *Ty,
+                                         const llvm::Twine &Name) const = 0;
+
+  /*
+   * Is the atomic size larger than the underlying value type?
+   * Note that the absence of padding does not mean that atomic
+   * objects are completely interchangeable with non-atomic
+   * objects: we might have promoted the alignment of a type
+   * without making it bigger.
+   */
+  bool hasPadding() const { return (ValueSizeInBits != AtomicSizeInBits); }
+
+  LLVMContext &getLLVMContext() const { return Builder->getContext(); }
+
+  /// Whether a value of type \p ValTy has to be accessed through an integer
+  /// of the atomic width: x86_fp80 always, other floating-point types only
+  /// for cmpxchg, and every type that is neither integer nor pointer.
+  static bool shouldCastToInt(llvm::Type *ValTy, bool CmpXchg) {
+    if (ValTy->isFloatingPointTy())
+      return ValTy->isX86_FP80Ty() || CmpXchg;
+    return !ValTy->isIntegerTy() && !ValTy->isPointerTy();
+  }
+
+  /// Emit a native `load atomic` of the object with ordering \p AO, using
+  /// an integer of the atomic width when shouldCastToInt says so.
+  llvm::Value *EmitAtomicLoadOp(llvm::AtomicOrdering AO, bool IsVolatile,
+                                bool CmpXchg = false) {
+    Value *Ptr = getAtomicPointer();
+    Type *AtomicTy = Ty;
+    if (shouldCastToInt(Ty, CmpXchg))
+      AtomicTy = llvm::IntegerType::get(getLLVMContext(), AtomicSizeInBits);
+    LoadInst *Load =
+        Builder->CreateAlignedLoad(AtomicTy, Ptr, AtomicAlign, "atomic-load");
+    Load->setAtomic(AO);
+    if (IsVolatile)
+      Load->setVolatile(true);
+    decorateWithTBAA(Load);
+    return Load;
+  }
+
+  /// Declare (if not yet present) and call the runtime function \p fnName
+  /// returning \p ResultType. The declaration's signature is derived from
+  /// the types of \p Args and is marked nounwind and willreturn.
+  static CallInst *EmitAtomicLibcall(IRBuilderTy *Builder, StringRef fnName,
+                                     Type *ResultType, ArrayRef<Value *> Args) {
+    LLVMContext &ctx = Builder->getContext();
+    SmallVector<Type *, 6> ArgTys;
+    for (Value *Arg : Args)
+      ArgTys.push_back(Arg->getType());
+    FunctionType *FnType = FunctionType::get(ResultType, ArgTys, false);
+    Module *M = Builder->GetInsertBlock()->getModule();
+
+    // TODO: Use llvm::TargetLowering for Libcall ABI
+    llvm::AttrBuilder fnAttrBuilder(ctx);
+    fnAttrBuilder.addAttribute(llvm::Attribute::NoUnwind);
+    fnAttrBuilder.addAttribute(llvm::Attribute::WillReturn);
+    llvm::AttributeList fnAttrs = llvm::AttributeList::get(
+        ctx, llvm::AttributeList::FunctionIndex, fnAttrBuilder);
+    FunctionCallee LibcallFn = M->getOrInsertFunction(fnName, FnType, fnAttrs);
+    CallInst *Call = Builder->CreateCall(LibcallFn, Args);
+    return Call;
+  }
+
+  /// The object's size in bytes, typed as the `size_t` argument of the
+  /// __atomic_* libcalls.
+  ///
+  /// `size_t` is approximated by the DataLayout's pointer-sized integer, so
+  /// both libcalls agree with each other and with targets whose pointers are
+  /// not 64 bits wide.
+  llvm::Value *getAtomicSizeValue() const {
+    // TODO: Get from llvm::TargetMachine / clang::TargetInfo
+    //   if clang shares this codegen in future
+    const DataLayout &DL =
+        Builder->GetInsertBlock()->getModule()->getDataLayout();
+    constexpr uint64_t BitsPerByte = 8;
+    return llvm::ConstantInt::get(DL.getIntPtrType(getLLVMContext()),
+                                  AtomicSizeInBits / BitsPerByte);
+  }
+
+  /// Emit `bool __atomic_compare_exchange(size_t size, void *obj,
+  /// void *expected, void *desired, int success, int failure)`.
+  ///
+  /// \p ExpectedVal and \p DesiredVal are *pointers* to buffers holding the
+  /// expected and desired values. \returns {ExpectedVal, i1 success flag};
+  /// on failure the runtime stores the object's current value into the
+  /// buffer behind ExpectedVal.
+  std::pair<llvm::Value *, llvm::Value *> EmitAtomicCompareExchangeLibcall(
+      llvm::Value *ExpectedVal, llvm::Value *DesiredVal,
+      llvm::AtomicOrdering Success, llvm::AtomicOrdering Failure) {
+    LLVMContext &ctx = getLLVMContext();
+
+    // TODO: Get from llvm::TargetMachine / clang::TargetInfo
+    //   if clang shares this codegen in future
+    constexpr unsigned IntBits = 32;
+    llvm::IntegerType *IntTy = llvm::IntegerType::get(ctx, IntBits);
+
+    // The runtime expects C ABI memory-order constants (__ATOMIC_RELAXED = 0
+    // .. __ATOMIC_SEQ_CST = 5), not llvm::AtomicOrdering enumerators: raw
+    // Monotonic (2) would be read as __ATOMIC_ACQUIRE, and
+    // SequentiallyConsistent (7) is out of range. Convert with toCABI, the
+    // same way EmitAtomicLoadLibcall encodes its ordering.
+    // NOTE: a relaxed update therefore emits `i32 0, i32 0` here; FileCheck
+    // expectations of `i32 2, i32 2` encoded the raw enum values.
+    llvm::Value *Args[6] = {
+        getAtomicSizeValue(),
+        getAtomicPointer(),
+        ExpectedVal,
+        DesiredVal,
+        llvm::ConstantInt::get(IntTy, static_cast<uint64_t>(toCABI(Success))),
+        llvm::ConstantInt::get(IntTy, static_cast<uint64_t>(toCABI(Failure))),
+    };
+    llvm::Value *Result =
+        EmitAtomicLibcall(Builder, "__atomic_compare_exchange",
+                          llvm::IntegerType::getInt1Ty(ctx), Args);
+    return std::make_pair(ExpectedVal, Result);
+  }
+
+  Value *castToAtomicIntPointer(Value *addr) const {
+    return addr; // opaque pointer
+  }
+
+  Value *getAtomicAddressAsAtomicIntPointer() const {
+    return castToAtomicIntPointer(getAtomicPointer());
+  }
+
+  /// Emit a native `cmpxchg` on the object.
+  /// \returns {previous value, i1 success flag}.
+  std::pair<llvm::Value *, llvm::Value *>
+  EmitAtomicCompareExchangeOp(llvm::Value *ExpectedVal, llvm::Value *DesiredVal,
+                              llvm::AtomicOrdering Success,
+                              llvm::AtomicOrdering Failure,
+                              bool IsVolatile = false, bool IsWeak = false) {
+    // Do the atomic store.
+    Value *Addr = getAtomicAddressAsAtomicIntPointer();
+    auto *Inst = Builder->CreateAtomicCmpXchg(Addr, ExpectedVal, DesiredVal,
+                                              getAtomicAlignment(), Success,
+                                              Failure, llvm::SyncScope::System);
+    // Other decoration.
+    Inst->setVolatile(IsVolatile);
+    Inst->setWeak(IsWeak);
+
+    auto *PreviousVal = Builder->CreateExtractValue(Inst, /*Idxs=*/0);
+    auto *SuccessFailureVal = Builder->CreateExtractValue(Inst, /*Idxs=*/1);
+    return std::make_pair(PreviousVal, SuccessFailureVal);
+  }
+
+  /// Compare-exchange dispatching on shouldUseLibcall(). Note that the two
+  /// forms take different operands: the libcall form takes pointers to the
+  /// expected/desired buffers, the native form takes the values themselves.
+  std::pair<llvm::Value *, llvm::Value *>
+  EmitAtomicCompareExchange(llvm::Value *ExpectedVal, llvm::Value *DesiredVal,
+                            llvm::AtomicOrdering Success,
+                            llvm::AtomicOrdering Failure, bool IsVolatile,
+                            bool IsWeak) {
+    if (shouldUseLibcall())
+      return EmitAtomicCompareExchangeLibcall(ExpectedVal, DesiredVal, Success,
+                                              Failure);
+
+    auto Res = EmitAtomicCompareExchangeOp(ExpectedVal, DesiredVal, Success,
+                                           Failure, IsVolatile, IsWeak);
+    return Res;
+  }
+
+  /// Emit `void __atomic_load(size_t size, void *mem, void *return,
+  /// int order)` into a fresh temporary, then reload the result from it.
+  ///
+  /// \returns {plain load of the temporary, the temporary}. The temporary can
+  /// serve as the `expected` buffer of a following __atomic_compare_exchange.
+  std::pair<llvm::LoadInst *, llvm::AllocaInst *>
+  EmitAtomicLoadLibcall(llvm::AtomicOrdering AO) {
+    LLVMContext &Ctx = getLLVMContext();
+    const DataLayout &DL =
+        Builder->GetInsertBlock()->getModule()->getDataLayout();
+
+    // The runtime takes a generic (address space 0) pointer to the object.
+    Value *PtrVal = Builder->CreateAddrSpaceCast(getAtomicPointer(),
+                                                 PointerType::getUnqual(Ctx));
+    AllocaInst *AllocaResult =
+        CreateAlloca(Ty, getAtomicPointer()->getName() + "atomic.temp.load");
+    // Align the temporary like an integer of the atomic width.
+    const Align AllocaAlignment =
+        DL.getPrefTypeAlign(Type::getIntNTy(Ctx, getAtomicSizeInBits()));
+    AllocaResult->setAlignment(AllocaAlignment);
+
+    Value *Args[] = {getAtomicSizeValue(), PtrVal, AllocaResult,
+                     ConstantInt::get(Type::getInt32Ty(Ctx),
+                                      static_cast<uint64_t>(toCABI(AO)))};
+    EmitAtomicLibcall(Builder, "__atomic_load", Type::getVoidTy(Ctx), Args);
+    return std::make_pair(
+        Builder->CreateAlignedLoad(Ty, AllocaResult, AllocaAlignment),
+        AllocaResult);
+  }
+};
+} // end namespace llvm
+
+#endif /* LLVM_FRONTEND_ATOMIC_ATOMIC_H */
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4be0159fb1dd9f..1b8a6e47b3baf8 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -15,6 +15,7 @@
 #define LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H
 
 #include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Frontend/Atomic/Atomic.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPGridValues.h"
 #include "llvm/IR/DebugLoc.h"
@@ -479,6 +480,27 @@ class OpenMPIRBuilder {
         T(Triple(M.getTargetTriple())) {}
   ~OpenMPIRBuilder();
 
+  /// OpenMPIRBuilder's concrete llvm::AtomicInfo. The atomic location is
+  /// the value \p AtomicVar given at construction, temporaries are allocas
+  /// created at the builder's current insertion point, and no TBAA metadata
+  /// is attached.
+  class AtomicInfo : public llvm::AtomicInfo<IRBuilder<>> {
+    using Base = llvm::AtomicInfo<IRBuilder<>>;
+
+    /// Address of the memory location operated on atomically.
+    llvm::Value *AtomicVar;
+
+  public:
+    AtomicInfo(IRBuilder<> *Builder, llvm::Type *Ty, uint64_t AtomicSizeInBits,
+               uint64_t ValueSizeInBits, llvm::Align AtomicAlign,
+               llvm::Align ValueAlign, bool UseLibcall, llvm::Value *AtomicVar)
+        : Base(Builder, Ty, AtomicSizeInBits, ValueSizeInBits, AtomicAlign,
+               ValueAlign, UseLibcall),
+          AtomicVar(AtomicVar) {}
+
+    llvm::Value *getAtomicPointer() const override { return AtomicVar; }
+
+    /// No TBAA information is available here; intentionally a no-op.
+    void decorateWithTBAA(llvm::Instruction * /*I*/) override {}
+
+    llvm::AllocaInst *CreateAlloca(llvm::Type *Ty,
+                                   const llvm::Twine &Name) const override {
+      // The builder names the instruction as it inserts it.
+      return Builder->CreateAlloca(Ty, /*ArraySize=*/nullptr, Name);
+    }
+  };
   /// Initialize the internal state, this will put structures types and
   /// potentially other helpers into the underlying module. Must be called
   /// before any other method and only once! This internal state includes types
@@ -3039,6 +3061,15 @@ class OpenMPIRBuilder {
                    AtomicUpdateCallbackTy &UpdateOp, bool VolatileX,
                    bool IsXBinopExpr);
 
+  /// NOTE(review): this patch adds no out-of-line definition for either of
+  /// the two members below. The struct-typed path in emitAtomicUpdate calls
+  /// OpenMPIRBuilder::AtomicInfo::EmitAtomicLoadLibcall and
+  /// EmitAtomicCompareExchangeLibcall directly instead. Either define these
+  /// or drop the declarations - TODO confirm.
+  ///
+  /// Presumably (inferred from the name only): emit a generic __atomic_load
+  /// of an \p XElemTy value stored at \p X with ordering \p AO.
+  std::pair<llvm::LoadInst *, llvm::AllocaInst *>
+  EmitAtomicLoadLibcall(Value *X, Type *XElemTy, llvm::AtomicOrdering AO,
+                        uint64_t AtomicSizeInBits);
+
+  /// Presumably (inferred from the name only): emit a generic
+  /// __atomic_compare_exchange on the \p XElemTy value stored at \p X.
+  std::pair<llvm::Value *, llvm::Value *> EmitAtomicCompareExchangeLibcall(
+      Value *X, Type *XElemTy, uint64_t AtomicSizeInBits,
+      llvm::Value *ExpectedVal, llvm::Value *DesiredVal,
+      llvm::AtomicOrdering Success, llvm::AtomicOrdering Failure);
+
   /// Emit the binary op. described by \p RMWOp, using \p Src1 and \p Src2 .
   ///
   /// \Return The instruction
diff --git a/llvm/lib/Frontend/Atomic/Atomic.cpp b/llvm/lib/Frontend/Atomic/Atomic.cpp
new file mode 100644
index 00000000000000..03b476d113bd48
--- /dev/null
+++ b/llvm/lib/Frontend/Atomic/Atomic.cpp
@@ -0,0 +1,19 @@
+//===--- Atomic.cpp - Codegen of atomic operations ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// llvm::AtomicInfo is a class template defined entirely in
+// llvm/Frontend/Atomic/Atomic.h. This file currently only provides the
+// LLVMFrontendAtomic component library with a source file. Including the
+// header first also checks that it is self-contained.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Frontend/Atomic/Atomic.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Operator.h"
diff --git a/llvm/lib/Frontend/Atomic/CMakeLists.txt b/llvm/lib/Frontend/Atomic/CMakeLists.txt
new file mode 100644
index 00000000000000..0d0d3d445b726d
--- /dev/null
+++ b/llvm/lib/Frontend/Atomic/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Component library for frontend codegen of atomic operations. The helpers
+# live in llvm/Frontend/Atomic/Atomic.h; Atomic.cpp is the library's source.
+add_llvm_component_library(LLVMFrontendAtomic
+  Atomic.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/Atomic
+
+  DEPENDS
+  LLVMAnalysis
+  LLVMTargetParser
+
+  LINK_COMPONENTS
+  Core
+  Support
+  Analysis
+  )
diff --git a/llvm/lib/Frontend/CMakeLists.txt b/llvm/lib/Frontend/CMakeLists.txt
index 62dd0da1e6c2de..b305ce7d771ce7 100644
--- a/llvm/lib/Frontend/CMakeLists.txt
+++ b/llvm/lib/Frontend/CMakeLists.txt
@@ -1,3 +1,4 @@
+# Atomic codegen helpers (llvm/Frontend/Atomic), used by OpenMPIRBuilder.
+add_subdirectory(Atomic)
 add_subdirectory(Driver)
 add_subdirectory(HLSL)
 add_subdirectory(OpenACC)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 922c65d7fc3f5c..47cc6ff7655caf 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7977,6 +7977,54 @@ std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
       Res.second = Res.first;
     else
       Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
+  } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
+             XElemTy->isStructTy()) {
+    // Aggregate element type (e.g. a complex number as {float, float}).
+    // IR `load atomic` / `cmpxchg` do not accept struct operands, so emit a
+    // generic __atomic_load followed by an __atomic_compare_exchange retry
+    // loop:
+    //
+    //   CurBB:  old = __atomic_load(X); br ContBB
+    //   ContBB: v = phi(old, reloaded); new = UpdateOp(v);
+    //           ok = __atomic_compare_exchange(X, &expected, &new)
+    //           br ok, ExitBB, ContBB
+    //
+    // Size and alignment are those of XElemTy itself. Taking the store size
+    // of X's *pointer* type would hand the runtime the pointer width, which
+    // is wrong for any aggregate of a different size (e.g. {double, double}).
+    const DataLayout &DL = M.getDataLayout();
+    const uint64_t LoadSize = DL.getTypeStoreSize(XElemTy).getFixedValue();
+    const Align LoadAlign = DL.getABITypeAlign(XElemTy);
+
+    OpenMPIRBuilder::AtomicInfo atomicInfo(
+        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, LoadAlign, LoadAlign,
+        /*UseLibcall=*/true, X);
+    auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
+    BasicBlock *CurBB = Builder.GetInsertBlock();
+    Instruction *CurBBTI = CurBB->getTerminator();
+    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
+    BasicBlock *ExitBB =
+        CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
+    BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
+                                                X->getName() + ".atomic.cont");
+    ContBB->getTerminator()->eraseFromParent();
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
+    NewAtomicAddr->setName(X->getName() + "x.new.val");
+    Builder.SetInsertPoint(ContBB);
+    llvm::PHINode *PHI = Builder.CreatePHI(XElemTy, 2);
+    PHI->addIncoming(AtomicLoadRes.first, CurBB);
+    Value *OldExprVal = PHI;
+    Value *Upd = UpdateOp(OldExprVal, Builder);
+    Builder.CreateStore(Upd, NewAtomicAddr);
+    AtomicOrdering Failure =
+        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
+    // On failure the runtime writes X's current contents into the `expected`
+    // buffer (the __atomic_load temporary); reload it for the next attempt.
+    auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
+        AtomicLoadRes.second, NewAtomicAddr, AO, Failure);
+    LoadInst *PHILoad = Builder.CreateLoad(XElemTy, Result.first);
+    PHI->addIncoming(PHILoad, Builder.GetInsertBlock());
+    Builder.CreateCondBr(Result.second, ExitBB, ContBB);
+    Res.first = OldExprVal;
+    Res.second = Upd;
+
+    // Continue after the loop. If ExitBB ends in `unreachable` (e.g. the
+    // placeholder created above), drop it and append to ExitBB. Otherwise
+    // insert before CurBB's original terminator, which now ends ExitBB.
+    if (isa<UnreachableInst>(ExitBB->getTerminator())) {
+      CurBBTI->eraseFromParent();
+      Builder.SetInsertPoint(ExitBB);
+    } else {
+      Builder.SetInsertPoint(ExitBB->getTerminator());
+    }
   } else {
     IntegerType *IntCastTy =
         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9e5f800dca60bd..19d80fbbd699b0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2055,28 +2055,23 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
     isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                       atomicCaptureOp.getAtomicUpdateOp().getOperation();
     auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
-    bool isRegionArgUsed{false};
     // Find the binary update operation that uses the region argument
     // and get the expression to update
-    for (Operation &innerOp : innerOpList) {
-      if (innerOp.getNumOperands() == 2) {
-        binop = convertBinOpToAtomic(innerOp);
-        if (!llvm::is_contained(innerOp.getOperands(),
-                                atomicUpdateOp.getRegion().getArgument(0)))
-          continue;
-        isRegionArgUsed = true;
-        isXBinopExpr =
-            innerOp.getNumOperands() > 0 &&
-            innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
-        mlirExpr =
-            (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
-        break;
+    // A region of exactly two operations is a single update op plus the
+    // omp.yield terminator; anything else takes the generic BAD_BINOP path.
+    if (innerOpList.size() == 2) {
+      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
+      if (!llvm::is_contained(innerOp.getOperands(),
+                              atomicUpdateOp.getRegion().getArgument(0))) {
+        return atomicUpdateOp.emitError(
+            "no atomic update operation with region argument"
+            " as operand found inside atomic.update region");
       }
+      // Only a two-operand op can describe `x = x binop expr`. For any other
+      // arity (e.g. a unary op on the region argument), reading operand #1
+      // would be out of range, so use the same BAD_BINOP fallback as the
+      // multi-op case below.
+      if (innerOp.getNumOperands() == 2) {
+        binop = convertBinOpToAtomic(innerOp);
+        isXBinopExpr =
+            innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
+        mlirExpr =
+            (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
+      } else {
+        binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
+      }
+    } else {
+      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
     }
-    if (!isRegionArgUsed)
-      return atomicUpdateOp.emitError(
-          "no atomic update operation with region argument"
-          " as operand found inside atomic.update region");
   }
 
   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 95e12e5bc4e742..5d76e87472dfe4 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -1450,6 +1450,120 @@ llvm.func @omp_atomic_update(%x:!llvm.ptr, %expr: i32, %xbool: !llvm.ptr, %exprb
 
 // -----
 
+//CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8
+//CHECK: {{.*}} = alloca { float, float }, i64 1, align 8
+//CHECK: %[[ORIG_VAL:.*]] = alloca { float, float }, i64 1, align 8
+
+//CHECK: br label %entry
+
+//CHECK: entry:
+//CHECK: %[[ATOMIC_TEMP_LOAD:.*]] = alloca { float, float }, align 8
+//CHECK: call void @__atomic_load(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], i32 0)
+//CHECK: %[[PHI_NODE_ENTRY_1:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 8
+//CHECK: br label %.atomic.cont
+
+//CHECK: .atomic.cont
+//CHECK: %[[VAL_4:.*]] = phi { float, float } [ %[[PHI_NODE_ENTRY_1]], %entry ], [ %{{.*}}, %.atomic.cont ]
+//CHECK: %[[VAL_5:.*]] = extractvalue { float, float } %[[VAL_4]], 0
+//CHECK: %[[VAL_6:.*]] = extractvalue { float, float } %[[VAL_4]], 1
+//CHECK: %[[VAL_7:.*]] = fadd contract float %[[VAL_5]], 1.000000e+00
+//CHECK: %[[VAL_8:.*]] = fadd contract float %[[VAL_6]], 1.000000e+00
+//CHECK: %[[VAL_9:.*]] = insertvalue { float, float } undef, float %[[VAL_7]], 0
+//CHECK: %[[VAL_10:.*]] = insertvalue { float, float } %[[VAL_9]], float %[[VAL_8]], 1
+//CHECK: store { float, float } %[[VAL_10]], ptr %[[X_NEW_VAL]], align 4
+//CHECK: %[[VAL_11:.*]] = call i1 @__atomic_compare_exchange(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], ptr %[[X_NEW_VAL]], i32 2, i32 2)
+//CHECK: %[[VAL_12:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 4
+//CHECK: br i1 %[[VAL_11]], label %.atomic.exit, label %.atomic.cont
+
+// Atomic update of a complex value modelled as !llvm.struct<(f32, f32)>.
+// The update region holds several operations rather than one binop, and the
+// element type is an aggregate, so translation is expected to emit the
+// __atomic_load / __atomic_compare_exchange retry loop matched above.
+llvm.func @_QPomp_atomic_update_complex() {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %1 = llvm.alloca %0 x !llvm.struct<(f32, f32)> {bindc_name = "ib"} : (i64) -> !llvm.ptr
+    %2 = llvm.mlir.constant(1 : i64) : i64
+    %3 = llvm.alloca %2 x !llvm.struct<(f32, f32)> {bindc_name = "ia"} : (i64) -> !llvm.ptr
+    %4 = llvm.mlir.constant(1.000000e+00 : f32) : f32
+    %5 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+    %6 = llvm.insertvalue %4, %5[0] : !llvm.struct<(f32, f32)>
+    %7 = llvm.insertvalue %4, %6[1] : !llvm.struct<(f32, f32)>
+    // ia = ia + (1.0, 1.0), computed component-wise inside the region.
+    omp.atomic.update %3 : !llvm.ptr {
+    ^bb0(%arg0: !llvm.struct<(f32, f32)>):
+      %8 = llvm.extractvalue %arg0[0] : !llvm.struct<(f32, f32)>
+      %9 = llvm.extractvalue %arg0[1] : !llvm.struct<(f32, f32)>
+      %10 = llvm.extractvalue %7[0] : !llvm.struct<(f32, f32)>
+      %11 = llvm.extractvalue %7[1] : !llvm.struct<(f32, f32)>
+      %12 = llvm.fadd %8, %10  {fastmathFlags = #llvm.fastmath<contract>} : f32
+      %13 = llvm.fadd %9, %11  {fastmathFlags = #llvm.fastmath<contract>} : f32
+      %14 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+      %15 = llvm.insertvalue %12, %14[0] : !llvm.struct<(f32, f32)>
+      %16 = llvm.insertvalue %13, %15[1] : !llvm.struct<(f32, f32)>
+      omp.yield(%16 : !llvm.struct<(f32, f32)>)
+    }
+   llvm.return
+}
+
+// -----
+
+//CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8
+//CHECK: %[[VAL_1:.*]] = alloca { float, float }, i64 1, align 8
+//CHECK: %[[ORIG_VAL:.*]] = alloca { float, float }, i64 1, align 8
+//CHECK: store { float, float } { float 2.000000e+00, float 2.000000e+00 }, ptr %[[ORIG_VAL]], align 4
+//CHECK: br label %entry
+
+//CHECK: entry:							; preds = %0
+//CHECK: %[[ATOMIC_TEMP_LOAD:.*]] = alloca { float, float }, align 8
+//CHECK: call void @__atomic_load(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], i32 0)
+//CHECK: %[[PHI_NODE_ENTRY_1:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 8
+//CHECK: br label %.atomic.cont
+
+//CHECK: .atomic.cont
+//CHECK: %[[VAL_4:.*]] = phi { float, float } [ %[[PHI_NODE_ENTRY_1]], %entry ], [ %{{.*}}, %.atomic.cont ]
+//CHECK: %[[VAL_5:.*]] = extractvalue { float, float } %[[VAL_4]], 0
+//CHECK: %[[VAL_6:.*]] = extractvalue { float, float } %[[VAL_4]], 1
+//CHECK: %[[VAL_7:.*]] = fadd contract float %[[VAL_5]], 1.000000e+00
+//CHECK: %[[VAL_8:.*]] = fadd contract float %[[VAL_6]], 1.000000e+00
+//CHECK: %[[VAL_9:.*]] = insertvalue { float, float } undef, float %[[VAL_7]], 0
+//CHECK: %[[VAL_10:.*]] = insertvalue { float, float } %[[VAL_9]], float %[[VAL_8]], 1
+//CHECK: store { float, float } %[[VAL_10]], ptr %[[X_NEW_VAL]], align 4 
+//CHECK: %[[VAL_11:.*]] = call i1 @__atomic_compare_exchange(i64 8, ptr %[[ORIG_VAL]], ptr %[[ATOMIC_TEMP_LOAD]], ptr %[[X_NEW_VAL]], i32 2, i32 2)
+//CHECK: %[[VAL_12:.*]] = load { float, float }, ptr %[[ATOMIC_TEMP_LOAD]], align 4
+//CHECK: br i1 %[[VAL_11]], label %.atomic.exit, label %.atomic.cont
+//CHECK: .atomic.exit
+//CHECK: store { float, float } %[[VAL_10]], ptr %[[VAL_1]], align 4
+
+// Atomic capture of a complex value modelled as !llvm.struct<(f32, f32)>:
+// ia (%3) starts at (2.0, 2.0), is updated by (1.0, 1.0), and is read into
+// ib (%1). The multi-op update region on an aggregate is expected to take
+// the __atomic_load / __atomic_compare_exchange path matched above.
+llvm.func @_QPomp_atomic_capture_complex() {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %1 = llvm.alloca %0 x !llvm.struct<(f32, f32)> {bindc_name = "ib"} : (i64) -> !llvm.ptr
+    %2 = llvm.mlir.constant(1 : i64) : i64
+    %3 = llvm.alloca %2 x !llvm.struct<(f32, f32)> {bindc_name = "ia"} : (i64) -> !llvm.ptr
+    %4 = llvm.mlir.constant(1.000000e+00 : f32) : f32
+    %5 = llvm.mlir.constant(2.000000e+00 : f32) : f32
+    %6 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+    %7 = llvm.insertvalue %5, %6[0] : !llvm.struct<(f32, f32)>
+    %8 = llvm.insertvalue %5, %7[1] : !llvm.struct<(f32, f32)>
+    llvm.store %8, %3 : !llvm.struct<(f32, f32)>, !llvm.ptr
+    %9 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+    %10 = llvm.insertvalue %4, %9[0] : !llvm.struct<(f32, f32)>
+    %11 = llvm.insertvalue %4, %10[1] : !llvm.struct<(f32, f32)>
+    // The read is the second operation of the capture, so ib receives the
+    // updated value (stored to %1 after the loop exits).
+    omp.atomic.capture {
+      omp.atomic.update %3 : !llvm.ptr {
+      ^bb0(%arg0: !llvm.struct<(f32, f32)>):
+        %12 = llvm.extractvalue %arg0[0] : !llvm.struct<(f32, f32)>
+        %13 = llvm.extractvalue %arg0[1] : !llvm.struct<(f32, f32)>
+        %14 = llvm.extractvalue %11[0] : !llvm.struct<(f32, f32)>
+        %15 = llvm.extractvalue %11[1] : !llvm.struct<(f32, f32)>
+        %16 = llvm.fadd %12, %14  {fastmathFlags = #llvm.fastmath<contract>} : f32
+        %17 = llvm.fadd %13, %15  {fastmathFlags = #llvm.fastmath<contract>} : f32
+        %18 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+        %19 = llvm.insertvalue %16, %18[0] : !llvm.struct<(f32, f32)>
+        %20 = llvm.insertvalue %17, %19[1] : !llvm.struct<(f32, f32)>
+        omp.yield(%20 : !llvm.struct<(f32, f32)>)
+      }
+      omp.atomic.read %1 = %3 : !llvm.ptr, !llvm.struct<(f32, f32)>
+    }
+    llvm.return
+}
+
+// -----
+
 // Checking an order-dependent operation when the order is `expr binop x`
 // CHECK-LABEL: @omp_atomic_update_ordering
 // CHECK-SAME: (ptr %[[x:.*]], i32 %[[expr:.*]])



More information about the flang-commits mailing list