[flang-commits] [flang] [flang][cuda] Use -1 for grid values when * is used (PR #115534)
via flang-commits
flang-commits at lists.llvm.org
Fri Nov 8 11:03:12 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Chevron syntax has been updated to allow `*` to be used for the grid value. When `*` is used, lowering now sets all three grid values to -1.
---
Full diff: https://github.com/llvm/llvm-project/pull/115534.diff
2 Files Affected:
- (modified) flang/lib/Lower/ConvertCall.cpp (+9-2)
- (modified) flang/test/Lower/CUDA/cuda-kernel-calls.cuf (+4-1)
``````````diff
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 9f5b58590fb79e..eaf5a25e4390ef 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -541,8 +541,15 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
loc, i32Ty,
fir::getBase(converter.genExprValue(
caller.getCallDescription().chevrons()[0], stmtCtx)));
- grid_y = one;
- grid_z = one;
+ auto gridXValue = fir::getIntIfConstant(grid_x);
+ if (gridXValue && *gridXValue < 0) {
+ // Call using * for grid size.
+ grid_y = grid_x;
+ grid_z = grid_x;
+ } else {
+ grid_y = one;
+ grid_z = one;
+ }
} else {
auto dim3Addr = converter.genExprAddr(
caller.getCallDescription().chevrons()[0], stmtCtx);
diff --git a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
index 82d1a61f8e157c..08ec2f433f5838 100644
--- a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
+++ b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
@@ -47,7 +47,10 @@ contains
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
call dev_kernel1<<<1, 32>>>(a)
-! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%1#1) : (!fir.ref<f32>)
+! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}}) : (!fir.ref<f32>)
+
+ call dev_kernel1<<<*, 32>>>(a)
+! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c-1{{.*}}, %c-1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}})
end
end
``````````
</details>
https://github.com/llvm/llvm-project/pull/115534
More information about the flang-commits
mailing list