[Mlir-commits] [mlir] 67e47fb - [mlir][gpu] Add SymbolUserOpInterface to launch_func op (#173277)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 17 04:14:50 PDT 2026
Author: lonely eagle
Date: 2026-03-17T19:14:45+08:00
New Revision: 67e47fb5317cabd82317379baf8e31d61984d43d
URL: https://github.com/llvm/llvm-project/commit/67e47fb5317cabd82317379baf8e31d61984d43d
DIFF: https://github.com/llvm/llvm-project/commit/67e47fb5317cabd82317379baf8e31d61984d43d.diff
LOG: [mlir][gpu] Add SymbolUserOpInterface to launch_func op (#173277)
The gpu.launch_func operation references symbols.
Currently, its symbol validation logic is implemented within
GPUDialect::verifyOperationAttribute. To improve the clarity and
structure of the validation logic, this PR makes LaunchFuncOp implement
the SymbolUserOpInterface. In addition, implementing this interface
allows the operation to benefit from various symbol-usage analysis
passes.
Added:
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/Dialect/GPU/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6b0fd1ed9080e..b5a9e3413ddfd 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -618,6 +618,7 @@ def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
GPU_AsyncOpInterface, AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
AllTypesMatch<["gridSizeX", "gridSizeY", "gridSizeZ", "blockSizeX",
"blockSizeY", "blockSizeZ"]>]>,
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 2680c7311924f..5d409f71847c6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -409,83 +409,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
return op->emitError("expected '")
<< getContainerModuleAttrName() << "' attribute to be attached to '"
<< ModuleOp::getOperationName() << '\'';
-
- auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
- // Ignore launches that are nested more or less deep than functions in the
- // module we are currently checking.
- if (!launchOp->getParentOp() ||
- launchOp->getParentOp()->getParentOp() != module)
- return success();
-
- // Ignore launch ops with missing attributes here. The errors will be
- // reported by the verifiers of those ops.
- if (!launchOp->getAttrOfType<SymbolRefAttr>(
- LaunchFuncOp::getKernelAttrName(launchOp->getName())))
- return success();
-
- // Check that `launch_func` refers to a well-formed GPU kernel container.
- StringAttr kernelContainerName = launchOp.getKernelModuleName();
- Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
- if (!kernelContainer)
- return launchOp.emitOpError()
- << "kernel container '" << kernelContainerName.getValue()
- << "' is undefined";
-
- // If the container is a GPU binary op return success.
- if (isa<BinaryOp>(kernelContainer))
- return success();
-
- auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
- if (!kernelModule)
- return launchOp.emitOpError()
- << "kernel module '" << kernelContainerName.getValue()
- << "' is undefined";
-
- // Check that `launch_func` refers to a well-formed kernel function.
- Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
- if (!kernelFunc)
- return launchOp.emitOpError("kernel function '")
- << launchOp.getKernel() << "' is undefined";
- auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
- if (!kernelConvertedFunction) {
- InFlightDiagnostic diag = launchOp.emitOpError()
- << "referenced kernel '" << launchOp.getKernel()
- << "' is not a function";
- diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
- return diag;
- }
-
- if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
- GPUDialect::getKernelFuncAttrName()))
- return launchOp.emitOpError("kernel function is missing the '")
- << GPUDialect::getKernelFuncAttrName() << "' attribute";
-
- // TODO: If the kernel isn't a GPU function (which happens during separate
- // compilation), do not check type correspondence as it would require the
- // verifier to be aware of the type conversion.
- auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
- if (!kernelGPUFunction)
- return success();
-
- unsigned actualNumArguments = launchOp.getNumKernelOperands();
- unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
- if (expectedNumArguments != actualNumArguments)
- return launchOp.emitOpError("got ")
- << actualNumArguments << " kernel operands but expected "
- << expectedNumArguments;
-
- auto functionType = kernelGPUFunction.getFunctionType();
- for (unsigned i = 0; i < expectedNumArguments; ++i) {
- if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
- return launchOp.emitOpError("type of function argument ")
- << i << " does not match";
- }
- }
-
- return success();
- });
-
- return walkResult.wasInterrupted() ? failure() : success();
+ return success();
}
/// Parses an optional list of async operands with an optional leading keyword.
@@ -1397,6 +1321,90 @@ LogicalResult LaunchFuncOp::verify() {
return success();
}
+LogicalResult
+LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ LaunchFuncOp launchOp = *this;
+ Operation *table = SymbolTable::getNearestSymbolTable(launchOp);
+ // GPU modules cannot be nested within each other, escape to resolve the name.
+ if (isa<GPUModuleOp>(table))
+ table = SymbolTable::getNearestSymbolTable(table->getParentOp());
+
+ // Ignore launches that are nested more or less deep than functions in the
+ // module we are currently checking.
+ if (!launchOp->getParentOp() ||
+ launchOp->getParentOp()->getParentOp() != table)
+ return success();
+
+ // Ignore launch ops with missing attributes here. The errors will be
+ // reported by the verifiers of those ops.
+ if (!launchOp->getAttrOfType<SymbolRefAttr>(
+ LaunchFuncOp::getKernelAttrName(launchOp->getName())))
+ return success();
+
+ // Check that `launch_func` refers to a well-formed GPU kernel container.
+ StringAttr kernelContainerName = launchOp.getKernelModuleName();
+ Operation *kernelContainer =
+ symbolTable.lookupNearestSymbolFrom(table, kernelContainerName);
+ if (!kernelContainer)
+ return launchOp.emitOpError()
+ << "kernel container '" << kernelContainerName.getValue()
+ << "' is undefined";
+
+ // If the container is a GPU binary op return success.
+ if (isa<BinaryOp>(kernelContainer))
+ return success();
+
+ auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
+ if (!kernelModule)
+ return launchOp.emitOpError()
+ << "kernel module '" << kernelContainerName.getValue()
+ << "' is undefined";
+
+ // Check that `launch_func` refers to a well-formed kernel function.
+ Operation *kernelFunc = symbolTable.lookupNearestSymbolFrom(
+ kernelModule, launchOp.getKernelName());
+ if (!kernelFunc)
+ return launchOp.emitOpError("kernel function '")
+ << launchOp.getKernel() << "' is undefined";
+ auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
+ if (!kernelConvertedFunction) {
+ InFlightDiagnostic diag = launchOp.emitOpError()
+ << "referenced kernel '" << launchOp.getKernel()
+ << "' is not a function";
+ diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
+ return diag;
+ }
+
+ if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
+ GPUDialect::getKernelFuncAttrName()))
+ return launchOp.emitOpError("kernel function is missing the '")
+ << GPUDialect::getKernelFuncAttrName() << "' attribute";
+
+ // TODO: If the kernel isn't a GPU function (which happens during separate
+ // compilation), do not check type correspondence as it would require the
+ // verifier to be aware of the type conversion.
+ auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
+ if (!kernelGPUFunction)
+ return success();
+
+ unsigned actualNumArguments = launchOp.getNumKernelOperands();
+ unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
+ if (expectedNumArguments != actualNumArguments)
+ return launchOp.emitOpError("got ")
+ << actualNumArguments << " kernel operands but expected "
+ << expectedNumArguments;
+
+ FunctionType functionType = kernelGPUFunction.getFunctionType();
+ for (unsigned i = 0; i < expectedNumArguments; ++i) {
+ if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
+ return launchOp.emitOpError("type of function argument ")
+ << i << " does not match";
+ }
+ }
+
+ return success();
+}
+
static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index f8e75cec9b7cb..bf862b2c5ae3c 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -190,15 +190,31 @@ module attributes {gpu.container_module} {
// -----
+module attributes {gpu.container_module} {
+ gpu.module @kernels_container {
+ gpu.func @kernel_1(%arg1 : !llvm.ptr) kernel {
+ gpu.return
+ }
+ }
+
+ func.func @launch_func_missing_kernel_attr(%sz : index, %arg : !llvm.ptr) {
+ // expected-error@+1 {{kernel container 'kernels' is undefined}}
+ gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) args(%arg : !llvm.ptr)
+ return
+ }
+}
+
+// -----
+
module attributes {gpu.container_module} {
module @kernels {
+ // expected-error@+1 {{'gpu.func' op expects parent op 'gpu.module'}}
gpu.func @kernel_1(%arg1 : !llvm.ptr) kernel {
gpu.return
}
}
func.func @launch_func_missing_kernel_attr(%sz : index, %arg : !llvm.ptr) {
- // expected-error@+1 {{kernel module 'kernels' is undefined}}
gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) args(%arg : !llvm.ptr)
return
}
More information about the Mlir-commits
mailing list