[Mlir-commits] [mlir] d02233f - [mlir][Linalg] Add ReduceOp to Linalg structured ops.
Adrian Kuegel
llvmlistbot at llvm.org
Thu Sep 29 07:23:27 PDT 2022
Author: Adrian Kuegel
Date: 2022-09-29T16:23:12+02:00
New Revision: d02233f0da17c73f2070b5d59c80547102fa12a3
URL: https://github.com/llvm/llvm-project/commit/d02233f0da17c73f2070b5d59c80547102fa12a3
DIFF: https://github.com/llvm/llvm-project/commit/d02233f0da17c73f2070b5d59c80547102fa12a3.diff
LOG: [mlir][Linalg] Add ReduceOp to Linalg structured ops.
This will allow modeling (variadic) reductions with this dedicated op instead
of using GenericOp.
RFC: https://discourse.llvm.org/t/rfc-primitive-ops-add-mapop-reductionop-transposeop-broadcastop-to-linalg/64184
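As a point of reference, below is a sum reduction over dimension 1 written with the new op (mirroring the example in the op documentation further down), together with a rough sketch of the linalg.generic form it is meant to replace. The generic version is an illustration for comparison only, not part of this patch.
```
// With the new op (see the roundtrip test in the patch):
%sum = linalg.reduce
    ins(%input : tensor<16x32x64xf32>)
    outs(%init : tensor<16x64xf32>)
    dimensions = [1]
    (%in: f32, %out: f32) {
      %0 = arith.addf %in, %out : f32
      linalg.yield %0 : f32
    }

// Sketch of the equivalent linalg.generic (illustration only):
%sum_generic = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d2)>],
    iterator_types = ["parallel", "reduction", "parallel"]}
    ins(%input : tensor<16x32x64xf32>)
    outs(%init : tensor<16x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %0 = arith.addf %in, %out : f32
    linalg.yield %0 : f32
  } -> tensor<16x64xf32>
```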
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 02df0511ca30e..3fc0968ff6102 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -221,6 +221,67 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> {
}
+//===----------------------------------------------------------------------===//
+// Reduce op.
+//===----------------------------------------------------------------------===//
+
+def TensorOrMemref :
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
+
+def ReduceOp : LinalgStructuredBase_Op<"reduce", [
+ SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">
+ ]> {
+ let summary = "Reduce operator";
+ let description = [{
+ Executes `combiner` on the `dimensions` of `inputs` and returns the
+ reduced result. The `dimensions` attribute needs to list the reduction
+ dimensions in increasing order.
+
+ Example:
+ ```
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ ```
+ }];
+
+ let arguments = (ins
+ // Input arg
+ Variadic<TensorOrMemref>:$inputs,
+ // Output arg
+ Variadic<TensorOrMemref>:$inits,
+
+ DenseI64ArrayAttr:$dimensions
+ );
+ let results = (outs Variadic<TensorOrMemref>);
+ let regions = (region SizedRegion<1>:$combiner);
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Declare functions necessary for LinalgStructuredInterface.
+ ArrayAttr getIteratorTypes();
+ ArrayAttr getIndexingMaps();
+
+ // Implement functions necessary for DestinationStyleOpInterface.
+ mlir::ValueRange getOutputs() { return getInits(); }
+ unsigned getNumInputs() { return getInputs().size(); };
+ unsigned getNumOutputs() { return getInits().size(); };
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return nullptr;
+ }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 952fbde9b487f..ddf8d7e4e76d1 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -243,12 +243,21 @@ class AffineMap {
/// Returns a new AffineMap with the same number of dims and symbols and one
/// less result at `pos`, dropped.
- AffineMap dropResult(unsigned pos) {
+ AffineMap dropResult(int64_t pos) {
auto exprs = llvm::to_vector<4>(getResults());
exprs.erase(exprs.begin() + pos);
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
}
+ // Returns a new AffineMap with the same number of dims and symbols, but all
+ // positions in `positions` dropped from results.
+ AffineMap dropResults(ArrayRef<int64_t> positions) {
+ AffineMap resultMap = *this;
+ for (int64_t pos : positions)
+ resultMap = resultMap.dropResult(pos);
+ return resultMap;
+ }
+
/// Returns a new AffineMap with the same number of dims and symbols and an
/// extra result inserted at `pos`.
AffineMap insertResult(AffineExpr expr, unsigned pos) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c088310eddc2d..6fe0593b5bfd4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -124,7 +124,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
- SmallVectorImpl<Type> &outputTypes) {
+ SmallVectorImpl<Type> &outputTypes,
+ bool addOperandSegmentSizes = true) {
SMLoc inputsOperandsLoc, outputsOperandsLoc;
SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
outputsOperands;
@@ -155,10 +156,12 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
result.operands))
return failure();
- result.addAttribute("operand_segment_sizes",
- parser.getBuilder().getDenseI32ArrayAttr(
- {static_cast<int32_t>(inputsOperands.size()),
- static_cast<int32_t>(outputsOperands.size())}));
+ if (addOperandSegmentSizes) {
+ result.addAttribute("operand_segment_sizes",
+ parser.getBuilder().getDenseI32ArrayAttr(
+ {static_cast<int32_t>(inputsOperands.size()),
+ static_cast<int32_t>(outputsOperands.size())}));
+ }
return success();
}
@@ -1180,6 +1183,209 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
return foldMemRefCast(*this);
}
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+ArrayAttr ReduceOp::getIteratorTypes() {
+ int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
+ SmallVector<StringRef> iteratorTypes(inputRank,
+ getParallelIteratorTypeName());
+ for (int64_t reductionDim : getDimensions())
+ iteratorTypes[reductionDim] = getReductionIteratorTypeName();
+ return Builder(getContext()).getStrArrayAttr(iteratorTypes);
+}
+
+ArrayAttr ReduceOp::getIndexingMaps() {
+ int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
+ SmallVector<AffineMap> affineMaps(
+ getNumInputs(),
+ AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
+ AffineMap resultMap =
+ AffineMap::getMultiDimIdentityMap(inputRank, getContext())
+ .dropResults(getDimensions());
+ for (int64_t i = 0, e = getNumOutputs(); i < e; ++i)
+ affineMaps.push_back(resultMap);
+ return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
+}
+
+void ReduceOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ SmallVector<Value> inputBuffers = getInputBufferOperands();
+ SmallVector<Value> outputBuffers = getOutputBufferOperands();
+ getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
+ outputBuffers);
+}
+
+static ParseResult parseDstStyleOp(
+ OpAsmParser &parser, OperationState &result,
+ function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
+ nullptr) {
+ // Parse `ins` and `outs`.
+ SmallVector<Type, 4> inputTypes, outputTypes;
+ if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
+ /*addOperandSegmentSizes=*/false))
+ return failure();
+
+ // Add result types.
+ for (Type outputType : outputTypes) {
+ if (!outputType.isa<RankedTensorType>())
+ return failure();
+ result.addTypes(outputType);
+ }
+
+ // Parse required attributes.
+ if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
+ return failure();
+
+ // Parse optional attributes.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ return success();
+}
+
+static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
+ NamedAttrList &attributes,
+ StringRef attributeName) {
+ if (parser.parseKeyword(attributeName) || parser.parseEqual())
+ return failure();
+
+ attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
+ return success();
+}
+
+ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
+ if (parseDstStyleOp(
+ parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
+ return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
+ }))
+ return failure();
+
+ SmallVector<OpAsmParser::Argument> regionArgs;
+ if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+ /*allowType=*/true, /*allowAttrs=*/true)) {
+ return failure();
+ }
+
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+
+ return success();
+}
+
+static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
+ ArrayRef<int64_t> attributeValue) {
+ p << " " << attributeName << " = [" << attributeValue << "] ";
+}
+
+void ReduceOp::print(OpAsmPrinter &p) {
+ printCommonStructuredOpParts(p, getInputs(), getOutputs());
+ printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
+ p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
+
+ p << "(";
+ llvm::interleaveComma(getCombiner().getArguments(), p,
+ [&](auto arg) { p.printRegionArgument(arg); });
+ p << ") ";
+
+ p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+}
+
+LogicalResult ReduceOp::verify() {
+ ArrayRef<int64_t> dimensionsRef = getDimensions();
+
+ for (int64_t i = 1; i < getNumInputs(); ++i) {
+ if (getInputs()[i].getType().cast<ShapedType>().getShape() !=
+ getInputs()[0].getType().cast<ShapedType>().getShape()) {
+ return emitOpError() << "expects all inputs to have the same shapes. "
+ "Shape at input-index "
+ << i
+ << " is not equal to the shape at input-index 0.";
+ }
+ }
+ for (int64_t i = 1; i < getNumOutputs(); ++i) {
+ if (getInits()[i].getType().cast<ShapedType>().getShape() !=
+ getInits()[0].getType().cast<ShapedType>().getShape()) {
+ return emitOpError() << "expects all outputs to have the same shapes. "
+ "Shape at output-index "
+ << i
+ << " is not equal to the shape at output-index 0.";
+ }
+ }
+ auto inputType = getInputs()[0].getType().cast<ShapedType>();
+ auto initType = getInits()[0].getType().cast<ShapedType>();
+
+ DenseSet<int64_t> dimensionsToReduce;
+ int64_t lastDimension = -1;
+ for (int64_t dimension : dimensionsRef) {
+ if (dimension < 0 || dimension >= inputType.getRank()) {
+ return emitOpError()
+ << "dimensions for reduction should be in the range [0, "
+ << inputType.getRank() - 1 << "].";
+ }
+ if (dimension <= lastDimension) {
+ return emitOpError()
+ << "reduction dimensions are not in increasing order: "
+ << dimensionsRef;
+ }
+
+ lastDimension = dimension;
+ dimensionsToReduce.insert(dimension);
+ }
+
+ auto inputDims = inputType.getShape();
+ auto initDims = initType.getShape();
+
+ // Input dimensions that will be left after the reduction.
+ SmallVector<int64_t> reducedInputDims;
+ for (const auto &en : llvm::enumerate(inputDims)) {
+ if (!dimensionsToReduce.count(en.index()))
+ reducedInputDims.push_back(en.value());
+ }
+
+ if (reducedInputDims.size() != initType.getRank()) {
+ return emitOpError() << "number of dimensions after reduction "
+ << reducedInputDims.size()
+ << " doesn't match the init rank "
+ << initType.getRank();
+ }
+
+ if (reducedInputDims != initDims)
+ return emitOpError() << "init dimensions [" << initDims
+ << "] doesn't match input dimensions after reduction ["
+ << reducedInputDims << "]";
+
+ Block *block = getBody();
+ if (block->getNumArguments() != this->getNumOperands())
+ return emitOpError()
+ << "mismatching number of operands and block arguments";
+
+ // Check that the first block arguments match the element type of the inputs.
+ for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
+ Type inputElementType = input.getType().cast<ShapedType>().getElementType();
+ if (inputElementType != bbArg.getType())
+ return emitOpError()
+ << "input element type " << inputElementType
+ << " does not match corresponding block argument type "
+ << bbArg.getType();
+ }
+
+ // Check that the last block arguments match the element type of the outputs.
+ for (auto [output, bbArg] : llvm::zip(
+ getOutputs(), block->getArguments().take_back(getNumOutputs()))) {
+ auto outputElementType =
+ output.getType().cast<ShapedType>().getElementType();
+ if (outputElementType != bbArg.getType())
+ return emitOpError()
+ << "output element type " << outputElementType
+ << " does not match corresponding block argument type "
+ << bbArg.getType();
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a7fe2f09e533f..d69a79869b1df 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -415,3 +415,175 @@ func.func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
}
return
}
+
+// -----
+
+func.func @reduce_input_vs_init_dimension_mismatch(
+ %input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ // expected-error @+1 {{'linalg.reduce' op init dimensions [16, 64] doesn't match input dimensions after reduction [16, 32]}}
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [2]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_dimensions_out_of_range(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ // expected-error @+1 {{'linalg.reduce' op dimensions for reduction should be in the range [0, 2].}}
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [3]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_duplicate_dimensions(%input: tensor<16x32x64xf32>,
+ %init: tensor<16xf32>) -> tensor<16xf32> {
+ // expected-error @+1 {{'linalg.reduce' op reduction dimensions are not in increasing order: 1, 1}}
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16xf32>)
+ dimensions = [1, 1]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ func.return %reduce : tensor<16xf32>
+}
+
+// -----
+
+func.func @reduce_non_increasing_dimensions(%input: tensor<16x32x64xf32>,
+ %init: tensor<16xf32>) -> tensor<16xf32> {
+ // expected-error @+1 {{'linalg.reduce' op reduction dimensions are not in increasing order: 2, 1}}
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16xf32>)
+ dimensions = [2, 1]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ func.return %reduce : tensor<16xf32>
+}
+
+// -----
+
+func.func @reduce_reduced_input_init_rank_mismatch(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ // expected-error @+1 {{'linalg.reduce' op number of dimensions after reduction 1 doesn't match the init rank 2}}
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1, 2]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_wrong_number_of_block_arguments(
+ %input1: tensor<16x32x64xf32>,
+ %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
+ %init2: tensor<16x64xf32>) -> (tensor<16x64xf32>, tensor<16x64xf32>) {
+ // expected-error @+1{{'linalg.reduce' op mismatching number of operands and block arguments}}
+ %reduce, %reduce2 = linalg.reduce
+ ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+ outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xf32>)
+ dimensions = [1]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_wrong_block_argument_input_type(
+ %input1: tensor<16x32x64xf32>,
+ %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
+ %init2: tensor<16x64xf32>) -> (tensor<16x64xf32>, tensor<16x64xf32>) {
+ // expected-error @+1{{'linalg.reduce' op input element type 'f32' does not match corresponding block argument type 'f64'}}
+ %reduce, %reduce2 = linalg.reduce
+ ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+ outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xf32>)
+ dimensions = [1]
+ (%in1: f32, %in2: f64, %out1: f32, %out2: f64) {
+ %0 = arith.addf %in1, %out1: f32
+ %1 = arith.addf %in2, %out2: f64
+ linalg.yield %0, %1: f32, f64
+ }
+ func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xf32>
+}
+
+// -----
+
+func.func @reduce_wrong_block_argument_output_type(
+ %input1: tensor<16x32x64xf32>,
+ %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
+ %init2: tensor<16x64xf64>) -> (tensor<16x64xf32>, tensor<16x64xf32>) {
+ // expected-error @+1{{'linalg.reduce' op output element type 'f64' does not match corresponding block argument type 'f32'}}
+ %reduce, %reduce2 = linalg.reduce
+ ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+ outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xf64>)
+ dimensions = [1]
+ (%in1: f32, %in2: f32, %out1: f32, %out2: f32) {
+ %0 = arith.addf %in1, %out1: f32
+ linalg.yield %0, %out2: f32, f32
+ }
+ func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xf64>
+}
+
+// -----
+
+func.func @reduce_different_input_shapes(%input1: tensor<16x32x64xf32>,
+ %init1: tensor<16x64xf32>, %input2: tensor<17x32x64xf32>,
+ %init2: tensor<17x64xf32>) -> (tensor<16x64xf32>, tensor<17x64xf32>) {
+ // expected-error @+1{{'linalg.reduce' op expects all inputs to have the same shapes. Shape at input-index 1 is not equal to the shape at input-index 0.}}
+ %reduce, %reduce2 = linalg.reduce
+ ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<17x32x64xf32>)
+ outs(%init1, %init2 : tensor<16x64xf32>, tensor<17x64xf32>)
+ dimensions = [1]
+ (%in1: f32, %in2: f32, %out1: f32, %out2: f32) {
+ %0 = arith.addf %in1, %out1: f32
+ %1 = arith.addf %in2, %out2: f32
+ linalg.yield %0, %1: f32, f32
+ }
+ func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32>
+}
+
+// -----
+
+func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>,
+ %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
+ %init2: tensor<17x64xf32>) -> (tensor<16x64xf32>, tensor<17x64xf32>) {
+ // expected-error @+1{{'linalg.reduce' op expects all outputs to have the same shapes. Shape at output-index 1 is not equal to the shape at output-index 0.}}
+ %reduce, %reduce2 = linalg.reduce
+ ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
+ outs(%init1, %init2 : tensor<16x64xf32>, tensor<17x64xf32>)
+ dimensions = [1]
+ (%in1: f32, %in2: f32, %out1: f32, %out2: f32) {
+ %0 = arith.addf %in1, %out1: f32
+ %1 = arith.addf %in2, %out2: f32
+ linalg.yield %0, %1: f32, f32
+ }
+ func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index ae68fb3ddc4a4..3d2e9bfd228eb 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -366,3 +366,39 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
}
// CHECK-LABEL: func @mixed_parallel_reduced_results
// CHECK: linalg.generic
+
+// -----
+
+func.func @reduce(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+// CHECK-LABEL: func @reduce
+// CHECK: linalg.reduce
+
+// -----
+
+func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
+ %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xi64>,
+ %init2: tensor<16x64xi64>) -> (tensor<16x64xf32>, tensor<16x64xi64>) {
+ %reduce, %reduce2 = linalg.reduce
+ ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xi64>)
+ outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xi64>)
+ dimensions = [1]
+ (%in1: f32, %in2: i64, %out1: f32, %out2: i64) {
+ %0 = arith.addf %in1, %out1: f32
+ %1 = arith.addi %in2, %out2: i64
+ linalg.yield %0, %1: f32, i64
+ }
+ func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xi64>
+}
+// CHECK-LABEL: func @variadic_reduce
+// CHECK: linalg.reduce