[Mlir-commits] [mlir] [mlir] Fix #93973 - linalg::ReduceOp verifier crash (PR #107005)

Clément Fournier llvmlistbot at llvm.org
Mon Sep 2 08:50:23 PDT 2024


https://github.com/oowekyala created https://github.com/llvm/llvm-project/pull/107005

Fix #93973. This allows using `linalg.reduce` to, e.g., reduce several tensors into one. The current implementation is limited to having the same number of inputs and outputs.


>From 37b3934ec01bb2a2aa3eac153755d878ac87e10e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Sat, 1 Jun 2024 19:23:45 +0200
Subject: [PATCH] [mlir] Fix #93973 - linalg::ReduceOp verifier crash

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 16 ++++---
 mlir/test/Dialect/Linalg/roundtrip.mlir       | 42 +++++++++++++++++++
 3 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..f20f036d6fe480 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -311,7 +311,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
 def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
-    SameVariadicOperandSize,
+    AttrSizedOperandSegments,
     SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Reduce operator";
   let description = [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..9c6c36075b55bd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1301,11 +1301,12 @@ LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 static ParseResult parseDstStyleOp(
     OpAsmParser &parser, OperationState &result,
     function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
-        nullptr) {
+        nullptr,
+    bool addOperandSegmentSizes = false) {
   // Parse `ins` and `outs`.
   SmallVector<Type, 4> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
-                                   /*addOperandSegmentSizes=*/false))
+                                   addOperandSegmentSizes))
     return failure();
 
   // Add result types.
@@ -1646,9 +1647,12 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
   }
 
   if (parseDstStyleOp(
-          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
+          parser, result,
+          [&](OpAsmParser &parser, NamedAttrList &attributes) {
             return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
-          }))
+          },
+          /*addOperandSegmentSizes=*/true))
+
     return failure();
 
   if (payloadOpName.has_value()) {
@@ -1683,7 +1687,9 @@ void ReduceOp::print(OpAsmPrinter &p) {
 
   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
-  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
+  p.printOptionalAttrDict(
+      (*this)->getAttrs(),
+      {getDimensionsAttrName(), getOperandSegmentSizesAttrName()});
   if (!payloadOp) {
     // Print region if the payload op was not detected.
     p.increaseIndent();
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 146e9780b8ebbe..802de7c335d9b1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -485,6 +485,48 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
 
 // -----
 
+func.func @reduce_asymmetric(%input: tensor<16x32x64xi32>, %input2: tensor<16x32x64xi32>,
+                  %init: tensor<16x64xi32>) -> tensor<16x64xi32> {
+  %reduce = linalg.reduce
+      ins(%input, %input2:tensor<16x32x64xi32>, tensor<16x32x64xi32>)
+      outs(%init:tensor<16x64xi32>)
+      dimensions = [1]
+      (%in: i32, %in2: i32, %out: i32) {
+        %0 = arith.muli %in, %in2: i32
+        %1 = arith.addi %out, %0: i32
+        linalg.yield %1: i32
+      }
+  func.return %reduce : tensor<16x64xi32>
+}
+// CHECK-LABEL: func @reduce_asymmetric
+//       CHECK:   linalg.reduce ins(%{{.*}}, %{{.*}}: tensor<16x32x64xi32>, tensor<16x32x64xi32>)
+//  CHECK-NOT:    operandSegmentSize
+//  CHECK-SAME:   outs(%{{.*}}: tensor<16x64xi32>)
+//  CHECK-SAME:   dimensions = [1]
+
+// -----
+
+func.func @reduce_asymmetric_memref(%input: memref<16x32x64xi32>, %input2: memref<16x32x64xi32>,
+                  %init: memref<16x64xi32>) {
+  linalg.reduce
+      ins(%input, %input2:memref<16x32x64xi32>, memref<16x32x64xi32>)
+      outs(%init:memref<16x64xi32>)
+      dimensions = [1]
+      (%in: i32, %in2: i32, %out: i32) {
+        %0 = arith.muli %in, %in2: i32
+        %1 = arith.addi %out, %0: i32
+        linalg.yield %1: i32
+      }
+  func.return
+}
+// CHECK-LABEL: func @reduce_asymmetric_memref
+//       CHECK:   linalg.reduce ins(%{{.*}}, %{{.*}}: memref<16x32x64xi32>, memref<16x32x64xi32>)
+//  CHECK-NOT:    operandSegmentSize
+//  CHECK-SAME:   outs(%{{.*}}: memref<16x64xi32>)
+//  CHECK-SAME:   dimensions = [1]
+
+// -----
+
 func.func @transpose(%input: tensor<16x32x64xf32>,
                      %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
   %transpose = linalg.transpose



More information about the Mlir-commits mailing list