[flang-commits] [flang] [flang][cuda] Add interfaces and lowering for barrier_try_wait(_sleep) (PR #165316)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Mon Oct 27 13:58:21 PDT 2025
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/165316
As described in the programming guide: https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/#load-and-store-functions-using-bulk-tma-operations
From 9a22f5551243055f33b8fd8dc15dc0a93dd482b2 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 23 Oct 2025 16:17:01 -0700
Subject: [PATCH] [flang][cuda] Add interfaces and lowering for
barrier_try_wait(_sleep)
---
.../flang/Optimizer/Builder/IntrinsicCall.h | 2 +
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 60 +++++++++++++++++++
flang/module/cudadevice.f90 | 27 +++++++--
flang/test/Lower/CUDA/cuda-device-proc.cuf | 22 +++++++
4 files changed, 105 insertions(+), 6 deletions(-)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index c3cd119b96174..ed0cbd3bdf16b 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -211,6 +211,8 @@ struct IntrinsicLibrary {
mlir::Value genBarrierArrive(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genBarrierArriveCnt(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
+ mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef<mlir::Value>);
+ mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genBesselJn(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genBesselYn(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 6b02fefb92196..645540ad22e51 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -50,6 +50,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -358,6 +359,14 @@ static constexpr IntrinsicHandler handlers[]{
&I::genBarrierInit,
{{{"barrier", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
+ {"barrier_try_wait",
+ &I::genBarrierTryWait,
+ {{{"barrier", asAddr}, {"token", asValue}}},
+ /*isElemental=*/false},
+ {"barrier_try_wait_sleep",
+ &I::genBarrierTryWaitSleep,
+ {{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}},
+ /*isElemental=*/false},
{"bessel_jn",
&I::genBesselJn,
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -3280,6 +3289,57 @@ void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);
}
+// BARRIER_TRY_WAIT (CUDA)
+//
+// Lowers barrier_try_wait(barrier, token) to an scf.while loop that re-issues
+// an inline-PTX `mbarrier.try_wait` until it reports success.
+//   args[0]: the barrier; converted to an LLVM pointer for the PTX address
+//            operand.
+//   args[1]: the token operand of the wait.
+// Returns the value carried out of the loop (of type `resultType`), i.e. the
+// last 1/0 produced by the PTX `selp` on the try_wait predicate.
+mlir::Value
+IntrinsicLibrary::genBarrierTryWait(mlir::Type resultType,
+                                    llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  // The loop-carried status starts at 0 ("not completed yet").
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+  mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
+  fir::StoreOp::create(builder, loc, zero, res);
+  // Fixed value passed as the last try_wait operand; it takes the place of
+  // the user-provided `ns` argument of barrier_try_wait_sleep.
+  mlir::Value ns =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 1000000);
+  mlir::Value load = fir::LoadOp::create(builder, loc, res);
+  auto whileOp = mlir::scf::WhileOp::create(
+      builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{load});
+
+  // "Before" region: decide whether another attempt is needed.
+  mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore());
+  mlir::Value beforeArg = beforeBlock->addArgument(resultType, loc);
+  builder.setInsertionPointToStart(beforeBlock);
+  // Keep looping while the previous attempt returned 0. The carried value is
+  // initialized to 0, so the comparison must be `eq`: with `ne` the condition
+  // is false on entry and the try_wait in the "after" region never runs.
+  mlir::Value condition = mlir::arith::CmpIOp::create(
+      builder, loc, mlir::arith::CmpIPredicate::eq, beforeArg, zero);
+  mlir::scf::ConditionOp::create(builder, loc, condition, beforeArg);
+
+  // "After" region: one try_wait attempt; its status feeds the next check.
+  mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter());
+  // The block argument is required by scf.while but its value is not used:
+  // each attempt yields a fresh status.
+  afterBlock->addArgument(resultType, loc);
+  builder.setInsertionPointToStart(afterBlock);
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
+  mlir::Value ret =
+      mlir::NVVM::InlinePtxOp::create(
+          builder, loc, {resultType}, {barrier, args[1], ns}, {},
+          ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
+          "selp.b32 %0, 1, 0, p;",
+          {})
+          .getResult(0);
+  mlir::scf::YieldOp::create(builder, loc, ret);
+
+  // Continue emitting code after the loop and hand back its final status.
+  builder.setInsertionPointAfter(whileOp);
+  return whileOp.getResult(0);
+}
+
+// BARRIER_TRY_WAIT_SLEEP (CUDA)
+//
+// Emits a single inline-PTX `mbarrier.try_wait` for
+// barrier_try_wait_sleep(barrier, token, ns); no retry loop is generated.
+//   args[0]: the barrier; converted to an LLVM pointer for the PTX address
+//            operand.
+//   args[1]: the token operand.
+//   args[2]: the `ns` operand, forwarded unchanged to the instruction.
+// Returns the 1/0 value that the PTX `selp` derives from the predicate.
+mlir::Value
+IntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
+                                         llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 3);
+  mlir::Value token = args[1];
+  mlir::Value sleepNs = args[2];
+  mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  mlir::Value barrierPtr = builder.createConvert(loc, ptrTy, args[0]);
+  auto ptxOp = mlir::NVVM::InlinePtxOp::create(
+      builder, loc, {resultType}, {barrierPtr, token, sleepNs}, {},
+      ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
+      "selp.b32 %0, 1, 0, p;",
+      {});
+  return ptxOp.getResult(0);
+}
+
// BESSEL_JN
fir::ExtendedValue
IntrinsicLibrary::genBesselJn(mlir::Type resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 5182950cbffea..ea54c974c9e7c 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -1998,6 +1998,18 @@ attributes(device,host) logical function on_device() bind(c)
! TMA Operations
+ interface barrier_arrive
+ attributes(device) function barrier_arrive(barrier) result(token)
+ integer(8), shared :: barrier
+ integer(8) :: token
+ end function
+ attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
+ integer(8), shared :: barrier
+ integer(4), value :: count
+ integer(8) :: token
+ end function
+ end interface
+
interface
attributes(device) subroutine barrier_init(barrier, count)
integer(8), shared :: barrier
@@ -2005,15 +2017,18 @@ attributes(device) subroutine barrier_init(barrier, count)
end subroutine
end interface
- interface barrier_arrive
- attributes(device) function barrier_arrive(barrier) result(token)
+ interface
+ attributes(device) integer function barrier_try_wait(barrier, token)
integer(8), shared :: barrier
- integer(8) :: token
+ integer(8), value :: token
end function
- attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
+ end interface
+
+ interface
+ attributes(device) integer function barrier_try_wait_sleep(barrier, token, ns)
integer(8), shared :: barrier
- integer(4), value :: count
- integer(8) :: token
+ integer(8), value :: token
+ integer(4), value :: ns
end function
end interface
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 7d6caf58d71b3..27bca167be2ed 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -479,3 +479,25 @@ end subroutine
! CHECK-LABEL: func.func @_QPtest_bulk_s2g
! CHECL: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+
+attributes(global) subroutine test_barrier_try_wait()
+ integer :: istat
+ integer(8), shared :: barrier1
+ integer(8) :: token
+ istat = barrier_try_wait(barrier1, token)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_barrier_try_wait()
+! CHECK: scf.while
+! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %{{.*}}, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %c1000000{{.*}} : !llvm.ptr, i64, i32) -> i32
+
+attributes(global) subroutine test_barrier_try_wait_sleep()
+ integer :: istat
+ integer(8), shared :: barrier1
+ integer(8) :: token
+ integer(4) :: sleep_time
+ istat = barrier_try_wait_sleep(barrier1, token, sleep_time)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
+! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32
More information about the flang-commits
mailing list