[flang-commits] [flang] [flang][cuda] Add interfaces and lowering for barrier_arrive (PR #162949)
via flang-commits
flang-commits at lists.llvm.org
Fri Oct 10 17:22:54 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/162949.diff
4 Files Affected:
- (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+2)
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+44-7)
- (modified) flang/module/cudadevice.f90 (+13)
- (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+14-1)
``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 695221cbcb42c..3f4250b703a21 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -208,6 +208,8 @@ struct IntrinsicLibrary {
fir::ExtendedValue genAssociated(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genAtand(mlir::Type, llvm::ArrayRef<mlir::Value>);
+ mlir::Value genBarrierArrive(mlir::Type, llvm::ArrayRef<mlir::Value>);
+ void genBarrierArriveCnt(llvm::ArrayRef<fir::ExtendedValue>);
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genBesselJn(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 2c21868295528..2371a617bf7a6 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -346,6 +346,14 @@ static constexpr IntrinsicHandler handlers[]{
&I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>,
{{{"mask", asValue}, {"pred", asValue}}},
/*isElemental=*/false},
+ {"barrier_arrive",
+ &I::genBarrierArrive,
+ {{{"barrier", asAddr}}},
+ /*isElemental=*/false},
+ {"barrier_arrive_cnt",
+ &I::genBarrierArriveCnt,
+ {{{"barrier", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
{"barrier_init",
&I::genBarrierInit,
{{{"barrier", asAddr}, {"count", asValue}}},
@@ -3180,19 +3188,48 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType,
return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox);
}
-// BARRIER_INIT (CUDA)
-void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 2);
- auto llvmPtr = fir::ConvertOp::create(
+static mlir::Value convertBarrierToLLVM(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::Value barrier) {
+ mlir::Value llvmPtr = fir::ConvertOp::create(
builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
- fir::getBase(args[0]));
- auto addrCast = mlir::LLVM::AddrSpaceCastOp::create(
+ barrier);
+ mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create(
builder, loc,
mlir::LLVM::LLVMPointerType::get(
builder.getContext(),
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)),
llvmPtr);
- mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, addrCast,
+ return addrCast;
+}
+
+// BARRIER_ARRIVE (CUDA)
+mlir::Value
+IntrinsicLibrary::genBarrierArrive(mlir::Type resultType,
+ llvm::ArrayRef<mlir::Value> args) {
+ assert(args.size() == 1);
+ mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
+ return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType,
+ barrier)
+ .getResult();
+}
+
+// BARRIER_ARRIVE_CNT (CUDA)
+void IntrinsicLibrary::genBarrierArriveCnt(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 2);
+ mlir::Value barrier =
+ convertBarrierToLLVM(builder, loc, fir::getBase(args[0]));
+ mlir::NVVM::MBarrierArriveExpectTxOp::create(builder, loc, barrier,
+ fir::getBase(args[1]), {});
+}
+
+// BARRIER_INIT (CUDA)
+void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 2);
+ mlir::Value barrier =
+ convertBarrierToLLVM(builder, loc, fir::getBase(args[0]));
+ mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier,
fir::getBase(args[1]), {});
}
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 4f552dcf08372..6d7b44613ecaa 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -1987,6 +1987,8 @@ attributes(device,host) logical function on_device() bind(c)
end function
end interface
+ ! TMA Operations
+
interface
attributes(device) subroutine barrier_init(barrier, count)
integer(8) :: barrier
@@ -1994,6 +1996,17 @@ attributes(device) subroutine barrier_init(barrier, count)
end subroutine
end interface
+ interface barrier_arrive
+ attributes(device) function barrier_arrive(barrier) result(token)
+ integer(8) :: barrier
+ integer(8) :: token
+ end function
+ attributes(device) subroutine barrier_arrive_cnt(barrier, count)
+ integer(8), shared :: barrier
+ integer(4) :: count
+ end subroutine
+ end interface
+
contains
attributes(device) subroutine syncthreads()
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index cdb337b115e47..5d48d188d6ff0 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -394,9 +394,14 @@ end subroutine
attributes(global) subroutine test_barrier()
integer(8), shared :: barrier
+ integer(8) :: token
+ integer :: count
call barrier_init(barrier, 256)
-end subroutine
+ token1 = barrier_arrive(barrier)
+
+ call barrier_arrive(barrier, count)
+end subroutine
! CHECK-LABEL: func.func @_QPtest_barrier()
@@ -406,3 +411,11 @@ end subroutine
! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr
! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3>
! CHECK: nvvm.mbarrier.init.shared %[[SHARED_PTR]], %[[COUNT]] : !llvm.ptr<3>, i32
+
+! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3>
+! CHECK: %{{.*}} = nvvm.mbarrier.arrive.shared %[[SHARED_PTR]] : !llvm.ptr<3> -> i64
+
+! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3>
+! CHECK: nvvm.mbarrier.arrive.expect_tx %[[SHARED_PTR]], %{{.*}} : !llvm.ptr<3>, i32
``````````
</details>
https://github.com/llvm/llvm-project/pull/162949
More information about the flang-commits
mailing list