[flang-commits] [flang] [flang][cuda] Adapt TargetRewrite to support gpu.launch_func (PR #119933)

via flang-commits flang-commits at lists.llvm.org
Fri Dec 13 14:50:48 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

gpu.func operations are already supported in the TargetRewrite pass. Update the pass to also support rewriting the gpu.launch_func operation.

---
Full diff: https://github.com/llvm/llvm-project/pull/119933.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+31-8) 
- (modified) flang/test/Fir/CUDA/cuda-target-rewrite.mlir (+24) 


``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 5a042b34a58c0a..b0b9499557e2b7 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -134,10 +134,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     mod.walk([&](mlir::Operation *op) {
       if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
         if (!hasPortableSignature(call.getFunctionType(), op))
-          convertCallOp(call);
+          convertCallOp(call, call.getFunctionType());
       } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
         if (!hasPortableSignature(dispatch.getFunctionType(), op))
-          convertCallOp(dispatch);
+          convertCallOp(dispatch, dispatch.getFunctionType());
+      } else if (auto gpuLaunchFunc =
+                     mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
+        llvm::SmallVector<mlir::Type> operandsTypes;
+        for (auto arg : gpuLaunchFunc.getKernelOperands())
+          operandsTypes.push_back(arg.getType());
+        auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
+        if (!hasPortableSignature(fctTy, op))
+          convertCallOp(gpuLaunchFunc, fctTy);
       } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
         if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
             !hasPortableSignature(addr.getType(), op))
@@ -357,8 +365,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
   // Convert fir.call and fir.dispatch Ops.
   template <typename A>
-  void convertCallOp(A callOp) {
-    auto fnTy = callOp.getFunctionType();
+  void convertCallOp(A callOp, mlir::FunctionType fnTy) {
     auto loc = callOp.getLoc();
     rewriter->setInsertionPoint(callOp);
     llvm::SmallVector<mlir::Type> newResTys;
@@ -376,7 +383,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
         newOpers.push_back(callOp.getOperand(0));
         dropFront = 1;
       }
-    } else {
+    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
       dropFront = 1; // First operand is the polymorphic object.
     }
 
@@ -402,10 +409,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
     llvm::SmallVector<mlir::Type> trailingInTys;
     llvm::SmallVector<mlir::Value> trailingOpers;
+    llvm::SmallVector<mlir::Value> operands;
     unsigned passArgShift = 0;
+    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
+      operands = callOp.getKernelOperands();
+    else
+      operands = callOp.getOperands().drop_front(dropFront);
     for (auto e : llvm::enumerate(
-             llvm::zip(fnTy.getInputs().drop_front(dropFront),
-                       callOp.getOperands().drop_front(dropFront)))) {
+             llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
       mlir::Type ty = std::get<0>(e.value());
       mlir::Value oper = std::get<1>(e.value());
       unsigned index = e.index();
@@ -507,7 +518,19 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
 
     llvm::SmallVector<mlir::Value, 1> newCallResults;
-    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
+    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
+      auto newCall = rewriter->create<A>(
+          loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
+          callOp.getBlockSizeOperandValues(),
+          callOp.getDynamicSharedMemorySize(), newOpers);
+      if (callOp.getClusterSizeX())
+        newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
+      if (callOp.getClusterSizeY())
+        newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
+      if (callOp.getClusterSizeZ())
+        newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
+      newCallResults.append(newCall.result_begin(), newCall.result_end());
+    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
       fir::CallOp newCall;
       if (callOp.getCallee()) {
         newCall =
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index d88b6776795a0b..0e7534e06c89c9 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -27,3 +27,27 @@ gpu.module @testmod {
 // CHECK-LABEL: gpu.func @_QPtest
 // CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
 // CHECK: gpu.return %{{.*}} : tuple<f64, f64>
+
+
+// -----
+module attributes {gpu.container_module} {
+
+gpu.module @testmod {
+  gpu.func @_QPtest(%arg0: complex<f64>) -> () kernel {
+    gpu.return
+  }
+}
+
+func.func @main(%arg0: complex<f64>) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(0 : i32) : i32
+  gpu.launch_func  @testmod::@_QPtest blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1 args(%arg0 : complex<f64>)
+  return
+}
+
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
+// CHECK: gpu.return
+// CHECK: gpu.launch_func  @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64)

``````````

</details>


https://github.com/llvm/llvm-project/pull/119933


More information about the flang-commits mailing list