[Mlir-commits] [mlir] [mlir][gpu] Guard negative workgroup_attributions in GPU ops (PR #174409)

Prathamesh Tagore llvmlistbot at llvm.org
Mon Jan 5 06:26:10 PST 2026


https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/174409

Prevent crashes caused by a negative workgroup_attributions value by validating the attribute during gpu.launch and gpu.func verification. In such cases the verifier now emits an error with a helpful message instead of crashing outright.

Fixes https://github.com/llvm/llvm-project/issues/159674

>From 1c9abef7d77293a621872aef9dd8988c5abe91f9 Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathameshtagore at gmail.com>
Date: Mon, 5 Jan 2026 19:43:21 +0530
Subject: [PATCH] [mlir][gpu] Guard negative workgroup_attributions in GPU ops

Prevent crashes caused by a negative workgroup_attributions value by
validating the attribute during gpu.launch and gpu.func verification. In
such cases the verifier now emits a descriptive error instead of crashing.
---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 19 +++++++++++++++++--
 mlir/test/Dialect/GPU/invalid.mlir     | 24 ++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 21c0d369b8d1c..7e174f1d21adb 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -886,12 +886,20 @@ LogicalResult LaunchOp::verify() {
 }
 
 LogicalResult LaunchOp::verifyRegions() {
+  int64_t numWorkgroupAttributions = 0;
+  if (auto attr = (*this)->getAttrOfType<IntegerAttr>(
+          getNumWorkgroupAttributionsAttrName())) {
+    numWorkgroupAttributions = attr.getInt();
+    if (numWorkgroupAttributions < 0)
+      return emitOpError() << "expected non-negative workgroup_attributions";
+  }
+
   // Kernel launch takes kNumConfigOperands leading operands for grid/block
   // sizes and transforms them into kNumConfigRegionAttributes region arguments
   // for block/thread identifiers and grid/block sizes.
   if (!getBody().empty()) {
     if (getBody().getNumArguments() <
-        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
+        kNumConfigRegionAttributes + numWorkgroupAttributions)
       return emitOpError("unexpected number of region arguments");
   }
 
@@ -1814,7 +1822,14 @@ LogicalResult GPUFuncOp::verifyBody() {
   if (empty())
     return emitOpError() << "expected body with at least one block";
   unsigned numFuncArguments = getNumArguments();
-  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
+  int64_t numWorkgroupAttributions = 0;
+  if (auto attr = (*this)->getAttrOfType<IntegerAttr>(
+          getNumWorkgroupAttributionsAttrName())) {
+    numWorkgroupAttributions = attr.getInt();
+    if (numWorkgroupAttributions < 0)
+      return emitOpError() << "expected non-negative workgroup_attributions";
+  }
+
   unsigned numBlockArguments = front().getNumArguments();
   if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
     return emitOpError() << "expected at least "
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 26bcf948bc85d..bc06e78186b19 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -1021,3 +1021,27 @@ func.func @warp_mismatch_rank(%laneid: index) {
   }
   return
 }
+
+// -----
+
+func.func @launch_negative_workgroup_attributions(%sz : index) {
+  // expected-error@+1 {{expected non-negative workgroup_attributions}}
+  "gpu.launch"(%sz, %sz, %sz, %sz, %sz, %sz) ({})
+      {operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0>,
+       workgroup_attributions = -1 : i64}
+      : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// -----
+
+module {
+  gpu.module @gpu_funcs {
+    // expected-error@+1 {{expected non-negative workgroup_attributions}}
+    "gpu.func"() ({
+    ^bb0:
+      gpu.return
+    }) {function_type = () -> (), sym_name = "gpu_func_negative_workgroup_attributions",
+         workgroup_attributions = -1 : i64} : () -> ()
+  }
+}



More information about the Mlir-commits mailing list