[flang-commits] [flang] [flang][cuda] Use -1 for grid values when * is used (PR #115534)

via flang-commits flang-commits at lists.llvm.org
Fri Nov 8 11:03:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Chevron syntax has been updated to allow `*` to be used for the grid value. Make sure all three grid values (x, y and z) are set to -1 in lowering when `*` is used. 

---
Full diff: https://github.com/llvm/llvm-project/pull/115534.diff


2 Files Affected:

- (modified) flang/lib/Lower/ConvertCall.cpp (+9-2) 
- (modified) flang/test/Lower/CUDA/cuda-kernel-calls.cuf (+4-1) 


``````````diff
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 9f5b58590fb79e..eaf5a25e4390ef 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -541,8 +541,15 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
           loc, i32Ty,
           fir::getBase(converter.genExprValue(
               caller.getCallDescription().chevrons()[0], stmtCtx)));
-      grid_y = one;
-      grid_z = one;
+      auto gridXValue = fir::getIntIfConstant(grid_x);
+      if (gridXValue && *gridXValue < 0) {
+        // Call using * for grid size.
+        grid_y = grid_x;
+        grid_z = grid_x;
+      } else {
+        grid_y = one;
+        grid_z = one;
+      }
     } else {
       auto dim3Addr = converter.genExprAddr(
           caller.getCallDescription().chevrons()[0], stmtCtx);
diff --git a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
index 82d1a61f8e157c..08ec2f433f5838 100644
--- a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
+++ b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
@@ -47,7 +47,10 @@ contains
 ! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
 
     call dev_kernel1<<<1, 32>>>(a)
-! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%1#1) : (!fir.ref<f32>)
+! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}}) : (!fir.ref<f32>)
+
+    call dev_kernel1<<<*, 32>>>(a)
+! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c-1{{.*}}, %c-1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}})
   end
 
 end

``````````

</details>


https://github.com/llvm/llvm-project/pull/115534


More information about the flang-commits mailing list