[Mlir-commits] [mlir] [mlir][linalg] Fix #93973 - linalg reduce verifier crash (PR #119871)

Clément Fournier llvmlistbot at llvm.org
Fri Dec 13 04:40:06 PST 2024


https://github.com/oowekyala created https://github.com/llvm/llvm-project/pull/119871

This PR fixes #93973. 

It is a new opening of #107005. AFAICT it is not settled whether linalg.reduce should be allowed to have variadic operands or be unary. The op is currently designed to have variadic operands, although its [behavior](https://github.com/llvm/llvm-project/pull/107005#issuecomment-2541277983) is confusing. AFAIU the planned [restructuring of the dialect](https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586), led by @rengolin, could benefit from having a unary `linalg.reduce`, as `linalg.contract` would be usable for some of the use cases (the case with two inputs and one output).

Anyway, this PR just preserves the status quo and fixes the crash in the verifier by rejecting a `linalg.reduce` whose number of inputs differs from its number of inits.
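To illustrate why checking the operand counts first prevents the crash, here is a minimal, MLIR-free C++ sketch. It assumes, hypothetically, that the failure comes from verifier code walking inputs and inits in lockstep; the exact crashing path is not spelled out above. `ReduceModel`, `reducedShape`, `verify` and `checkPrExample` are illustrative names, not MLIR APIs, and operands are modeled only by their shapes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical model of a variadic reduce op: each operand is just a shape.
using Shape = std::vector<int64_t>;

struct ReduceModel {
  std::vector<Shape> inputs;
  std::vector<Shape> inits;
  std::vector<int64_t> dimensions;  // dimensions reduced away from each input
};

// Drops the reduced dimensions from an input shape,
// e.g. 16x32x64 reduced over {1} becomes 16x64.
Shape reducedShape(const Shape &input, const std::vector<int64_t> &dims) {
  Shape result;
  for (int64_t d = 0; d < static_cast<int64_t>(input.size()); ++d)
    if (std::find(dims.begin(), dims.end(), d) == dims.end())
      result.push_back(input[d]);
  return result;
}

// Returns an empty string on success, otherwise an error message.
std::string verify(const ReduceModel &op) {
  // Modeled on the check the patch adds. In this model, without it the loop
  // below would index past the end of `inits` when there are more inputs.
  if (op.inits.size() != op.inputs.size())
    return "requires same number of input and init operands";
  for (size_t i = 0; i < op.inputs.size(); ++i) {
    // Safe: inits.size() == inputs.size() was established above.
    if (reducedShape(op.inputs[i], op.dimensions) != op.inits[i])
      return "init shape does not match reduced input shape";
  }
  return "";
}

// Usage: mirrors @reduce_mismatched_inputs_outputs from the test below.
void checkPrExample() {
  ReduceModel bad{{{16, 32, 64}, {16, 32, 64}}, {{16, 64}}, {1}};
  assert(verify(bad) == "requires same number of input and init operands");
  ReduceModel good{{{16, 32, 64}}, {{16, 64}}, {1}};
  assert(verify(good).empty());
}
```

Checking the counts up front makes every later per-operand index into the inits safe. Zipping to the shorter list instead would avoid the crash too, but it would silently accept the malformed op rather than report it.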



>From 384e8cd9970a5ee9c211439faf08d64908a339f2 Mon Sep 17 00:00:00 2001
From: Clément Fournier <clement.fournier at tu-dresden.de>
Date: Fri, 13 Dec 2024 13:27:35 +0100
Subject: [PATCH] [mlir][linalg] Fix #93973 - linalg reduce verifier crash

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp |  4 ++++
 mlir/test/Dialect/Linalg/invalid.mlir    | 19 +++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d9840e3923c4f7..d31d3ef4bd7ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1749,6 +1749,10 @@ void ReduceOp::print(OpAsmPrinter &p) {
 LogicalResult ReduceOp::verify() {
   ArrayRef<int64_t> dimensionsRef = getDimensions();
 
+  if (getNumDpsInits() != getNumDpsInputs()) {
+    return emitOpError() << "requires same number of input and init operands";
+  }
+
   for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
     if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
         llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index e3b6958cfa8816..bf502eab798780 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -728,6 +728,25 @@ func.func @reduce_reduced_input_init_rank_mismatch(%input: tensor<16x32x64xf32>,
   func.return %reduce : tensor<16x64xf32>
 }
 
+
+// -----
+
+func.func @reduce_mismatched_inputs_outputs(
+    %input1: tensor<16x32x64xf32>,
+    %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>)  -> (tensor<16x64xf32>) {
+  // expected-error @+1{{'linalg.reduce' op requires same number of input and init operands}}
+  %reduce = linalg.reduce
+      ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+      outs(%init1 : tensor<16x64xf32>)
+      dimensions = [1]
+      (%in: f32, %in2: f32, %out: f32) {
+        %0 = arith.mulf %in, %in2: f32
+        %1 = arith.addf %in, %out: f32
+        linalg.yield %1: f32
+      }
+  func.return %reduce : tensor<16x64xf32>
+}
+
 // -----
 
 func.func @reduce_wrong_number_of_block_arguments(