[Mlir-commits] [mlir] 051b36b - [mlir][vector] Add accumulator operand to MultiDimReduce op

Thomas Raoux llvmlistbot at llvm.org
Tue Jul 12 07:36:01 PDT 2022


Author: Thomas Raoux
Date: 2022-07-12T14:28:30Z
New Revision: 051b36ba2857f5a532f6fd0eeb574731cb1e0df3

URL: https://github.com/llvm/llvm-project/commit/051b36ba2857f5a532f6fd0eeb574731cb1e0df3
DIFF: https://github.com/llvm/llvm-project/commit/051b36ba2857f5a532f6fd0eeb574731cb1e0df3.diff

LOG: [mlir][vector] Add accumulator operand to MultiDimReduce op

This allows vectorizing linalg reductions without changing the order of
operations. Therefore, it produces a valid vectorization even if the
operations are not associative.

Differential Revision: https://reviews.llvm.org/D129535

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
    mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
    mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
    mlir/test/Dialect/Vector/vector-unroll-options.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 9a29825eda506..212084a75d4ca 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -318,6 +318,7 @@ def Vector_ReductionOp :
 
 def Vector_MultiDimReductionOp :
   Vector_Op<"multi_reduction", [NoSideEffect,
+     AllTypesMatch<["dest", "acc"]>,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      DeclareOpInterfaceMethods<InferTypeOpInterface>,
@@ -325,6 +326,7 @@ def Vector_MultiDimReductionOp :
                                ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
                    AnyVector:$source,
+                   AnyType:$acc,
                    I64ArrayAttr:$reduction_dims)>,
     Results<(outs AnyType:$dest)> {
   let summary = "Multi-dimensional reduction operation";
@@ -332,19 +334,20 @@ def Vector_MultiDimReductionOp :
     Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
     using the given operation (add/mul/min/max for int/fp and and/or/xor for
     int only).
+    Takes an initial accumulator operand.
 
     Example:
 
     ```mlir
-    %1 = vector.multi_reduction <add>, %0 [1, 3] :
+    %1 = vector.multi_reduction <add>, %0, %acc0 [1, 3] :
       vector<4x8x16x32xf32> into vector<4x16xf32>
-    %2 = vector.multi_reduction <add>, %1 [0, 1] :
+    %2 = vector.multi_reduction <add>, %1, %acc1 [0, 1] :
       vector<4x16xf32> into f32
     ```
   }];
   let builders = [
-    OpBuilder<(ins "Value":$source, "ArrayRef<bool>":$reductionMask,
-                   "CombiningKind":$kind)>
+    OpBuilder<(ins "Value":$source, "Value":$acc,
+                   "ArrayRef<bool>":$reductionMask, "CombiningKind":$kind)>                   
   ];
   let extraClassDeclaration = [{
     static StringRef getKindAttrStrName() { return "kind"; }
@@ -378,8 +381,9 @@ def Vector_MultiDimReductionOp :
     }
   }];
   let assemblyFormat =
-    "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
+    "$kind `,` $source `,` $acc attr-dict $reduction_dims `:` type($source) `to` type($dest)";
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_BroadcastOp :

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c406dc5de88cf..3422ab7c0a76e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -174,13 +174,13 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
 /// assumes that `reductionOp` has two operands and one of them is the reduction
 /// initial value.
-static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
-                                 Value valueToReduce,
-                                 const SmallVector<bool> &reductionMask) {
+static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+                                      Value valueToReduce, Value acc,
+                                      const SmallVector<bool> &reductionMask) {
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
-      reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
+      reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
 }
 
 static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
@@ -315,10 +315,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
       (outputType && reduceType.getShape() == outputType.getShape()))
     return nullptr;
   SmallVector<bool> reductionMask = getReductionMask(linalgOp);
-  Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
-  return b.create(op->getLoc(), op->getName().getIdentifier(),
-                  /*operands=*/{reduce, outputVec}, reduce.getType(),
-                  op->getAttrs());
+  return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
 }
 
 /// Generic vectorization for a single operation `op`, given already vectorized

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 00db4650f1206..f803868c2150d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -334,34 +334,14 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
 
 void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                         OperationState &result, Value source,
-                                        ArrayRef<bool> reductionMask,
+                                        Value acc, ArrayRef<bool> reductionMask,
                                         CombiningKind kind) {
   SmallVector<int64_t> reductionDims;
   for (const auto &en : llvm::enumerate(reductionMask))
     if (en.value())
       reductionDims.push_back(en.index());
-  build(builder, result, kind, source, builder.getI64ArrayAttr(reductionDims));
-}
-
-LogicalResult MultiDimReductionOp::inferReturnTypes(
-    MLIRContext *, Optional<Location>, ValueRange operands,
-    DictionaryAttr attributes, RegionRange,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  MultiDimReductionOp::Adaptor op(operands, attributes);
-  auto vectorType = op.getSource().getType().cast<VectorType>();
-  SmallVector<int64_t> targetShape;
-  for (auto it : llvm::enumerate(vectorType.getShape()))
-    if (!llvm::any_of(op.getReductionDims().getValue(), [&](Attribute attr) {
-          return attr.cast<IntegerAttr>().getValue() == it.index();
-        }))
-      targetShape.push_back(it.value());
-  // TODO: update to also allow 0-d vectors when available.
-  if (targetShape.empty())
-    inferredReturnTypes.push_back(vectorType.getElementType());
-  else
-    inferredReturnTypes.push_back(
-        VectorType::get(targetShape, vectorType.getElementType()));
-  return success();
+  build(builder, result, kind, source, acc,
+        builder.getI64ArrayAttr(reductionDims));
 }
 
 OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
@@ -375,6 +355,28 @@ Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getSourceVectorType().getShape());
 }
 
+LogicalResult MultiDimReductionOp::verify() {
+  SmallVector<int64_t> targetShape;
+  Type inferredReturnType;
+  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
+    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
+          return attr.cast<IntegerAttr>().getValue() == it.index();
+        }))
+      targetShape.push_back(it.value());
+  // TODO: update to also allow 0-d vectors when available.
+  if (targetShape.empty())
+    inferredReturnType = getSourceVectorType().getElementType();
+  else
+    inferredReturnType =
+        VectorType::get(targetShape, getSourceVectorType().getElementType());
+  if (getType() != inferredReturnType)
+    return emitOpError() << "destination type " << getType()
+                         << " is incompatible with source type "
+                         << getSourceVectorType();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 0e023ca448322..2582781aaab08 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -87,8 +87,8 @@ class InnerOuterDimReductionConversion
         reductionMask[i] = true;
     }
     rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
-        multiReductionOp, transposeOp.getResult(), reductionMask,
-        multiReductionOp.getKind());
+        multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
+        reductionMask, multiReductionOp.getKind());
     return success();
   }
 
@@ -188,11 +188,17 @@ class ReduceMultiDimReductionRank
         vectorShape, multiReductionOp.getSourceVectorType().getElementType());
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
-
+    Value acc = multiReductionOp.getAcc();
+    if (flattenedParallelDim) {
+      auto accType = VectorType::get(
+          {flattenedParallelDim},
+          multiReductionOp.getSourceVectorType().getElementType());
+      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
+    }
     // 5. Creates the flattened form of vector.multi_reduction with inner/outer
     // most dim as reduction.
     auto newOp = rewriter.create<vector::MultiDimReductionOp>(
-        loc, cast, mask, multiReductionOp.getKind());
+        loc, cast, acc, mask, multiReductionOp.getKind());
 
     // 6. If there are no parallel shapes, the result is a scalar.
     // TODO: support 0-d vectors when available.
@@ -238,10 +244,8 @@ struct TwoDimMultiReductionToElementWise
     if (!elementType.isIntOrIndexOrFloat())
       return failure();
 
-    Value result =
-        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.getSource(), 0)
-            .getResult();
-    for (int64_t i = 1; i < srcShape[0]; i++) {
+    Value result = multiReductionOp.getAcc();
+    for (int64_t i = 0; i < srcShape[0]; i++) {
       auto operand = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getSource(), i);
       result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
@@ -277,8 +281,10 @@ struct TwoDimMultiReductionToReduction
     for (int i = 0; i < outerDim; ++i) {
       auto v = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
+      auto acc = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
       auto reducedValue = rewriter.create<vector::ReductionOp>(
-          loc, multiReductionOp.getKind(), v);
+          loc, multiReductionOp.getKind(), v, acc);
       result = rewriter.create<vector::InsertElementOp>(
           loc, reducedValue, result,
           rewriter.create<arith::ConstantIndexOp>(loc, i));
@@ -309,6 +315,8 @@ struct OneDimMultiReductionToTwoDim
     auto srcShape = srcVectorType.getShape();
     auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
                                       srcVectorType.getElementType());
+    auto accType =
+        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
     assert(!multiReductionOp.getDestType().isa<VectorType>() &&
            "multi_reduction with a single dimension expects a scalar result");
 
@@ -319,8 +327,10 @@ struct OneDimMultiReductionToTwoDim
     /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
+    Value castAcc = rewriter.create<vector::BroadcastOp>(
+        loc, accType, multiReductionOp.getAcc());
     Value reduced = rewriter.create<vector::MultiDimReductionOp>(
-        loc, cast, mask, multiReductionOp.getKind());
+        loc, cast, castAcc, mask, multiReductionOp.getKind());
     rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
                                                    ArrayRef<int64_t>{0});
     return success();

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bb8cc2bfae396..76151fc358ade 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -997,11 +997,8 @@ struct MultiReduceToContract
     }
     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                  /*symCount=*/0, exprs, reduceOp.getContext());
-    Value zero = rewriter.create<arith::ConstantOp>(
-        reduceOp.getLoc(), reduceOp.getDestType(),
-        rewriter.getZeroAttr(reduceOp.getDestType()));
     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
-        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
+        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
         rewriter.getStrArrayAttr(iteratorTypes));
     return success();

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index d75d1098d53f6..15f43dc0536c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -431,10 +431,11 @@ struct UnrollMultiReductionPattern
       SmallVector<int64_t, 4> offsets =
           getVectorOffset(originalSize, *targetShape, i);
 
+      SmallVector<Value> operands;
       SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);
-
+          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+      operands.push_back(slicedOperand);
       SmallVector<int64_t> dstShape;
       SmallVector<int64_t> destOffset;
       for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
@@ -443,17 +444,22 @@ struct UnrollMultiReductionPattern
           dstShape.push_back((*targetShape)[i]);
         }
       }
+      Value acc;
+      SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
+      // If a version of the accumulator has already been computed, use it
+      // otherwise extract the first version from the original operand.
+      auto accIt = accCache.find(destOffset);
+      if (accIt != accCache.end())
+        acc = accIt->second;
+      else
+        acc = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
+      operands.push_back(acc);
       auto targetType = VectorType::get(
           dstShape, reductionOp.getSourceVectorType().getElementType());
       Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
-                                                     slicedOperand, targetType);
+                                                     operands, targetType);
       Value result = newOp->getResult(0);
-      // Save the accumulated value until all the loops are unrolled since
-      // reduction loop keeps updating the accumulator.
-      auto accIt = accCache.find(destOffset);
-      if (accIt != accCache.end())
-        result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
-                                    result, accIt->second);
       accCache[destOffset] = result;
     }
     // Assemble back the accumulator into a single vector.

diff  --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index b4dc8e4325673..b08d7d1ff1c10 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -10,9 +10,8 @@ func.func @vectorize_matmul(%arg0: tensor<24x12xf32>,
   // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
   // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
   // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
-  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
-  // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
-  // CHECK: vector.transfer_write %[[vS]], %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
@@ -67,9 +66,8 @@ func.func @vectorize_keep_pad(
   // CHECK: %[[vA:.+]] = vector.transfer_read %[[pA]]
   // CHECK: %[[vB:.+]] = vector.transfer_read %[[pB]]
   // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
-  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
-  // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
-  // CHECK: vector.transfer_write %[[vS]], %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
   %8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
   %9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
   return %9 : tensor<24x25xf32>
@@ -127,9 +125,8 @@ func.func @vectorize_pad(
     tensor.yield %cst : f32
   } : tensor<?x5xf32> to tensor<7x5xf32>
   // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
-  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
-  // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
-  // CHECK: vector.transfer_write %[[vS]], %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
   %8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
   %9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
   return %9 : tensor<24x25xf32>

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index dbd09576cb76b..bbc36b12556ed 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -6,8 +6,7 @@
 func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
 
 // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
-// CHECK: vector.multi_reduction <add>, %{{.*}} [0] : vector<1584xf32> to f32
-// CHECK: arith.addf %{{.*}}, %{{.*}} : f32
+// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [0] : vector<1584xf32> to f32
   linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
             outs(%C: memref<f32>)
   return
@@ -19,8 +18,7 @@ func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memre
 func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
 
 // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
-// CHECK: vector.multi_reduction <add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
-// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
   linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
             outs(%C: memref<1584xf32>)
   return
@@ -31,8 +29,7 @@ func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %
 // CHECK-LABEL: contraction_matmul
 func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
 // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
-// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
-// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
   linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
             outs(%C: memref<1584x1584xf32>)
   return
@@ -43,8 +40,7 @@ func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf3
 // CHECK-LABEL: contraction_batch_matmul
 func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
 // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
-// CHECK: vector.multi_reduction <add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
-// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
   linalg.batch_matmul
     ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
    outs(%C: memref<1584x1584x1584xf32>)
@@ -69,10 +65,9 @@ func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
-  //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
+  //       CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
   //       CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
-  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
-  //       CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
+  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -103,10 +98,9 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<32x8xf32>) {
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
-  //       CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
+  //       CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
   //       CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
-  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
-  //       CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
+  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
   linalg.generic #matmul_transpose_out_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -157,11 +151,9 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32
                                  %C: memref<8x32xi32>) {
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
-  //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+  //       CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
   //       CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
-  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
-  //       CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
-
+  //       CHECK: vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
@@ -180,8 +172,7 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32
 func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
   //       CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
-  //       CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
-  //       CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
+  //       CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
   linalg.matmul
     ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
    outs(%C: memref<8x32xf32>)
@@ -560,9 +551,8 @@ func.func @matmul_tensors(
   // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
   // convert it to a 2D contract.
   //       CHECK:   %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
-  //       CHECK:   %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
-  //       CHECK:   %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
-  //       CHECK:   %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
+  //       CHECK:   %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+  //       CHECK:   %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
   %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
                      outs(%arg2: tensor<8x12xf32>)
     -> tensor<8x12xf32>
@@ -801,8 +791,7 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
   // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: math.exp {{.*}} : vector<4x16x8xf32>
-  // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
-  // CHECK: addf {{.*}} : vector<4x16xf32>
+  // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
   // CHECK: return {{.*}} : tensor<4x16xf32>
   %0 = linalg.generic {
@@ -836,8 +825,7 @@ func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
-  // CHECK: vector.multi_reduction <add>, {{.*}}  [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
-  // CHECK: addf {{.*}} : vector<2x5xf32>
+  // CHECK: vector.multi_reduction <add>, {{.*}}, %{{.*}}  [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
   // CHECK: return {{.*}} : tensor<5x2xf32>
   %0 = linalg.generic {
@@ -865,8 +853,7 @@ func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
-  // CHECK: %[[R:.+]] = vector.multi_reduction <maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
-  // CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32>
+  // CHECK: vector.multi_reduction <maxf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant -3.40282e+38 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>
@@ -890,8 +877,7 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
-  // CHECK: %[[R:.+]] = vector.multi_reduction <minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
-  // CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32>
+  // CHECK: vector.multi_reduction <minf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %maxf32 = arith.constant 3.40282e+38 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>
@@ -914,7 +900,7 @@ func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
-  // CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.multi_reduction <mul>, {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant 1.0 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>
@@ -937,7 +923,7 @@ func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
-  // CHECK: vector.multi_reduction <or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+  // CHECK: vector.multi_reduction <or>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant false
   %init = linalg.init_tensor [4] : tensor<4xi1>
@@ -960,7 +946,7 @@ func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
-  // CHECK: vector.multi_reduction <and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+  // CHECK: vector.multi_reduction <and>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant true
   %init = linalg.init_tensor [4] : tensor<4xi1>
@@ -983,7 +969,7 @@ func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
-  // CHECK: vector.multi_reduction <xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+  // CHECK: vector.multi_reduction <xor>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant false
   %init = linalg.init_tensor [4] : tensor<4xi1>
@@ -1035,8 +1021,7 @@ func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
   // CHECK: subf {{.*}} : vector<4x4xf32>
   // CHECK: math.exp {{.*}} : vector<4x4xf32>
-  // CHECK: vector.multi_reduction <add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
-  // CHECK: addf {{.*}} : vector<4xf32>
+  // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
   %c0 = arith.constant 0.0 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>
@@ -1075,10 +1060,9 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
   //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
   // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
   //      CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
-  //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]] [0]
+  //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
   // CHECK-SAME:   : vector<32xf32> to f32
-  //      CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
-  //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>
+  //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
   //      CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
   // CHECK-SAME:   : vector<f32>, tensor<f32>
   %2 = linalg.generic {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 84b5a45f19e65..702670095c8d5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1281,9 +1281,9 @@ func.func @do_not_swap_extract_slice_transfer_write(%arg0 : vector<8xf32>,
 // -----
 
 // CHECK-LABEL: func @vector_multi_reduction_single_parallel(
-//  CHECK-SAME:     %[[v:.*]]: vector<2xf32>
-func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction <mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
+//  CHECK-SAME:     %[[v:.*]]: vector<2xf32>,
+func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [] : vector<2xf32> to vector<2xf32>
 
 //       CHECK:     return %[[v]] : vector<2xf32>
     return %0 : vector<2xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 87e5f9443807c..d50315970d744 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1138,9 +1138,9 @@ func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
 
 // -----
 
-func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>) -> f32 {
-  // expected-error@+1 {{'vector.multi_reduction' op inferred type(s) 'vector<4xf32>' are incompatible with return type(s) of operation 'vector<16xf32>'}}
-  %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<4x16xf32> to vector<16xf32>
+func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>, %acc: vector<16xf32>) -> f32 {
+  // expected-error@+1 {{'vector.multi_reduction' op destination type 'vector<16xf32>' is incompatible with source type 'vector<4x16xf32>'}}
+  %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<4x16xf32> to vector<16xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index e42c94f252c46..dc69bb0a78a6f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -705,10 +705,13 @@ func.func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
 }
 
 // CHECK-LABEL: @multi_reduction
-func.func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
-  %1 = vector.multi_reduction <add>, %0 [1, 3] :
+func.func @multi_reduction(%0: vector<4x8x16x32xf32>, %acc0: vector<4x16xf32>,
+                           %acc1: f32) -> f32 {
+  // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [1, 3] : vector<4x8x16x32xf32> to vector<4x16xf32>
+  %1 = vector.multi_reduction <add>, %0, %acc0 [1, 3] :
     vector<4x8x16x32xf32> to vector<4x16xf32>
-  %2 = vector.multi_reduction <add>, %1 [0, 1] :
+  // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x16xf32> to f32
+  %2 = vector.multi_reduction <add>, %1, %acc1 [0, 1] :
     vector<4x16xf32> to f32
   return %2 : f32
 }

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index a39ac990f8281..6b372c3ef1c3e 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -1,40 +1,42 @@
 // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
 
-func.func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
     return %0 : vector<2xf32>
 }
 // CHECK-LABEL: func @vector_multi_reduction
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
 //       CHECK:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
 //       CHECK:       %[[C0:.+]] = arith.constant 0 : index
 //       CHECK:       %[[C1:.+]] = arith.constant 1 : index
 //       CHECK:       %[[V0:.+]] = vector.extract %[[INPUT]][0]
-//       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]] : vector<4xf32> into f32
+//       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+//       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
 //       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[INPUT]][1]
-//       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]] : vector<4xf32> into f32
+//       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+//       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
-    %0 = vector.multi_reduction <mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
     return %0 : f32
 }
 // CHECK-LABEL: func @vector_multi_reduction_to_scalar
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
 //       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-//       CHECK:   %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]] : vector<8xf32> into f32
+//       CHECK:   %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
 //       CHECK:   %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
 //       CHECK:   %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
 //       CHECK:   return %[[RES]]
 
-func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
-    %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
 }
 // CHECK-LABEL: func @vector_reduction_inner
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
 //       CHECK:       %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32>
 //   CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
 //   CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
@@ -44,29 +46,35 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32>
 //   CHECK-DAG:       %[[C5:.+]] = arith.constant 5 : index
 //       CHECK:       %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
 //       CHECK:       %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
-//       CHECK:       %[[V0R:.+]] = vector.reduction <add>, %[[V0]] : vector<20xi32> into i32
+//       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : vector<2x3xi32>
+//       CHECK:       %[[V0R:.+]] = vector.reduction <add>, %[[V0]], %[[ACC0]] : vector<20xi32> into i32
 //       CHECK:       %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : index] : vector<6xi32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
-//       CHECK:       %[[V1R:.+]] = vector.reduction <add>, %[[V1]] : vector<20xi32> into i32
+//       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : vector<2x3xi32>
+//       CHECK:       %[[V1R:.+]] = vector.reduction <add>, %[[V1]], %[[ACC1]] : vector<20xi32> into i32
 //       CHECK:       %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : index] : vector<6xi32>
 //       CHECK:       %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
-//       CHECK:       %[[V2R:.+]] = vector.reduction <add>, %[[V2]] : vector<20xi32> into i32
+//       CHECK:       %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : vector<2x3xi32>
+//       CHECK:       %[[V2R:.+]] = vector.reduction <add>, %[[V2]], %[[ACC2]] : vector<20xi32> into i32
 //       CHECK:       %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : index] : vector<6xi32>
 //       CHECK:       %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
-//       CHECK:       %[[V3R:.+]] = vector.reduction <add>, %[[V3]] : vector<20xi32> into i32
+//       CHECK:       %[[ACC3:.+]] = vector.extract %[[ACC]][1, 0] : vector<2x3xi32>
+//       CHECK:       %[[V3R:.+]] = vector.reduction <add>, %[[V3]], %[[ACC3]] : vector<20xi32> into i32
 //       CHECK:       %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : index] : vector<6xi32>
 //       CHECK:       %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
-//       CHECK:       %[[V4R:.+]] = vector.reduction <add>, %[[V4]] : vector<20xi32> into i32
+//       CHECK:       %[[ACC4:.+]] = vector.extract %[[ACC]][1, 1] : vector<2x3xi32>
+//       CHECK:       %[[V4R:.+]] = vector.reduction <add>, %[[V4]], %[[ACC4]] : vector<20xi32> into i32
 //       CHECK:       %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : index] : vector<6xi32>
 ///       CHECK:      %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
-//       CHECK:       %[[V5R:.+]] = vector.reduction <add>, %[[V5]] : vector<20xi32> into i32
+//       CHECK:       %[[ACC5:.+]] = vector.extract %[[ACC]][1, 2] : vector<2x3xi32>
+//       CHECK:       %[[V5R:.+]] = vector.reduction <add>, %[[V5]], %[[ACC5]] : vector<20xi32> into i32
 //       CHECK:       %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : index] : vector<6xi32>
 //       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
 //       CHECK:       return %[[RESULT]]
 
 
-func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
-    %0 = vector.multi_reduction <add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
     return %0 : vector<2x5xf32>
 }
 
@@ -77,12 +85,12 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vect
 //       CHECK:     %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
 //       CHECK:       return %[[RESULT]]
 
-func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
-    %0 = vector.multi_reduction <mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
+func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
     return %0 : vector<2x4xf32>
 }
 // CHECK-LABEL: func @vector_multi_reduction_ordering
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x2x4xf32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x2x4xf32>, %[[ACC:.*]]: vector<2x4xf32>)
 //       CHECK:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32>
 //       CHECK:       %[[C0:.+]] = arith.constant 0 : index
 //       CHECK:       %[[C1:.+]] = arith.constant 1 : index
@@ -94,28 +102,36 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2
 //       CHECK:       %[[C7:.+]] = arith.constant 7 : index
 //       CHECK:       %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
 //       CHECK:       %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
-//       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]] : vector<3xf32> into f32
+//       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : vector<2x4xf32>
+//       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<3xf32> into f32
 //       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<8xf32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
-//       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]] : vector<3xf32> into f32
+//       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : vector<2x4xf32>
+//       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<3xf32> into f32
 //       CHECK:       %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<8xf32>
 //       CHECK:       %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
-//       CHECK:       %[[RV2:.+]] = vector.reduction <mul>, %[[V2]] : vector<3xf32> into f32
+//       CHECK:       %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : vector<2x4xf32>
+//       CHECK:       %[[RV2:.+]] = vector.reduction <mul>, %[[V2]], %[[ACC2]] : vector<3xf32> into f32
 //       CHECK:       %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : index] : vector<8xf32>
 //       CHECK:       %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
-//       CHECK:       %[[RV3:.+]] = vector.reduction <mul>, %[[V3]] : vector<3xf32> into f32
+//       CHECK:       %[[ACC3:.+]] = vector.extract %[[ACC]][0, 3] : vector<2x4xf32>
+//       CHECK:       %[[RV3:.+]] = vector.reduction <mul>, %[[V3]], %[[ACC3]] : vector<3xf32> into f32
 //       CHECK:       %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : index] : vector<8xf32>
 //       CHECK:       %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
-//       CHECK:       %[[RV4:.+]] = vector.reduction <mul>, %[[V4]] : vector<3xf32> into f32
+//       CHECK:       %[[ACC4:.+]] = vector.extract %[[ACC]][1, 0] : vector<2x4xf32>
+//       CHECK:       %[[RV4:.+]] = vector.reduction <mul>, %[[V4]], %[[ACC4]] : vector<3xf32> into f32
 //       CHECK:       %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : index] : vector<8xf32>
 //       CHECK:       %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
-//       CHECK:       %[[RV5:.+]] = vector.reduction <mul>, %[[V5]] : vector<3xf32> into f32
+//       CHECK:       %[[ACC5:.+]] = vector.extract %[[ACC]][1, 1] : vector<2x4xf32>
+//       CHECK:       %[[RV5:.+]] = vector.reduction <mul>, %[[V5]], %[[ACC5]] : vector<3xf32> into f32
 //       CHECK:       %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : index] : vector<8xf32>
 //       CHECK:       %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
-//       CHECK:       %[[RV6:.+]] = vector.reduction <mul>, %[[V6]] : vector<3xf32> into f32
+//       CHECK:       %[[ACC6:.+]] = vector.extract %[[ACC]][1, 2] : vector<2x4xf32>
+//       CHECK:       %[[RV6:.+]] = vector.reduction <mul>, %[[V6]], %[[ACC6]] : vector<3xf32> into f32
 //       CHECK:       %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : index] : vector<8xf32>
 //       CHECK:       %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
-//       CHECK:       %[[RV7:.+]] = vector.reduction <mul>, %[[V7]] : vector<3xf32> into f32
+//       CHECK:       %[[ACC7:.+]] = vector.extract %[[ACC]][1, 3] : vector<2x4xf32>
+//       CHECK:       %[[RV7:.+]] = vector.reduction <mul>, %[[V7]], %[[ACC7]] : vector<3xf32> into f32
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
 //       CHECK:       %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
 //       CHECK:       return %[[RESHAPED_VEC]]

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index bd06e48f48231..8a8bf86bfd38b 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -1,101 +1,107 @@
 // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
 
-func.func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
     return %0 : vector<2xf32>
 }
 
 // CHECK-LABEL: func @vector_multi_reduction
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.mulf %[[V1]], %[[V0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
 //       CHECK:   %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
 //       CHECK:   %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
-func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction <minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <minf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
     return %0 : vector<2xf32>
 }
 
 // CHECK-LABEL: func @vector_multi_reduction_min
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.minf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.minf %[[V1]], %[[V0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.minf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
 //       CHECK:   %[[RV012:.+]] = arith.minf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
 //       CHECK:   %[[RESULT_VEC:.+]] = arith.minf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
-func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
-    %0 = vector.multi_reduction <maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <maxf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
     return %0 : vector<2xf32>
 }
 
 // CHECK-LABEL: func @vector_multi_reduction_max
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.maxf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.maxf %[[V1]], %[[V0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.maxf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
 //       CHECK:   %[[RV012:.+]] = arith.maxf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
 //       CHECK:   %[[RESULT_VEC:.+]] = arith.maxf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
-func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
-    %0 = vector.multi_reduction <and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    %0 = vector.multi_reduction <and>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
     return %0 : vector<2xi32>
 }
 
 // CHECK-LABEL: func @vector_multi_reduction_and
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+//       CHECK:   %[[RV0:.+]] = arith.andi %[[V0]], %[[ACC]] : vector<2xi32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
-//       CHECK:   %[[RV01:.+]] = arith.andi %[[V1]], %[[V0]] : vector<2xi32>
+//       CHECK:   %[[RV01:.+]] = arith.andi %[[V1]], %[[RV0]] : vector<2xi32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
 //       CHECK:   %[[RV012:.+]] = arith.andi %[[V2]], %[[RV01]] : vector<2xi32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
 //       CHECK:   %[[RESULT_VEC:.+]] = arith.andi %[[V3]], %[[RV012]] : vector<2xi32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
 
-func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
-    %0 = vector.multi_reduction <or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    %0 = vector.multi_reduction <or>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
     return %0 : vector<2xi32>
 }
 
 // CHECK-LABEL: func @vector_multi_reduction_or
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+//       CHECK:   %[[RV0:.+]] = arith.ori %[[V0]], %[[ACC]] : vector<2xi32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
-//       CHECK:   %[[RV01:.+]] = arith.ori %[[V1]], %[[V0]] : vector<2xi32>
+//       CHECK:   %[[RV01:.+]] = arith.ori %[[V1]], %[[RV0]] : vector<2xi32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
 //       CHECK:   %[[RV012:.+]] = arith.ori %[[V2]], %[[RV01]] : vector<2xi32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
 //       CHECK:   %[[RESULT_VEC:.+]] = arith.ori %[[V3]], %[[RV012]] : vector<2xi32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
 
-func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
-    %0 = vector.multi_reduction <xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    %0 = vector.multi_reduction <xor>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
     return %0 : vector<2xi32>
 }
 
 // CHECK-LABEL: func @vector_multi_reduction_xor
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+//       CHECK:   %[[RV0:.+]] = arith.xori %[[V0]], %[[ACC]] : vector<2xi32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
-//       CHECK:   %[[RV01:.+]] = arith.xori %[[V1]], %[[V0]] : vector<2xi32>
+//       CHECK:   %[[RV01:.+]] = arith.xori %[[V1]], %[[RV0]] : vector<2xi32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
 //       CHECK:   %[[RV012:.+]] = arith.xori %[[V2]], %[[RV01]] : vector<2xi32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
@@ -103,18 +109,20 @@ func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
 
 
-func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
-    %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
 }
 
 // CHECK-LABEL: func @vector_reduction_outer
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32>
 //       CHECK:   %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32>
+//       CHECK:   %[[FACC:.+]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<20x6xi32>
+//       CHECK:   %[[R:.+]] = arith.addi %[[V0]], %[[FACC]] : vector<6xi32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<20x6xi32>
-//       CHECK:   %[[R0:.+]] = arith.addi %[[V1]], %[[V0]] : vector<6xi32>
+//       CHECK:   %[[R0:.+]] = arith.addi %[[V1]], %[[R]] : vector<6xi32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<20x6xi32>
 //       CHECK:   %[[R1:.+]] = arith.addi %[[V2]], %[[R0]] : vector<6xi32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<20x6xi32>
@@ -157,15 +165,15 @@ func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32>
 // This test is mainly to catch a bug that running
 // `InnerOuterDimReductionConversion` on this function results in an
 // infinite loop. So just check that some value is returned.
-func.func @vector_reduction_1D(%arg0 : vector<2xf32>) -> f32 {
-  %0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [0] : vector<2xf32> to f32
+func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 {
+  %0 = vector.multi_reduction #vector.kind<maxf>, %arg0, %acc [0] : vector<2xf32> to f32
   return %0 : f32
 }
 // CHECK-LABEL: func @vector_reduction_1D
 //       CHECK:   return %{{.+}}
 
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>) -> f32 {
-  %0 = vector.multi_reduction <add>, %arg0 [0, 1] : vector<2x3xf32> to f32
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
+  %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
   return %0 : f32
 }
 // CHECK-LABEL: func @vector_multi_reduction_to_scalar

diff  --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index ade539e278226..f1587c2e2f3d6 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -4,15 +4,15 @@
 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 
 // CHECK-LABEL: multidimreduction_contract
-//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>)
 //  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
-//  CHECK-SAME:   %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
 //  CHECK-NEXT:   return %[[R]] : vector<8x16xf32>
 func.func @multidimreduction_contract(
-  %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
+  %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>, %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
-  %1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
+  %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<8x32x16xf32> to vector<8x16xf32>
   return %1 : vector<8x16xf32>
 }
 
@@ -22,15 +22,15 @@ func.func @multidimreduction_contract(
 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 
 // CHECK-LABEL: multidimreduction_contract_int
-//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>)
 //  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
-//  CHECK-SAME:   %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
 //  CHECK-NEXT:   return %[[R]] : vector<8x16xi32>
 func.func @multidimreduction_contract_int(
-  %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
+  %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>, %acc: vector<8x16xi32>) -> vector<8x16xi32> {
   %0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
-  %1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
+  %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<8x32x16xi32> to vector<8x16xi32>
   return %1 : vector<8x16xi32>
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index d0d10887d6a2d..db6a40d489d6c 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,30 +188,28 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
-func.func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
-  %0 = vector.multi_reduction #vector.kind<add>, %v [1] : vector<4x6xf32> to vector<4xf32>
+func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
   return %0 : vector<4xf32>
 }
 // CHECK-LABEL: func @vector_multi_reduction
 //       CHECK:   %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
 //       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-//       CHECK:   %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[ACC0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK:   %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]], %[[ACC0]] [1] : vector<2x2xf32> to vector<2xf32>
 //       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-//       CHECK:   %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]] [1] : vector<2x2xf32> to vector<2xf32>
-//       CHECK:   %[[A0:.*]] = arith.addf %[[R1]], %[[R0]] : vector<2xf32>
+//       CHECK:   %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]], %[[R0]] [1] : vector<2x2xf32> to vector<2xf32>
 //       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-//       CHECK:   %[[R2:.*]] = vector.multi_reduction <add>, %5 [1] : vector<2x2xf32> to vector<2xf32>
-//       CHECK:   %[[A1:.*]] = arith.addf %[[R2]], %[[A0]] : vector<2xf32>
+//       CHECK:   %[[R2:.*]] = vector.multi_reduction <add>, %[[E2]], %[[R1]] [1] : vector<2x2xf32> to vector<2xf32>
 //       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-//       CHECK:   %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[ACC1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK:   %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]], %[[ACC1]] [1] : vector<2x2xf32> to vector<2xf32>
 //       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-//       CHECK:   %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]] [1] : vector<2x2xf32> to vector<2xf32>
-//       CHECK:   %[[A2:.*]] = arith.addf %[[R4]], %[[R3]] : vector<2xf32>
+//       CHECK:   %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]], %[[R3]] [1] : vector<2x2xf32> to vector<2xf32>
 //       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-//       CHECK:   %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]] [1] : vector<2x2xf32> to vector<2xf32>
-//       CHECK:   %[[A3:.*]] = arith.addf %[[R5]], %[[A2]] : vector<2xf32>
-//       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+//       CHECK:   %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]], %[[R4]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[R2]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK:   return %[[V2]] : vector<4xf32>
 
 


        


More information about the Mlir-commits mailing list