[Mlir-commits] [mlir] [mlir][gpu] Support outlining nested `gpu.launch` (PR #152696)

Longsheng Mou llvmlistbot at llvm.org
Fri Aug 8 05:03:18 PDT 2025


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/152696

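This PR fixes a crash in `GpuKernelOutliningPass` that occurred when a symbol use inside an outlined kernel was not a `FlatSymbolRefAttr`. Outlining a nested `gpu.launch` produces exactly such a use: the inner launch becomes a `gpu.launch_func` whose kernel is a nested reference such as `@nested_launch_kernel::@nested_launch_kernel`, and the unconditional `cast<FlatSymbolRefAttr>` failed on it. The pass now takes the leaf name with `getLeafReference()`, which enables outlining of nested `gpu.launch` operations. Fixes #149318.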
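The core of the change is which name is extracted from a symbol use. The following is a minimal, self-contained C++ model of that difference, assuming a simplified symbol reference made of a root name plus zero or more nested names. The `SymRef` struct and the `oldSymbolName`/`newSymbolName` helpers are hypothetical illustrations, not MLIR API; in MLIR the real types are `SymbolRefAttr` and `FlatSymbolRefAttr`, and the real accessor is `getLeafReference()`.

```cpp
// Hypothetical toy model, NOT the MLIR API. It mimics a symbol reference
// such as `@nested_launch_kernel::@nested_launch_kernel`: a root name plus
// zero or more nested names. In MLIR a `FlatSymbolRefAttr` is a
// `SymbolRefAttr` with no nested references.
#include <optional>
#include <string>
#include <vector>

struct SymRef {
  std::string root;                 // e.g. the gpu.module name
  std::vector<std::string> nested;  // e.g. the gpu.func name inside it

  // Flat means "no nested references", as for FlatSymbolRefAttr.
  bool isFlat() const { return nested.empty(); }

  // Like SymbolRefAttr::getLeafReference(): the last component.
  const std::string &leaf() const {
    return nested.empty() ? root : nested.back();
  }
};

// Old behaviour: only flat references are handled. In MLIR,
// `cast<FlatSymbolRefAttr>` fails an assertion instead; an empty optional
// stands in for that crash here.
std::optional<std::string> oldSymbolName(const SymRef &ref) {
  if (!ref.isFlat())
    return std::nullopt;
  return ref.root;
}

// New behaviour: take the leaf name, which works for flat and nested refs.
std::string newSymbolName(const SymRef &ref) { return ref.leaf(); }
```

For a flat reference such as `@callee` both helpers yield `callee`; for the nested `@nested_launch_kernel::@nested_launch_kernel` only the leaf-based helper yields a name (`nested_launch_kernel`).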

From 95d8bc20775ad4e16da7fa664461d3f7d2cb8122 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 8 Aug 2025 16:51:51 +0800
Subject: [PATCH] [mlir][gpu] Support outlining nested `gpu.launch`

This PR fixes a crash in `GpuKernelOutliningPass` that occurred when encountering a symbol that was not a `FlatSymbolRefAttr`, enabling outlining of nested `gpu.launch` operations.
---
 .../GPU/Transforms/KernelOutlining.cpp        |  3 +--
 mlir/test/Dialect/GPU/outlining.mlir          | 26 +++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index d4978ca768747..97adad64d78c4 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -431,8 +431,7 @@ class GpuKernelOutliningPass
       if (std::optional<SymbolTable::UseRange> symbolUses =
               SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
         for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
-          StringRef symbolName =
-              cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
+          StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
           if (symbolTable.lookup(symbolName))
             continue;
 
diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
index 04901182a80f5..e14521ab9bb5c 100644
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -634,3 +634,29 @@ func.func @testNoAttributes() {
   }
   return
 }
+
+// -----
+
+// Check that a `gpu.launch` nested inside another `gpu.launch` is outlined.
+
+// CHECK-LABEL: func.func @nested_launch(
+//  CHECK-SAME:                          %[[ARG0:.*]]: index) {
+//       CHECK:   gpu.launch_func  @nested_launch_kernel_0::@nested_launch_kernel blocks in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) threads in (%[[ARG0]], %[[ARG0]], %[[ARG0]])  args(%[[ARG0]] : index)
+//       CHECK: gpu.module @nested_launch_kernel
+//       CHECK:   gpu.func @nested_launch_kernel() kernel
+//       CHECK:     "some_op"
+//       CHECK: gpu.module @nested_launch_kernel_0
+//       CHECK:   gpu.func @nested_launch_kernel(%[[VAL_0:.*]]: index) kernel
+//       CHECK:     gpu.launch_func  @nested_launch_kernel::@nested_launch_kernel blocks in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) threads in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]])
+func.func @nested_launch(%sz : index) {
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+             threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+    gpu.launch blocks(%bx1, %by1, %bz1) in (%grid_x1 = %sz, %grid_y1 = %sz, %grid_z1 = %sz)
+               threads(%tx1, %ty1, %tz1) in (%block_x1 = %sz, %block_y1 = %sz, %block_z1 = %sz) {
+      "some_op"(%bx1, %tx1) : (index, index) -> ()
+      gpu.terminator
+    }
+    gpu.terminator
+  }
+  return
+}


