[Mlir-commits] [mlir] [mlir][linalg] Fix crash in linalg.reduce verifier when inputs != inits count (PR #186278)

Mehdi Amini llvmlistbot at llvm.org
Mon Mar 23 06:35:45 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/186278

>From 00ff2387d92ad9d6a799f9ac403143b8670fad96 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 12 Mar 2026 15:30:25 -0700
Subject: [PATCH] [mlir][linalg] Fix crash in linalg.reduce verifier when
 inputs != inits count

`ReduceOp` uses the `SameVariadicOperandSize` ODS trait, which computes
each variadic group's size as `total_operands / num_groups` (floordiv).
When the number of `ins` operands and `outs` operands differ (e.g. 2 ins
and 1 out), the division is inexact and the generated `getInputs()` and
`getInits()` accessors return slices of incorrect size.
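
A minimal C++ sketch of the slicing described above. It assumes, as this
message states, that `ins` and `outs` are the only two variadic groups and
that each is sized `total / 2` with integer division. `Slice`,
`sameVariadicSlice`, and `uncoveredOperands` are hypothetical helpers for
illustration, not MLIR or ODS API.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical model of SameVariadicOperandSize slicing: each variadic group
// gets total / numGroups operands (integer division, remainder dropped) and
// group g starts at g * size. This simplification only holds when every
// operand group is variadic, as with linalg.reduce's `ins` and `outs`.
// It is not the ODS-generated code.
struct Slice {
  std::size_t start;
  std::size_t length;
};

Slice sameVariadicSlice(std::size_t totalOperands, std::size_t numGroups,
                        std::size_t group) {
  std::size_t size = totalOperands / numGroups;  // floordiv
  return {group * size, size};
}

// Counts operands that fall outside both the `ins` (group 0) and
// `outs` (group 1) slices.
std::size_t uncoveredOperands(std::size_t numIns, std::size_t numOuts) {
  std::size_t total = numIns + numOuts;
  Slice inputs = sameVariadicSlice(total, 2, 0);
  Slice inits = sameVariadicSlice(total, 2, 1);
  return total - inputs.length - inits.length;
}
```

For `ins(%arg0, %arg1)` and `outs(%init)` (3 operands), the inputs slice is
`{0, 1}` and the inits slice is `{1, 1}`. The inits slice therefore points at
`%arg1`, and `%init` at index 2 belongs to neither slice.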

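The verifier then uses `getNumDpsInputs()` as the loop bound. The
`DestinationStyleOpInterface` derives that count as the total number of
operands minus the number of inits, so it also includes the operand left
over by the floordiv. The loop therefore dereferences `getInputs()[i]`
past the end of the shorter, incorrectly-sized range, causing an
out-of-bounds assertion.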
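
A sketch of why the loop overruns, under the same hypothetical floordiv model.
It assumes `getNumDpsInputs()` equals the total operand count minus the size
of the inits slice. The functions below are illustrative stand-ins for
`getInputs().size()`, `getInits().size()`, and `getNumDpsInputs()`, not the
real accessors.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>

// Hypothetical stand-ins, assuming both variadic groups are sized
// total / 2 (integer division) and that the DPS input count is derived as
// total minus the size of the inits slice. Not the real MLIR accessors.
std::size_t odsInputsSize(std::size_t total) { return total / 2; }  // ~ getInputs().size()
std::size_t odsInitsSize(std::size_t total) { return total / 2; }   // ~ getInits().size()
std::size_t dpsNumInputs(std::size_t total) {                        // ~ getNumDpsInputs()
  return total - odsInitsSize(total);
}

// Mirrors the loop `for (i = 1; i < getNumDpsInputs(); ++i)` that reads
// getInputs()[i]. Returns the first loop index that lies outside the
// getInputs() slice, or std::nullopt if every access in the loop is in bounds.
std::optional<std::size_t> firstOutOfBoundsIndex(std::size_t total) {
  for (std::size_t i = 1; i < dpsNumInputs(total); ++i)
    if (i >= odsInputsSize(total))
      return i;
  return std::nullopt;
}
```

With 3 operands, `dpsNumInputs(3)` is 2 while `odsInputsSize(3)` is 1, so the
very first iteration (`i == 1`) already indexes past the end of the
`getInputs()` slice.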

Add an early check in `ReduceOp::verify()` that compares the operand
count from the ODS accessor with `getNumDpsInputs()`.  A mismatch means
the `SameVariadicOperandSize` invariant is violated and the verifier
emits a clear diagnostic instead of crashing.
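
Under the same hypothetical floordiv model, the new early check reduces to a
comparison of the two counts. `countsMismatch` is an illustrative name; the
patch itself compares `getInputs().size()` with `getNumDpsInputs()` directly.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sketch of the patch's first check, assuming
// getInputs().size() == total / 2 and getNumDpsInputs() == total - total / 2.
// Returns true when the verifier would reject the op.
bool countsMismatch(std::size_t total) {
  std::size_t inputsSliceSize = total / 2;       // ~ getInputs().size()
  std::size_t numDpsInputs = total - total / 2;  // ~ getNumDpsInputs()
  return inputsSliceSize != numDpsInputs;
}
```

In this model the check fires exactly when the total operand count is odd.
An even total (for example 3 ins and 1 out) divides exactly, both counts
agree, and the loop cannot overrun; the operands are simply split as 2 inputs
and 2 inits, and whether such an op verifies depends on the remaining checks
in `ReduceOp::verify()`.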

Fixes #93973

Assisted-by: Claude Code
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 17 +++++++++++++++++
 mlir/test/Dialect/Linalg/invalid.mlir    | 22 ++++++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ad2909f656eea..11435b4524a2f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1883,6 +1883,23 @@ void ReduceOp::print(OpAsmPrinter &p) {
 LogicalResult ReduceOp::verify() {
   ArrayRef<int64_t> dimensionsRef = getDimensions();
 
+  // The ReduceOp uses `SameVariadicOperandSize`, which requires equal numbers
+  // of inputs and inits. Detect a mismatch early: when they differ, the
+  // ODS-generated getInputs()/getInits() accessors compute each group's size
+  // via floordiv of the total operand count, producing incorrect slices that
+  // would cause out-of-bounds accesses below.
+  if (getInputs().size() != static_cast<size_t>(getNumDpsInputs()))
+    return emitOpError()
+           << "expected equal number of inputs and outputs (required by "
+              "SameVariadicOperandSize), got "
+           << getNumDpsInputs() << " input(s) and " << getNumDpsInits()
+           << " output(s)";
+
+  if (getInputs().empty())
+    return emitOpError() << "expected at least one input";
+  if (getInits().empty())
+    return emitOpError() << "expected at least one output";
+
   for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
     if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
         llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9500d00a5e647..d2c0934e27fe7 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -2172,3 +2172,25 @@ func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
                                 outs(%f : vector<4xf16>) -> tensor<?xf16>
   func.return %0, %f : tensor<?xf16>, vector<4xf16>
 }
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/93973.
+// Having more inputs than inits must not crash but produce a clear error.
+
+func.func @reduce_unequal_input_output_count(
+    %arg0: tensor<32xi32>, %arg1: tensor<32xi32>) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %init = tensor.from_elements %c0 : tensor<i32>
+  // expected-error @+1 {{'linalg.reduce' op expected equal number of inputs and outputs}}
+  %reduced = linalg.reduce
+      ins(%arg0, %arg1 : tensor<32xi32>, tensor<32xi32>)
+      outs(%init : tensor<i32>)
+      dimensions = [0]
+      (%in0: i32, %in1: i32, %acc: i32) {
+        %v = arith.addi %in0, %acc : i32
+        linalg.yield %v : i32
+      }
+  %ext = tensor.extract %reduced[] : tensor<i32>
+  return %ext : i32
+}



More information about the Mlir-commits mailing list