[Mlir-commits] [mlir] [mlir][gpu] Avoid kernel outlining crash on invalid symbol refs (PR #169843)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 27 09:47:02 PST 2025


https://github.com/Men-cotton created https://github.com/llvm/llvm-project/pull/169843

Fixes: https://github.com/llvm/llvm-project/issues/64639

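Fix a crash in GPU kernel outlining when a kernel references a symbol that cannot be resolved. Previously the pass dereferenced the result of `parentSymbolTable.lookup(...)` without checking it. An unresolvable reference, whether an undefined symbol or a nested reference such as `@nested::@ref`, therefore crashed the pass. The pass now resolves the full reference with `SymbolTable::lookupSymbolIn`. When resolution fails, it emits an error on the referencing op and interrupts outlining.
Add a test covering both an undefined flat reference and an unresolvable nested reference.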

CC: @fabianmcg 

>From 63d403b153bbc23c3451f14ba63b410fc253a559 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Fri, 28 Nov 2025 02:40:59 +0900
Subject: [PATCH] [mlir][gpu] Avoid kernel outlining crash on invalid symbol
 refs

---
 .../GPU/Transforms/KernelOutlining.cpp        | 31 ++++++++++++++-----
 .../Dialect/GPU/outlining-invalid-symbol.mlir | 29 +++++++++++++++++
 2 files changed, 53 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir

diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 97adad64d78c4..dace8fa38dc6d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -371,7 +371,11 @@ class GpuKernelOutliningPass
         // Create nested module and insert outlinedFunc. The module will
         // originally get the same name as the function, but may be renamed on
         // insertion into the parent module.
-        auto kernelModule = createKernelModule(op, outlinedFunc, symbolTable);
+        auto kernelModuleOrFailure =
+            createKernelModule(op, outlinedFunc, symbolTable);
+        if (failed(kernelModuleOrFailure))
+          return WalkResult::interrupt();
+        auto kernelModule = *kernelModuleOrFailure;
         symbolTable.insert(kernelModule, insertPt);
 
         // Potentially changes signature, pulling in constants.
@@ -392,9 +396,9 @@ class GpuKernelOutliningPass
 
 private:
   /// Returns a gpu.module containing kernelFunc and all callees (recursive).
-  gpu::GPUModuleOp createKernelModule(gpu::LaunchOp gpuLaunchOp,
-                                      gpu::GPUFuncOp kernelFunc,
-                                      const SymbolTable &parentSymbolTable) {
+  FailureOr<gpu::GPUModuleOp>
+  createKernelModule(gpu::LaunchOp gpuLaunchOp, gpu::GPUFuncOp kernelFunc,
+                     const SymbolTable &parentSymbolTable) {
     // TODO: This code cannot use an OpBuilder because it must be inserted into
     // a SymbolTable by the caller. SymbolTable needs to be refactored to
     // prevent manual building of Ops with symbols in code using SymbolTables
@@ -431,12 +435,25 @@ class GpuKernelOutliningPass
       if (std::optional<SymbolTable::UseRange> symbolUses =
               SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
         for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
-          StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
+          SymbolRefAttr symbolRef = symbolUse.getSymbolRef();
+          StringAttr symbolName = symbolRef.getLeafReference();
           if (symbolTable.lookup(symbolName))
             continue;
 
-          Operation *symbolDefClone =
-              parentSymbolTable.lookup(symbolName)->clone();
+          Operation *symbolDef =
+              SymbolTable::lookupSymbolIn(parentSymbolTable.getOp(), symbolRef);
+          if (!symbolDef) {
+            if (isa<FlatSymbolRefAttr>(symbolRef)) {
+              return symbolUse.getUser()->emitOpError(
+                  "failed to outline gpu kernel: symbol '" +
+                  symbolName.getValue() + "' not found");
+            }
+            return symbolUse.getUser()->emitOpError(
+                       "failed to outline gpu kernel: "
+                       "found invalid symbol reference: ")
+                   << symbolRef;
+          }
+          Operation *symbolDefClone = symbolDef->clone();
           symbolDefWorklist.push_back(symbolDefClone);
           symbolTable.insert(symbolDefClone);
         }
diff --git a/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir b/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir
new file mode 100644
index 0000000000000..5cd290fc396c3
--- /dev/null
+++ b/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -gpu-kernel-outlining -verify-diagnostics -split-input-file %s
+
+module attributes {gpu.container_module} {
+  func.func @kernel_crash() {  // Flat reference to an undefined symbol.
+    %c1 = arith.constant 1 : index
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+      // expected-error@+1 {{failed to outline gpu kernel: symbol 'unknown_func' not found}}
+      "test.op"() {symbol = @unknown_func} : () -> ()
+      gpu.terminator
+    }
+    return
+  }
+}
+
+// -----
+
+module attributes {gpu.container_module} {
+  func.func @kernel_invalid_ref() {  // Nested reference whose root @nested is undefined.
+    %c1 = arith.constant 1 : index
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+      // expected-error@+1 {{failed to outline gpu kernel: found invalid symbol reference: @nested::@ref}}
+      "test.op"() {symbol = @nested::@ref} : () -> ()
+      gpu.terminator
+    }
+    return
+  }
+}



More information about the Mlir-commits mailing list