[Mlir-commits] [mlir] [mlir][gpu] Validate argument count in gpu.launch parser (PR #180388)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 7 20:15:17 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR adds validation to the `gpu.launch` parser so that each launch-configuration segment (`clusters`, `blocks`, `threads`) must provide exactly 3 arguments. A parser error naming the offending segment is emitted when the argument count is not 3. Fixes #<!-- -->176426.

---
Full diff: https://github.com/llvm/llvm-project/pull/180388.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+17-12) 
- (modified) mlir/test/Dialect/GPU/invalid.mlir (+34) 


``````````diff
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8bd508b364e37..5953553ae0df0 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1000,13 +1000,19 @@ static ParseResult
 parseSizeAssignment(OpAsmParser &parser,
                     MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                     MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
-                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
+                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices,
+                    StringRef keyword) {
   assert(indices.size() == 3 && "space for three indices expected");
   SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
   if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                               /*allowResultNumber=*/false) ||
       parser.parseKeyword("in") || parser.parseLParen())
     return failure();
+
+  if (args.size() != 3) {
+    return parser.emitError(parser.getNameLoc())
+           << keyword << " expects 3 arguments, but got " << args.size();
+  }
   std::move(args.begin(), args.end(), indices.begin());
 
   for (int i = 0; i < 3; ++i) {
@@ -1057,8 +1063,7 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   }
 
   bool hasCluster = false;
-  if (succeeded(
-          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
+  if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
     hasCluster = true;
     sizes.resize(9);
     regionArgs.resize(18);
@@ -1069,9 +1074,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   // Last three segment assigns the cluster size. In the region argument
   // list, this is last 6 arguments.
   if (hasCluster) {
-    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
-                            regionArgsRef.slice(15, 3),
-                            regionArgsRef.slice(12, 3)))
+    if (parseSizeAssignment(
+            parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
+            regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
       return failure();
   }
   // Parse the size assignment segments: the first segment assigns grid sizes
@@ -1079,14 +1084,14 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   // sizes and defines values for thread identifiers.  In the region argument
   // list, identifiers precede sizes, and block-related values precede
   // thread-related values.
-  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
+  if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
       parseSizeAssignment(parser, sizesRef.take_front(3),
-                          regionArgsRef.slice(6, 3),
-                          regionArgsRef.slice(0, 3)) ||
-      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
+                          regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
+                          LaunchOp::getBlocksKeyword()) ||
+      parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
       parseSizeAssignment(parser, sizesRef.drop_front(3),
-                          regionArgsRef.slice(9, 3),
-                          regionArgsRef.slice(3, 3)) ||
+                          regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
+                          LaunchOp::getThreadsKeyword()) ||
       parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                              result.operands))
     return failure();
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 6e67d682703ec..7c678e4f34d3d 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -46,6 +46,40 @@ func.func @launch_result_no_async() {
 
 // -----
 
+func.func @launch_wrong_clusters(%sz : index) {
+  // expected-error at +1 {{'gpu.launch' clusters expects 3 arguments, but got 1}}
+  gpu.launch clusters(%cx) in (%scx = %sz, %scy = %sz, %scz = %sz)
+             blocks(%bx, %by, %bz) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+             threads(%tx, %ty, %tz) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @launch_wrong_blocks(%sz : index) {
+  // expected-error at +1 {{'gpu.launch' blocks expects 3 arguments, but got 1}}
+  gpu.launch blocks(%bx) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+             threads(%tx, %ty, %tz) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @launch_wrong_threads(%sz : index) {
+  // expected-error at +1 {{'gpu.launch' threads expects 3 arguments, but got 1}}
+  gpu.launch blocks(%bx, %by, %bz) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+             threads(%tx) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @launch_func_too_few_operands(%sz : index) {
   // expected-error at +1 {{expected 6 or more operands}}
   "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz)

``````````

</details>


https://github.com/llvm/llvm-project/pull/180388


More information about the Mlir-commits mailing list