[Mlir-commits] [mlir] 31270eb - [mlir][Vector] Let vector.multi_reduction reduce down to a scalar.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Oct 12 04:03:59 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-12T11:03:54Z
New Revision: 31270eb16501cca73fb3fbac254fe9965a3f3fc1
URL: https://github.com/llvm/llvm-project/commit/31270eb16501cca73fb3fbac254fe9965a3f3fc1
DIFF: https://github.com/llvm/llvm-project/commit/31270eb16501cca73fb3fbac254fe9965a3f3fc1.diff
LOG: [mlir][Vector] Let vector.multi_reduction reduce down to a scalar.
vector.multi_reduction currently does not allow reducing down to a scalar.
This creates corner cases that are hard to handle during vectorization.
This revision extends the semantics and adds the proper transforms, lowerings and canonicalizations to allow lowering out of vector.multi_reduction to other abstractions all the way to LLVM.
Even once 0-d vectors are allowed in the future, scalars will remain relevant: 0-d vectors and scalars are not equivalent on all hardware.
In the process, splice out the implementation patterns related to vector.multi_reduction into a new file.
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D111442
Added:
mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index a98ca36025e82..a6fbf93f29a08 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -79,8 +79,28 @@ void populateVectorTransferPermutationMapLoweringPatterns(
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
-// Collect a set of patterns to convert vector.multi_reduction op into
-// a sequence of vector.reduction ops.
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of lower-level vector ops. The patterns comprise:
+/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
+///   that all reduction dimensions are either innermost or outermost, by
+///   adding the proper vector.transpose operations.
+/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
+///   form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction
+///   (1-D when there are only parallel or only reduction dims), by introducing
+///   vector.shape_cast ops to collapse + multi-reduce + expand back.
+/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
+///   form, with an **outermost** reduction dimension, unroll the outer
+///   dimension to obtain a sequence of 1-D vector ops. This also has an
+///   opportunity for tree-reduction (in the future).
+/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
+///   with an **innermost** reduction dimension, unroll the outer dimension to
+///   obtain a sequence of extract + vector.reduction + insert. This can further
+///   lower to horizontal reduction ops.
+/// - OneDimMultiReductionToTwoDim: for a 1-D vector<k> source that is fully
+///   reduced to a scalar (i.e. it has no parallel dimension), lift it back up
+///   to 2-D with a simple vector.shape_cast to vector<1xk> so that the other
+///   patterns can kick in, thus fully exiting out of the
+///   vector.multi_reduction abstraction.
+///
+/// Only one of the two 2-D patterns is added: TwoDimMultiReductionToReduction
+/// when `useInnerDimsForReduction` is true, TwoDimMultiReductionToElementWise
+/// otherwise.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, bool useInnerDimsForReduction = false);
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index ea86336fd787e..c334773d6654e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -301,14 +301,17 @@ def Vector_MultiDimReductionOp :
Results<(outs AnyType:$dest)> {
let summary = "Multi-dimensional reduction operation";
let description = [{
- Reduces an n-D vector into an (n-k)-D vector using the given operation
- (add/mul/min/max for int/fp and and/or/xor for int only).
+ Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
+ using the given operation (add/mul/min/max for int/fp and and/or/xor for
+ int only). Here `k` is the number of dimensions listed for reduction.
Example:
```mlir
%1 = vector.multi_reduction "add", %0 [1, 3] :
-vector<4x8x16x32xf32> into vector<4x16xf32>
+ vector<4x8x16x32xf32> to vector<4x16xf32>
+ %2 = vector.multi_reduction "add", %1 [0, 1] :
+ vector<4x16xf32> to f32
```
}];
let builders = [
@@ -322,8 +325,14 @@ def Vector_MultiDimReductionOp :
VectorType getSourceVectorType() {
return source().getType().cast<VectorType>();
}
- VectorType getDestVectorType() {
- return dest().getType().cast<VectorType>();
+ /// Result type: a vector of the non-reduced dims, or the element type when
+ /// every dimension is reduced.
+ Type getDestType() { return dest().getType(); }
+
+ /// Whether source dimension `d` appears in the reduction dims.
+ bool isReducedDim(int64_t d) {
+ assert(0 <= d && d < static_cast<int64_t>(getReductionMask().size()) &&
+ "d overflows the number of dims");
+ return getReductionMask()[d];
}
SmallVector<bool> getReductionMask() {
@@ -341,18 +350,28 @@ def Vector_MultiDimReductionOp :
}
+ /// Returns the sizes of the dims of `sourceShape` whose entry in
+ /// `reducedDimsMask` is false, in their original order.
static SmallVector<int64_t> inferDestShape(
- ArrayRef<int64_t> shape, ArrayRef<bool> reducedDimsMask) {
- assert(shape.size() == reducedDimsMask.size() &&
- "shape and maks of different sizes");
+ ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask) {
+ assert(sourceShape.size() == reducedDimsMask.size() &&
+ "sourceShape and mask of different sizes");
SmallVector<int64_t> res;
- for (auto it : llvm::zip(reducedDimsMask, shape))
+ for (auto it : llvm::zip(reducedDimsMask, sourceShape))
if (!std::get<0>(it))
res.push_back(std::get<1>(it));
return res;
}
+
+ /// Returns the result type of reducing a `sourceShape` vector of
+ /// `elementType` along the dims set in `reducedDimsMask`: a vector of the
+ /// remaining dims, or `elementType` itself when every dim is reduced.
+ static Type inferDestType(ArrayRef<int64_t> sourceShape,
+ ArrayRef<bool> reducedDimsMask,
+ Type elementType) {
+ auto targetShape = inferDestShape(sourceShape, reducedDimsMask);
+ // TODO: update to also allow 0-d vectors when available.
+ if (targetShape.empty())
+ return elementType;
+ return VectorType::get(targetShape, elementType);
+ }
}];
let assemblyFormat =
"$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
+ // The folder returns the source when the op reduces nothing, i.e. a rank-1
+ // source whose single dimension is parallel.
+ let hasFolder = 1;
}
def Vector_BroadcastOp :
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 9ea8aabb698de..f0c3d9eeb2a06 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVector
VectorOps.cpp
+ VectorMultiDimReductionTransforms.cpp
VectorTransferOpTransforms.cpp
VectorTransforms.cpp
VectorUtils.cpp
diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
new file mode 100644
index 0000000000000..6eba54226171d
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -0,0 +1,409 @@
+//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
+//
+/// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+/// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+/// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This file implements target-independent rewrites of MultiDimReductionOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#define DEBUG_TYPE "vector-multi-reduction"
+
+using namespace mlir;
+
+/// This file implements the following transformations as composable atomic
+/// patterns.
+
+/// Converts vector.multi_reduction into inner-most/outer-most reduction form
+/// by using vector.transpose.
+///
+/// With `useInnerDimsForReduction` the source is transposed so that parallel
+/// dims come first and reduction dims last; otherwise reduction dims come
+/// first and parallel dims last. The pattern bails when the op is already in
+/// the requested form, or when it has no parallel dim at all (a full
+/// reduction is trivially in both forms and needs no transpose).
+class InnerOuterDimReductionConversion
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+public:
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  explicit InnerOuterDimReductionConversion(MLIRContext *context,
+                                            bool useInnerDimsForReduction)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+        useInnerDimsForReduction(useInnerDimsForReduction) {}
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto src = multiReductionOp.source();
+    auto loc = multiReductionOp.getLoc();
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    // Separate reduction and parallel dims
+    auto reductionDimsRange =
+        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
+    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
+        reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
+    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
+                                                  reductionDims.end());
+    int64_t reductionSize = reductionDims.size();
+    SmallVector<int64_t, 4> parallelDims;
+    for (int64_t i = 0; i < srcRank; ++i)
+      if (!reductionDimsSet.contains(i))
+        parallelDims.push_back(i);
+
+    // A full reduction has nothing to move. Without this bail-out the
+    // outer-most case would rebuild an identical op and loop forever.
+    if (parallelDims.empty())
+      return failure();
+
+    // Add transpose only if inner-most/outer-most dimensions are not parallel.
+    // Inner-most reduction form: parallel dims are exactly [0 .. #parallel).
+    if (useInnerDimsForReduction &&
+        (parallelDims ==
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
+
+    // Outer-most reduction form: parallel dims are exactly the trailing ones,
+    // [srcRank - #parallel .. srcRank). Checking "not leading" instead would
+    // wrongly skip interleaved cases such as parallel [0, 2] / reduce [1, 3].
+    int64_t firstTrailingDim =
+        srcRank - static_cast<int64_t>(parallelDims.size());
+    if (!useInnerDimsForReduction &&
+        (parallelDims ==
+         llvm::to_vector<4>(llvm::seq<int64_t>(firstTrailingDim, srcRank))))
+      return failure();
+
+    SmallVector<int64_t, 4> indices;
+    if (useInnerDimsForReduction) {
+      indices.append(parallelDims.begin(), parallelDims.end());
+      indices.append(reductionDims.begin(), reductionDims.end());
+    } else {
+      indices.append(reductionDims.begin(), reductionDims.end());
+      indices.append(parallelDims.begin(), parallelDims.end());
+    }
+    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
+    SmallVector<bool> reductionMask(srcRank, false);
+    for (int i = 0; i < reductionSize; ++i) {
+      if (useInnerDimsForReduction)
+        reductionMask[srcRank - i - 1] = true;
+      else
+        reductionMask[i] = true;
+    }
+    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
+        multiReductionOp, transposeOp.result(), reductionMask,
+        multiReductionOp.kind());
+    return success();
+  }
+
+private:
+  const bool useInnerDimsForReduction;
+};
+
+/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
+/// dimensions are either inner most or outer most.
+/// When the source has only parallel or only reduction dims the collapsed form
+/// is 1-D; a full reduction then yields a scalar directly.
+class ReduceMultiDimReductionRank
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+public:
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  explicit ReduceMultiDimReductionRank(MLIRContext *context,
+                                       bool useInnerDimsForReduction)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+        useInnerDimsForReduction(useInnerDimsForReduction) {}
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
+    auto loc = multiReductionOp.getLoc();
+
+    // If rank less than 2, nothing to do.
+    if (srcRank < 2)
+      return failure();
+
+    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
+    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
+    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
+      return failure();
+
+    // 1. Separate reduction and parallel dims.
+    SmallVector<int64_t, 4> parallelDims, parallelShapes;
+    SmallVector<int64_t, 4> reductionDims, reductionShapes;
+    for (auto it : llvm::enumerate(reductionMask)) {
+      int64_t i = it.index();
+      bool isReduction = it.value();
+      if (isReduction) {
+        reductionDims.push_back(i);
+        reductionShapes.push_back(srcShape[i]);
+      } else {
+        parallelDims.push_back(i);
+        parallelShapes.push_back(srcShape[i]);
+      }
+    }
+
+    // 2. Compute flattened parallel and reduction sizes. Presence is tracked
+    // with explicit flags instead of a 0 sentinel, and the products use
+    // int64_t (the vector shape's type) so they are never truncated.
+    bool hasParallelDims = !parallelShapes.empty();
+    bool hasReductionDims = !reductionShapes.empty();
+    int64_t flattenedParallelDim = 1;
+    for (int64_t d : parallelShapes)
+      flattenedParallelDim *= d;
+    int64_t flattenedReductionDim = 1;
+    for (int64_t d : reductionShapes)
+      flattenedReductionDim *= d;
+    // We must at least have some parallel or some reduction.
+    assert((hasParallelDims || hasReductionDims) &&
+           "expected at least one parallel or reduction dim");
+
+    // 3. Fail if reduction/parallel dims are not contiguous.
+    // Inner-most reduction form: parallelDims are exactly [0 .. #parallel).
+    int64_t counter = 0;
+    if (useInnerDimsForReduction &&
+        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
+      return failure();
+    // Outer-most reduction form: parallelDims are exactly
+    // [#reduction .. #reduction + #parallel).
+    counter = reductionDims.size();
+    if (!useInnerDimsForReduction &&
+        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
+      return failure();
+
+    // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
+    // a single parallel (resp. reduction) dim.
+    SmallVector<bool, 2> mask;
+    SmallVector<int64_t, 2> vectorShape;
+    if (hasParallelDims) {
+      mask.push_back(false);
+      vectorShape.push_back(flattenedParallelDim);
+    }
+    if (hasReductionDims) {
+      mask.push_back(true);
+      vectorShape.push_back(flattenedReductionDim);
+    }
+    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
+      std::swap(mask.front(), mask.back());
+      std::swap(vectorShape.front(), vectorShape.back());
+    }
+    auto castedType = VectorType::get(
+        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
+    Value cast = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+
+    // 5. Creates the flattened form of vector.multi_reduction with inner/outer
+    // most dim as reduction.
+    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, cast, mask, multiReductionOp.kind());
+
+    // 6. If there are no parallel shapes, the result is a scalar.
+    // TODO: support 0-d vectors when available.
+    if (parallelShapes.empty()) {
+      rewriter.replaceOp(multiReductionOp, newOp.dest());
+      return success();
+    }
+
+    // 7. Shape cast the collapsed 1-D result back to the original n-D
+    // parallel shape.
+    VectorType outputCastedType = VectorType::get(
+        parallelShapes,
+        multiReductionOp.getSourceVectorType().getElementType());
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        multiReductionOp, outputCastedType, newOp.dest());
+    return success();
+  }
+
+private:
+  const bool useInnerDimsForReduction;
+};
+
+/// Unrolls a 2-D vector.multi_reduction whose outermost dim is reduced and
+/// whose innermost dim is parallel: the rows are extracted and combined
+/// elementwise with the op's combining kind.
+struct TwoDimMultiReductionToElementWise
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    // Rank-2 ["reduce", "parallel"] or bail.
+    if (srcRank != 2)
+      return failure();
+
+    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return failure();
+
+    // Seed with row 0, then fold every remaining row into the accumulator.
+    Value result =
+        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
+            .getResult();
+    for (int64_t i = 1; i < srcShape[0]; i++) {
+      auto operand =
+          rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
+      switch (multiReductionOp.kind()) {
+      case vector::CombiningKind::ADD:
+        if (elementType.isIntOrIndex())
+          result = rewriter.create<AddIOp>(loc, operand, result);
+        else
+          result = rewriter.create<AddFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MUL:
+        if (elementType.isIntOrIndex())
+          result = rewriter.create<MulIOp>(loc, operand, result);
+        else
+          result = rewriter.create<MulFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINUI:
+        result = rewriter.create<MinUIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINSI:
+        result = rewriter.create<MinSIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINF:
+        result = rewriter.create<MinFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXUI:
+        result = rewriter.create<MaxUIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXSI:
+        result = rewriter.create<MaxSIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXF:
+        result = rewriter.create<MaxFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::AND:
+        result = rewriter.create<AndOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::OR:
+        result = rewriter.create<OrOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::XOR:
+        result = rewriter.create<XOrOp>(loc, operand, result);
+        break;
+      }
+    }
+
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
+/// Converts 2d vector.multi_reduction with inner most reduction dimension into
+/// a sequence of vector.reduction ops, one per row, inserted into the result.
+struct TwoDimMultiReductionToReduction
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    // Rank-2 ["parallel", "reduce"] or bail.
+    if (srcRank != 2)
+      return failure();
+
+    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    // Dim 0 is parallel, so the result is a vector: start from zeros and
+    // insert one reduced scalar per row.
+    Value result = rewriter.create<ConstantOp>(
+        loc, multiReductionOp.getDestType(),
+        rewriter.getZeroAttr(multiReductionOp.getDestType()));
+    // Keep the shape's int64_t width instead of narrowing to int.
+    int64_t outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
+
+    // TODO: Add vector::CombiningKind attribute instead of string to
+    // vector.reduction.
+    auto getKindStr = [](vector::CombiningKind kind) {
+      switch (kind) {
+      case vector::CombiningKind::ADD:
+        return "add";
+      case vector::CombiningKind::MUL:
+        return "mul";
+      case vector::CombiningKind::MINUI:
+        return "minui";
+      case vector::CombiningKind::MINSI:
+        return "minsi";
+      case vector::CombiningKind::MINF:
+        return "minf";
+      case vector::CombiningKind::MAXUI:
+        return "maxui";
+      case vector::CombiningKind::MAXSI:
+        return "maxsi";
+      case vector::CombiningKind::MAXF:
+        return "maxf";
+      case vector::CombiningKind::AND:
+        return "and";
+      case vector::CombiningKind::OR:
+        return "or";
+      case vector::CombiningKind::XOR:
+        return "xor";
+      }
+      llvm_unreachable("unknown combining kind");
+    };
+
+    for (int64_t i = 0; i < outerDim; ++i) {
+      auto v = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
+      auto reducedValue = rewriter.create<vector::ReductionOp>(
+          loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
+          rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
+          ValueRange{});
+      result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
+                                                        result, i);
+    }
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
+/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
+/// form with both a single parallel and reduction dimension.
+/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
+/// The case with a single parallel dimension is a noop that folds away
+/// separately; this pattern does not match it.
+struct OneDimMultiReductionToTwoDim
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    // Rank-1 or bail.
+    if (srcRank != 1)
+      return failure();
+
+    // Bail on the parallel (noop) case instead of relying on the folder having
+    // run first: rewriting it would replace a vector<k> result with a scalar
+    // and produce invalid IR.
+    if (!multiReductionOp.isReducedDim(0))
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    auto srcVectorType = multiReductionOp.getSourceVectorType();
+    auto srcShape = srcVectorType.getShape();
+    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
+                                      srcVectorType.getElementType());
+    assert(!multiReductionOp.getDestType().isa<VectorType>() &&
+           "multi_reduction with a single dimension expects a scalar result");
+
+    // If the unique dim is reduced and we insert a parallel in front, we need a
+    // {false, true} mask.
+    SmallVector<bool, 2> mask{false, true};
+
+    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
+    Value cast = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
+        loc, cast, mask, multiReductionOp.kind());
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
+                                                   ArrayRef<int64_t>{0});
+    return success();
+  }
+};
+
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns, bool useInnerDimsForReduction) {
+  // Only these two patterns have a (context, useInnerDimsForReduction)
+  // constructor.
+  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
+      patterns.getContext(), useInnerDimsForReduction);
+  // OneDimMultiReductionToTwoDim only inherits OpRewritePattern's
+  // (context, PatternBenefit, ...) constructor; forwarding the flag would bind
+  // it to PatternBenefit and silently give benefit 0 when the flag is false.
+  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
+  if (useInnerDimsForReduction)
+    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
+  else
+    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 757e1a3362f0d..36898a44bf273 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -260,11 +260,10 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
CombiningKind kind) {
result.addOperands(source);
auto sourceVectorType = source.getType().cast<VectorType>();
- auto targetShape = MultiDimReductionOp::inferDestShape(
- sourceVectorType.getShape(), reductionMask);
- auto targetVectorType =
- VectorType::get(targetShape, sourceVectorType.getElementType());
- result.addTypes(targetVectorType);
+ auto targetType = MultiDimReductionOp::inferDestType(
+ sourceVectorType.getShape(), reductionMask,
+ sourceVectorType.getElementType());
+ result.addTypes(targetType);
SmallVector<int64_t> reductionDims;
for (auto en : llvm::enumerate(reductionMask))
@@ -278,17 +277,23 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
+/// Checks that the result type equals the one inferred from the source shape
+/// and reduction mask (the element type when every dimension is reduced).
static LogicalResult verify(MultiDimReductionOp op) {
auto reductionMask = op.getReductionMask();
- auto targetShape = MultiDimReductionOp::inferDestShape(
- op.getSourceVectorType().getShape(), reductionMask);
- auto targetVectorType =
- VectorType::get(targetShape, op.getSourceVectorType().getElementType());
- if (targetVectorType != op.getDestVectorType())
+ auto targetType = MultiDimReductionOp::inferDestType(
+ op.getSourceVectorType().getShape(), reductionMask,
+ op.getSourceVectorType().getElementType());
+ // TODO: update to support 0-d vectors when available.
+ // NOTE(review): the diagnostic below still says "vector type" although the
+ // result may now be a scalar.
+ if (targetType != op.getDestType())
return op.emitError("invalid output vector type: ")
- << op.getDestVectorType() << " (expected: " << targetVectorType
- << ")";
+ << op.getDestType() << " (expected: " << targetType << ")";
return success();
}
+OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
+  // A rank-1 source whose only dimension is parallel reduces nothing, so the
+  // op folds to its source; every other form is left alone.
+  bool reducesNothing =
+      getSourceVectorType().getRank() == 1 && !isReducedDim(0);
+  if (!reducesNothing)
+    return {};
+  return source();
+}
+
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 999f37fd9dfea..c76c43afbed3f 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -875,14 +875,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
case CombiningKind::MAXF:
combinedResult = rewriter.create<MaxFOp>(loc, mul, acc);
break;
- case CombiningKind::ADD: // Already handled this special case above.
- case CombiningKind::AND: // Only valid for integer types.
+ case CombiningKind::ADD: // Already handled this special case above.
+ case CombiningKind::AND: // Only valid for integer types.
case CombiningKind::MINUI: // Only valid for integer types.
case CombiningKind::MINSI: // Only valid for integer types.
case CombiningKind::MAXUI: // Only valid for integer types.
case CombiningKind::MAXSI: // Only valid for integer types.
- case CombiningKind::OR: // Only valid for integer types.
- case CombiningKind::XOR: // Only valid for integer types.
+ case CombiningKind::OR: // Only valid for integer types.
+ case CombiningKind::XOR: // Only valid for integer types.
return Optional<Value>();
}
return Optional<Value>(combinedResult);
@@ -3504,315 +3504,6 @@ class VectorCreateMaskOpConversion
const bool enableIndexOptimizations;
};
-// Converts vector.multi_reduction into inner-most/outer-most reduction form
-// by using vector.tranpose
-class InnerOuterDimReductionConversion
- : public OpRewritePattern<vector::MultiDimReductionOp> {
-public:
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
-
- explicit InnerOuterDimReductionConversion(MLIRContext *context,
- bool useInnerDimsForReduction)
- : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
- useInnerDimsForReduction(useInnerDimsForReduction) {}
-
- LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
- PatternRewriter &rewriter) const override {
- auto src = multiReductionOp.source();
- auto loc = multiReductionOp.getLoc();
- auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-
- // Separate reduction and parallel dims
- auto reductionDimsRange =
- multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
- auto reductionDims = llvm::to_vector<4>(llvm::map_range(
- reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
- llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
- reductionDims.end());
- int64_t reductionSize = reductionDims.size();
- SmallVector<int64_t, 4> parallelDims;
- for (int64_t i = 0; i < srcRank; i++) {
- if (!reductionDimsSet.contains(i))
- parallelDims.push_back(i);
- }
-
- // Add transpose only if inner-most/outer-most dimensions are not parallel
- if (useInnerDimsForReduction &&
- (parallelDims ==
- llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
- return failure();
-
- if (!useInnerDimsForReduction &&
- (parallelDims !=
- llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
- return failure();
-
- SmallVector<int64_t, 4> indices;
- if (useInnerDimsForReduction) {
- indices.append(parallelDims.begin(), parallelDims.end());
- indices.append(reductionDims.begin(), reductionDims.end());
- } else {
- indices.append(reductionDims.begin(), reductionDims.end());
- indices.append(parallelDims.begin(), parallelDims.end());
- }
- auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
- SmallVector<bool> reductionMask(srcRank, false);
- for (int i = 0; i < reductionSize; ++i) {
- if (useInnerDimsForReduction)
- reductionMask[srcRank - i - 1] = true;
- else
- reductionMask[i] = true;
- }
- rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
- multiReductionOp, transposeOp.result(), reductionMask,
- multiReductionOp.kind());
- return success();
- }
-
-private:
- const bool useInnerDimsForReduction;
-};
-
-// Reduces the rank of vector.mult_reduction nd -> 2d given all reduction
-// dimensions are either inner most or outer most.
-class ReduceMultiDimReductionRank
- : public OpRewritePattern<vector::MultiDimReductionOp> {
-public:
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
-
- explicit ReduceMultiDimReductionRank(MLIRContext *context,
- bool useInnerDimsForReduction)
- : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
- useInnerDimsForReduction(useInnerDimsForReduction) {}
-
- LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
- PatternRewriter &rewriter) const override {
- auto srcRank = multiReductionOp.getSourceVectorType().getRank();
- auto srcShape = multiReductionOp.getSourceVectorType().getShape();
- auto loc = multiReductionOp.getLoc();
- if (srcRank == 2)
- return failure();
-
- // Separate reduction and parallel dims
- auto reductionDimsRange =
- multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
- auto reductionDims = llvm::to_vector<4>(llvm::map_range(
- reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
- llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
- reductionDims.end());
- SmallVector<int64_t, 4> parallelDims, parallelShapes;
- int canonicalReductionDim = 1;
- int canonicalParallelDim = 1;
- for (int64_t i = 0; i < srcRank; i++) {
- if (!reductionDimsSet.contains(i)) {
- parallelDims.push_back(i);
- parallelShapes.push_back(srcShape[i]);
- canonicalParallelDim *= srcShape[i];
- } else {
- canonicalReductionDim *= srcShape[i];
- }
- }
-
- // Fail if reduction dims are not either inner-most or outer-most
- if (useInnerDimsForReduction &&
- (parallelDims !=
- llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
- return failure();
-
- if (!useInnerDimsForReduction &&
- (parallelDims ==
- llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
- return failure();
-
- // Creates shape cast for the inputs n_d -> 2d
- int64_t outerDim =
- useInnerDimsForReduction ? canonicalParallelDim : canonicalReductionDim;
- int64_t innerDim =
- useInnerDimsForReduction ? canonicalReductionDim : canonicalParallelDim;
-
- auto castedType = VectorType::get(
- ArrayRef<int64_t>{outerDim, innerDim},
- multiReductionOp.getSourceVectorType().getElementType());
- auto castedOp = rewriter.create<vector::ShapeCastOp>(
- loc, castedType, multiReductionOp.source());
-
- // Creates the canonical form of 2d vector.multi_reduction with inner/outer
- // most dim as reduction.
- SmallVector<bool, 2> mask{!useInnerDimsForReduction,
- useInnerDimsForReduction};
- auto newOp = rewriter.create<vector::MultiDimReductionOp>(
- loc, castedOp.result(), mask, multiReductionOp.kind());
-
- // Creates shape cast for the output 2d -> nd
- VectorType outputCastedType = VectorType::get(
- parallelShapes,
- multiReductionOp.getSourceVectorType().getElementType());
- Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
- loc, outputCastedType, newOp.dest());
-
- rewriter.replaceOp(multiReductionOp, castedOutputOp);
- return success();
- }
-
-private:
- const bool useInnerDimsForReduction;
-};
-
-// Unrolls vector.multi_reduction with outermost reductions
-// and combines results
-struct UnrollOuterMultiReduction
- : public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
- PatternRewriter &rewriter) const override {
- auto srcRank = multiReductionOp.getSourceVectorType().getRank();
- if (srcRank != 2)
- return failure();
-
- if (multiReductionOp.getReductionMask()[1] ||
- !multiReductionOp.getReductionMask()[0])
- return failure();
-
- auto loc = multiReductionOp.getLoc();
- ArrayRef<int64_t> srcShape =
- multiReductionOp.getSourceVectorType().getShape();
-
- Type elementType = multiReductionOp.getDestVectorType().getElementType();
- if (!elementType.isIntOrIndexOrFloat())
- return failure();
-
- Value condition;
- Value result =
- rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
- .getResult();
- for (int64_t i = 1; i < srcShape[0]; i++) {
- auto operand =
- rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
- switch (multiReductionOp.kind()) {
- case vector::CombiningKind::ADD:
- if (elementType.isIntOrIndex())
- result = rewriter.create<AddIOp>(loc, operand, result);
- else
- result = rewriter.create<AddFOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MUL:
- if (elementType.isIntOrIndex())
- result = rewriter.create<MulIOp>(loc, operand, result);
- else
- result = rewriter.create<MulFOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MINUI:
- result = rewriter.create<MinUIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MINSI:
- result = rewriter.create<MinSIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MINF:
- result = rewriter.create<MinFOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MAXUI:
- result = rewriter.create<MaxUIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MAXSI:
- result = rewriter.create<MaxSIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MAXF:
- result = rewriter.create<MaxFOp>(loc, operand, result);
- break;
- case vector::CombiningKind::AND:
- result = rewriter.create<AndOp>(loc, operand, result);
- break;
- case vector::CombiningKind::OR:
- result = rewriter.create<OrOp>(loc, operand, result);
- break;
- case vector::CombiningKind::XOR:
- result = rewriter.create<XOrOp>(loc, operand, result);
- break;
- }
- }
-
- rewriter.replaceOp(multiReductionOp, result);
- return success();
- }
-};
-
-// Converts 2d vector.multi_reduction with inner most reduction dimension into a
-// sequence of vector.reduction ops.
-struct TwoDimMultiReductionToReduction
- : public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
- PatternRewriter &rewriter) const override {
- auto srcRank = multiReductionOp.getSourceVectorType().getRank();
- if (srcRank != 2)
- return failure();
-
- if (multiReductionOp.getReductionMask()[0] ||
- !multiReductionOp.getReductionMask()[1])
- return failure();
-
- auto loc = multiReductionOp.getLoc();
-
- Value result =
- multiReductionOp.getDestVectorType().getElementType().isIntOrIndex()
- ? rewriter.create<ConstantOp>(
- loc, multiReductionOp.getDestVectorType(),
- DenseElementsAttr::get(multiReductionOp.getDestVectorType(),
- 0))
- : rewriter.create<ConstantOp>(
- loc, multiReductionOp.getDestVectorType(),
- DenseElementsAttr::get(multiReductionOp.getDestVectorType(),
- 0.0f));
-
- int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
-
- // TODO: Add vector::CombiningKind attribute instead of string to
- // vector.reduction.
- auto getKindStr = [](vector::CombiningKind kind) {
- switch (kind) {
- case vector::CombiningKind::ADD:
- return "add";
- case vector::CombiningKind::MUL:
- return "mul";
- case vector::CombiningKind::MINUI:
- return "minui";
- case vector::CombiningKind::MINSI:
- return "minsi";
- case vector::CombiningKind::MINF:
- return "minf";
- case vector::CombiningKind::MAXUI:
- return "maxui";
- case vector::CombiningKind::MAXSI:
- return "maxsi";
- case vector::CombiningKind::MAXF:
- return "maxf";
- case vector::CombiningKind::AND:
- return "and";
- case vector::CombiningKind::OR:
- return "or";
- case vector::CombiningKind::XOR:
- return "xor";
- }
- llvm_unreachable("unknown combining kind");
- };
-
- for (int i = 0; i < outerDim; ++i) {
- auto v = rewriter.create<vector::ExtractOp>(
- loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
- auto reducedValue = rewriter.create<vector::ReductionOp>(
- loc, multiReductionOp.getDestVectorType().getElementType(),
- rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
- ValueRange{});
- result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
- result, i);
- }
- rewriter.replaceOp(multiReductionOp, result);
- return success();
- }
-};
-
void mlir::vector::populateVectorMaskMaterializationPatterns(
RewritePatternSet &patterns, bool enableIndexOptimizations) {
patterns.add<VectorCreateMaskOpConversion,
@@ -3893,16 +3584,6 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
}
-void mlir::vector::populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns, bool useInnerDimsForReduction) {
- patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
- patterns.getContext(), useInnerDimsForReduction);
- if (useInnerDimsForReduction)
- patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
- else
- patterns.add<UnrollOuterMultiReduction>(patterns.getContext());
-}
-
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options) {
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8b3674e59b7f4..f713ac38ce761 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1026,3 +1026,14 @@ func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v
%1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
return %1 : tensor<?x?x12xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @vector_multi_reduction_single_parallel(
+// CHECK-SAME: %[[v:.*]]: vector<2xf32>
+func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
+
+// CHECK: return %[[v]] : vector<2xf32>
+ return %0 : vector<2xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index d5afa674274d8..6f715ce95ba2f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -621,3 +621,11 @@ func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
return %r, %r2 : vector<32xf32>, vector<16x32xf32>
}
+// CHECK-LABEL: @multi_reduction
+func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
+ %1 = vector.multi_reduction #vector.kind<add>, %0 [1, 3] :
+ vector<4x8x16x32xf32> to vector<4x16xf32>
+ %2 = vector.multi_reduction #vector.kind<add>, %1 [0, 1] :
+ vector<4x16xf32> to f32
+ return %2 : f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 4121262722e34..192f66c047091 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -17,6 +17,18 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]]
+func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
+ %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
+ return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+// CHECK: %[[REDUCED:.*]] = vector.reduction "mul", %[[CASTED]] : vector<8xf32> into f32
+// CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
+// CHECK: return %[[RES]]
+
func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
@@ -50,7 +62,7 @@ func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
// CHECK: %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
-// CHECK: return %[[RESULT]]
+// CHECK: return %[[RESULT]]
func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
@@ -63,7 +75,7 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
-// CHECK: return %[[RESULT]]
+// CHECK: return %[[RESULT]]
func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
More information about the Mlir-commits
mailing list