[flang-commits] [flang] [flang][cuda] Adapt TargetRewrite to support gpu.launch_func (PR #119933)
via flang-commits
flang-commits at lists.llvm.org
Fri Dec 13 14:50:48 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
gpu.func operations are already supported in the TargetRewrite pass. Update the pass to also support rewriting the gpu.launch_func operation.
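As an illustration, here is a rough before/after sketch of the effect on a launch site, mirroring the new test added in this diff. The SSA value names `%c1`, `%smem`, `%re`, and `%im` are placeholders, and the surrounding module is omitted:

```mlir
// Before TargetRewrite: the kernel argument is passed as a single complex<f64>.
gpu.launch_func @testmod::@_QPtest
    blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) : i64
    dynamic_shared_memory_size %smem args(%arg0 : complex<f64>)

// After TargetRewrite: on the test's target the complex value is split into
// two f64 operands, matching the rewritten gpu.func signature.
gpu.launch_func @testmod::@_QPtest
    blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) : i64
    dynamic_shared_memory_size %smem args(%re : f64, %im : f64)
```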
---
Full diff: https://github.com/llvm/llvm-project/pull/119933.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+31-8)
- (modified) flang/test/Fir/CUDA/cuda-target-rewrite.mlir (+24)
``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 5a042b34a58c0a..b0b9499557e2b7 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -134,10 +134,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
mod.walk([&](mlir::Operation *op) {
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
if (!hasPortableSignature(call.getFunctionType(), op))
- convertCallOp(call);
+ convertCallOp(call, call.getFunctionType());
} else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
if (!hasPortableSignature(dispatch.getFunctionType(), op))
- convertCallOp(dispatch);
+ convertCallOp(dispatch, dispatch.getFunctionType());
+ } else if (auto gpuLaunchFunc =
+ mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
+ llvm::SmallVector<mlir::Type> operandsTypes;
+ for (auto arg : gpuLaunchFunc.getKernelOperands())
+ operandsTypes.push_back(arg.getType());
+ auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
+ if (!hasPortableSignature(fctTy, op))
+ convertCallOp(gpuLaunchFunc, fctTy);
} else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
!hasPortableSignature(addr.getType(), op))
@@ -357,8 +365,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// Convert fir.call and fir.dispatch Ops.
template <typename A>
- void convertCallOp(A callOp) {
- auto fnTy = callOp.getFunctionType();
+ void convertCallOp(A callOp, mlir::FunctionType fnTy) {
auto loc = callOp.getLoc();
rewriter->setInsertionPoint(callOp);
llvm::SmallVector<mlir::Type> newResTys;
@@ -376,7 +383,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newOpers.push_back(callOp.getOperand(0));
dropFront = 1;
}
- } else {
+ } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
dropFront = 1; // First operand is the polymorphic object.
}
@@ -402,10 +409,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
llvm::SmallVector<mlir::Type> trailingInTys;
llvm::SmallVector<mlir::Value> trailingOpers;
+ llvm::SmallVector<mlir::Value> operands;
unsigned passArgShift = 0;
+ if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
+ operands = callOp.getKernelOperands();
+ else
+ operands = callOp.getOperands().drop_front(dropFront);
for (auto e : llvm::enumerate(
- llvm::zip(fnTy.getInputs().drop_front(dropFront),
- callOp.getOperands().drop_front(dropFront)))) {
+ llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
mlir::Type ty = std::get<0>(e.value());
mlir::Value oper = std::get<1>(e.value());
unsigned index = e.index();
@@ -507,7 +518,19 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
llvm::SmallVector<mlir::Value, 1> newCallResults;
- if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
+ if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
+ auto newCall = rewriter->create<A>(
+ loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
+ callOp.getBlockSizeOperandValues(),
+ callOp.getDynamicSharedMemorySize(), newOpers);
+ if (callOp.getClusterSizeX())
+ newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
+ if (callOp.getClusterSizeY())
+ newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
+ if (callOp.getClusterSizeZ())
+ newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
+ newCallResults.append(newCall.result_begin(), newCall.result_end());
+ } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
fir::CallOp newCall;
if (callOp.getCallee()) {
newCall =
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index d88b6776795a0b..0e7534e06c89c9 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -27,3 +27,27 @@ gpu.module @testmod {
// CHECK-LABEL: gpu.func @_QPtest
// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
// CHECK: gpu.return %{{.*}} : tuple<f64, f64>
+
+
+// -----
+module attributes {gpu.container_module} {
+
+gpu.module @testmod {
+ gpu.func @_QPtest(%arg0: complex<f64>) -> () kernel {
+ gpu.return
+ }
+}
+
+func.func @main(%arg0: complex<f64>) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(0 : i32) : i32
+ gpu.launch_func @testmod::@_QPtest blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1 args(%arg0 : complex<f64>)
+ return
+}
+
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
+// CHECK: gpu.return
+// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64)
``````````
https://github.com/llvm/llvm-project/pull/119933
More information about the flang-commits mailing list