[Mlir-commits] [flang] [llvm] [mlir] [OpenMP][OMPIRBuilder] Support complex types in atomic update/capture (PR #191490)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 10 11:55:29 PDT 2026


https://github.com/chichunchen created https://github.com/llvm/llvm-project/pull/191490

Route struct-typed values through the libcall path in `emitAtomicUpdate`.

Previously, the libcall path was gated on `RMWOp == BAD_BINOP`. Atomic capture swap patterns (`v = x; x = expr`) on complex values, which are lowered as LLVM structs, do not use `BAD_BINOP`, so they fell through to the cmpxchg path. That path called `getScalarSizeInBits()` on the struct type, which returned 0, and the request for a zero-width integer type triggered an assertion in `IntegerType::get()`.

Remove the `BAD_BINOP` restriction so struct types always use the libcall path. This is safe because the libcall path does not use `RMWOp` and already handles arbitrary type sizes correctly.

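Also fix `LoadSize` in the libcall path to use the store size of `XElemTy` rather than that of the load's pointer operand type. A pointer is 8 bytes on 64-bit targets, which happens to match `complex(4)` but is too small for larger complex types such as `complex(8)` (16 bytes).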

Fixes https://github.com/llvm/llvm-project/issues/191317

Assisted with copilot and GPT-5.4

>From c93431e506d912f72e00fbc191d506e1be5587b8 Mon Sep 17 00:00:00 2001
From: "Chi Chun, Chen" <chichun.chen at hpe.com>
Date: Fri, 10 Apr 2026 02:32:49 -0500
Subject: [PATCH] [OpenMP][OMPIRBuilder] Support complex types in atomic
 update/capture

Route struct-typed values through the libcall path in
`emitAtomicUpdate`.

Previously, the libcall path was gated on `RMWOp == BAD_BINOP`.
Atomic capture swap patterns (`v = x; x = expr`) on complex values,
which are lowered as LLVM structs, do not use `BAD_BINOP`, so they
fell through to the cmpxchg path. That path called
`getScalarSizeInBits()` on the struct type, which returned 0, and the
request for a zero-width integer type triggered an assertion in
`IntegerType::get()`.

Remove the `BAD_BINOP` restriction so struct types always use the
libcall path. This is safe because the libcall path does not use
`RMWOp` and already handles arbitrary type sizes correctly.

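Also fix `LoadSize` in the libcall path to use the store size of
`XElemTy` rather than that of the load's pointer operand type. A
pointer is 8 bytes on 64-bit targets, which happens to match
`complex(4)` but is too small for larger complex types such as
`complex(8)` (16 bytes).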

Fixes https://github.com/llvm/llvm-project/issues/191317

Assisted with copilot and GPT-5.4
---
 .../Lower/OpenMP/atomic-capture-complex.f90   | 34 +++++++++++++++++++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 13 +++----
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      | 22 ++++++++++++
 3 files changed, 61 insertions(+), 8 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/atomic-capture-complex.f90

diff --git a/flang/test/Lower/OpenMP/atomic-capture-complex.f90 b/flang/test/Lower/OpenMP/atomic-capture-complex.f90
new file mode 100644
index 0000000000000..f5fed98903460
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-capture-complex.f90
@@ -0,0 +1,34 @@
+! RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-version=40 -o - %s | FileCheck %s
+
+! CHECK-LABEL: define {{.*}} @test_capture_swap_complex4_
+! CHECK: call void @__atomic_load(i64 8, ptr %{{.*}}, ptr %{{.*}}, i32 {{.*}})
+! CHECK: call i1 @__atomic_compare_exchange(i64 8, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 {{.*}}, i32 {{.*}})
+subroutine test_capture_swap_complex4(x, v, expr)
+  complex(4) :: x, v, expr
+  !$omp atomic capture
+  v = x
+  x = expr
+  !$omp end atomic
+end subroutine
+
+! CHECK-LABEL: define {{.*}} @test_capture_swap_complex8_
+! CHECK: call void @__atomic_load(i64 16, ptr %{{.*}}, ptr %{{.*}}, i32 {{.*}})
+! CHECK: call i1 @__atomic_compare_exchange(i64 16, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 {{.*}}, i32 {{.*}})
+subroutine test_capture_swap_complex8(x, v, expr)
+  complex(8) :: x, v, expr
+  !$omp atomic capture
+  v = x
+  x = expr
+  !$omp end atomic
+end subroutine
+
+! CHECK-LABEL: define {{.*}} @test_capture_swap_seqcst_
+! CHECK: call void @__atomic_load(i64 8, ptr %{{.*}}, ptr %{{.*}}, i32 5)
+! CHECK: call i1 @__atomic_compare_exchange(i64 8, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 {{.*}}, i32 {{.*}})
+subroutine test_capture_swap_seqcst(x, v, expr)
+  complex(4) :: x, v, expr
+  !$omp atomic seq_cst capture
+  v = x
+  x = expr
+  !$omp end atomic
+end subroutine
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d06ebbaca9f08..783178a1b6f79 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -10429,7 +10429,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
            "OMP Atomic expects a pointer to target memory");
     Type *XElemTy = X.ElemTy;
     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
-            XElemTy->isPointerTy()) &&
+            XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
            "OMP atomic update expected a scalar type");
     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
            (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
@@ -10489,8 +10489,7 @@ Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
     AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
     AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr,
     bool IsIgnoreDenormalMode, bool IsFineGrainedMemory, bool IsRemoteMemory) {
-  // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
-  // or a complex datatype.
+  // TODO: handle the case where XElemTy is not byte-sized or not a power of 2.
   bool emitRMWOp = false;
   switch (RMWOp) {
   case AtomicRMWInst::Add:
@@ -10532,14 +10531,12 @@ Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
       Res.second = Res.first;
     else
       Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
-  } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
-             XElemTy->isStructTy()) {
+  } else if (XElemTy->isStructTy()) {
     LoadInst *OldVal =
         Builder.CreateLoad(XElemTy, X, X->getName() + ".atomic.load");
     OldVal->setAtomic(AO);
     const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
-    unsigned LoadSize =
-        LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
+    unsigned LoadSize = LoadDL.getTypeStoreSize(XElemTy);
 
     OpenMPIRBuilder::AtomicInfo atomicInfo(
         &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
@@ -10667,7 +10664,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
            "OMP Atomic expects a pointer to target memory");
     Type *XElemTy = X.ElemTy;
     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
-            XElemTy->isPointerTy()) &&
+            XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
            "OMP atomic capture expected a scalar type");
     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
            "OpenMP atomic does not support LT or GT operations");
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index c5cdecd091770..795297cd94d42 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -1766,6 +1766,28 @@ llvm.func @_QPomp_atomic_capture_complex() {
 
 // -----
 
+// CHECK-LABEL: define void @omp_atomic_capture_complex_swap
+llvm.func @omp_atomic_capture_complex_swap(%x_arg: !llvm.ptr, %v_arg: !llvm.ptr) {
+    // CHECK: %[[ATOMIC_TEMP_LOAD:.*]] = alloca { float, float }, align 8
+    // CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8
+    // CHECK: call void @__atomic_load(i64 8, ptr %{{.*}}, ptr %[[ATOMIC_TEMP_LOAD]], i32 0)
+    // CHECK: %[[PHI:.*]] = phi { float, float }
+    // CHECK: store { float, float } { float 1.000000e+00, float 1.000000e+00 }, ptr %[[X_NEW_VAL]], align 4
+    // CHECK: call i1 @__atomic_compare_exchange(i64 8, ptr %{{.*}}, ptr %[[ATOMIC_TEMP_LOAD]], ptr %[[X_NEW_VAL]], i32 2, i32 2)
+    // CHECK: store { float, float } %[[PHI]], ptr %{{.*}}, align 4
+    %0 = llvm.mlir.constant(1.000000e+00 : f32) : f32
+    %1 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+    %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(f32, f32)>
+    %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(f32, f32)>
+    omp.atomic.capture {
+      omp.atomic.read %v_arg = %x_arg : !llvm.ptr, !llvm.ptr, !llvm.struct<(f32, f32)>
+      omp.atomic.write %x_arg = %3 : !llvm.ptr, !llvm.struct<(f32, f32)>
+    }
+    llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: define void @omp_atomic_read_complex() {
 llvm.func @omp_atomic_read_complex(){
 



More information about the Mlir-commits mailing list