[Mlir-commits] [mlir] [mlir][analysis] Fix a crash in TestMatchReductionPass (PR #149803)

llvmlistbot at llvm.org
Mon Jul 21 05:28:21 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

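Previously, `TestMatchReductionPass` assumed that every region operation it tested had exactly one input. It dropped only the first block argument and treated all remaining arguments as outputs. For operations with more than one input, such as `linalg.pooling_nhwc_sum`, this misclassified input arguments as outputs, so `matchReduction` could be queried at a position past the terminator's operands and crash. This PR makes the pass obtain the number of inputs from `DestinationStyleOpInterface::getNumDpsInputs()` for operations that implement that interface, and adds an assertion in `matchReduction` that the reduction position is within the terminator's operand range.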
Fixes #131437.

---
Full diff: https://github.com/llvm/llvm-project/pull/149803.diff


3 Files Affected:

- (modified) mlir/lib/Analysis/SliceAnalysis.cpp (+2) 
- (modified) mlir/test/Analysis/test-match-reduction.mlir (+14) 
- (modified) mlir/test/lib/Analysis/TestMatchReduction.cpp (+5-1) 


``````````diff
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 991c71e3f689a..510aa2fcdcd8c 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -327,6 +327,8 @@ Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
   // Check that the yielded value is in the same position as in
   // `iterCarriedArgs`.
   Operation *terminatorOp = combinerOp;
+  assert(redPos < terminatorOp->getNumOperands() &&
+         "'redPos' is out of bounds");
   if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
     return nullptr;
 
diff --git a/mlir/test/Analysis/test-match-reduction.mlir b/mlir/test/Analysis/test-match-reduction.mlir
index b5902db77e899..a92a4b5ee77e0 100644
--- a/mlir/test/Analysis/test-match-reduction.mlir
+++ b/mlir/test/Analysis/test-match-reduction.mlir
@@ -81,6 +81,20 @@ func.func @linalg_fused_red_add(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) {
 
 // -----
 
+// expected-remark@below {{Testing function}}
+func.func @linalg_multiple_inputs(%arg0: tensor<1x8229x40x8xf32>, %arg1: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  %0 = tensor.empty() : tensor<1x1xf32>
+  // expected-remark@below {{Reduction found in output #0!}}
+  // expected-remark@below {{Reduced Value: <block argument> of type 'f32' at index: 0}}
+  // expected-remark@below {{Combiner Op: %2 = arith.addf }}
+  %1 = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+                               ins(%arg0, %0 : tensor<1x8229x40x8xf32>, tensor<1x1xf32>)
+                               outs(%arg1 : tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %1 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
 // expected-remark@below {{Testing function}}
 func.func @affine_no_red_rec(%in: memref<512xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
diff --git a/mlir/test/lib/Analysis/TestMatchReduction.cpp b/mlir/test/lib/Analysis/TestMatchReduction.cpp
index 54aea8410388f..12aa1f9ded3d4 100644
--- a/mlir/test/lib/Analysis/TestMatchReduction.cpp
+++ b/mlir/test/lib/Analysis/TestMatchReduction.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 
@@ -69,7 +70,10 @@ struct TestMatchReductionPass
       if (args.size() < 2)
         return;
 
-      auto outputs = args.drop_front();
+      unsigned inputsNum = 1;
+      if (auto destOp = dyn_cast<DestinationStyleOpInterface>(op))
+        inputsNum = destOp.getNumDpsInputs();
+      auto outputs = args.drop_front(inputsNum);
       for (int i = 0, size = outputs.size(); i < size; ++i) {
         SmallVector<Operation *, 4> combinerOps;
         Value reducedValue = matchReduction(outputs, i, combinerOps);

``````````

</details>


https://github.com/llvm/llvm-project/pull/149803

