[Mlir-commits] [mlir] [mlir][gpu] Support outlining nested `gpu.launch` (PR #152696)
Longsheng Mou
llvmlistbot at llvm.org
Fri Aug 8 05:03:18 PDT 2025
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/152696
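This PR fixes a crash in `GpuKernelOutliningPass` that occurred when the pass encountered a symbol reference that was not a `FlatSymbolRefAttr`, such as the nested reference `@nested_launch_kernel::@nested_launch_kernel` that a `gpu.launch_func` uses to name an already-outlined inner kernel. With this fix, nested `gpu.launch` operations can be outlined. Fixes #149318.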
From 95d8bc20775ad4e16da7fa664461d3f7d2cb8122 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 8 Aug 2025 16:51:51 +0800
Subject: [PATCH] [mlir][gpu] Support outlining nested `gpu.launch`
This PR fixes a crash in `GpuKernelOutliningPass` that occurred when encountering a symbol that was not a `FlatSymbolRefAttr`, enabling outlining of nested `gpu.launch` operations.
---
.../GPU/Transforms/KernelOutlining.cpp | 3 +--
mlir/test/Dialect/GPU/outlining.mlir | 26 +++++++++++++++++++
2 files changed, 27 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index d4978ca768747..97adad64d78c4 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -431,8 +431,7 @@ class GpuKernelOutliningPass
if (std::optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
- StringRef symbolName =
- cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
+ StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
if (symbolTable.lookup(symbolName))
continue;
diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
index 04901182a80f5..e14521ab9bb5c 100644
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -634,3 +634,29 @@ func.func @testNoAttributes() {
}
return
}
+
+// -----
+
+// This test tests nested `gpu.launch`.
+
+// CHECK-LABEL: func.func @nested_launch(
+// CHECK-SAME: %[[ARG0:.*]]: index) {
+// CHECK: gpu.launch_func @nested_launch_kernel_0::@nested_launch_kernel blocks in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) threads in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) args(%[[ARG0]] : index)
+// CHECK: gpu.module @nested_launch_kernel
+// CHECK: gpu.func @nested_launch_kernel() kernel
+// CHECK: "some_op"
+// CHECK: gpu.module @nested_launch_kernel_0
+// CHECK: gpu.func @nested_launch_kernel(%[[VAL_0:.*]]: index) kernel
+// CHECK: gpu.launch_func @nested_launch_kernel::@nested_launch_kernel blocks in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) threads in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]])
+func.func @nested_launch(%sz : index) {
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+ threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+ gpu.launch blocks(%bx1, %by1, %bz1) in (%grid_x1 = %sz, %grid_y1 = %sz, %grid_z1 = %sz)
+ threads(%tx1, %ty1, %tz1) in (%block_x1 = %sz, %block_y1 = %sz, %block_z1 = %sz) {
+ "some_op"(%bx1, %tx1) : (index, index) -> ()
+ gpu.terminator
+ }
+ gpu.terminator
+ }
+ return
+}