[llvm] [mlir] Add __atomic_store to AtomicInfo (PR #121055)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 24 07:14:27 PST 2024


https://github.com/NimishMishra created https://github.com/llvm/llvm-project/pull/121055

This PR adds support for emitting the `__atomic_store` libcall in `AtomicInfo`. This allows `atomic write` to support complex types, which are lowered to struct types that cannot be written with a native atomic store.

Fixes https://github.com/llvm/llvm-project/issues/113479

From b400ef53c69b82d772502be1f0b9f14002d6ef04 Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Tue, 24 Dec 2024 20:41:07 +0530
Subject: [PATCH] [mlir] Add __atomic_store to AtomicInfo

---
 llvm/include/llvm/Frontend/Atomic/Atomic.h |  2 ++
 llvm/lib/Frontend/Atomic/Atomic.cpp        | 33 ++++++++++++++++++++++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp  | 12 +++++++-
 mlir/test/Target/LLVMIR/openmp-llvm.mlir   | 22 +++++++++++++++
 4 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Frontend/Atomic/Atomic.h b/llvm/include/llvm/Frontend/Atomic/Atomic.h
index 9f46fde6292a90..d9819db17cf645 100644
--- a/llvm/include/llvm/Frontend/Atomic/Atomic.h
+++ b/llvm/include/llvm/Frontend/Atomic/Atomic.h
@@ -96,6 +96,8 @@ class AtomicInfo {
                             bool IsVolatile, bool IsWeak);
 
   std::pair<LoadInst *, AllocaInst *> EmitAtomicLoadLibcall(AtomicOrdering AO);
+
+  void EmitAtomicStoreLibcall(AtomicOrdering AO, Value *Source);
 };
 } // end namespace llvm
 
diff --git a/llvm/lib/Frontend/Atomic/Atomic.cpp b/llvm/lib/Frontend/Atomic/Atomic.cpp
index c9f9a9dcfb702a..31a8794b2b83a9 100644
--- a/llvm/lib/Frontend/Atomic/Atomic.cpp
+++ b/llvm/lib/Frontend/Atomic/Atomic.cpp
@@ -141,6 +141,39 @@ AtomicInfo::EmitAtomicLoadLibcall(AtomicOrdering AO) {
       AllocaResult);
 }
 
+void AtomicInfo::EmitAtomicStoreLibcall(AtomicOrdering AO, Value *Source) {
+  LLVMContext &Ctx = getLLVMContext();
+  SmallVector<Value *, 6> Args;
+  AttributeList Attr;
+  Module *M = Builder->GetInsertBlock()->getModule();
+  const DataLayout &DL = M->getDataLayout();
+  Args.push_back(
+      ConstantInt::get(DL.getIntPtrType(Ctx), this->getAtomicSizeInBits() / 8));
+
+  Value *PtrVal = getAtomicPointer();
+  PtrVal = Builder->CreateAddrSpaceCast(PtrVal, PointerType::getUnqual(Ctx));
+  Args.push_back(PtrVal);
+
+  Value *SourceAlloca = Builder->CreateAlloca(Source->getType());
+  Builder->CreateStore(Source, SourceAlloca);
+  SourceAlloca = Builder->CreatePointerBitCastOrAddrSpaceCast(
+      SourceAlloca, Builder->getPtrTy());
+  Args.push_back(SourceAlloca);
+
+  Constant *OrderingVal =
+      ConstantInt::get(Type::getInt32Ty(Ctx), (int)toCABI(AO));
+  Args.push_back(OrderingVal);
+
+  SmallVector<Type *, 6> ArgTys;
+  for (Value *Arg : Args)
+    ArgTys.push_back(Arg->getType());
+  FunctionType *FnType = FunctionType::get(Type::getVoidTy(Ctx), ArgTys, false);
+  FunctionCallee LibcallFn =
+      M->getOrInsertFunction("__atomic_store", FnType, Attr);
+  CallInst *Call = Builder->CreateCall(LibcallFn, Args);
+  Call->setAttributes(Attr);
+}
+
 std::pair<Value *, Value *> AtomicInfo::EmitAtomicCompareExchange(
     Value *ExpectedVal, Value *DesiredVal, AtomicOrdering Success,
     AtomicOrdering Failure, bool IsVolatile, bool IsWeak) {
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 0d8dbbe3a8a718..d47759a7bc9bd9 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8400,12 +8400,22 @@ OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
          "OMP Atomic expects a pointer to target memory");
   Type *XElemTy = X.ElemTy;
   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
-          XElemTy->isPointerTy()) &&
+          XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
          "OMP atomic write expected a scalar type");
 
   if (XElemTy->isIntegerTy()) {
     StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
     XSt->setAtomic(AO);
+  } else if (XElemTy->isStructTy()) {
+    LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read");
+    const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
+    unsigned LoadSize =
+        LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
+    OpenMPIRBuilder::AtomicInfo atomicInfo(
+        &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
+        OldVal->getAlign(), true /* UseLibcall */, X.Var);
+    atomicInfo.EmitAtomicStoreLibcall(AO, Expr);
+    OldVal->eraseFromParent();
   } else {
     // We need to bitcast and perform atomic op as integers
     IntegerType *IntCastTy =
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 5f8bdf8afdf783..1ae6218ec11004 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -1409,6 +1409,28 @@ llvm.func @omp_atomic_update(%x:!llvm.ptr, %expr: i32, %xbool: !llvm.ptr, %exprb
   llvm.return
 }
 
+// -----
+
+//CHECK-LABEL: @atomic_complex_write
+//CHECK: %[[V:.*]] = alloca { float, float }, i64 1, align 8
+//CHECK: %[[X:.*]] = alloca { float, float }, i64 1, align 8
+//CHECK: %[[LOAD:.*]] = load { float, float }, ptr %1, align 4
+//CHECK: %[[ALLOCA:.*]] = alloca { float, float }, align 8
+//CHECK: store { float, float } %[[LOAD]], ptr %[[ALLOCA]], align 4
+//CHECK: call void @__atomic_store(i64 8, ptr %[[X]], ptr %[[ALLOCA]], i32 0)
+
+llvm.func @atomic_complex_write() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x !llvm.struct<(f32, f32)> {bindc_name = "r"} : (i64) -> !llvm.ptr
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.alloca %2 x !llvm.struct<(f32, f32)> {bindc_name = "l"} : (i64) -> !llvm.ptr
+  %4 = llvm.mlir.constant(1 : i64) : i64
+  %5 = llvm.mlir.constant(1 : i64) : i64
+  %6 = llvm.load %1 : !llvm.ptr -> !llvm.struct<(f32, f32)>
+  omp.atomic.write %3 = %6 : !llvm.ptr, !llvm.struct<(f32, f32)>
+  llvm.return
+}
+
 // -----
 
 //CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8



More information about the llvm-commits mailing list