[flang-commits] [flang] [flang][cuda] Add support for f16 atomicadd (PR #166229)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Mon Nov 3 12:19:15 PST 2025
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/166229
The values need to be loaded and repacked into a vector<f16> to be passed to the atomic add operation.
>From 94622dc503a5cd5e0eead8314522bad402ecfca6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 3 Nov 2025 12:01:22 -0800
Subject: [PATCH] [flang][cuda] Add support for f16 atomicadd
---
.../flang/Optimizer/Builder/IntrinsicCall.h | 2 +
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 44 ++++++++++++++++++-
flang/module/cudadevice.f90 | 5 +++
flang/test/Lower/CUDA/cuda-device-proc.cuf | 4 ++
4 files changed, 54 insertions(+), 1 deletion(-)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 3407dd01dd504..9f15ce68eb3d5 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -188,6 +188,8 @@ struct IntrinsicLibrary {
fir::ExtendedValue genAny(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genAtanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genAtomicAdd(mlir::Type, llvm::ArrayRef<mlir::Value>);
+ fir::ExtendedValue genAtomicAddR2(mlir::Type,
+ llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genAtomicCas(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 15ea84565dd75..79996b82df5be 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -294,6 +294,10 @@ static constexpr IntrinsicHandler handlers[]{
{"atomicaddf", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
{"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
{"atomicaddl", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
+ {"atomicaddr2",
+ &I::genAtomicAddR2,
+ {{{"a", asAddr}, {"v", asAddr}}},
+ false},
{"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false},
{"atomiccasd",
&I::genAtomicCas,
@@ -3119,7 +3123,6 @@ static mlir::Value genAtomBinOp(fir::FirOpBuilder &builder, mlir::Location &loc,
mlir::Value IntrinsicLibrary::genAtomicAdd(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 2);
-
mlir::LLVM::AtomicBinOp binOp =
mlir::isa<mlir::IntegerType>(args[1].getType())
? mlir::LLVM::AtomicBinOp::add
@@ -3127,6 +3130,45 @@ mlir::Value IntrinsicLibrary::genAtomicAdd(mlir::Type resultType,
return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
}
+/// Lower the CUDA Fortran `atomicaddr2` intrinsic (atomic add on two f16).
+///
+/// \p args[0] is the address to update and \p args[1] is the address of the
+/// two f16 addends; the handler table passes both `asAddr`. The addends are
+/// loaded and repacked into a vector<2xf16> before being handed to the atomic
+/// fadd. The value produced by the atomic operation is reinterpreted as a
+/// single i32 and returned.
+fir::ExtendedValue
+IntrinsicLibrary::genAtomicAddR2(mlir::Type resultType,
+                                 llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+
+  // A boxed destination is unwrapped to its base address first.
+  mlir::Value a = fir::getBase(args[0]);
+  if (mlir::isa<fir::BaseBoxType>(a.getType()))
+    a = fir::BoxAddrOp::create(builder, loc, a);
+
+  // Every op below uses the member `loc`, as genAtomicAdd does. It is
+  // intentionally not shadowed by builder.getUnknownLoc(), which would strip
+  // the source location from the generated operations.
+  auto f16Ty = builder.getF16Type();
+  auto i32Ty = builder.getI32Type();
+  auto vecF16Ty = mlir::VectorType::get({2}, f16Ty);
+  mlir::Type idxTy = builder.getIndexType();
+  auto f16RefTy = fir::ReferenceType::get(f16Ty);
+
+  // Load the two f16 addends from the `v` argument.
+  mlir::Value v = fir::getBase(args[1]);
+  auto zero = builder.createIntegerConstant(loc, idxTy, 0);
+  auto one = builder.createIntegerConstant(loc, idxTy, 1);
+  auto v1Coord = fir::CoordinateOp::create(builder, loc, f16RefTy, v, zero);
+  auto v2Coord = fir::CoordinateOp::create(builder, loc, f16RefTy, v, one);
+  auto v1 = fir::LoadOp::create(builder, loc, v1Coord);
+  auto v2 = fir::LoadOp::create(builder, loc, v2Coord);
+
+  // Pack both scalars into a vector<2xf16>, the operand type of the atomic.
+  mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecF16Ty);
+  mlir::Value vec1 = mlir::LLVM::InsertElementOp::create(
+      builder, loc, undef, v1, builder.createIntegerConstant(loc, i32Ty, 0));
+  mlir::Value vec2 = mlir::LLVM::InsertElementOp::create(
+      builder, loc, vec1, v2, builder.createIntegerConstant(loc, i32Ty, 1));
+  auto res =
+      genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, vec2);
+
+  // Reinterpret the vector<2xf16> result as vector<1xi32> and return its
+  // only element.
+  auto i32VecTy = mlir::VectorType::get({1}, i32Ty);
+  mlir::Value vecI32 =
+      mlir::vector::BitCastOp::create(builder, loc, i32VecTy, res);
+  return mlir::vector::ExtractOp::create(builder, loc, vecI32,
+                                         mlir::ArrayRef<int64_t>{0});
+}
+
mlir::Value IntrinsicLibrary::genAtomicSub(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 2);
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 59af58ddcd32e..7a764b589dc56 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -1171,6 +1171,11 @@ attributes(device) pure integer(8) function atomicaddl(address, val)
integer(8), intent(inout) :: address
integer(8), value :: val
end function
+ attributes(device) pure integer(4) function atomicaddr2(address, val) ! device procedure: takes two real(2) pairs, result is integer(4)
+ !dir$ ignore_tkr (rd) address, (d) val
+ real(2), dimension(2), intent(inout) :: address ! two-element real(2) destination, declared intent(inout)
+ real(2), dimension(2), intent(in) :: val ! two-element real(2) addend; no VALUE attribute, unlike atomicaddl's val
+ end function
end interface
interface atomicsub
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 09b4302446ee7..674548b7489e8 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -14,6 +14,8 @@ attributes(global) subroutine devsub()
integer :: smalltime
integer(4) :: res, offset
integer(8) :: resl
+ real(2) :: r2a(2)
+ real(2) :: tmp2(2)
integer :: tid
tid = threadIdx%x
@@ -34,6 +36,7 @@ attributes(global) subroutine devsub()
al = atomicadd(al, 1_8)
af = atomicadd(af, 1.0_4)
ad = atomicadd(ad, 1.0_8)
+ ai = atomicadd(r2a, tmp2)
ai = atomicsub(ai, 1_4)
al = atomicsub(al, 1_8)
@@ -128,6 +131,7 @@ end
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f64
+! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<2xf16>
! CHECK: %{{.*}} = llvm.atomicrmw sub %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = llvm.atomicrmw sub %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
More information about the flang-commits
mailing list