[flang-commits] [flang] 69ccb13 - [flang][cuda] Make argument passed by value for sync functions (#125909)

via flang-commits flang-commits at lists.llvm.org
Wed Feb 5 13:47:56 PST 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-02-05T13:47:53-08:00
New Revision: 69ccb1357fa6cf72063c737d06d6b29ffc465bee

URL: https://github.com/llvm/llvm-project/commit/69ccb1357fa6cf72063c737d06d6b29ffc465bee
DIFF: https://github.com/llvm/llvm-project/commit/69ccb1357fa6cf72063c737d06d6b29ffc465bee.diff

LOG: [flang][cuda] Make argument passed by value for sync functions (#125909)

`syncthreads_and`, `syncthreads_count`, `syncthreads_or`, and `syncwarp` must
take their argument by value. This patch updates their interfaces accordingly
and makes sure these functions can also be called inside a CUF kernel.

Added: 
    

Modified: 
    flang/module/cudadevice.f90
    flang/test/Lower/CUDA/cuda-device-proc.cuf

Removed: 
    


################################################################################
diff  --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 00e8b3db73ad87..1fe99b30b1db08 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -29,28 +29,28 @@ attributes(device) subroutine syncthreads()
 
   interface
     attributes(device) integer function syncthreads_and(value)
-      integer :: value
+      integer, value :: value
     end function
   end interface
   public :: syncthreads_and
 
   interface
     attributes(device) integer function syncthreads_count(value)
-      integer :: value
+      integer, value :: value
     end function
   end interface
   public :: syncthreads_count
 
   interface
     attributes(device) integer function syncthreads_or(value)
-      integer :: value
+      integer, value :: value
     end function
   end interface
   public :: syncthreads_or
 
   interface
     attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
-      integer :: mask
+      integer, value :: mask
     end subroutine
   end interface
   public :: syncwarp

diff  --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 5805dd5010a842..ec825263474c1e 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -47,7 +47,7 @@ end
 
 ! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
 ! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
+! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -79,17 +79,9 @@ end
 ! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap  %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
 ! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap  %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
 
-! CHECK: func.func private @llvm.nvvm.barrier0()
-! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
-! CHECK: func.func private @llvm.nvvm.membar.gl()
-! CHECK: func.func private @llvm.nvvm.membar.cta()
-! CHECK: func.func private @llvm.nvvm.membar.sys()
-! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
-
 subroutine host1()
   integer, device :: a(32)
+  integer, device :: ret
   integer :: i, j
 
 block; use cudadevice
@@ -98,6 +90,28 @@ block; use cudadevice
     a(i) = a(i) * 2.0
     call syncthreads()
     a(i) = a(i) + a(j) - 34.0
+
+    call syncwarp(1)
+    ret = syncthreads_and(1)
+    ret = syncthreads_count(1)
+    ret = syncthreads_or(1)
   end do
 end block
 end 
+
+! CHECK-LABEL: func.func @_QPhost1()
+! CHECK: cuf.kernel
+! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
+! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+
+! CHECK: func.func private @llvm.nvvm.barrier0()
+! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
+! CHECK: func.func private @llvm.nvvm.membar.gl()
+! CHECK: func.func private @llvm.nvvm.membar.cta()
+! CHECK: func.func private @llvm.nvvm.membar.sys()
+! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32


        


More information about the flang-commits mailing list