[Mlir-commits] [mlir] [mlir][linalg] Fix #93973 - linalg reduce verifier crash (PR #119871)
llvmlistbot at llvm.org
Fri Dec 13 04:40:45 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Clément Fournier (oowekyala)
<details>
<summary>Changes</summary>
This PR fixes #93973.
It reopens #107005. AFAICT it is not settled whether `linalg.reduce` should be allowed to have variadic operands or be unary. The op is currently designed to have variadic operands, although its [behavior](https://github.com/llvm/llvm-project/pull/107005#issuecomment-2541277983) is confusing. AFAIU the planned [restructuring of the dialect](https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586), led by @rengolin, could benefit from a unary `linalg.reduce`, since `linalg.contract` would cover some of the current use cases (the case with two inputs and one output).
Anyway this PR just preserves the status quo and fixes the crash in the verifier.
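For readers skimming the diff, here is a minimal standalone sketch of why the count check is placed first. It is not MLIR code: `Shape`, `verifyReduceLike`, and the rank rule in the paired loop are hypothetical stand-ins, and the sketch assumes that some later check reads `inits[i]` for each input index `i`. The actual crash site reported in #93973 is not visible in this diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an operand's shape; not an MLIR type.
using Shape = std::vector<int64_t>;

// Returns an error message, or std::nullopt if the operands pass.
// `numReducedDims` models the size of the op's `dimensions` attribute.
std::optional<std::string>
verifyReduceLike(const std::vector<Shape> &inputs,
                 const std::vector<Shape> &inits,
                 std::size_t numReducedDims) {
  // Guard first: the paired loop below assumes one init per input.
  if (inits.size() != inputs.size())
    return "requires same number of input and init operands";

  // All inputs must have the same shape as the first input.
  for (std::size_t i = 1; i < inputs.size(); ++i)
    if (inputs[i] != inputs[0])
      return "expects all inputs to have the same shapes";

  // Assumed paired check: indexes `inits` by the input index.
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    // Holds only because of the guard above; without it, this
    // access could run past the end of `inits`.
    assert(i < inits.size());
    if (inits[i].size() + numReducedDims != inputs[i].size())
      return "init rank does not match reduced input rank";
  }
  return std::nullopt;
}
```

With two inputs of shape 16x32x64 and a single init of shape 16x64, as in the new test case, the sketch returns the count-mismatch message before any paired indexing happens.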
---
Full diff: https://github.com/llvm/llvm-project/pull/119871.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+4)
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+19)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d9840e3923c4f7..d31d3ef4bd7ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1749,6 +1749,10 @@ void ReduceOp::print(OpAsmPrinter &p) {
LogicalResult ReduceOp::verify() {
ArrayRef<int64_t> dimensionsRef = getDimensions();
+ if (getNumDpsInits() != getNumDpsInputs()) {
+ return emitOpError() << "requires same number of input and init operands";
+ }
+
for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index e3b6958cfa8816..bf502eab798780 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -728,6 +728,25 @@ func.func @reduce_reduced_input_init_rank_mismatch(%input: tensor<16x32x64xf32>,
func.return %reduce : tensor<16x64xf32>
}
+
+// -----
+
+func.func @reduce_mismatched_inputs_outputs(
+ %input1: tensor<16x32x64xf32>,
+ %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>) -> (tensor<16x64xf32>) {
+ // expected-error @+1{{'linalg.reduce' op requires same number of input and init operands}}
+ %reduce = linalg.reduce
+ ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+ outs(%init1 : tensor<16x64xf32>)
+ dimensions = [1]
+ (%in: f32, %in2: f32, %out: f32) {
+ %0 = arith.mulf %in, %in2: f32
+ %1 = arith.addf %in, %out: f32
+ linalg.yield %1: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+
// -----
func.func @reduce_wrong_number_of_block_arguments(
``````````
</details>
https://github.com/llvm/llvm-project/pull/119871