[Mlir-commits] [mlir] generalize pass gpu-kernel-outlining for symbol op (PR #72074)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 12 21:28:17 PST 2023

https://github.com/fengxie updated https://github.com/llvm/llvm-project/pull/72074

>From 33c60ccee1075f0602f6916baff3d5cbbe97abde Mon Sep 17 00:00:00 2001
From: Fung Xie <ftse at nvidia.com>
Date: Mon, 13 Nov 2023 09:46:49 +0800
Subject: [PATCH 1/2] generalize pass gpu-kernel-outlining for symbol op

 mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index b1e2f914db4cb9b..7432a58f18b4422 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -349,13 +349,13 @@ class GpuKernelOutliningPass
   void runOnOperation() override {
     SymbolTable symbolTable(getOperation());
     bool modified = false;
-    for (auto func : getOperation().getOps<func::FuncOp>()) {
+    for (auto func : getOperation().getOps<SymbolOpInterface>()) {
       // Insert just after the function.
       Block::iterator insertPt(func->getNextNode());
       auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
         SetVector<Value> operands;
         std::string kernelFnName =
-            Twine(op->getParentOfType<func::FuncOp>().getName(), "_kernel")
+            Twine(op->getParentOfType<SymbolOpInterface>().getName(), "_kernel")
         gpu::GPUFuncOp outlinedFunc =

>From 4f22f60741d9dd9af4277172a5bd131d6ce3877b Mon Sep 17 00:00:00 2001
From: Fung Xie <ftse at nvidia.com>
Date: Mon, 13 Nov 2023 13:27:59 +0800
Subject: [PATCH 2/2] add unit test to verify gpu.launch call from llvm.func

 mlir/test/Dialect/GPU/outlining.mlir | 37 +++++++++++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
index 28c121a550100c2..8020f6dfa65b745 100644
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -37,7 +37,6 @@ func.func @launch() {
 // CHECK-DL-LABEL: gpu.module @launch_kernel attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>}
 // CHECK-LABEL: gpu.module @launch_kernel
 // CHECK-NEXT: gpu.func @launch_kernel
 // CHECK-SAME: (%[[KERNEL_ARG0:.*]]: f32, %[[KERNEL_ARG1:.*]]: memref<?xf32, 1>)
@@ -63,6 +62,42 @@ func.func @launch() {
 // -----
+// Checks that gpu-kernel-outlining handles a gpu.launch inside an llvm.func.
+// CHECK-LABEL: @launch_from_llvm_func
+llvm.func @launch_from_llvm_func() {
+  // CHECK: %[[ARG0:.*]] = "op"() : () -> f32
+  %0 = "op"() : () -> (f32)
+  // CHECK: %[[ARG1:.*]] = "op"() : () -> memref<?xf32, 1>
+  %1 = "op"() : () -> (memref<?xf32, 1>)
+  // CHECK: %[[DIM:.*]] = arith.constant 1
+  %dim = arith.constant 1 : index
+  // CHECK: gpu.launch_func @launch_from_llvm_func_kernel::@launch_from_llvm_func_kernel
+  // CHECK-SAME: (%[[DIM]], %[[DIM]], %[[DIM]])
+  // CHECK-SAME: (%[[DIM]], %[[DIM]], %[[DIM]]) args(%[[ARG0]] : f32, %[[ARG1]] : memref<?xf32, 1>)
+  // CHECK-NEXT: llvm.return
+  // CHECK: gpu.func {{.*}} kernel attributes
+  // CHECK-SAME: gpu.known_block_size = array<i32: 1, 1, 1>
+  // CHECK-SAME: gpu.known_grid_size = array<i32: 1, 1, 1>
+  // CHECK: gpu.return
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %dim, %grid_y = %dim,
+                                       %grid_z = %dim)
+             threads(%tx, %ty, %tz) in (%block_x = %dim, %block_y = %dim,
+                                        %block_z = %dim) {
+    "use"(%0): (f32) -> ()
+    "some_op"(%bx, %block_x) : (index, index) -> ()
+    %2 = memref.load %1[%tx] : memref<?xf32, 1>
+    gpu.terminator
+  }
+  llvm.return
+// CHECK-DL-LABEL: gpu.module @launch_from_llvm_func_kernel attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>}
+// -----
 // CHECK: module attributes {gpu.container_module}
 // CHECK-LABEL: @multiple_launches
 func.func @multiple_launches() {

More information about the Mlir-commits mailing list