[flang-commits] [flang] b4d3c2c - [flang][cuda] Update FIROps.td to add $grid_z to CudaKernelLaunch (#85318)

via flang-commits flang-commits at lists.llvm.org
Thu Mar 14 18:11:41 PDT 2024


Author: Iman Hosseini
Date: 2024-03-14T18:11:38-07:00
New Revision: b4d3c2cac2426070258cdb32d6932bf05e938c7d

URL: https://github.com/llvm/llvm-project/commit/b4d3c2cac2426070258cdb32d6932bf05e938c7d
DIFF: https://github.com/llvm/llvm-project/commit/b4d3c2cac2426070258cdb32d6932bf05e938c7d.diff

LOG: [flang][cuda] Update FIROps.td to add $grid_z to CudaKernelLaunch (#85318)

The grid can be three-dimensional. (@clementval)

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Lower/ConvertCall.cpp
    flang/test/Lower/CUDA/cuda-kernel-calls.cuf

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 65a86d25333b5d..f4792637f481c0 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2454,6 +2454,7 @@ def fir_CUDAKernelLaunch : fir_Op<"cuda_kernel_launch", [CallOpInterface,
     SymbolRefAttr:$callee,
     I32:$grid_x,
     I32:$grid_y,
+    I32:$grid_z,
     I32:$block_x,
     I32:$block_y,
     I32:$block_z,
@@ -2463,8 +2464,8 @@ def fir_CUDAKernelLaunch : fir_Op<"cuda_kernel_launch", [CallOpInterface,
   );
 
   let assemblyFormat = [{
-    $callee `<` `<` `<` $grid_x `,` $grid_y `,` $block_x `,` $block_y `,`
-        $block_z ( `,` $bytes^ ( `,` $stream^ )? )? `>` `>` `>`
+    $callee `<` `<` `<` $grid_x `,` $grid_y `,` $grid_z `,`$block_x `,`
+        $block_y `,` $block_z ( `,` $bytes^ ( `,` $stream^ )? )? `>` `>` `>`
         `` `(` ( $args^ `:` type($args) )? `)` attr-dict
   }];
 

diff  --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 6e3ce101ef1af9..990912195d1445 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -466,8 +466,8 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
               caller.getCallDescription().chevrons()[3], stmtCtx)));
 
     builder.create<fir::CUDAKernelLaunch>(
-        loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, block_x,
-        block_y, block_z, bytes, stream, operands);
+        loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, one,
+        block_x, block_y, block_z, bytes, stream, operands);
     callNumResults = 0;
   } else if (caller.requireDispatchCall()) {
     // Procedure call requiring a dynamic dispatch. Call is created with

diff  --git a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
index c1e89d1978e4c2..d5dabaa1df962b 100644
--- a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
+++ b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
@@ -18,7 +18,7 @@ contains
 ! CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMtest_callFhostEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 
     call dev_kernel0<<<10, 20>>>()
-! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>()
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>()
 
     call dev_kernel0<<< __builtin_dim3(1,1), __builtin_dim3(32,1,1) >>>
 ! CHECK: %[[ADDR_DIM3_GRID:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
@@ -35,16 +35,16 @@ contains
 ! CHECK: %[[BLOCK_Y_LOAD:.*]] = fir.load %[[BLOCK_Y]] : !fir.ref<i32>
 ! CHECK: %[[BLOCK_Z:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"z"}   : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
 ! CHECK: %[[BLOCK_Z_LOAD:.*]] = fir.load %[[BLOCK_Z]] : !fir.ref<i32>
-! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %c1{{.*}}, %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()
 
     call dev_kernel0<<<10, 20, 2>>>()
-! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>() 
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>() 
 
     call dev_kernel0<<<10, 20, 2, 0>>>()
-! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
 
     call dev_kernel1<<<1, 32>>>(a)
-! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%1#1 : !fir.ref<f32>)
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%1#1 : !fir.ref<f32>)
   end
 
 end


        


More information about the flang-commits mailing list