[Mlir-commits] [mlir] [mlir][gpu] Avoid kernel outlining crash on invalid symbol refs (PR #169843)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 27 09:47:36 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Men-cotton (Men-cotton)

<details>
<summary>Changes</summary>

Fixes: https://github.com/llvm/llvm-project/issues/64639

Handle nested/invalid symbol references during GPU kernel outlining to avoid crashes and emit proper diagnostics.
Add a test to cover the new behavior.

CC: @fabianmcg

---
Full diff: https://github.com/llvm/llvm-project/pull/169843.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp (+24-7) 
- (added) mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir (+29) 


``````````diff
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 97adad64d78c4..dace8fa38dc6d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -371,7 +371,11 @@ class GpuKernelOutliningPass
         // Create nested module and insert outlinedFunc. The module will
         // originally get the same name as the function, but may be renamed on
         // insertion into the parent module.
-        auto kernelModule = createKernelModule(op, outlinedFunc, symbolTable);
+        auto kernelModuleOrFailure =
+            createKernelModule(op, outlinedFunc, symbolTable);
+        if (failed(kernelModuleOrFailure))
+          return WalkResult::interrupt();
+        auto kernelModule = *kernelModuleOrFailure;
         symbolTable.insert(kernelModule, insertPt);
 
         // Potentially changes signature, pulling in constants.
@@ -392,9 +396,9 @@ class GpuKernelOutliningPass
 
 private:
   /// Returns a gpu.module containing kernelFunc and all callees (recursive).
-  gpu::GPUModuleOp createKernelModule(gpu::LaunchOp gpuLaunchOp,
-                                      gpu::GPUFuncOp kernelFunc,
-                                      const SymbolTable &parentSymbolTable) {
+  FailureOr<gpu::GPUModuleOp>
+  createKernelModule(gpu::LaunchOp gpuLaunchOp, gpu::GPUFuncOp kernelFunc,
+                     const SymbolTable &parentSymbolTable) {
     // TODO: This code cannot use an OpBuilder because it must be inserted into
     // a SymbolTable by the caller. SymbolTable needs to be refactored to
     // prevent manual building of Ops with symbols in code using SymbolTables
@@ -431,12 +435,25 @@ class GpuKernelOutliningPass
       if (std::optional<SymbolTable::UseRange> symbolUses =
               SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
         for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
-          StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
+          SymbolRefAttr symbolRef = symbolUse.getSymbolRef();
+          StringAttr symbolName = symbolRef.getLeafReference();
           if (symbolTable.lookup(symbolName))
             continue;
 
-          Operation *symbolDefClone =
-              parentSymbolTable.lookup(symbolName)->clone();
+          Operation *symbolDef =
+              SymbolTable::lookupSymbolIn(parentSymbolTable.getOp(), symbolRef);
+          if (!symbolDef) {
+            if (isa<FlatSymbolRefAttr>(symbolRef)) {
+              return symbolUse.getUser()->emitOpError(
+                  "failed to outline gpu kernel: symbol '" +
+                  symbolName.getValue() + "' not found");
+            }
+            return symbolUse.getUser()->emitOpError(
+                       "failed to outline gpu kernel: "
+                       "found invalid symbol reference: ")
+                   << symbolRef;
+          }
+          Operation *symbolDefClone = symbolDef->clone();
           symbolDefWorklist.push_back(symbolDefClone);
           symbolTable.insert(symbolDefClone);
         }
diff --git a/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir b/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir
new file mode 100644
index 0000000000000..5cd290fc396c3
--- /dev/null
+++ b/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -gpu-kernel-outlining -verify-diagnostics -split-input-file %s
+
+module attributes {gpu.container_module} {
+  func.func @kernel_crash() {
+    %c1 = arith.constant 1 : index
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+      // expected-error at +1 {{failed to outline gpu kernel: symbol 'unknown_func' not found}}
+      "test.op"() {symbol = @unknown_func} : () -> ()
+      gpu.terminator
+    }
+    return
+  }
+}
+
+// -----
+
+module attributes {gpu.container_module} {
+  func.func @kernel_invalid_ref() {
+    %c1 = arith.constant 1 : index
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+      // expected-error at +1 {{failed to outline gpu kernel: found invalid symbol reference: @nested::@ref}}
+      "test.op"() {symbol = @nested::@ref} : () -> ()
+      gpu.terminator
+    }
+    return
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/169843


More information about the Mlir-commits mailing list