[Mlir-commits] [mlir] 051b36b - [mlir][vector] Add accumulator operand to MultiDimReduce op
Thomas Raoux
llvmlistbot at llvm.org
Tue Jul 12 07:36:01 PDT 2022
Author: Thomas Raoux
Date: 2022-07-12T14:28:30Z
New Revision: 051b36ba2857f5a532f6fd0eeb574731cb1e0df3
URL: https://github.com/llvm/llvm-project/commit/051b36ba2857f5a532f6fd0eeb574731cb1e0df3
DIFF: https://github.com/llvm/llvm-project/commit/051b36ba2857f5a532f6fd0eeb574731cb1e0df3.diff
LOG: [mlir][vector] Add accumulator operand to MultiDimReduce op
This allows vectorizing linalg reductions without changing the operation
order. Therefore this produces a valid vectorization even if the operations
are not associative.
Differential Revision: https://reviews.llvm.org/D129535
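For reference, a minimal before/after sketch (shapes and value names here are illustrative, not taken from the patch):

```mlir
// Before this change the accumulator had to be combined after the
// reduction, which reassociates the reduction with the initial value:
//   %r   = vector.multi_reduction <add>, %src [1] : vector<4x8xf32> to vector<4xf32>
//   %res = arith.addf %r, %acc : vector<4xf32>
// With the accumulator as an operand, no extra combining op is needed:
%res = vector.multi_reduction <add>, %src, %acc [1] : vector<4x8xf32> to vector<4xf32>
```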
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
mlir/test/Dialect/Vector/vector-unroll-options.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 9a29825eda506..212084a75d4ca 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -318,6 +318,7 @@ def Vector_ReductionOp :
def Vector_MultiDimReductionOp :
Vector_Op<"multi_reduction", [NoSideEffect,
+ AllTypesMatch<["dest", "acc"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
@@ -325,6 +326,7 @@ def Vector_MultiDimReductionOp :
["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVector:$source,
+ AnyType:$acc,
I64ArrayAttr:$reduction_dims)>,
Results<(outs AnyType:$dest)> {
let summary = "Multi-dimensional reduction operation";
@@ -332,19 +334,20 @@ def Vector_MultiDimReductionOp :
Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
using the given operation (add/mul/min/max for int/fp and and/or/xor for
int only).
+ Takes an initial accumulator operand.
Example:
```mlir
- %1 = vector.multi_reduction <add>, %0 [1, 3] :
+ %1 = vector.multi_reduction <add>, %0, %acc0 [1, 3] :
vector<4x8x16x32xf32> into vector<4x16xf32>
- %2 = vector.multi_reduction <add>, %1 [0, 1] :
+ %2 = vector.multi_reduction <add>, %1, %acc1 [0, 1] :
vector<4x16xf32> into f32
```
}];
let builders = [
- OpBuilder<(ins "Value":$source, "ArrayRef<bool>":$reductionMask,
- "CombiningKind":$kind)>
+ OpBuilder<(ins "Value":$source, "Value":$acc,
+ "ArrayRef<bool>":$reductionMask, "CombiningKind":$kind)>
];
let extraClassDeclaration = [{
static StringRef getKindAttrStrName() { return "kind"; }
@@ -378,8 +381,9 @@ def Vector_MultiDimReductionOp :
}
}];
let assemblyFormat =
- "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
+ "$kind `,` $source `,` $acc attr-dict $reduction_dims `:` type($source) `to` type($dest)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_BroadcastOp :
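The `AllTypesMatch<["dest", "acc"]>` constraint ties the accumulator to the result type, and the accumulator is printed between the source and the reduction dims. A small illustrative sketch (shapes are hypothetical):

```mlir
// %acc must have the same type as the result (vector<4x16xf32> here).
%r = vector.multi_reduction <add>, %src, %acc [1, 3]
    : vector<4x8x16x32xf32> to vector<4x16xf32>
```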
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c406dc5de88cf..3422ab7c0a76e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -174,13 +174,13 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
-static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
- Value valueToReduce,
- const SmallVector<bool> &reductionMask) {
+static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+ Value valueToReduce, Value acc,
+ const SmallVector<bool> &reductionMask) {
auto maybeKind = getCombinerOpKind(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
return b.create<vector::MultiDimReductionOp>(
- reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
+ reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
}
static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
@@ -315,10 +315,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
(outputType && reduceType.getShape() == outputType.getShape()))
return nullptr;
SmallVector<bool> reductionMask = getReductionMask(linalgOp);
- Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
- return b.create(op->getLoc(), op->getName().getIdentifier(),
- /*operands=*/{reduce, outputVec}, reduce.getType(),
- op->getAttrs());
+ return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
}
/// Generic vectorization for a single operation `op`, given already vectorized
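The effect on vectorized linalg reductions is that `reduceIfNeeded` no longer clones the scalar combiner after the reduction; a rough sketch with hypothetical shapes:

```mlir
// Previously emitted:
//   %red = vector.multi_reduction <add>, %mul [2] : vector<8x32x16xf32> to vector<8x32xf32>
//   %out = arith.addf %red, %acc : vector<8x32xf32>
// Now emitted:
%out = vector.multi_reduction <add>, %mul, %acc [2]
    : vector<8x32x16xf32> to vector<8x32xf32>
```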
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 00db4650f1206..f803868c2150d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -334,34 +334,14 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
void vector::MultiDimReductionOp::build(OpBuilder &builder,
OperationState &result, Value source,
- ArrayRef<bool> reductionMask,
+ Value acc, ArrayRef<bool> reductionMask,
CombiningKind kind) {
SmallVector<int64_t> reductionDims;
for (const auto &en : llvm::enumerate(reductionMask))
if (en.value())
reductionDims.push_back(en.index());
- build(builder, result, kind, source, builder.getI64ArrayAttr(reductionDims));
-}
-
-LogicalResult MultiDimReductionOp::inferReturnTypes(
- MLIRContext *, Optional<Location>, ValueRange operands,
- DictionaryAttr attributes, RegionRange,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- MultiDimReductionOp::Adaptor op(operands, attributes);
- auto vectorType = op.getSource().getType().cast<VectorType>();
- SmallVector<int64_t> targetShape;
- for (auto it : llvm::enumerate(vectorType.getShape()))
- if (!llvm::any_of(op.getReductionDims().getValue(), [&](Attribute attr) {
- return attr.cast<IntegerAttr>().getValue() == it.index();
- }))
- targetShape.push_back(it.value());
- // TODO: update to also allow 0-d vectors when available.
- if (targetShape.empty())
- inferredReturnTypes.push_back(vectorType.getElementType());
- else
- inferredReturnTypes.push_back(
- VectorType::get(targetShape, vectorType.getElementType()));
- return success();
+ build(builder, result, kind, source, acc,
+ builder.getI64ArrayAttr(reductionDims));
}
OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
@@ -375,6 +355,28 @@ Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
return llvm::to_vector<4>(getSourceVectorType().getShape());
}
+LogicalResult MultiDimReductionOp::verify() {
+ SmallVector<int64_t> targetShape;
+ Type inferredReturnType;
+ for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
+ if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
+ return attr.cast<IntegerAttr>().getValue() == it.index();
+ }))
+ targetShape.push_back(it.value());
+ // TODO: update to also allow 0-d vectors when available.
+ if (targetShape.empty())
+ inferredReturnType = getSourceVectorType().getElementType();
+ else
+ inferredReturnType =
+ VectorType::get(targetShape, getSourceVectorType().getElementType());
+ if (getType() != inferredReturnType)
+ return emitOpError() << "destination type " << getType()
+ << " is incompatible with source type "
+ << getSourceVectorType();
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
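The verifier takes over the check previously done through return-type inference; a sketch of an op it now rejects (types are hypothetical):

```mlir
// error: destination type 'vector<16xf32>' is incompatible with
// source type 'vector<4x16xf32>' (reducing dim 1 should yield vector<4xf32>).
%0 = vector.multi_reduction <mul>, %src, %acc [1] : vector<4x16xf32> to vector<16xf32>
```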
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 0e023ca448322..2582781aaab08 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -87,8 +87,8 @@ class InnerOuterDimReductionConversion
reductionMask[i] = true;
}
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
- multiReductionOp, transposeOp.getResult(), reductionMask,
- multiReductionOp.getKind());
+ multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
+ reductionMask, multiReductionOp.getKind());
return success();
}
@@ -188,11 +188,17 @@ class ReduceMultiDimReductionRank
vectorShape, multiReductionOp.getSourceVectorType().getElementType());
Value cast = rewriter.create<vector::ShapeCastOp>(
loc, castedType, multiReductionOp.getSource());
-
+ Value acc = multiReductionOp.getAcc();
+ if (flattenedParallelDim) {
+ auto accType = VectorType::get(
+ {flattenedParallelDim},
+ multiReductionOp.getSourceVectorType().getElementType());
+ acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
+ }
// 5. Creates the flattened form of vector.multi_reduction with inner/outer
// most dim as reduction.
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, mask, multiReductionOp.getKind());
+ loc, cast, acc, mask, multiReductionOp.getKind());
// 6. If there are no parallel shapes, the result is a scalar.
// TODO: support 0-d vectors when available.
@@ -238,10 +244,8 @@ struct TwoDimMultiReductionToElementWise
if (!elementType.isIntOrIndexOrFloat())
return failure();
- Value result =
- rewriter.create<vector::ExtractOp>(loc, multiReductionOp.getSource(), 0)
- .getResult();
- for (int64_t i = 1; i < srcShape[0]; i++) {
+ Value result = multiReductionOp.getAcc();
+ for (int64_t i = 0; i < srcShape[0]; i++) {
auto operand = rewriter.create<vector::ExtractOp>(
loc, multiReductionOp.getSource(), i);
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
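With the accumulator operand, the elementwise lowering starts the running value from the accumulator and folds in every row (the loop now starts at 0). A rough sketch for a hypothetical vector<2x3xf32> reduced over dim 0:

```mlir
// vector.multi_reduction <add>, %src, %acc [0] : vector<2x3xf32> to vector<3xf32>
%r0  = vector.extract %src[0] : vector<2x3xf32>
%a0  = arith.addf %r0, %acc : vector<3xf32>
%r1  = vector.extract %src[1] : vector<2x3xf32>
%res = arith.addf %r1, %a0 : vector<3xf32>
```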
@@ -277,8 +281,10 @@ struct TwoDimMultiReductionToReduction
for (int i = 0; i < outerDim; ++i) {
auto v = rewriter.create<vector::ExtractOp>(
loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
+ auto acc = rewriter.create<vector::ExtractOp>(
+ loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
auto reducedValue = rewriter.create<vector::ReductionOp>(
- loc, multiReductionOp.getKind(), v);
+ loc, multiReductionOp.getKind(), v, acc);
result = rewriter.create<vector::InsertElementOp>(
loc, reducedValue, result,
rewriter.create<arith::ConstantIndexOp>(loc, i));
@@ -309,6 +315,8 @@ struct OneDimMultiReductionToTwoDim
auto srcShape = srcVectorType.getShape();
auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
srcVectorType.getElementType());
+ auto accType =
+ VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
assert(!multiReductionOp.getDestType().isa<VectorType>() &&
"multi_reduction with a single dimension expects a scalar result");
@@ -319,8 +327,10 @@ struct OneDimMultiReductionToTwoDim
/// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
Value cast = rewriter.create<vector::ShapeCastOp>(
loc, castedType, multiReductionOp.getSource());
+ Value castAcc = rewriter.create<vector::BroadcastOp>(
+ loc, accType, multiReductionOp.getAcc());
Value reduced = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, mask, multiReductionOp.getKind());
+ loc, cast, castAcc, mask, multiReductionOp.getKind());
rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
ArrayRef<int64_t>{0});
return success();
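For the 1-D case, the scalar accumulator is broadcast to a vector<1xT> so it can feed the intermediate 1xK reduction; a sketch with a hypothetical vector<8xf32> source:

```mlir
// vector.multi_reduction <add>, %src, %acc [0] : vector<8xf32> to f32
%cast = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
%acc1 = vector.broadcast %acc : f32 to vector<1xf32>
%red  = vector.multi_reduction <add>, %cast, %acc1 [1] : vector<1x8xf32> to vector<1xf32>
%res  = vector.extract %red[0] : vector<1xf32>
```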
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bb8cc2bfae396..76151fc358ade 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -997,11 +997,8 @@ struct MultiReduceToContract
}
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
/*symCount=*/0, exprs, reduceOp.getContext());
- Value zero = rewriter.create<arith::ConstantOp>(
- reduceOp.getLoc(), reduceOp.getDestType(),
- rewriter.getZeroAttr(reduceOp.getDestType()));
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
- reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
+ reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
rewriter.getStrArrayAttr(iteratorTypes));
return success();
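The reduce-to-contract pattern now feeds the accumulator straight into `vector.contract` instead of materializing a zero splat; a sketch with hypothetical maps and shapes:

```mlir
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
// %m = arith.mulf %a, %b : vector<8x32x16xf32>
// %r = vector.multi_reduction <add>, %m, %acc [1] : vector<8x32x16xf32> to vector<8x16xf32>
// becomes:
%r = vector.contract {indexing_maps = [#map0, #map0, #map1],
                      iterator_types = ["parallel", "reduction", "parallel"],
                      kind = #vector.kind<add>}
     %a, %b, %acc : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
```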
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index d75d1098d53f6..15f43dc0536c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -431,10 +431,11 @@ struct UnrollMultiReductionPattern
SmallVector<int64_t, 4> offsets =
getVectorOffset(originalSize, *targetShape, i);
+ SmallVector<Value> operands;
SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);
-
+ loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+ operands.push_back(slicedOperand);
SmallVector<int64_t> dstShape;
SmallVector<int64_t> destOffset;
for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
@@ -443,17 +444,22 @@ struct UnrollMultiReductionPattern
dstShape.push_back((*targetShape)[i]);
}
}
+ Value acc;
+ SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
+ // If a version of the accumulator has already been computed, use it
+ // otherwise extract the first version from the original operand.
+ auto accIt = accCache.find(destOffset);
+ if (accIt != accCache.end())
+ acc = accIt->second;
+ else
+ acc = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
+ operands.push_back(acc);
auto targetType = VectorType::get(
dstShape, reductionOp.getSourceVectorType().getElementType());
Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
- slicedOperand, targetType);
+ operands, targetType);
Value result = newOp->getResult(0);
- // Save the accumulated value until all the loops are unrolled since
- // reduction loop keeps updating the accumulator.
- auto accIt = accCache.find(destOffset);
- if (accIt != accCache.end())
- result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
- result, accIt->second);
accCache[destOffset] = result;
}
// Assemble back the accumulator into a single vector.
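During unrolling, the first tile for a given destination offset takes its accumulator slice from the original accumulator, and later tiles chain through the cached partial result, so no extra combining ops are emitted. A rough sketch for a hypothetical vector<2x4xf32> reduced over dim 1 with target shape [2, 2]:

```mlir
// vector.multi_reduction <add>, %src, %acc [1] : vector<2x4xf32> to vector<2xf32>
%s0 = vector.extract_strided_slice %src {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
      : vector<2x4xf32> to vector<2x2xf32>
%a0 = vector.extract_strided_slice %acc {offsets = [0], sizes = [2], strides = [1]}
      : vector<2xf32> to vector<2xf32>
%r0 = vector.multi_reduction <add>, %s0, %a0 [1] : vector<2x2xf32> to vector<2xf32>
%s1 = vector.extract_strided_slice %src {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]}
      : vector<2x4xf32> to vector<2x2xf32>
// The cached partial result %r0 is reused as the accumulator of the next tile;
// the final value is then inserted back into the full destination vector.
%r1 = vector.multi_reduction <add>, %s1, %r0 [1] : vector<2x2xf32> to vector<2xf32>
```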
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index b4dc8e4325673..b08d7d1ff1c10 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -10,9 +10,8 @@ func.func @vectorize_matmul(%arg0: tensor<24x12xf32>,
// CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
// CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
// CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
- // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
- // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
- // CHECK: vector.transfer_write %[[vS]], %[[C]]
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
@@ -67,9 +66,8 @@ func.func @vectorize_keep_pad(
// CHECK: %[[vA:.+]] = vector.transfer_read %[[pA]]
// CHECK: %[[vB:.+]] = vector.transfer_read %[[pB]]
// CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
- // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
- // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
- // CHECK: vector.transfer_write %[[vS]], %[[C]]
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
%8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
%9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
return %9 : tensor<24x25xf32>
@@ -127,9 +125,8 @@ func.func @vectorize_pad(
tensor.yield %cst : f32
} : tensor<?x5xf32> to tensor<7x5xf32>
// CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
- // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
- // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
- // CHECK: vector.transfer_write %[[vS]], %[[C]]
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
%8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
%9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
return %9 : tensor<24x25xf32>
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index dbd09576cb76b..bbc36b12556ed 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -6,8 +6,7 @@
func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
-// CHECK: vector.multi_reduction <add>, %{{.*}} [0] : vector<1584xf32> to f32
-// CHECK: arith.addf %{{.*}}, %{{.*}} : f32
+// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [0] : vector<1584xf32> to f32
linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
outs(%C: memref<f32>)
return
@@ -19,8 +18,7 @@ func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memre
func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
-// CHECK: vector.multi_reduction <add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
-// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
outs(%C: memref<1584xf32>)
return
@@ -31,8 +29,7 @@ func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %
// CHECK-LABEL: contraction_matmul
func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
-// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
-// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
outs(%C: memref<1584x1584xf32>)
return
@@ -43,8 +40,7 @@ func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf3
// CHECK-LABEL: contraction_batch_matmul
func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
-// CHECK: vector.multi_reduction <add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
-// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
linalg.batch_matmul
ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
outs(%C: memref<1584x1584x1584xf32>)
@@ -69,10 +65,9 @@ func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
- // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
+ // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
- // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
linalg.generic #matmul_trait
ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -103,10 +98,9 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<32x8xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
- // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
+ // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
- // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
linalg.generic #matmul_transpose_out_trait
ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -157,11 +151,9 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32
%C: memref<8x32xi32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32>
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
- // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+ // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
// CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
- // CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
-
+ // CHECK: vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
linalg.generic #matmul_trait
ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
@@ -180,8 +172,7 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32
func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
- // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
linalg.matmul
ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
outs(%C: memref<8x32xf32>)
@@ -560,9 +551,8 @@ func.func @matmul_tensors(
// linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
// convert it to a 2D contract.
// CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
- // CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
- // CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+ // CHECK: %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
outs(%arg2: tensor<8x12xf32>)
-> tensor<8x12xf32>
@@ -801,8 +791,7 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
// CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
// CHECK: math.exp {{.*}} : vector<4x16x8xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
- // CHECK: addf {{.*}} : vector<4x16xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
// CHECK: return {{.*}} : tensor<4x16xf32>
%0 = linalg.generic {
@@ -836,8 +825,7 @@ func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: addf {{.*}} : vector<2x3x4x5xf32>
- // CHECK: vector.multi_reduction <add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
- // CHECK: addf {{.*}} : vector<2x5xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, %{{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
// CHECK: return {{.*}} : tensor<5x2xf32>
%0 = linalg.generic {
@@ -865,8 +853,7 @@ func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
- // CHECK: %[[R:.+]] = vector.multi_reduction <maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
- // CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32>
+ // CHECK: vector.multi_reduction <maxf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant -3.40282e+38 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -890,8 +877,7 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
- // CHECK: %[[R:.+]] = vector.multi_reduction <minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
- // CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32>
+ // CHECK: vector.multi_reduction <minf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%maxf32 = arith.constant 3.40282e+38 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -914,7 +900,7 @@ func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
- // CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+ // CHECK: vector.multi_reduction <mul>, {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant 1.0 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -937,7 +923,7 @@ func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction <or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <or>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant false
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -960,7 +946,7 @@ func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction <and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <and>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant true
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -983,7 +969,7 @@ func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction <xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <xor>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant false
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -1035,8 +1021,7 @@ func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
// CHECK: subf {{.*}} : vector<4x4xf32>
// CHECK: math.exp {{.*}} : vector<4x4xf32>
- // CHECK: vector.multi_reduction <add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
- // CHECK: addf {{.*}} : vector<4xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
%c0 = arith.constant 0.0 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -1075,10 +1060,9 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
// CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
- // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]] [0]
+ // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
// CHECK-SAME: : vector<32xf32> to f32
- // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
- // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>
+ // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
// CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
// CHECK-SAME: : vector<f32>, tensor<f32>
%2 = linalg.generic {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 84b5a45f19e65..702670095c8d5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1281,9 +1281,9 @@ func.func @do_not_swap_extract_slice_transfer_write(%arg0 : vector<8xf32>,
// -----
// CHECK-LABEL: func @vector_multi_reduction_single_parallel(
-// CHECK-SAME: %[[v:.*]]: vector<2xf32>
-func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
+// CHECK-SAME: %[[v:.*]]: vector<2xf32>,
+func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [] : vector<2xf32> to vector<2xf32>
// CHECK: return %[[v]] : vector<2xf32>
return %0 : vector<2xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 87e5f9443807c..d50315970d744 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1138,9 +1138,9 @@ func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
// -----
-func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>) -> f32 {
- // expected-error @+1 {{'vector.multi_reduction' op inferred type(s) 'vector<4xf32>' are incompatible with return type(s) of operation 'vector<16xf32>'}}
- %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<4x16xf32> to vector<16xf32>
+func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>, %acc: vector<16xf32>) -> f32 {
+ // expected-error @+1 {{'vector.multi_reduction' op destination type 'vector<16xf32>' is incompatible with source type 'vector<4x16xf32>'}}
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<4x16xf32> to vector<16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index e42c94f252c46..dc69bb0a78a6f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -705,10 +705,13 @@ func.func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
}
// CHECK-LABEL: @multi_reduction
-func.func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
- %1 = vector.multi_reduction <add>, %0 [1, 3] :
+func.func @multi_reduction(%0: vector<4x8x16x32xf32>, %acc0: vector<4x16xf32>,
+ %acc1: f32) -> f32 {
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [1, 3] : vector<4x8x16x32xf32> to vector<4x16xf32>
+ %1 = vector.multi_reduction <add>, %0, %acc0 [1, 3] :
vector<4x8x16x32xf32> to vector<4x16xf32>
- %2 = vector.multi_reduction <add>, %1 [0, 1] :
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x16xf32> to f32
+ %2 = vector.multi_reduction <add>, %1, %acc1 [0, 1] :
vector<4x16xf32> to f32
return %2 : f32
}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index a39ac990f8281..6b372c3ef1c3e 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -1,40 +1,42 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
-func.func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
// CHECK: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
-// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]] : vector<4xf32> into f32
+// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
-// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]] : vector<4xf32> into f32
+// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]]
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
- %0 = vector.multi_reduction <mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]] : vector<8xf32> into f32
+// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
// CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
// CHECK: return %[[RES]]
-func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
- %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}
// CHECK-LABEL: func @vector_reduction_inner
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
// CHECK: %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
@@ -44,29 +46,35 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32>
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
-// CHECK: %[[V0R:.+]] = vector.reduction <add>, %[[V0]] : vector<20xi32> into i32
+// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : vector<2x3xi32>
+// CHECK: %[[V0R:.+]] = vector.reduction <add>, %[[V0]], %[[ACC0]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : index] : vector<6xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
-// CHECK: %[[V1R:.+]] = vector.reduction <add>, %[[V1]] : vector<20xi32> into i32
+// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : vector<2x3xi32>
+// CHECK: %[[V1R:.+]] = vector.reduction <add>, %[[V1]], %[[ACC1]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : index] : vector<6xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
-// CHECK: %[[V2R:.+]] = vector.reduction <add>, %[[V2]] : vector<20xi32> into i32
+// CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : vector<2x3xi32>
+// CHECK: %[[V2R:.+]] = vector.reduction <add>, %[[V2]], %[[ACC2]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : index] : vector<6xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
-// CHECK: %[[V3R:.+]] = vector.reduction <add>, %[[V3]] : vector<20xi32> into i32
+// CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][1, 0] : vector<2x3xi32>
+// CHECK: %[[V3R:.+]] = vector.reduction <add>, %[[V3]], %[[ACC3]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : index] : vector<6xi32>
// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
-// CHECK: %[[V4R:.+]] = vector.reduction <add>, %[[V4]] : vector<20xi32> into i32
+// CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 1] : vector<2x3xi32>
+// CHECK: %[[V4R:.+]] = vector.reduction <add>, %[[V4]], %[[ACC4]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : index] : vector<6xi32>
/// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
-// CHECK: %[[V5R:.+]] = vector.reduction <add>, %[[V5]] : vector<20xi32> into i32
+// CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 2] : vector<2x3xi32>
+// CHECK: %[[V5R:.+]] = vector.reduction <add>, %[[V5]], %[[ACC5]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : index] : vector<6xi32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
// CHECK: return %[[RESULT]]
-func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
- %0 = vector.multi_reduction <add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
return %0 : vector<2x5xf32>
}
@@ -77,12 +85,12 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vect
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
// CHECK: return %[[RESULT]]
-func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
- %0 = vector.multi_reduction <mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
+func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
return %0 : vector<2x4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_ordering
-// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>, %[[ACC:.*]]: vector<2x4xf32>)
// CHECK: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -94,28 +102,36 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2
// CHECK: %[[C7:.+]] = arith.constant 7 : index
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
-// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]] : vector<3xf32> into f32
+// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : vector<2x4xf32>
+// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<8xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
-// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]] : vector<3xf32> into f32
+// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : vector<2x4xf32>
+// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<8xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
-// CHECK: %[[RV2:.+]] = vector.reduction <mul>, %[[V2]] : vector<3xf32> into f32
+// CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : vector<2x4xf32>
+// CHECK: %[[RV2:.+]] = vector.reduction <mul>, %[[V2]], %[[ACC2]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : index] : vector<8xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
-// CHECK: %[[RV3:.+]] = vector.reduction <mul>, %[[V3]] : vector<3xf32> into f32
+// CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][0, 3] : vector<2x4xf32>
+// CHECK: %[[RV3:.+]] = vector.reduction <mul>, %[[V3]], %[[ACC3]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : index] : vector<8xf32>
// CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
-// CHECK: %[[RV4:.+]] = vector.reduction <mul>, %[[V4]] : vector<3xf32> into f32
+// CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 0] : vector<2x4xf32>
+// CHECK: %[[RV4:.+]] = vector.reduction <mul>, %[[V4]], %[[ACC4]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : index] : vector<8xf32>
// CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
-// CHECK: %[[RV5:.+]] = vector.reduction <mul>, %[[V5]] : vector<3xf32> into f32
+// CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 1] : vector<2x4xf32>
+// CHECK: %[[RV5:.+]] = vector.reduction <mul>, %[[V5]], %[[ACC5]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : index] : vector<8xf32>
// CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
-// CHECK: %[[RV6:.+]] = vector.reduction <mul>, %[[V6]] : vector<3xf32> into f32
+// CHECK: %[[ACC6:.+]] = vector.extract %[[ACC]][1, 2] : vector<2x4xf32>
+// CHECK: %[[RV6:.+]] = vector.reduction <mul>, %[[V6]], %[[ACC6]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : index] : vector<8xf32>
// CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
-// CHECK: %[[RV7:.+]] = vector.reduction <mul>, %[[V7]] : vector<3xf32> into f32
+// CHECK: %[[ACC7:.+]] = vector.extract %[[ACC]][1, 3] : vector<2x4xf32>
+// CHECK: %[[RV7:.+]] = vector.reduction <mul>, %[[V7]], %[[ACC7]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
// CHECK: return %[[RESHAPED_VEC]]
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index bd06e48f48231..8a8bf86bfd38b 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -1,101 +1,107 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
-func.func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[V0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
// CHECK: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
// CHECK: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
-func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <minf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_min
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.minf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.minf %[[V1]], %[[V0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.minf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
// CHECK: %[[RV012:.+]] = arith.minf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
// CHECK: %[[RESULT_VEC:.+]] = arith.minf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
-func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <maxf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_max
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.maxf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.maxf %[[V1]], %[[V0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.maxf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
// CHECK: %[[RV012:.+]] = arith.maxf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
// CHECK: %[[RESULT_VEC:.+]] = arith.maxf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
-func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ %0 = vector.multi_reduction <and>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
// CHECK-LABEL: func @vector_multi_reduction_and
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+// CHECK: %[[RV0:.+]] = arith.andi %[[V0]], %[[ACC]] : vector<2xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
-// CHECK: %[[RV01:.+]] = arith.andi %[[V1]], %[[V0]] : vector<2xi32>
+// CHECK: %[[RV01:.+]] = arith.andi %[[V1]], %[[RV0]] : vector<2xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
// CHECK: %[[RV012:.+]] = arith.andi %[[V2]], %[[RV01]] : vector<2xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
// CHECK: %[[RESULT_VEC:.+]] = arith.andi %[[V3]], %[[RV012]] : vector<2xi32>
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
-func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ %0 = vector.multi_reduction <or>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
// CHECK-LABEL: func @vector_multi_reduction_or
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+// CHECK: %[[RV0:.+]] = arith.ori %[[V0]], %[[ACC]] : vector<2xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
-// CHECK: %[[RV01:.+]] = arith.ori %[[V1]], %[[V0]] : vector<2xi32>
+// CHECK: %[[RV01:.+]] = arith.ori %[[V1]], %[[RV0]] : vector<2xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
// CHECK: %[[RV012:.+]] = arith.ori %[[V2]], %[[RV01]] : vector<2xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
// CHECK: %[[RESULT_VEC:.+]] = arith.ori %[[V3]], %[[RV012]] : vector<2xi32>
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
-func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ %0 = vector.multi_reduction <xor>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
// CHECK-LABEL: func @vector_multi_reduction_xor
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+// CHECK: %[[RV0:.+]] = arith.xori %[[V0]], %[[ACC]] : vector<2xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
-// CHECK: %[[RV01:.+]] = arith.xori %[[V1]], %[[V0]] : vector<2xi32>
+// CHECK: %[[RV01:.+]] = arith.xori %[[V1]], %[[RV0]] : vector<2xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
// CHECK: %[[RV012:.+]] = arith.xori %[[V2]], %[[RV01]] : vector<2xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
@@ -103,18 +109,20 @@ func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
-func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
- %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}
// CHECK-LABEL: func @vector_reduction_outer
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32>
// CHECK: %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32>
+// CHECK: %[[FACC:.+]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<20x6xi32>
+// CHECK: %[[R:.+]] = arith.addi %[[V0]], %[[FACC]] : vector<6xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<20x6xi32>
-// CHECK: %[[R0:.+]] = arith.addi %[[V1]], %[[V0]] : vector<6xi32>
+// CHECK: %[[R0:.+]] = arith.addi %[[V1]], %[[R]] : vector<6xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<20x6xi32>
// CHECK: %[[R1:.+]] = arith.addi %[[V2]], %[[R0]] : vector<6xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<20x6xi32>
@@ -157,15 +165,15 @@ func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32>
// This test is mainly to catch a bug that running
// `InnerOuterDimReductionConversion` on this function results in an
// infinite loop. So just check that some value is returned.
-func.func @vector_reduction_1D(%arg0 : vector<2xf32>) -> f32 {
- %0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [0] : vector<2xf32> to f32
+func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction #vector.kind<maxf>, %arg0, %acc [0] : vector<2xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @vector_reduction_1D
// CHECK: return %{{.+}}
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>) -> f32 {
- %0 = vector.multi_reduction <add>, %arg0 [0, 1] : vector<2x3xf32> to f32
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
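For readability, here is a hand-written MLIR equivalent of the lowered form that the updated xor checks above describe; the function and value names are illustrative, not taken from the test. The point is that the accumulator now seeds the combine chain instead of being folded in by an extra add after it.

```mlir
// Illustrative sketch only (names are not from the test): lowering of a <xor>
// multi_reduction of vector<2x4xi32> along dim 1 with an accumulator.
func.func @sketch_lowered_xor(%input: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
  // Move the reduced dimension outermost.
  %t = vector.transpose %input, [1, 0] : vector<2x4xi32> to vector<4x2xi32>
  // The first combine uses the accumulator; the rest chain the partial result.
  %v0 = vector.extract %t[0] : vector<4x2xi32>
  %r0 = arith.xori %v0, %acc : vector<2xi32>
  %v1 = vector.extract %t[1] : vector<4x2xi32>
  %r1 = arith.xori %v1, %r0 : vector<2xi32>
  %v2 = vector.extract %t[2] : vector<4x2xi32>
  %r2 = arith.xori %v2, %r1 : vector<2xi32>
  %v3 = vector.extract %t[3] : vector<4x2xi32>
  %r3 = arith.xori %v3, %r2 : vector<2xi32>
  return %r3 : vector<2xi32>
}
```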
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index ade539e278226..f1587c2e2f3d6 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -4,15 +4,15 @@
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-LABEL: multidimreduction_contract
-// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>)
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
-// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
func.func @multidimreduction_contract(
- %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
+ %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>, %acc: vector<8x16xf32>) -> vector<8x16xf32> {
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
- %1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
+ %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<8x32x16xf32> to vector<8x16xf32>
return %1 : vector<8x16xf32>
}
@@ -22,15 +22,15 @@ func.func @multidimreduction_contract(
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-LABEL: multidimreduction_contract_int
-// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
+// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>)
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
-// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
func.func @multidimreduction_contract_int(
- %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
+ %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>, %acc: vector<8x16xi32>) -> vector<8x16xi32> {
%0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
- %1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
+ %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<8x32x16xi32> to vector<8x16xi32>
return %1 : vector<8x16xi32>
}
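As a sketch of what the reduce-to-contract rewrite now produces (the function and map names below are mine; only the shapes, maps, and attributes follow the checks above), the mulf/multi_reduction pair becomes a single vector.contract that consumes the new accumulator operand directly instead of a materialized zero splat:

```mlir
// Illustrative result of the rewrite for the f32 case above (names invented).
#map_ab  = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map_acc = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @sketch_to_contract(%a: vector<8x32x16xf32>, %b: vector<8x32x16xf32>,
                              %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  // d1 is the reduced dimension; %acc feeds the contraction accumulator.
  %r = vector.contract {indexing_maps = [#map_ab, #map_ab, #map_acc],
                        iterator_types = ["parallel", "reduction", "parallel"],
                        kind = #vector.kind<add>}
       %a, %b, %acc : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
  return %r : vector<8x16xf32>
}
```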
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index d0d10887d6a2d..db6a40d489d6c 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,30 +188,28 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
-func.func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
- %0 = vector.multi_reduction #vector.kind<add>, %v [1] : vector<4x6xf32> to vector<4xf32>
+func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
// CHECK: %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK: %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[ACC0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]], %[[ACC0]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK: %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]] [1] : vector<2x2xf32> to vector<2xf32>
-// CHECK: %[[A0:.*]] = arith.addf %[[R1]], %[[R0]] : vector<2xf32>
+// CHECK: %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]], %[[R0]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK: %[[R2:.*]] = vector.multi_reduction <add>, %5 [1] : vector<2x2xf32> to vector<2xf32>
-// CHECK: %[[A1:.*]] = arith.addf %[[R2]], %[[A0]] : vector<2xf32>
+// CHECK: %[[R2:.*]] = vector.multi_reduction <add>, %[[E2]], %[[R1]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK: %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[ACC1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]], %[[ACC1]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK: %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]] [1] : vector<2x2xf32> to vector<2xf32>
-// CHECK: %[[A2:.*]] = arith.addf %[[R4]], %[[R3]] : vector<2xf32>
+// CHECK: %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]], %[[R3]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK: %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]] [1] : vector<2x2xf32> to vector<2xf32>
-// CHECK: %[[A3:.*]] = arith.addf %[[R5]], %[[A2]] : vector<2xf32>
-// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]], %[[R4]] [1] : vector<2x2xf32> to vector<2xf32>
+// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[R2]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: return %[[V2]] : vector<4xf32>
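With the 2x2 tiling used by this test, a hand-written equivalent of the first output row of the unrolled form checked above looks like the sketch below (function and SSA names are illustrative, not from the test). Each tile along the reduction dimension passes its partial result to the next vector.multi_reduction as the accumulator, which is why the old standalone arith.addf combines disappear.

```mlir
// Illustrative sketch of the first row of tiles after unrolling
//   vector.multi_reduction <add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
func.func @sketch_unrolled_row(%v: vector<4x6xf32>, %acc: vector<4xf32>) -> vector<2xf32> {
  %e0 = vector.extract_strided_slice %v {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
          : vector<4x6xf32> to vector<2x2xf32>
  %acc0 = vector.extract_strided_slice %acc {offsets = [0], sizes = [2], strides = [1]}
          : vector<4xf32> to vector<2xf32>
  // First tile uses the sliced accumulator.
  %r0 = vector.multi_reduction <add>, %e0, %acc0 [1] : vector<2x2xf32> to vector<2xf32>
  %e1 = vector.extract_strided_slice %v {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]}
          : vector<4x6xf32> to vector<2x2xf32>
  // Subsequent tiles chain the previous partial result as the accumulator.
  %r1 = vector.multi_reduction <add>, %e1, %r0 [1] : vector<2x2xf32> to vector<2xf32>
  %e2 = vector.extract_strided_slice %v {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]}
          : vector<4x6xf32> to vector<2x2xf32>
  %r2 = vector.multi_reduction <add>, %e2, %r1 [1] : vector<2x2xf32> to vector<2xf32>
  return %r2 : vector<2xf32>
}
```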