[Mlir-commits] [mlir] d02233f - [mlir][Linalg] Add ReduceOp to Linalg structured ops.

Adrian Kuegel llvmlistbot at llvm.org
Thu Sep 29 07:23:27 PDT 2022


Author: Adrian Kuegel
Date: 2022-09-29T16:23:12+02:00
New Revision: d02233f0da17c73f2070b5d59c80547102fa12a3

URL: https://github.com/llvm/llvm-project/commit/d02233f0da17c73f2070b5d59c80547102fa12a3
DIFF: https://github.com/llvm/llvm-project/commit/d02233f0da17c73f2070b5d59c80547102fa12a3.diff

LOG: [mlir][Linalg] Add ReduceOp to Linalg structured ops.

This will allow to model (variadic) reductions with this special op instead of
using GenericOp.

RFC: https://discourse.llvm.org/t/rfc-primitive-ops-add-mapop-reductionop-transposeop-broadcastop-to-linalg/64184

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 02df0511ca30e..3fc0968ff6102 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -221,6 +221,67 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> {
 }
 
 
+//===----------------------------------------------------------------------===//
+// Reduce op.
+//===----------------------------------------------------------------------===//
+
+def TensorOrMemref :
+  AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
+
+def ReduceOp : LinalgStructuredBase_Op<"reduce", [
+      SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">
+    ]> {
+  let summary = "Reduce operator";
+  let description = [{
+    Executes `combiner` on the `dimensions` of `inputs` and returns the
+    reduced result. The `dimensions` attribute needs to list the reduction
+    dimensions in increasing order.
+
+    Example:
+    ```
+      %reduce = linalg.reduce
+          ins(%input:tensor<16x32x64xf32>)
+          outs(%init:tensor<16x64xf32>)
+          dimensions = [1]
+          (%in: f32, %out: f32) {
+            %0 = arith.addf %in, %out: f32
+            linalg.yield %0: f32
+          }
+    ```
+  }];
+
+  let arguments = (ins
+    // Input arg
+    Variadic<TensorOrMemref>:$inputs,
+    // Output arg
+    Variadic<TensorOrMemref>:$inits,
+
+    DenseI64ArrayAttr:$dimensions
+  );
+  let results = (outs Variadic<TensorOrMemref>);
+  let regions = (region SizedRegion<1>:$combiner);
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare functions necessary for LinalgStructuredInterface.
+    ArrayAttr getIteratorTypes();
+    ArrayAttr getIndexingMaps();
+
+    // Implement functions necessary for DestinationStyleOpInterface.
+    mlir::ValueRange getOutputs() { return getInits(); }
+    unsigned getNumInputs() { return getInputs().size(); };
+    unsigned getNumOutputs() { return getInits().size(); };
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
+    getRegionBuilder() {
+      return nullptr;
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 952fbde9b487f..ddf8d7e4e76d1 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -243,12 +243,21 @@ class AffineMap {
 
   /// Returns a new AffineMap with the same number of dims and symbols and one
   /// less result at `pos`, dropped.
-  AffineMap dropResult(unsigned pos) {
+  AffineMap dropResult(int64_t pos) {
     auto exprs = llvm::to_vector<4>(getResults());
     exprs.erase(exprs.begin() + pos);
     return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
   }
 
+  // Returns a new AffineMap with the same number of dims and symbols, but all
+  // positions in `positions` dropped from results.
+  AffineMap dropResults(ArrayRef<int64_t> positions) {
+    AffineMap resultMap = *this;
+    for (int64_t pos : positions)
+      resultMap = resultMap.dropResult(pos);
+    return resultMap;
+  }
+
   /// Returns a new AffineMap with the same number of dims and symbols and an
   /// extra result inserted at `pos`.
   AffineMap insertResult(AffineExpr expr, unsigned pos) {

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c088310eddc2d..6fe0593b5bfd4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -124,7 +124,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
 static ParseResult
 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                              SmallVectorImpl<Type> &inputTypes,
-                             SmallVectorImpl<Type> &outputTypes) {
+                             SmallVectorImpl<Type> &outputTypes,
+                             bool addOperandSegmentSizes = true) {
   SMLoc inputsOperandsLoc, outputsOperandsLoc;
   SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
       outputsOperands;
@@ -155,10 +156,12 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                              result.operands))
     return failure();
 
-  result.addAttribute("operand_segment_sizes",
-                      parser.getBuilder().getDenseI32ArrayAttr(
-                          {static_cast<int32_t>(inputsOperands.size()),
-                           static_cast<int32_t>(outputsOperands.size())}));
+  if (addOperandSegmentSizes) {
+    result.addAttribute("operand_segment_sizes",
+                        parser.getBuilder().getDenseI32ArrayAttr(
+                            {static_cast<int32_t>(inputsOperands.size()),
+                             static_cast<int32_t>(outputsOperands.size())}));
+  }
   return success();
 }
 
@@ -1180,6 +1183,209 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+ArrayAttr ReduceOp::getIteratorTypes() {
+  int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
+  SmallVector<StringRef> iteratorTypes(inputRank,
+                                       getParallelIteratorTypeName());
+  for (int64_t reductionDim : getDimensions())
+    iteratorTypes[reductionDim] = getReductionIteratorTypeName();
+  return Builder(getContext()).getStrArrayAttr(iteratorTypes);
+}
+
+ArrayAttr ReduceOp::getIndexingMaps() {
+  int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
+  SmallVector<AffineMap> affineMaps(
+      getNumInputs(),
+      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
+  AffineMap resultMap =
+      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
+          .dropResults(getDimensions());
+  for (int64_t i = 0, e = getNumOutputs(); i < e; ++i)
+    affineMaps.push_back(resultMap);
+  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
+}
+
+void ReduceOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  SmallVector<Value> inputBuffers = getInputBufferOperands();
+  SmallVector<Value> outputBuffers = getOutputBufferOperands();
+  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
+                        outputBuffers);
+}
+
+static ParseResult parseDstStyleOp(
+    OpAsmParser &parser, OperationState &result,
+    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
+        nullptr) {
+  // Parse `ins` and `outs`.
+  SmallVector<Type, 4> inputTypes, outputTypes;
+  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
+                                   /*addOperandSegmentSizes=*/false))
+    return failure();
+
+  // Add result types.
+  for (Type outputType : outputTypes) {
+    if (!outputType.isa<RankedTensorType>())
+      return failure();
+    result.addTypes(outputType);
+  }
+
+  // Parse required attributes.
+  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
+    return failure();
+
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
+                                          NamedAttrList &attributes,
+                                          StringRef attributeName) {
+  if (parser.parseKeyword(attributeName) || parser.parseEqual())
+    return failure();
+
+  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
+  return success();
+}
+
+ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parseDstStyleOp(
+          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
+            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
+          }))
+    return failure();
+
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+                               /*allowType=*/true, /*allowAttrs=*/true)) {
+    return failure();
+  }
+
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  return success();
+}
+
+static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
+                                   ArrayRef<int64_t> attributeValue) {
+  p << " " << attributeName << " = [" << attributeValue << "] ";
+}
+
+void ReduceOp::print(OpAsmPrinter &p) {
+  printCommonStructuredOpParts(p, getInputs(), getOutputs());
+  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
+  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
+
+  p << "(";
+  llvm::interleaveComma(getCombiner().getArguments(), p,
+                        [&](auto arg) { p.printRegionArgument(arg); });
+  p << ") ";
+
+  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+}
+
+LogicalResult ReduceOp::verify() {
+  ArrayRef<int64_t> dimensionsRef = getDimensions();
+
+  for (int64_t i = 1; i < getNumInputs(); ++i) {
+    if (getInputs()[i].getType().cast<ShapedType>().getShape() !=
+        getInputs()[0].getType().cast<ShapedType>().getShape()) {
+      return emitOpError() << "expects all inputs to have the same shapes. "
+                              "Shape at input-index "
+                           << i
+                           << " is not equal to the shape at input-index 0.";
+    }
+  }
+  for (int64_t i = 1; i < getNumOutputs(); ++i) {
+    if (getInits()[i].getType().cast<ShapedType>().getShape() !=
+        getInits()[0].getType().cast<ShapedType>().getShape()) {
+      return emitOpError() << "expects all outputs to have the same shapes. "
+                              "Shape at output-index "
+                           << i
+                           << " is not equal to the shape at output-index 0.";
+    }
+  }
+  auto inputType = getInputs()[0].getType().cast<ShapedType>();
+  auto initType = getInits()[0].getType().cast<ShapedType>();
+
+  DenseSet<int64_t> dimensionsToReduce;
+  int64_t lastDimension = -1;
+  for (int64_t dimension : dimensionsRef) {
+    if (dimension < 0 || dimension >= inputType.getRank()) {
+      return emitOpError()
+             << "dimensions for reduction should be in the range [0, "
+             << inputType.getRank() - 1 << "].";
+    }
+    if (dimension <= lastDimension) {
+      return emitOpError()
+             << "reduction dimensions are not in increasing order: "
+             << dimensionsRef;
+    }
+
+    lastDimension = dimension;
+    dimensionsToReduce.insert(dimension);
+  }
+
+  auto inputDims = inputType.getShape();
+  auto initDims = initType.getShape();
+
+  // Input dimensions that will be left after the reduction.
+  SmallVector<int64_t> reducedInputDims;
+  for (const auto &en : llvm::enumerate(inputDims)) {
+    if (!dimensionsToReduce.count(en.index()))
+      reducedInputDims.push_back(en.value());
+  }
+
+  if (reducedInputDims.size() != initType.getRank()) {
+    return emitOpError() << "number of dimensions after reduction "
+                         << reducedInputDims.size()
+                         << " doesn't match the init rank "
+                         << initType.getRank();
+  }
+
+  if (reducedInputDims != initDims)
+    return emitOpError() << "init dimensions [" << initDims
+                         << "] doesn't match input dimensions after reduction ["
+                         << reducedInputDims << "]";
+
+  Block *block = getBody();
+  if (block->getNumArguments() != this->getNumOperands())
+    return emitOpError()
+           << "mismatching number of operands and block arguments";
+
+  // Check that the first block arguments match the element type of the inputs.
+  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
+    Type inputElementType = input.getType().cast<ShapedType>().getElementType();
+    if (inputElementType != bbArg.getType())
+      return emitOpError()
+             << "input element type " << inputElementType
+             << " does not match corresponding block argument type "
+             << bbArg.getType();
+  }
+
+  // Check that the last block arguments match the element type of the outputs.
+  for (auto [output, bbArg] : llvm::zip(
+           getOutputs(), block->getArguments().take_back(getNumOutputs()))) {
+    auto outputElementType =
+        output.getType().cast<ShapedType>().getElementType();
+    if (outputElementType != bbArg.getType())
+      return emitOpError()
+             << "output element type " << outputElementType
+             << " does not match corresponding block argument type "
+             << bbArg.getType();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a7fe2f09e533f..d69a79869b1df 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -415,3 +415,175 @@ func.func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
         }
         return
 }
+
+// -----
+
+func.func @reduce_input_vs_init_dimension_mismatch(
+    %input: tensor<16x32x64xf32>,
+    %init: tensor<16x64xf32>)  -> tensor<16x64xf32> {
+  // expected-error @+1 {{'linalg.reduce' op init dimensions [16, 64] doesn't match input dimensions after reduction [16, 32]}}
+  %reduce = linalg.reduce
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x64xf32>)
+      dimensions = [2]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return %reduce : tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_dimensions_out_of_range(%input: tensor<16x32x64xf32>,
+    %init: tensor<16x64xf32>)  -> tensor<16x64xf32> {
+  // expected-error @+1 {{'linalg.reduce' op dimensions for reduction should be in the range [0, 2].}}
+  %reduce = linalg.reduce
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x64xf32>)
+      dimensions = [3]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return %reduce : tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_duplicate_dimensions(%input: tensor<16x32x64xf32>,
+    %init: tensor<16xf32>)  -> tensor<16xf32> {
+  // expected-error @+1 {{'linalg.reduce' op reduction dimensions are not in increasing order: 1, 1}}
+  %reduce = linalg.reduce
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16xf32>)
+      dimensions = [1, 1]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return %reduce : tensor<16xf32>
+}
+
+// -----
+
+func.func @reduce_non_increasing_dimensions(%input: tensor<16x32x64xf32>,
+    %init: tensor<16xf32>)  -> tensor<16xf32> {
+  // expected-error @+1 {{'linalg.reduce' op reduction dimensions are not in increasing order: 2, 1}}
+  %reduce = linalg.reduce
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16xf32>)
+      dimensions = [2, 1]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return %reduce : tensor<16xf32>
+}
+
+// -----
+
+func.func @reduce_reduced_input_init_rank_mismatch(%input: tensor<16x32x64xf32>,
+    %init: tensor<16x64xf32>)  -> tensor<16x64xf32> {
+  // expected-error @+1 {{'linalg.reduce' op number of dimensions after reduction 1 doesn't match the init rank 2}}
+  %reduce = linalg.reduce
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x64xf32>)
+      dimensions = [1, 2]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return %reduce : tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_wrong_number_of_block_arguments(
+    %input1: tensor<16x32x64xf32>,
+    %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
+    %init2: tensor<16x64xf32>)  -> (tensor<16x64xf32>, tensor<16x64xf32>) {
+  // expected-error @+1{{'linalg.reduce' op mismatching number of operands and block arguments}}
+  %reduce, %reduce2 = linalg.reduce
+      ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+      outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xf32>)
+      dimensions = [1]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_wrong_block_argument_input_type(
+    %input1: tensor<16x32x64xf32>,
+    %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
+    %init2: tensor<16x64xf32>)  -> (tensor<16x64xf32>, tensor<16x64xf32>) {
+  // expected-error @+1{{'linalg.reduce' op input element type 'f32' does not match corresponding block argument type 'f64'}}
+  %reduce, %reduce2 = linalg.reduce
+      ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+      outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xf32>)
+      dimensions = [1]
+      (%in1: f32, %in2: f64, %out1: f32, %out2: f64) {
+        %0 = arith.addf %in1, %out1: f32
+        %1 = arith.addf %in2, %out2: f64
+        linalg.yield %0, %1: f32, f64
+      }
+  func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_wrong_block_argument_output_type(
+    %input1: tensor<16x32x64xf32>,
+    %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
+    %init2: tensor<16x64xf64>)  -> (tensor<16x64xf32>, tensor<16x64xf32>) {
+  // expected-error @+1{{'linalg.reduce' op output element type 'f64' does not match corresponding block argument type 'f32'}}
+  %reduce, %reduce2 = linalg.reduce
+      ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+      outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xf64>)
+      dimensions = [1]
+      (%in1: f32, %in2: f32, %out1: f32, %out2: f32) {
+        %0 = arith.addf %in1, %out1: f32
+        linalg.yield %0, %out2: f32, f32
+      }
+  func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xf64>
+}
+
+// -----
+
+func.func @reduce_different_input_shapes(%input1: tensor<16x32x64xf32>,
+    %init1: tensor<16x64xf32>, %input2: tensor<17x32x64xf32>,
+    %init2: tensor<17x64xf32>)  -> (tensor<16x64xf32>, tensor<17x64xf32>) {
+  // expected-error @+1{{'linalg.reduce' op expects all inputs to have the same shapes. Shape at input-index 1 is not equal to the shape at input-index 0.}}
+  %reduce, %reduce2 = linalg.reduce
+      ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<17x32x64xf32>)
+      outs(%init1, %init2 : tensor<16x64xf32>, tensor<17x64xf32>)
+      dimensions = [1]
+      (%in1: f32, %in2: f32, %out1: f32, %out2: f32) {
+        %0 = arith.addf %in1, %out1: f32
+        %1 = arith.addf %in2, %out2: f32
+        linalg.yield %0, %1: f32, f32
+      }
+  func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32>
+}
+
+// -----
+
+func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>,
+    %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
+    %init2: tensor<17x64xf32>)  -> (tensor<16x64xf32>, tensor<17x64xf32>) {
+  // expected-error @+1{{'linalg.reduce' op expects all outputs to have the same shapes. Shape at output-index 1 is not equal to the shape at output-index 0.}}
+  %reduce, %reduce2 = linalg.reduce
+      ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+      outs(%init1, %init2 : tensor<16x64xf32>, tensor<17x64xf32>)
+      dimensions = [1]
+      (%in1: f32, %in2: f32, %out1: f32, %out2: f32) {
+        %0 = arith.addf %in1, %out1: f32
+        %1 = arith.addf %in2, %out2: f32
+        linalg.yield %0, %1: f32, f32
+      }
+  func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32>
+}

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index ae68fb3ddc4a4..3d2e9bfd228eb 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -366,3 +366,39 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
 }
 // CHECK-LABEL: func @mixed_parallel_reduced_results
 //       CHECK:     linalg.generic
+
+// -----
+
+func.func @reduce(%input: tensor<16x32x64xf32>,
+                     %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+  %reduce = linalg.reduce
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x64xf32>)
+      dimensions = [1]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return %reduce : tensor<16x64xf32>
+}
+// CHECK-LABEL: func @reduce
+//       CHECK:     linalg.reduce
+
+// -----
+
+func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
+    %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xi64>,
+    %init2: tensor<16x64xi64>)  -> (tensor<16x64xf32>, tensor<16x64xi64>) {
+  %reduce, %reduce2 = linalg.reduce
+      ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xi64>)
+      outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xi64>)
+      dimensions = [1]
+      (%in1: f32, %in2: i64, %out1: f32, %out2: i64) {
+        %0 = arith.addf %in1, %out1: f32
+        %1 = arith.addi %in2, %out2: i64
+        linalg.yield %0, %1: f32, i64
+      }
+  func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xi64>
+}
+// CHECK-LABEL: func @variadic_reduce
+//       CHECK:     linalg.reduce


        


More information about the Mlir-commits mailing list