[Mlir-commits] [mlir] [mlir] Fix semantics of linalg::ReduceOp with several inputs (PR #107005)
Clément Fournier
llvmlistbot at llvm.org
Thu Nov 28 02:26:53 PST 2024
https://github.com/oowekyala updated https://github.com/llvm/llvm-project/pull/107005
>From 9e1383f2c3a69d5df5beaef8fff522af0bd389a0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Sat, 1 Jun 2024 19:23:45 +0200
Subject: [PATCH] [mlir] Fix #93973 - linalg::ReduceOp verifier crash
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 2 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 16 ++++---
mlir/test/Dialect/Linalg/roundtrip.mlir | 42 +++++++++++++++++++
3 files changed, 54 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 37eec6e07963b1..461e2ed091fa4d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -331,7 +331,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
def ReduceOp : LinalgStructuredBase_Op<"reduce", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
- SameVariadicOperandSize,
+ AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "Reduce operator";
let description = [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d9840e3923c4f7..2fa1405ff86186 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1339,11 +1339,12 @@ LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
static ParseResult parseDstStyleOp(
OpAsmParser &parser, OperationState &result,
function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
- nullptr) {
+ nullptr,
+ bool addOperandSegmentSizes = false) {
// Parse `ins` and `outs`.
SmallVector<Type, 4> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
- /*addOperandSegmentSizes=*/false))
+ addOperandSegmentSizes))
return failure();
// Add result types.
@@ -1694,9 +1695,12 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
}
+ // ReduceOp is AttrSizedOperandSegments: `ins` and `outs` may differ in count.
if (parseDstStyleOp(
- parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
+ parser, result,
+ [&](OpAsmParser &parser, NamedAttrList &attributes) {
return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
- }))
+ },
+ /*addOperandSegmentSizes=*/true))
return failure();
if (payloadOpName.has_value()) {
@@ -1731,7 +1735,9 @@ void ReduceOp::print(OpAsmPrinter &p) {
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
- p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
+ p.printOptionalAttrDict(
+ (*this)->getAttrs(),
+ {getDimensionsAttrName(), getOperandSegmentSizesAttrName()});
if (!payloadOp) {
// Print region if the payload op was not detected.
p.increaseIndent();
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd115595..d4ad7584d00d86 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -485,6 +485,48 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
// -----
+func.func @reduce_asymmetric(%input: tensor<16x32x64xi32>, %input2: tensor<16x32x64xi32>,
+ %init: tensor<16x64xi32>) -> tensor<16x64xi32> {
+ %reduce = linalg.reduce
+ ins(%input, %input2:tensor<16x32x64xi32>, tensor<16x32x64xi32>)
+ outs(%init:tensor<16x64xi32>)
+ dimensions = [1]
+ (%in: i32, %in2: i32, %out: i32) {
+ %0 = arith.muli %in, %in2: i32
+ %1 = arith.addi %out, %0: i32
+ linalg.yield %1: i32
+ }
+ func.return %reduce : tensor<16x64xi32>
+}
+// CHECK-LABEL: func @reduce_asymmetric
+// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: tensor<16x32x64xi32>, tensor<16x32x64xi32>)
+// CHECK-SAME: outs(%{{.*}}: tensor<16x64xi32>)
+// CHECK-SAME: dimensions = [1]
+// The attr-dict is printed after `dimensions`; the segment sizes must be elided there.
+// CHECK-NOT: operandSegmentSize
+// -----
+
+func.func @reduce_asymmetric_memref(%input: memref<16x32x64xi32>, %input2: memref<16x32x64xi32>,
+ %init: memref<16x64xi32>) {
+ linalg.reduce
+ ins(%input, %input2:memref<16x32x64xi32>, memref<16x32x64xi32>)
+ outs(%init:memref<16x64xi32>)
+ dimensions = [1]
+ (%in: i32, %in2: i32, %out: i32) {
+ %0 = arith.muli %in, %in2: i32
+ %1 = arith.addi %out, %0: i32
+ linalg.yield %1: i32
+ }
+ func.return
+}
+// CHECK-LABEL: func @reduce_asymmetric_memref
+// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: memref<16x32x64xi32>, memref<16x32x64xi32>)
+// CHECK-SAME: outs(%{{.*}}: memref<16x64xi32>)
+// CHECK-SAME: dimensions = [1]
+// The attr-dict is printed after `dimensions`; the segment sizes must be elided there.
+// CHECK-NOT: operandSegmentSize
+// -----
+
func.func @transpose(%input: tensor<16x32x64xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
%transpose = linalg.transpose
More information about the Mlir-commits
mailing list