[Mlir-commits] [mlir] [mlir][gpu] Validate argument count in gpu.launch parser (PR #180388)

Longsheng Mou llvmlistbot at llvm.org
Sat Feb 7 20:14:45 PST 2026


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/180388

This PR adds validation to the `gpu.launch` parser to ensure that each launch configuration segment (`clusters`, `blocks`, and `threads`) provides exactly 3 arguments. The parser now emits an error naming the offending segment when its argument count is not 3. Fixes #176426.

>From 6c66b89b808d55553bd54065b37fb7326a220342 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 8 Feb 2026 12:09:57 +0800
Subject: [PATCH] [mlir][gpu] Validate argument count in gpu.launch parser

This PR adds validation to the `gpu.launch` parser to ensure that each launch
configuration segment (`clusters`, `blocks`, and `threads`) provides exactly 3
arguments. The parser now emits an error when the argument count is not 3.
---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 29 +++++++++++++---------
 mlir/test/Dialect/GPU/invalid.mlir     | 34 ++++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8bd508b364e37..5953553ae0df0 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1000,13 +1000,19 @@ static ParseResult
 parseSizeAssignment(OpAsmParser &parser,
                     MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                     MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
-                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
+                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices,
+                    StringRef keyword) {
   assert(indices.size() == 3 && "space for three indices expected");
   SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
   if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                               /*allowResultNumber=*/false) ||
       parser.parseKeyword("in") || parser.parseLParen())
     return failure();
+
+  if (args.size() != 3) {
+    return parser.emitError(parser.getNameLoc())
+           << keyword << " expects 3 arguments, but got " << args.size();
+  }
   std::move(args.begin(), args.end(), indices.begin());
 
   for (int i = 0; i < 3; ++i) {
@@ -1057,8 +1063,7 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   }
 
   bool hasCluster = false;
-  if (succeeded(
-          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
+  if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
     hasCluster = true;
     sizes.resize(9);
     regionArgs.resize(18);
@@ -1069,9 +1074,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   // Last three segment assigns the cluster size. In the region argument
   // list, this is last 6 arguments.
   if (hasCluster) {
-    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
-                            regionArgsRef.slice(15, 3),
-                            regionArgsRef.slice(12, 3)))
+    if (parseSizeAssignment(
+            parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
+            regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
       return failure();
   }
   // Parse the size assignment segments: the first segment assigns grid sizes
@@ -1079,14 +1084,14 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   // sizes and defines values for thread identifiers.  In the region argument
   // list, identifiers precede sizes, and block-related values precede
   // thread-related values.
-  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
+  if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
       parseSizeAssignment(parser, sizesRef.take_front(3),
-                          regionArgsRef.slice(6, 3),
-                          regionArgsRef.slice(0, 3)) ||
-      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
+                          regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
+                          LaunchOp::getBlocksKeyword()) ||
+      parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
       parseSizeAssignment(parser, sizesRef.drop_front(3),
-                          regionArgsRef.slice(9, 3),
-                          regionArgsRef.slice(3, 3)) ||
+                          regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
+                          LaunchOp::getThreadsKeyword()) ||
       parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                              result.operands))
     return failure();
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 6e67d682703ec..7c678e4f34d3d 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -46,6 +46,40 @@ func.func @launch_result_no_async() {
 
 // -----
 
+func.func @launch_wrong_clusters(%sz : index) {
+  // expected-error @+1 {{'gpu.launch' clusters expects 3 arguments, but got 1}}
+  gpu.launch clusters(%cx) in (%scx = %sz, %scy = %sz, %scz = %sz)
+             blocks(%bx, %by, %bz) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+             threads(%tx, %ty, %tz) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @launch_wrong_blocks(%sz : index) {
+  // expected-error @+1 {{'gpu.launch' blocks expects 3 arguments, but got 1}}
+  gpu.launch blocks(%bx) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+             threads(%tx, %ty, %tz) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @launch_wrong_threads(%sz : index) {
+  // expected-error @+1 {{'gpu.launch' threads expects 3 arguments, but got 1}}
+  gpu.launch blocks(%bx, %by, %bz) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
+             threads(%tx) in (%stx = %sz, %sty = %sz, %stz = %sz) {
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @launch_func_too_few_operands(%sz : index) {
   // expected-error @+1 {{expected 6 or more operands}}
   "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz)



More information about the Mlir-commits mailing list