[Mlir-commits] [mlir] 67e47fb - [mlir][gpu] Add SymbolUserOpInterface to launch_func op (#173277)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 17 04:14:50 PDT 2026


Author: lonely eagle
Date: 2026-03-17T19:14:45+08:00
New Revision: 67e47fb5317cabd82317379baf8e31d61984d43d

URL: https://github.com/llvm/llvm-project/commit/67e47fb5317cabd82317379baf8e31d61984d43d
DIFF: https://github.com/llvm/llvm-project/commit/67e47fb5317cabd82317379baf8e31d61984d43d.diff

LOG: [mlir][gpu] Add SymbolUserOpInterface to launch_func op (#173277)

The gpu.launch_func op references symbols: its kernel container and its
kernel function. Currently, the validation of these references is
implemented in GPUDialect::verifyOperationAttribute. To make the
validation logic clearer and better structured, this PR makes
LaunchFuncOp implement the SymbolUserOpInterface and moves the checks
into LaunchFuncOp::verifySymbolUses. Implementing this interface also
lets the op benefit from analyses and passes that reason about symbol
uses.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/test/Dialect/GPU/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6b0fd1ed9080e..b5a9e3413ddfd 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -618,6 +618,7 @@ def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
 
 def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
       GPU_AsyncOpInterface, AttrSizedOperandSegments,
+      DeclareOpInterfaceMethods<SymbolUserOpInterface>,
       AllTypesMatch<["gridSizeX", "gridSizeY", "gridSizeZ", "blockSizeX",
                      "blockSizeY", "blockSizeZ"]>]>,
     Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 2680c7311924f..5d409f71847c6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -409,83 +409,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
     return op->emitError("expected '")
            << getContainerModuleAttrName() << "' attribute to be attached to '"
            << ModuleOp::getOperationName() << '\'';
-
-  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
-    // Ignore launches that are nested more or less deep than functions in the
-    // module we are currently checking.
-    if (!launchOp->getParentOp() ||
-        launchOp->getParentOp()->getParentOp() != module)
-      return success();
-
-    // Ignore launch ops with missing attributes here. The errors will be
-    // reported by the verifiers of those ops.
-    if (!launchOp->getAttrOfType<SymbolRefAttr>(
-            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
-      return success();
-
-    // Check that `launch_func` refers to a well-formed GPU kernel container.
-    StringAttr kernelContainerName = launchOp.getKernelModuleName();
-    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
-    if (!kernelContainer)
-      return launchOp.emitOpError()
-             << "kernel container '" << kernelContainerName.getValue()
-             << "' is undefined";
-
-    // If the container is a GPU binary op return success.
-    if (isa<BinaryOp>(kernelContainer))
-      return success();
-
-    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
-    if (!kernelModule)
-      return launchOp.emitOpError()
-             << "kernel module '" << kernelContainerName.getValue()
-             << "' is undefined";
-
-    // Check that `launch_func` refers to a well-formed kernel function.
-    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
-    if (!kernelFunc)
-      return launchOp.emitOpError("kernel function '")
-             << launchOp.getKernel() << "' is undefined";
-    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
-    if (!kernelConvertedFunction) {
-      InFlightDiagnostic diag = launchOp.emitOpError()
-                                << "referenced kernel '" << launchOp.getKernel()
-                                << "' is not a function";
-      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
-      return diag;
-    }
-
-    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
-            GPUDialect::getKernelFuncAttrName()))
-      return launchOp.emitOpError("kernel function is missing the '")
-             << GPUDialect::getKernelFuncAttrName() << "' attribute";
-
-    // TODO: If the kernel isn't a GPU function (which happens during separate
-    // compilation), do not check type correspondence as it would require the
-    // verifier to be aware of the type conversion.
-    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
-    if (!kernelGPUFunction)
-      return success();
-
-    unsigned actualNumArguments = launchOp.getNumKernelOperands();
-    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
-    if (expectedNumArguments != actualNumArguments)
-      return launchOp.emitOpError("got ")
-             << actualNumArguments << " kernel operands but expected "
-             << expectedNumArguments;
-
-    auto functionType = kernelGPUFunction.getFunctionType();
-    for (unsigned i = 0; i < expectedNumArguments; ++i) {
-      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
-        return launchOp.emitOpError("type of function argument ")
-               << i << " does not match";
-      }
-    }
-
-    return success();
-  });
-
-  return walkResult.wasInterrupted() ? failure() : success();
+  return success();
 }
 
 /// Parses an optional list of async operands with an optional leading keyword.
@@ -1397,6 +1321,90 @@ LogicalResult LaunchFuncOp::verify() {
   return success();
 }
 
+LogicalResult
+LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  LaunchFuncOp launchOp = *this;
+  Operation *table = SymbolTable::getNearestSymbolTable(launchOp);
+  // GPU modules cannot be nested within each other, escape to resolve the name.
+  if (isa<GPUModuleOp>(table))
+    table = SymbolTable::getNearestSymbolTable(table->getParentOp());
+
+  // Ignore launches that are nested more or less deep than functions in the
+  // module we are currently checking.
+  if (!launchOp->getParentOp() ||
+      launchOp->getParentOp()->getParentOp() != table)
+    return success();
+
+  // Ignore launch ops with missing attributes here. The errors will be
+  // reported by the verifiers of those ops.
+  if (!launchOp->getAttrOfType<SymbolRefAttr>(
+          LaunchFuncOp::getKernelAttrName(launchOp->getName())))
+    return success();
+
+  // Check that `launch_func` refers to a well-formed GPU kernel container.
+  StringAttr kernelContainerName = launchOp.getKernelModuleName();
+  Operation *kernelContainer =
+      symbolTable.lookupNearestSymbolFrom(table, kernelContainerName);
+  if (!kernelContainer)
+    return launchOp.emitOpError()
+           << "kernel container '" << kernelContainerName.getValue()
+           << "' is undefined";
+
+  // If the container is a GPU binary op return success.
+  if (isa<BinaryOp>(kernelContainer))
+    return success();
+
+  auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
+  if (!kernelModule)
+    return launchOp.emitOpError()
+           << "kernel module '" << kernelContainerName.getValue()
+           << "' is undefined";
+
+  // Check that `launch_func` refers to a well-formed kernel function.
+  Operation *kernelFunc = symbolTable.lookupNearestSymbolFrom(
+      kernelModule, launchOp.getKernelName());
+  if (!kernelFunc)
+    return launchOp.emitOpError("kernel function '")
+           << launchOp.getKernel() << "' is undefined";
+  auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
+  if (!kernelConvertedFunction) {
+    InFlightDiagnostic diag = launchOp.emitOpError()
+                              << "referenced kernel '" << launchOp.getKernel()
+                              << "' is not a function";
+    diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
+    return diag;
+  }
+
+  if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
+          GPUDialect::getKernelFuncAttrName()))
+    return launchOp.emitOpError("kernel function is missing the '")
+           << GPUDialect::getKernelFuncAttrName() << "' attribute";
+
+  // TODO: If the kernel isn't a GPU function (which happens during separate
+  // compilation), do not check type correspondence as it would require the
+  // verifier to be aware of the type conversion.
+  auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
+  if (!kernelGPUFunction)
+    return success();
+
+  unsigned actualNumArguments = launchOp.getNumKernelOperands();
+  unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
+  if (expectedNumArguments != actualNumArguments)
+    return launchOp.emitOpError("got ")
+           << actualNumArguments << " kernel operands but expected "
+           << expectedNumArguments;
+
+  FunctionType functionType = kernelGPUFunction.getFunctionType();
+  for (unsigned i = 0; i < expectedNumArguments; ++i) {
+    if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
+      return launchOp.emitOpError("type of function argument ")
+             << i << " does not match";
+    }
+  }
+
+  return success();
+}
+
 static ParseResult
 parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index f8e75cec9b7cb..bf862b2c5ae3c 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -190,15 +190,31 @@ module attributes {gpu.container_module} {
 
 // -----
 
+module attributes {gpu.container_module} {
+  gpu.module @kernels_container {
+    gpu.func @kernel_1(%arg1 : !llvm.ptr) kernel {
+      gpu.return
+    }
+  }
+
+  func.func @launch_func_missing_kernel_attr(%sz : index, %arg : !llvm.ptr) {
+    // expected-error at +1 {{kernel container 'kernels' is undefined}}
+    gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) args(%arg : !llvm.ptr)
+    return
+  }
+}
+
+// -----
+
 module attributes {gpu.container_module} {
   module @kernels {
+    // expected-error at +1 {{'gpu.func' op expects parent op 'gpu.module'}}
     gpu.func @kernel_1(%arg1 : !llvm.ptr) kernel {
       gpu.return
     }
   }
 
   func.func @launch_func_missing_kernel_attr(%sz : index, %arg : !llvm.ptr) {
-    // expected-error at +1 {{kernel module 'kernels' is undefined}}
     gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) args(%arg : !llvm.ptr)
     return
   }


        


More information about the Mlir-commits mailing list