[flang-commits] [flang] [flang][cuda] Make argument passed by value for sync functions (PR #125909)
via flang-commits
flang-commits at lists.llvm.org
Wed Feb 5 10:49:05 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
`syncthreads_and`, `syncthreads_count`, `syncthreads_or`, and `syncwarp` must take their argument by value. This patch updates the interfaces and makes sure these functions can be called inside a CUF kernel as well.
---
Full diff: https://github.com/llvm/llvm-project/pull/125909.diff
2 Files Affected:
- (modified) flang/module/cudadevice.f90 (+4-4)
- (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+24-10)
``````````diff
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 00e8b3db73ad87..1fe99b30b1db08 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -29,28 +29,28 @@ attributes(device) subroutine syncthreads()
interface
attributes(device) integer function syncthreads_and(value)
- integer :: value
+ integer, value :: value
end function
end interface
public :: syncthreads_and
interface
attributes(device) integer function syncthreads_count(value)
- integer :: value
+ integer, value :: value
end function
end interface
public :: syncthreads_count
interface
attributes(device) integer function syncthreads_or(value)
- integer :: value
+ integer, value :: value
end function
end interface
public :: syncthreads_or
interface
attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
- integer :: mask
+ integer, value :: mask
end subroutine
end interface
public :: syncwarp
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 5805dd5010a842..ec825263474c1e 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -47,7 +47,7 @@ end
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
+! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -79,17 +79,9 @@ end
! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
-! CHECK: func.func private @llvm.nvvm.barrier0()
-! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
-! CHECK: func.func private @llvm.nvvm.membar.gl()
-! CHECK: func.func private @llvm.nvvm.membar.cta()
-! CHECK: func.func private @llvm.nvvm.membar.sys()
-! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
-
subroutine host1()
integer, device :: a(32)
+ integer, device :: ret
integer :: i, j
block; use cudadevice
@@ -98,6 +90,28 @@ block; use cudadevice
a(i) = a(i) * 2.0
call syncthreads()
a(i) = a(i) + a(j) - 34.0
+
+ call syncwarp(1)
+ ret = syncthreads_and(1)
+ ret = syncthreads_count(1)
+ ret = syncthreads_or(1)
end do
end block
end
+
+! CHECK-LABEL: func.func @_QPhost1()
+! CHECK: cuf.kernel
+! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
+! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+
+! CHECK: func.func private @llvm.nvvm.barrier0()
+! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
+! CHECK: func.func private @llvm.nvvm.membar.gl()
+! CHECK: func.func private @llvm.nvvm.membar.cta()
+! CHECK: func.func private @llvm.nvvm.membar.sys()
+! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
``````````
</details>
https://github.com/llvm/llvm-project/pull/125909
More information about the flang-commits
mailing list