[Mlir-commits] [mlir] [mlir][gpu] Validate argument count in gpu.launch parser (PR #180388)
Longsheng Mou
llvmlistbot at llvm.org
Sat Feb 7 20:14:45 PST 2026
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/180388
This PR adds validation to the `gpu.launch` parser to ensure that each launch configuration clause (`blocks`, `threads`, and the optional `clusters`) declares exactly 3 arguments. The parser now emits an error when a clause's argument count is not 3. Fixes #176426.
>From 6c66b89b808d55553bd54065b37fb7326a220342 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 8 Feb 2026 12:09:57 +0800
Subject: [PATCH] [mlir][gpu] Validate argument count in gpu.launch parser
This PR adds validation to the `gpu.launch` parser to ensure that each launch
configuration clause (`blocks`, `threads`, and the optional `clusters`) declares
exactly 3 arguments. The parser now emits an error when a clause's argument
count is not 3.
---
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 29 +++++++++++++---------
mlir/test/Dialect/GPU/invalid.mlir | 34 ++++++++++++++++++++++++++
2 files changed, 51 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8bd508b364e37..5953553ae0df0 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1000,13 +1000,19 @@ static ParseResult
parseSizeAssignment(OpAsmParser &parser,
MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
- MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
+ MutableArrayRef<OpAsmParser::UnresolvedOperand> indices,
+ StringRef keyword) {
assert(indices.size() == 3 && "space for three indices expected");
SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false) ||
parser.parseKeyword("in") || parser.parseLParen())
return failure();
+
+ if (args.size() != 3) {
+ return parser.emitError(parser.getNameLoc())
+ << keyword << " expects 3 arguments, but got " << args.size();
+ }
std::move(args.begin(), args.end(), indices.begin());
for (int i = 0; i < 3; ++i) {
@@ -1057,8 +1063,7 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
}
bool hasCluster = false;
- if (succeeded(
- parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
+ if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
hasCluster = true;
sizes.resize(9);
regionArgs.resize(18);
@@ -1069,9 +1074,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
// Last three segment assigns the cluster size. In the region argument
// list, this is last 6 arguments.
if (hasCluster) {
- if (parseSizeAssignment(parser, sizesRef.drop_front(6),
- regionArgsRef.slice(15, 3),
- regionArgsRef.slice(12, 3)))
+ if (parseSizeAssignment(
+ parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
+ regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
return failure();
}
// Parse the size assignment segments: the first segment assigns grid sizes
@@ -1079,14 +1084,14 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
// sizes and defines values for thread identifiers. In the region argument
// list, identifiers precede sizes, and block-related values precede
// thread-related values.
- if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
+ if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
parseSizeAssignment(parser, sizesRef.take_front(3),
- regionArgsRef.slice(6, 3),
- regionArgsRef.slice(0, 3)) ||
- parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
+ regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
+ LaunchOp::getBlocksKeyword()) ||
+ parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
parseSizeAssignment(parser, sizesRef.drop_front(3),
- regionArgsRef.slice(9, 3),
- regionArgsRef.slice(3, 3)) ||
+ regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
+ LaunchOp::getThreadsKeyword()) ||
parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
result.operands))
return failure();
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 6e67d682703ec..7c678e4f34d3d 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -46,6 +46,40 @@ func.func @launch_result_no_async() {
// -----
+func.func @launch_wrong_clusters(%sz : index) {
+ // expected-error at +1 {{'gpu.launch' clusters expects 3 arguments, but got 1}}
+ gpu.launch clusters(%cx) in (%scx = %sz, %scy = %sz, %scz = %sz)
+ blocks(%bx, %by, %bz) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+ threads(%tx, %ty, %tz) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @launch_wrong_blocks(%sz : index) {
+ // expected-error at +1 {{'gpu.launch' blocks expects 3 arguments, but got 1}}
+ gpu.launch blocks(%bx) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+ threads(%tx, %ty, %tz) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @launch_wrong_threads(%sz : index) {
+ // expected-error at +1 {{'gpu.launch' threads expects 3 arguments, but got 1}}
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+ threads(%tx) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
func.func @launch_func_too_few_operands(%sz : index) {
// expected-error at +1 {{expected 6 or more operands}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz)
More information about the Mlir-commits
mailing list