[Mlir-commits] [mlir] [mlir][gpu] Make launch_func op use SymbolUserOpInterface (PR #173277)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 22 08:10:53 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu
Author: lonely eagle (linuxlonelyeagle)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/173277.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+1)
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+81-77)
- (modified) mlir/test/Dialect/GPU/invalid.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 5c7df25c58cde..0215adf7b8105 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -606,6 +606,7 @@ def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
GPU_AsyncOpInterface, AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
AllTypesMatch<["gridSizeX", "gridSizeY", "gridSizeZ", "blockSizeX",
"blockSizeY", "blockSizeZ"]>]>,
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 61a630aa88960..46e6c8f9386cb 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -405,83 +405,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
return op->emitError("expected '")
<< getContainerModuleAttrName() << "' attribute to be attached to '"
<< ModuleOp::getOperationName() << '\'';
-
- auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
- // Ignore launches that are nested more or less deep than functions in the
- // module we are currently checking.
- if (!launchOp->getParentOp() ||
- launchOp->getParentOp()->getParentOp() != module)
- return success();
-
- // Ignore launch ops with missing attributes here. The errors will be
- // reported by the verifiers of those ops.
- if (!launchOp->getAttrOfType<SymbolRefAttr>(
- LaunchFuncOp::getKernelAttrName(launchOp->getName())))
- return success();
-
- // Check that `launch_func` refers to a well-formed GPU kernel container.
- StringAttr kernelContainerName = launchOp.getKernelModuleName();
- Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
- if (!kernelContainer)
- return launchOp.emitOpError()
- << "kernel container '" << kernelContainerName.getValue()
- << "' is undefined";
-
- // If the container is a GPU binary op return success.
- if (isa<BinaryOp>(kernelContainer))
- return success();
-
- auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
- if (!kernelModule)
- return launchOp.emitOpError()
- << "kernel module '" << kernelContainerName.getValue()
- << "' is undefined";
-
- // Check that `launch_func` refers to a well-formed kernel function.
- Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
- if (!kernelFunc)
- return launchOp.emitOpError("kernel function '")
- << launchOp.getKernel() << "' is undefined";
- auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
- if (!kernelConvertedFunction) {
- InFlightDiagnostic diag = launchOp.emitOpError()
- << "referenced kernel '" << launchOp.getKernel()
- << "' is not a function";
- diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
- return diag;
- }
-
- if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
- GPUDialect::getKernelFuncAttrName()))
- return launchOp.emitOpError("kernel function is missing the '")
- << GPUDialect::getKernelFuncAttrName() << "' attribute";
-
- // TODO: If the kernel isn't a GPU function (which happens during separate
- // compilation), do not check type correspondence as it would require the
- // verifier to be aware of the type conversion.
- auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
- if (!kernelGPUFunction)
- return success();
-
- unsigned actualNumArguments = launchOp.getNumKernelOperands();
- unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
- if (expectedNumArguments != actualNumArguments)
- return launchOp.emitOpError("got ")
- << actualNumArguments << " kernel operands but expected "
- << expectedNumArguments;
-
- auto functionType = kernelGPUFunction.getFunctionType();
- for (unsigned i = 0; i < expectedNumArguments; ++i) {
- if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
- return launchOp.emitOpError("type of function argument ")
- << i << " does not match";
- }
- }
-
- return success();
- });
-
- return walkResult.wasInterrupted() ? failure() : success();
+ return success();
}
/// Parses an optional list of async operands with an optional leading keyword.
@@ -1381,6 +1305,86 @@ LogicalResult LaunchFuncOp::verify() {
return success();
}
+LogicalResult
+LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ LaunchFuncOp launchOp = *this;
+ ModuleOp module = (*this)->getParentOfType<ModuleOp>();
+ // Ignore launches that are nested more or less deep than functions in the
+ // module we are currently checking.
+ if (!launchOp->getParentOp() ||
+ launchOp->getParentOp()->getParentOp() != module)
+ return success();
+
+ // Ignore launch ops with missing attributes here. The errors will be
+ // reported by the verifiers of those ops.
+ if (!launchOp->getAttrOfType<SymbolRefAttr>(
+ LaunchFuncOp::getKernelAttrName(launchOp->getName())))
+ return success();
+
+ // Check that `launch_func` refers to a well-formed GPU kernel container.
+ StringAttr kernelContainerName = launchOp.getKernelModuleName();
+ Operation *kernelContainer =
+ symbolTable.lookupNearestSymbolFrom(module, kernelContainerName);
+ if (!kernelContainer)
+ return launchOp.emitOpError()
+ << "kernel container '" << kernelContainerName.getValue()
+ << "' is undefined";
+
+ // If the container is a GPU binary op return success.
+ if (isa<BinaryOp>(kernelContainer))
+ return success();
+
+ auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
+ if (!kernelModule)
+ return launchOp.emitOpError()
+ << "kernel module '" << kernelContainerName.getValue()
+ << "' is undefined";
+
+ // Check that `launch_func` refers to a well-formed kernel function.
+ Operation *kernelFunc =
+ symbolTable.lookupNearestSymbolFrom(module, launchOp.getKernelAttr());
+ if (!kernelFunc)
+ return launchOp.emitOpError("kernel function '")
+ << launchOp.getKernel() << "' is undefined";
+ auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
+ if (!kernelConvertedFunction) {
+ InFlightDiagnostic diag = launchOp.emitOpError()
+ << "referenced kernel '" << launchOp.getKernel()
+ << "' is not a function";
+ diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
+ return diag;
+ }
+
+ if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
+ GPUDialect::getKernelFuncAttrName()))
+ return launchOp.emitOpError("kernel function is missing the '")
+ << GPUDialect::getKernelFuncAttrName() << "' attribute";
+
+ // TODO: If the kernel isn't a GPU function (which happens during separate
+ // compilation), do not check type correspondence as it would require the
+ // verifier to be aware of the type conversion.
+ auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
+ if (!kernelGPUFunction)
+ return success();
+
+ unsigned actualNumArguments = launchOp.getNumKernelOperands();
+ unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
+ if (expectedNumArguments != actualNumArguments)
+ return launchOp.emitOpError("got ")
+ << actualNumArguments << " kernel operands but expected "
+ << expectedNumArguments;
+
+ auto functionType = kernelGPUFunction.getFunctionType();
+ for (unsigned i = 0; i < expectedNumArguments; ++i) {
+ if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
+ return launchOp.emitOpError("type of function argument ")
+ << i << " does not match";
+ }
+ }
+
+ return success();
+}
+
static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 26bcf948bc85d..eaff7d8a0d5db 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -137,14 +137,14 @@ module attributes {gpu.container_module} {
// -----
module attributes {gpu.container_module} {
- module @kernels {
+ gpu.module @kernels_container {
gpu.func @kernel_1(%arg1 : !llvm.ptr) kernel {
gpu.return
}
}
func.func @launch_func_missing_kernel_attr(%sz : index, %arg : !llvm.ptr) {
- // expected-error@+1 {{kernel module 'kernels' is undefined}}
+ // expected-error@+1 {{kernel container 'kernels' is undefined}}
gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) args(%arg : !llvm.ptr)
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/173277
More information about the Mlir-commits
mailing list