[Mlir-commits] [mlir] [MLIR][SCF] Verify number of operands in scf.parallel reduce terminator (PR #171450)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 11 02:00:35 PST 2025


https://github.com/Men-cotton updated https://github.com/llvm/llvm-project/pull/171450

>From 26649106a111661c8ef12ca4aba3b1a57b1913a2 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Tue, 9 Dec 2025 23:32:38 +0900
Subject: [PATCH 1/2] [MLIR][SCF] Verify number of operands in scf.parallel
 reduce terminator

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp    |  5 +++++
 mlir/test/Dialect/SCF/invalid.mlir | 14 ++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c75528a76c999..bb18b2a1e5abc 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3152,6 +3152,11 @@ LogicalResult ParallelOp::verify() {
     return emitOpError() << "expects number of results: " << resultsSize
                          << " to be the same as number of initial values: "
                          << initValsSize;
+  if (reduceOp.getNumOperands() != initValsSize)
+    return emitOpError() << "expects number of operands in the terminator: "
+                         << reduceOp.getNumOperands()
+                         << " to be the same as number of initial values: "
+                         << initValsSize;
 
   // Check that the types of the results and reductions are the same.
   for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 3f481ad5dbba7..e2edeb6805864 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -274,6 +274,20 @@ func.func @parallel_different_types_of_results_and_reduces(
 
 // -----
 
+// The scf.parallel operation requires the number of operands in the terminator
+// (scf.reduce) to match the number of initial values provided to the loop.
+func.func @invalid_parallel_reduce_operand_count() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // expected-error @+1 {{expects number of operands in the terminator: 1 to be the same as number of initial values: 0}}
+  scf.parallel (%arg1) = (%c0) to (%c1) step (%c1) {
+    scf.reduce(%c1 : index)
+  }
+  return
+}
+
+// -----
+
 func.func @top_level_reduce(%arg0 : f32) {
  // expected-error@+1 {{expects parent op 'scf.parallel'}}
   scf.reduce(%arg0 : f32) {

>From 0f31f9da8a4345776232e553ad2c93c280458150 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Thu, 11 Dec 2025 18:58:17 +0900
Subject: [PATCH 2/2] Fix: scf.reduce verification masking and crash

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp    | 11 +++++++----
 mlir/test/Dialect/SCF/invalid.mlir | 21 +++++++++++++++++++--
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index bb18b2a1e5abc..edb7299410370 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3153,10 +3153,8 @@ LogicalResult ParallelOp::verify() {
                          << " to be the same as number of initial values: "
                          << initValsSize;
   if (reduceOp.getNumOperands() != initValsSize)
-    return emitOpError() << "expects number of operands in the terminator: "
-                         << reduceOp.getNumOperands()
-                         << " to be the same as number of initial values: "
-                         << initValsSize;
+    // Delegate error reporting to ReduceOp
+    return success();
 
   // Check that the types of the results and reductions are the same.
   for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
@@ -3459,6 +3457,11 @@ void ReduceOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult ReduceOp::verifyRegions() {
+  if (getReductions().size() != getOperands().size())
+    return emitOpError() << "expects number of reduction regions: "
+                         << getReductions().size()
+                         << " to be the same as number of reduction operands: "
+                         << getOperands().size();
   // The region of a ReduceOp has two arguments of the same type as its
   // corresponding operand.
   for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index e2edeb6805864..6db43ffd4b81b 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -276,11 +276,11 @@ func.func @parallel_different_types_of_results_and_reduces(
 
 // The scf.parallel operation requires the number of operands in the terminator
 // (scf.reduce) to match the number of initial values provided to the loop.
-func.func @invalid_parallel_reduce_operand_count() {
+func.func @invalid_reduce_too_few_regions() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // expected-error @+1 {{expects number of operands in the terminator: 1 to be the same as number of initial values: 0}}
   scf.parallel (%arg1) = (%c0) to (%c1) step (%c1) {
+    // expected-error @+1 {{expects number of reduction regions: 0 to be the same as number of reduction operands: 1}}
     scf.reduce(%c1 : index)
   }
   return
@@ -288,6 +288,23 @@ func.func @invalid_parallel_reduce_operand_count() {
 
 // -----
 
+// The scf.reduce terminator requires exactly one reduction region per reduction
+// operand; here a reduction region is supplied without a corresponding operand.
+func.func @invalid_reduce_too_many_regions() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.parallel (%i0) = (%c0) to (%c1) step (%c1) init (%c0) -> (index) {
+    // expected-error @+1 {{expects number of reduction regions: 1 to be the same as number of reduction operands: 0}}
+    scf.reduce {
+      ^bb0(%lhs : index, %rhs : index):
+        scf.reduce.return %lhs : index
+    }
+  }
+  return
+}
+
+// -----
+
 func.func @top_level_reduce(%arg0 : f32) {
  // expected-error@+1 {{expects parent op 'scf.parallel'}}
   scf.reduce(%arg0 : f32) {



More information about the Mlir-commits mailing list