[Mlir-commits] [mlir] 31270eb - [mlir][Vector] Let vector.multi_reduction reduce down to a scalar.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Oct 12 04:03:59 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-12T11:03:54Z
New Revision: 31270eb16501cca73fb3fbac254fe9965a3f3fc1

URL: https://github.com/llvm/llvm-project/commit/31270eb16501cca73fb3fbac254fe9965a3f3fc1
DIFF: https://github.com/llvm/llvm-project/commit/31270eb16501cca73fb3fbac254fe9965a3f3fc1.diff

LOG: [mlir][Vector] Let vector.multi_reduction reduce down to a scalar.

vector.multi_reduction currently does not allow reducing down to a scalar.
This creates corner cases that are hard to handle during vectorization.
This revision extends the semantics and adds the proper transforms, lowerings and canonicalizations to allow lowering out of vector.multi_reduction to other abstractions all the way to LLVM.

In a future, where we will also allow 0-d vectors, scalars will still be relevant: 0-d vector and scalars are not equivalent on all hardware.

In the process, splice out the implementation patterns related to vector.multi_reduce into a new file.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111442

Added: 
    mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index a98ca36025e82..a6fbf93f29a08 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -79,8 +79,28 @@ void populateVectorTransferPermutationMapLoweringPatterns(
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                bool enableIndexOptimizations);
 
-// Collect a set of patterns to convert vector.multi_reduction op into
-// a sequence of vector.reduction ops.
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of vector.reduction ops. The patterns comprise:
+/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
+/// that all reduction dimensions are either innermost or outermost, by adding
+/// the proper vector.transpose operations.
+/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
+/// form, with an **outermost** reduction dimension, unroll the outer dimension
+/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
+/// tree-reduction (in the future).
+/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
+/// with an **innermost** reduction dimension, unroll the outer dimension to
+/// obtain a sequence of extract + vector.reduction + insert. This can further
+/// lower to horizontal reduction ops.
+/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
+/// reduction (and are thus missing either a parallel or a reduction), we lift
+/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
+/// the other patterns can kick in, thus fully exiting out of the
+/// vector.multi_reduction abstraction.
 void populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, bool useInnerDimsForReduction = false);
 

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index ea86336fd787e..c334773d6654e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -301,14 +301,17 @@ def Vector_MultiDimReductionOp :
     Results<(outs AnyType:$dest)> {
   let summary = "Multi-dimensional reduction operation";
   let description = [{
-    Reduces an n-D vector into an (n-k)-D vector using the given operation
-    (add/mul/min/max for int/fp and and/or/xor for int only).
+    Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n) 
+    using the given operation (add/mul/min/max for int/fp and and/or/xor for
+    int only).
 
     Example:
 
     ```mlir
     %1 = vector.multi_reduction "add", %0 [1, 3] :
       vector<4x8x16x32xf32> into vector<4x16xf32>
+    %2 = vector.multi_reduction "add", %1 [0, 1] :
+      vector<4x16xf32> into f32
     ```
   }];
   let builders = [
@@ -322,8 +325,14 @@ def Vector_MultiDimReductionOp :
     VectorType getSourceVectorType() {
       return source().getType().cast<VectorType>();
     }
-    VectorType getDestVectorType() {
-      return dest().getType().cast<VectorType>();
+    Type getDestType() {
+      return dest().getType();
+    }
+
+    bool isReducedDim(int64_t d) {
+      assert(d >= 0 && d < static_cast<int64_t>(getReductionMask().size()) &&
+        "d overflows the number of dims");
+      return getReductionMask()[d];
     }
 
     SmallVector<bool> getReductionMask() {
@@ -341,18 +350,28 @@ def Vector_MultiDimReductionOp :
     }
 
     static SmallVector<int64_t> inferDestShape(
-      ArrayRef<int64_t> shape, ArrayRef<bool> reducedDimsMask) {
-      assert(shape.size() == reducedDimsMask.size() &&
-             "shape and maks of different sizes");
+      ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask) {
+      assert(sourceShape.size() == reducedDimsMask.size() &&
+             "sourceShape and maks of different sizes");
       SmallVector<int64_t> res;
-      for (auto it : llvm::zip(reducedDimsMask, shape))
+      for (auto it : llvm::zip(reducedDimsMask, sourceShape))
         if (!std::get<0>(it))
           res.push_back(std::get<1>(it));
       return res;
     }
+
+    static Type inferDestType(
+      ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask, Type elementType) {
+      auto targetShape = inferDestShape(sourceShape, reducedDimsMask);
+      // TODO: update to also allow 0-d vectors when available.
+      if (targetShape.empty())
+        return elementType;
+      return VectorType::get(targetShape, elementType);
+    }
   }];
   let assemblyFormat =
     "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
 }
 
 def Vector_BroadcastOp :

diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 9ea8aabb698de..f0c3d9eeb2a06 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVector
   VectorOps.cpp
+  VectorMultiDimReductionTransforms.cpp
   VectorTransferOpTransforms.cpp
   VectorTransforms.cpp
   VectorUtils.cpp

diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
new file mode 100644
index 0000000000000..6eba54226171d
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -0,0 +1,409 @@
+//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
+//
+/// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+/// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+/// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This file implements target-independent rewrites of MultiDimReductionOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#define DEBUG_TYPE "vector-multi-reduction"
+
+using namespace mlir;
+
+/// This file implements the following transformations as composable atomic
+/// patterns.
+
+/// Converts vector.multi_reduction into inner-most/outer-most reduction form
+/// by using vector.transpose
+class InnerOuterDimReductionConversion
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+public:
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  explicit InnerOuterDimReductionConversion(MLIRContext *context,
+                                            bool useInnerDimsForReduction)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+        useInnerDimsForReduction(useInnerDimsForReduction) {}
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto src = multiReductionOp.source();
+    auto loc = multiReductionOp.getLoc();
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    // Separate reduction and parallel dims
+    auto reductionDimsRange =
+        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
+    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
+        reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
+    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
+                                                  reductionDims.end());
+    int64_t reductionSize = reductionDims.size();
+    SmallVector<int64_t, 4> parallelDims;
+    for (int64_t i = 0; i < srcRank; ++i)
+      if (!reductionDimsSet.contains(i))
+        parallelDims.push_back(i);
+
+    // Add transpose only if inner-most/outer-most dimensions are not parallel
+    if (useInnerDimsForReduction &&
+        (parallelDims ==
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
+
+    if (!useInnerDimsForReduction &&
+        (parallelDims !=
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
+
+    SmallVector<int64_t, 4> indices;
+    if (useInnerDimsForReduction) {
+      indices.append(parallelDims.begin(), parallelDims.end());
+      indices.append(reductionDims.begin(), reductionDims.end());
+    } else {
+      indices.append(reductionDims.begin(), reductionDims.end());
+      indices.append(parallelDims.begin(), parallelDims.end());
+    }
+    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
+    SmallVector<bool> reductionMask(srcRank, false);
+    for (int i = 0; i < reductionSize; ++i) {
+      if (useInnerDimsForReduction)
+        reductionMask[srcRank - i - 1] = true;
+      else
+        reductionMask[i] = true;
+    }
+    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
+        multiReductionOp, transposeOp.result(), reductionMask,
+        multiReductionOp.kind());
+    return success();
+  }
+
+private:
+  const bool useInnerDimsForReduction;
+};
+
+/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
+/// dimensions are either inner most or outer most.
+class ReduceMultiDimReductionRank
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+public:
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  explicit ReduceMultiDimReductionRank(MLIRContext *context,
+                                       bool useInnerDimsForReduction)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+        useInnerDimsForReduction(useInnerDimsForReduction) {}
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
+    auto loc = multiReductionOp.getLoc();
+
+    // If rank less than 2, nothing to do.
+    if (srcRank < 2)
+      return failure();
+
+    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
+    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
+    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
+      return failure();
+
+    // 1. Separate reduction and parallel dims.
+    SmallVector<int64_t, 4> parallelDims, parallelShapes;
+    SmallVector<int64_t, 4> reductionDims, reductionShapes;
+    for (auto it : llvm::enumerate(reductionMask)) {
+      int64_t i = it.index();
+      bool isReduction = it.value();
+      if (isReduction) {
+        reductionDims.push_back(i);
+        reductionShapes.push_back(srcShape[i]);
+      } else {
+        parallelDims.push_back(i);
+        parallelShapes.push_back(srcShape[i]);
+      }
+    }
+
+    // 2. Compute flattened parallel and reduction sizes.
+    int flattenedParallelDim = 0;
+    int flattenedReductionDim = 0;
+    if (parallelShapes.size() > 0) {
+      flattenedParallelDim = 1;
+      for (auto d : parallelShapes)
+        flattenedParallelDim *= d;
+    }
+    if (reductionShapes.size() > 0) {
+      flattenedReductionDim = 1;
+      for (auto d : reductionShapes)
+        flattenedReductionDim *= d;
+    }
+    // We must at least have some parallel or some reduction.
+    assert((flattenedParallelDim || flattenedReductionDim) &&
+           "expected at least one parallel or reduction dim");
+
+    // 3. Fail if reduction/parallel dims are not contiguous.
+    // Check parallelDims are exactly [0 .. size).
+    int64_t counter = 0;
+    if (useInnerDimsForReduction &&
+        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
+      return failure();
+    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
+    counter = reductionDims.size();
+    if (!useInnerDimsForReduction &&
+        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
+      return failure();
+
+    // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
+    // a single parallel (resp. reduction) dim.
+    SmallVector<bool, 2> mask;
+    SmallVector<int64_t, 2> vectorShape;
+    if (flattenedParallelDim) {
+      mask.push_back(false);
+      vectorShape.push_back(flattenedParallelDim);
+    }
+    if (flattenedReductionDim) {
+      mask.push_back(true);
+      vectorShape.push_back(flattenedReductionDim);
+    }
+    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
+      std::swap(mask.front(), mask.back());
+      std::swap(vectorShape.front(), vectorShape.back());
+    }
+    auto castedType = VectorType::get(
+        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
+    Value cast = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+
+    // 5. Creates the flattened form of vector.multi_reduction with inner/outer
+    // most dim as reduction.
+    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, cast, mask, multiReductionOp.kind());
+
+    // 6. If there are no parallel shapes, the result is a scalar.
+    // TODO: support 0-d vectors when available.
+    if (parallelShapes.empty()) {
+      rewriter.replaceOp(multiReductionOp, newOp.dest());
+      return success();
+    }
+
+    // 7. Creates shape cast for the output n-D -> 2-D
+    VectorType outputCastedType = VectorType::get(
+        parallelShapes,
+        multiReductionOp.getSourceVectorType().getElementType());
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        multiReductionOp, outputCastedType, newOp.dest());
+    return success();
+  }
+
+private:
+  const bool useInnerDimsForReduction;
+};
+
+/// Unrolls vector.multi_reduction with outermost reductions
+/// and combines results
+struct TwoDimMultiReductionToElementWise
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    // Rank-2 ["parallel", "reduce"] or bail.
+    if (srcRank != 2)
+      return failure();
+
+    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return failure();
+
+    Value condition;
+    Value result =
+        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
+            .getResult();
+    for (int64_t i = 1; i < srcShape[0]; i++) {
+      auto operand =
+          rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
+      switch (multiReductionOp.kind()) {
+      case vector::CombiningKind::ADD:
+        if (elementType.isIntOrIndex())
+          result = rewriter.create<AddIOp>(loc, operand, result);
+        else
+          result = rewriter.create<AddFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MUL:
+        if (elementType.isIntOrIndex())
+          result = rewriter.create<MulIOp>(loc, operand, result);
+        else
+          result = rewriter.create<MulFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINUI:
+        result = rewriter.create<MinUIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINSI:
+        result = rewriter.create<MinSIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINF:
+        result = rewriter.create<MinFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXUI:
+        result = rewriter.create<MaxUIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXSI:
+        result = rewriter.create<MaxSIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXF:
+        result = rewriter.create<MaxFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::AND:
+        result = rewriter.create<AndOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::OR:
+        result = rewriter.create<OrOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::XOR:
+        result = rewriter.create<XOrOp>(loc, operand, result);
+        break;
+      }
+    }
+
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
+/// Converts 2d vector.multi_reduction with inner most reduction dimension into
+/// a sequence of vector.reduction ops.
+struct TwoDimMultiReductionToReduction
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 2)
+      return failure();
+
+    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    Value result = rewriter.create<ConstantOp>(
+        loc, multiReductionOp.getDestType(),
+        rewriter.getZeroAttr(multiReductionOp.getDestType()));
+    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
+
+    // TODO: Add vector::CombiningKind attribute instead of string to
+    // vector.reduction.
+    auto getKindStr = [](vector::CombiningKind kind) {
+      switch (kind) {
+      case vector::CombiningKind::ADD:
+        return "add";
+      case vector::CombiningKind::MUL:
+        return "mul";
+      case vector::CombiningKind::MINUI:
+        return "minui";
+      case vector::CombiningKind::MINSI:
+        return "minsi";
+      case vector::CombiningKind::MINF:
+        return "minf";
+      case vector::CombiningKind::MAXUI:
+        return "maxui";
+      case vector::CombiningKind::MAXSI:
+        return "maxsi";
+      case vector::CombiningKind::MAXF:
+        return "maxf";
+      case vector::CombiningKind::AND:
+        return "and";
+      case vector::CombiningKind::OR:
+        return "or";
+      case vector::CombiningKind::XOR:
+        return "xor";
+      }
+      llvm_unreachable("unknown combining kind");
+    };
+
+    for (int i = 0; i < outerDim; ++i) {
+      auto v = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
+      auto reducedValue = rewriter.create<vector::ReductionOp>(
+          loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
+          rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
+          ValueRange{});
+      result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
+                                                        result, i);
+    }
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
+/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
+/// form with both a single parallel and reduction dimension.
+/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
+/// The case with a single parallel dimension is a noop and folds away
+/// separately.
+struct OneDimMultiReductionToTwoDim
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    // Rank-1 or bail.
+    if (srcRank != 1)
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    auto srcVectorType = multiReductionOp.getSourceVectorType();
+    auto srcShape = srcVectorType.getShape();
+    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
+                                      srcVectorType.getElementType());
+    assert(!multiReductionOp.getDestType().isa<VectorType>() &&
+           "multi_reduction with a single dimension expects a scalar result");
+
+    // If the unique dim is reduced and we insert a parallel in front, we need a
+    // {false, true} mask.
+    SmallVector<bool, 2> mask{false, true};
+
+    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
+    Value cast = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
+        loc, cast, mask, multiReductionOp.kind());
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
+                                                   ArrayRef<int64_t>{0});
+    return success();
+  }
+};
+
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns, bool useInnerDimsForReduction) {
+  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank,
+               OneDimMultiReductionToTwoDim>(patterns.getContext(),
+                                             useInnerDimsForReduction);
+  if (useInnerDimsForReduction)
+    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
+  else
+    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
+}

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 757e1a3362f0d..36898a44bf273 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -260,11 +260,10 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                         CombiningKind kind) {
   result.addOperands(source);
   auto sourceVectorType = source.getType().cast<VectorType>();
-  auto targetShape = MultiDimReductionOp::inferDestShape(
-      sourceVectorType.getShape(), reductionMask);
-  auto targetVectorType =
-      VectorType::get(targetShape, sourceVectorType.getElementType());
-  result.addTypes(targetVectorType);
+  auto targetType = MultiDimReductionOp::inferDestType(
+      sourceVectorType.getShape(), reductionMask,
+      sourceVectorType.getElementType());
+  result.addTypes(targetType);
 
   SmallVector<int64_t> reductionDims;
   for (auto en : llvm::enumerate(reductionMask))
@@ -278,17 +277,23 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
 
 static LogicalResult verify(MultiDimReductionOp op) {
   auto reductionMask = op.getReductionMask();
-  auto targetShape = MultiDimReductionOp::inferDestShape(
-      op.getSourceVectorType().getShape(), reductionMask);
-  auto targetVectorType =
-      VectorType::get(targetShape, op.getSourceVectorType().getElementType());
-  if (targetVectorType != op.getDestVectorType())
+  auto targetType = MultiDimReductionOp::inferDestType(
+      op.getSourceVectorType().getShape(), reductionMask,
+      op.getSourceVectorType().getElementType());
+  // TODO: update to support 0-d vectors when available.
+  if (targetType != op.getDestType())
     return op.emitError("invalid output vector type: ")
-           << op.getDestVectorType() << " (expected: " << targetVectorType
-           << ")";
+           << op.getDestType() << " (expected: " << targetType << ")";
   return success();
 }
 
+OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
+  // Single parallel dim, this is a noop.
+  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
+    return source();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 999f37fd9dfea..c76c43afbed3f 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -875,14 +875,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     case CombiningKind::MAXF:
       combinedResult = rewriter.create<MaxFOp>(loc, mul, acc);
       break;
-    case CombiningKind::ADD: // Already handled this special case above.
-    case CombiningKind::AND: // Only valid for integer types.
+    case CombiningKind::ADD:   // Already handled this special case above.
+    case CombiningKind::AND:   // Only valid for integer types.
     case CombiningKind::MINUI: // Only valid for integer types.
     case CombiningKind::MINSI: // Only valid for integer types.
     case CombiningKind::MAXUI: // Only valid for integer types.
     case CombiningKind::MAXSI: // Only valid for integer types.
-    case CombiningKind::OR:  // Only valid for integer types.
-    case CombiningKind::XOR: // Only valid for integer types.
+    case CombiningKind::OR:    // Only valid for integer types.
+    case CombiningKind::XOR:   // Only valid for integer types.
       return Optional<Value>();
     }
     return Optional<Value>(combinedResult);
@@ -3504,315 +3504,6 @@ class VectorCreateMaskOpConversion
   const bool enableIndexOptimizations;
 };
 
-// Converts vector.multi_reduction into inner-most/outer-most reduction form
-// by using vector.tranpose
-class InnerOuterDimReductionConversion
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-public:
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
-
-  explicit InnerOuterDimReductionConversion(MLIRContext *context,
-                                            bool useInnerDimsForReduction)
-      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
-        useInnerDimsForReduction(useInnerDimsForReduction) {}
-
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
-    auto src = multiReductionOp.source();
-    auto loc = multiReductionOp.getLoc();
-    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-
-    // Separate reduction and parallel dims
-    auto reductionDimsRange =
-        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
-    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
-        reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
-    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
-                                                  reductionDims.end());
-    int64_t reductionSize = reductionDims.size();
-    SmallVector<int64_t, 4> parallelDims;
-    for (int64_t i = 0; i < srcRank; i++) {
-      if (!reductionDimsSet.contains(i))
-        parallelDims.push_back(i);
-    }
-
-    // Add transpose only if inner-most/outer-most dimensions are not parallel
-    if (useInnerDimsForReduction &&
-        (parallelDims ==
-         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
-      return failure();
-
-    if (!useInnerDimsForReduction &&
-        (parallelDims !=
-         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
-      return failure();
-
-    SmallVector<int64_t, 4> indices;
-    if (useInnerDimsForReduction) {
-      indices.append(parallelDims.begin(), parallelDims.end());
-      indices.append(reductionDims.begin(), reductionDims.end());
-    } else {
-      indices.append(reductionDims.begin(), reductionDims.end());
-      indices.append(parallelDims.begin(), parallelDims.end());
-    }
-    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
-    SmallVector<bool> reductionMask(srcRank, false);
-    for (int i = 0; i < reductionSize; ++i) {
-      if (useInnerDimsForReduction)
-        reductionMask[srcRank - i - 1] = true;
-      else
-        reductionMask[i] = true;
-    }
-    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
-        multiReductionOp, transposeOp.result(), reductionMask,
-        multiReductionOp.kind());
-    return success();
-  }
-
-private:
-  const bool useInnerDimsForReduction;
-};
-
-// Reduces the rank of vector.mult_reduction nd -> 2d given all reduction
-// dimensions are either inner most or outer most.
-class ReduceMultiDimReductionRank
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-public:
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
-
-  explicit ReduceMultiDimReductionRank(MLIRContext *context,
-                                       bool useInnerDimsForReduction)
-      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
-        useInnerDimsForReduction(useInnerDimsForReduction) {}
-
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
-    auto loc = multiReductionOp.getLoc();
-    if (srcRank == 2)
-      return failure();
-
-    // Separate reduction and parallel dims
-    auto reductionDimsRange =
-        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
-    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
-        reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
-    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
-                                                  reductionDims.end());
-    SmallVector<int64_t, 4> parallelDims, parallelShapes;
-    int canonicalReductionDim = 1;
-    int canonicalParallelDim = 1;
-    for (int64_t i = 0; i < srcRank; i++) {
-      if (!reductionDimsSet.contains(i)) {
-        parallelDims.push_back(i);
-        parallelShapes.push_back(srcShape[i]);
-        canonicalParallelDim *= srcShape[i];
-      } else {
-        canonicalReductionDim *= srcShape[i];
-      }
-    }
-
-    // Fail if reduction dims are not either inner-most or outer-most
-    if (useInnerDimsForReduction &&
-        (parallelDims !=
-         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
-      return failure();
-
-    if (!useInnerDimsForReduction &&
-        (parallelDims ==
-         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
-      return failure();
-
-    // Creates shape cast for the inputs n_d -> 2d
-    int64_t outerDim =
-        useInnerDimsForReduction ? canonicalParallelDim : canonicalReductionDim;
-    int64_t innerDim =
-        useInnerDimsForReduction ? canonicalReductionDim : canonicalParallelDim;
-
-    auto castedType = VectorType::get(
-        ArrayRef<int64_t>{outerDim, innerDim},
-        multiReductionOp.getSourceVectorType().getElementType());
-    auto castedOp = rewriter.create<vector::ShapeCastOp>(
-        loc, castedType, multiReductionOp.source());
-
-    // Creates the canonical form of 2d vector.multi_reduction with inner/outer
-    // most dim as reduction.
-    SmallVector<bool, 2> mask{!useInnerDimsForReduction,
-                              useInnerDimsForReduction};
-    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
-        loc, castedOp.result(), mask, multiReductionOp.kind());
-
-    // Creates shape cast for the output 2d -> nd
-    VectorType outputCastedType = VectorType::get(
-        parallelShapes,
-        multiReductionOp.getSourceVectorType().getElementType());
-    Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
-        loc, outputCastedType, newOp.dest());
-
-    rewriter.replaceOp(multiReductionOp, castedOutputOp);
-    return success();
-  }
-
-private:
-  const bool useInnerDimsForReduction;
-};
-
-// Unrolls vector.multi_reduction with outermost reductions
-// and combines results
-struct UnrollOuterMultiReduction
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-    if (srcRank != 2)
-      return failure();
-
-    if (multiReductionOp.getReductionMask()[1] ||
-        !multiReductionOp.getReductionMask()[0])
-      return failure();
-
-    auto loc = multiReductionOp.getLoc();
-    ArrayRef<int64_t> srcShape =
-        multiReductionOp.getSourceVectorType().getShape();
-
-    Type elementType = multiReductionOp.getDestVectorType().getElementType();
-    if (!elementType.isIntOrIndexOrFloat())
-      return failure();
-
-    Value condition;
-    Value result =
-        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
-            .getResult();
-    for (int64_t i = 1; i < srcShape[0]; i++) {
-      auto operand =
-          rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
-      switch (multiReductionOp.kind()) {
-      case vector::CombiningKind::ADD:
-        if (elementType.isIntOrIndex())
-          result = rewriter.create<AddIOp>(loc, operand, result);
-        else
-          result = rewriter.create<AddFOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MUL:
-        if (elementType.isIntOrIndex())
-          result = rewriter.create<MulIOp>(loc, operand, result);
-        else
-          result = rewriter.create<MulFOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MINUI:
-        result = rewriter.create<MinUIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MINSI:
-        result = rewriter.create<MinSIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MINF:
-        result = rewriter.create<MinFOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MAXUI:
-        result = rewriter.create<MaxUIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MAXSI:
-        result = rewriter.create<MaxSIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MAXF:
-        result = rewriter.create<MaxFOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::AND:
-        result = rewriter.create<AndOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::OR:
-        result = rewriter.create<OrOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::XOR:
-        result = rewriter.create<XOrOp>(loc, operand, result);
-        break;
-      }
-    }
-
-    rewriter.replaceOp(multiReductionOp, result);
-    return success();
-  }
-};
-
-// Converts 2d vector.multi_reduction with inner most reduction dimension into a
-// sequence of vector.reduction ops.
-struct TwoDimMultiReductionToReduction
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-    if (srcRank != 2)
-      return failure();
-
-    if (multiReductionOp.getReductionMask()[0] ||
-        !multiReductionOp.getReductionMask()[1])
-      return failure();
-
-    auto loc = multiReductionOp.getLoc();
-
-    Value result =
-        multiReductionOp.getDestVectorType().getElementType().isIntOrIndex()
-            ? rewriter.create<ConstantOp>(
-                  loc, multiReductionOp.getDestVectorType(),
-                  DenseElementsAttr::get(multiReductionOp.getDestVectorType(),
-                                         0))
-            : rewriter.create<ConstantOp>(
-                  loc, multiReductionOp.getDestVectorType(),
-                  DenseElementsAttr::get(multiReductionOp.getDestVectorType(),
-                                         0.0f));
-
-    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
-
-    // TODO: Add vector::CombiningKind attribute instead of string to
-    // vector.reduction.
-    auto getKindStr = [](vector::CombiningKind kind) {
-      switch (kind) {
-      case vector::CombiningKind::ADD:
-        return "add";
-      case vector::CombiningKind::MUL:
-        return "mul";
-      case vector::CombiningKind::MINUI:
-        return "minui";
-      case vector::CombiningKind::MINSI:
-        return "minsi";
-      case vector::CombiningKind::MINF:
-        return "minf";
-      case vector::CombiningKind::MAXUI:
-        return "maxui";
-      case vector::CombiningKind::MAXSI:
-        return "maxsi";
-      case vector::CombiningKind::MAXF:
-        return "maxf";
-      case vector::CombiningKind::AND:
-        return "and";
-      case vector::CombiningKind::OR:
-        return "or";
-      case vector::CombiningKind::XOR:
-        return "xor";
-      }
-      llvm_unreachable("unknown combining kind");
-    };
-
-    for (int i = 0; i < outerDim; ++i) {
-      auto v = rewriter.create<vector::ExtractOp>(
-          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
-      auto reducedValue = rewriter.create<vector::ReductionOp>(
-          loc, multiReductionOp.getDestVectorType().getElementType(),
-          rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
-          ValueRange{});
-      result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
-                                                        result, i);
-    }
-    rewriter.replaceOp(multiReductionOp, result);
-    return success();
-  }
-};
-
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool enableIndexOptimizations) {
   patterns.add<VectorCreateMaskOpConversion,
@@ -3893,16 +3584,6 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
   patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
 }
 
-void mlir::vector::populateVectorMultiReductionLoweringPatterns(
-    RewritePatternSet &patterns, bool useInnerDimsForReduction) {
-  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
-      patterns.getContext(), useInnerDimsForReduction);
-  if (useInnerDimsForReduction)
-    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
-  else
-    patterns.add<UnrollOuterMultiReduction>(patterns.getContext());
-}
-
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8b3674e59b7f4..f713ac38ce761 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1026,3 +1026,14 @@ func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v
   %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
   return %1 : tensor<?x?x12xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @vector_multi_reduction_single_parallel(
+//  CHECK-SAME:     %[[v:.*]]: vector<2xf32>
+func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
+
+//       CHECK:     return %[[v]] : vector<2xf32>
+    return %0 : vector<2xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index d5afa674274d8..6f715ce95ba2f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -621,3 +621,11 @@ func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
   return %r, %r2 : vector<32xf32>, vector<16x32xf32>
 }
 
+// CHECK-LABEL: @multi_reduction
+func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
+  %1 = vector.multi_reduction #vector.kind<add>, %0 [1, 3] :
+    vector<4x8x16x32xf32> to vector<4x16xf32>
+  %2 = vector.multi_reduction #vector.kind<add>, %1 [0, 1] :
+    vector<4x16xf32> to f32
+  return %2 : f32
+}

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 4121262722e34..192f66c047091 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -17,6 +17,18 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
+func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
+    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
+    return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+//       CHECK:   %[[REDUCED:.*]] = vector.reduction "mul", %[[CASTED]] : vector<8xf32> into f32
+//       CHECK:   %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
+//       CHECK:   %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
+//       CHECK:   return %[[RES]]
+
 func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
     %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
@@ -50,7 +62,7 @@ func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
 //       CHECK:       %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
 //       CHECK:       %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
 //       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
-//       CHECK:       return %[[RESULT]]     
+//       CHECK:       return %[[RESULT]]
 
 
 func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
@@ -63,7 +75,7 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
 //       CHECK:     %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
 //       CHECK:     vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
 //       CHECK:     %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
-//       CHECK:       return %[[RESULT]]     
+//       CHECK:       return %[[RESULT]]
 
 func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
     %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>


        


More information about the Mlir-commits mailing list