[flang-commits] [flang] [flang][cuda] Lower dim3 grid z correctly on calls (PR #85346)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Mar 14 18:26:01 PDT 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/85346

fir.cuda_kernel_launch now accepts a z dimension for the grid parameters. Update lowering to pass the correct value. 

>From a27b7d177680e15c2d59a2d561d9614649e27209 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 14 Mar 2024 18:24:51 -0700
Subject: [PATCH] [flang][cuda] Lower dim3 grid z correctly on calls

---
 flang/lib/Lower/ConvertCall.cpp             | 6 ++++--
 flang/test/Lower/CUDA/cuda-kernel-calls.cuf | 6 ++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 990912195d1445..95569337a06e90 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -416,7 +416,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
     mlir::Type i32Ty = builder.getI32Type();
     mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
 
-    mlir::Value grid_x, grid_y;
+    mlir::Value grid_x, grid_y, grid_z;
     if (caller.getCallDescription().chevrons()[0].GetType()->category() ==
         Fortran::common::TypeCategory::Integer) {
       // If grid is an integer, it is converted to dim3(grid,1,1). Since z is
@@ -426,11 +426,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
           fir::getBase(converter.genExprValue(
               caller.getCallDescription().chevrons()[0], stmtCtx)));
       grid_y = one;
+      grid_z = one;
     } else {
       auto dim3Addr = converter.genExprAddr(
           caller.getCallDescription().chevrons()[0], stmtCtx);
       grid_x = readDim3Value(builder, loc, fir::getBase(dim3Addr), "x");
       grid_y = readDim3Value(builder, loc, fir::getBase(dim3Addr), "y");
+      grid_z = readDim3Value(builder, loc, fir::getBase(dim3Addr), "z");
     }
 
     mlir::Value block_x, block_y, block_z;
@@ -466,7 +468,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
               caller.getCallDescription().chevrons()[3], stmtCtx)));
 
     builder.create<fir::CUDAKernelLaunch>(
-        loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, one,
+        loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z,
         block_x, block_y, block_z, bytes, stream, operands);
     callNumResults = 0;
   } else if (caller.requireDispatchCall()) {
diff --git a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
index d5dabaa1df962b..55b5246e59a006 100644
--- a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
+++ b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
@@ -20,13 +20,15 @@ contains
     call dev_kernel0<<<10, 20>>>()
 ! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>()
 
-    call dev_kernel0<<< __builtin_dim3(1,1), __builtin_dim3(32,1,1) >>>
+    call dev_kernel0<<< __builtin_dim3(1,1,4), __builtin_dim3(32,1,1) >>>
 ! CHECK: %[[ADDR_DIM3_GRID:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
 ! CHECK: %[[DIM3_GRID:.*]]:2 = hlfir.declare %[[ADDR_DIM3_GRID]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.0"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>)
 ! CHECK: %[[GRID_X:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"x"}   : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
 ! CHECK: %[[GRID_X_LOAD:.*]] = fir.load %[[GRID_X]] : !fir.ref<i32>
 ! CHECK: %[[GRID_Y:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"y"}   : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
 ! CHECK: %[[GRID_Y_LOAD:.*]] = fir.load %[[GRID_Y]] : !fir.ref<i32>
+! CHECK: %[[GRID_Z:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"z"}   : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
+! CHECK: %[[GRID_Z_LOAD:.*]] = fir.load %[[GRID_Z]] : !fir.ref<i32>
 ! CHECK: %[[ADDR_DIM3_BLOCK:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
 ! CHECK: %[[DIM3_BLOCK:.*]]:2 = hlfir.declare %[[ADDR_DIM3_BLOCK]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.1"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>)
 ! CHECK: %[[BLOCK_X:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"x"}   : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
@@ -35,7 +37,7 @@ contains
 ! CHECK: %[[BLOCK_Y_LOAD:.*]] = fir.load %[[BLOCK_Y]] : !fir.ref<i32>
 ! CHECK: %[[BLOCK_Z:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"z"}   : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
 ! CHECK: %[[BLOCK_Z_LOAD:.*]] = fir.load %[[BLOCK_Z]] : !fir.ref<i32>
-! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %c1{{.*}}, %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()
+! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %[[GRID_Z_LOAD]], %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()
 
     call dev_kernel0<<<10, 20, 2>>>()
 ! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>() 



More information about the flang-commits mailing list