[Mlir-commits] [mlir] 499e89f - Add patterns to lower vector.multi_reduction into a sequence of vector.reduction
Ahmed Taei
llvmlistbot at llvm.org
Fri Apr 30 10:53:20 PDT 2021
Author: Ahmed Taei
Date: 2021-04-30T10:52:21-07:00
New Revision: 499e89fc9119d901132bcc8ab460b1c161c22acc
URL: https://github.com/llvm/llvm-project/commit/499e89fc9119d901132bcc8ab460b1c161c22acc
DIFF: https://github.com/llvm/llvm-project/commit/499e89fc9119d901132bcc8ab460b1c161c22acc.diff
LOG: Add patterns to lower vector.multi_reduction into a sequence of vector.reduction
Three patterns are added to convert vector.multi_reduction into a
sequence of vector.reduction ops as follows:
- Transpose the inputs so that the innermost dimensions are always the reduction dimensions.
- Reduce the rank of vector.multi_reduction to 2d with the innermost
reduction dim (the 2d canonical form).
- 2D canonical form is converted into a sequence of vector.reduction.
There are two things that might be worth doing in a follow-up diff:
- An scf.for (possibly optional) around vector.reduction instead of unrolling it.
- Break down the vector.reduction into a sequence of vector.reduction ops
(e.g. a tree-based reduction) instead of relying on how downstream dialects
handle it.
Note: this will require passing target-vector-length.
Differential Revision: https://reviews.llvm.org/D101570
Added:
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index c11e8112b2e8e..399a90645df22 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -92,6 +92,10 @@ void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
+/// Collects a set of patterns that lower vector.multi_reduction ops into a
+/// sequence of vector.reduction ops. The patterns move the reduction
+/// dimensions to the inner-most positions (vector.transpose), flatten the op
+/// to a 2-d canonical form (vector.shape_cast), and unroll that 2-d form into
+/// one vector.reduction per row (vector.extract / vector.insertelement).
+void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
+
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 58e2c1d3a83dc..dab1c14e51979 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3575,6 +3575,198 @@ class VectorCreateMaskOpConversion
const bool enableIndexOptimizations;
};
+// Converts vector.multi_reduction into inner-most reduction form by inserting
+// a vector.transpose that moves all reduction dimensions behind the parallel
+// (non-reduced) ones.
+//
+// The parallel dimensions keep their original relative order, so the new
+// vector.multi_reduction produces exactly the original result type; the
+// reduction dimensions follow in ascending order. For example, reducing dims
+// [0, 2] of vector<2x3x4x5xf32> (result vector<3x5xf32>) transposes the source
+// with permutation [1, 3, 0, 2] into vector<3x5x2x4xf32> and then reduces its
+// two inner-most dimensions.
+struct InnerDimReductionConversion
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto src = multiReductionOp.source();
+    auto loc = multiReductionOp.getLoc();
+    int64_t srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    auto reductionDims = llvm::to_vector<4>(
+        llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
+                        [](Attribute attr) -> int64_t {
+                          return attr.cast<IntegerAttr>().getInt();
+                        }));
+    llvm::sort(reductionDims);
+    int64_t reductionSize = reductionDims.size();
+    int64_t numParallelDims = srcRank - reductionSize;
+
+    // Fails if already inner most reduction, i.e. the sorted reduction dims
+    // are exactly [numParallelDims, srcRank).
+    bool innerMostReduction = true;
+    for (int64_t i = 0; i < reductionSize; ++i) {
+      if (reductionDims[i] != numParallelDims + i) {
+        innerMostReduction = false;
+        break;
+      }
+    }
+    if (innerMostReduction)
+      return failure();
+
+    // Permutes the indices so reduction dims are inner most dims. The parallel
+    // dims must stay in their original order; otherwise the result of the
+    // rewritten op would be a permutation of the original result (and have a
+    // different type from the value it replaces).
+    SmallVector<bool> isReductionDim(srcRank, false);
+    for (int64_t dim : reductionDims)
+      isReductionDim[dim] = true;
+    SmallVector<int64_t> indices;
+    indices.reserve(srcRank);
+    for (int64_t dim = 0; dim < srcRank; ++dim)
+      if (!isReductionDim[dim])
+        indices.push_back(dim);
+    indices.append(reductionDims.begin(), reductionDims.end());
+
+    // Sets inner most dims as reduction.
+    SmallVector<bool> reductionMask(srcRank, false);
+    for (int64_t i = numParallelDims; i < srcRank; ++i)
+      reductionMask[i] = true;
+
+    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
+    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
+        multiReductionOp, transposeOp.result(), reductionMask,
+        multiReductionOp.kind());
+    return success();
+  }
+};
+
+// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
+// dimensions are inner most.
+//
+// All leading (parallel) dimensions are flattened into the first dimension and
+// all trailing (reduction) dimensions into the second one. For example,
+// reducing dims [2, 3] of vector<2x3x4x5xi32> becomes a reduction of dim 1 of
+// a vector<6x20xi32>, whose vector<6xi32> result is shape_cast back to the
+// original vector<2x3xi32> result type.
+struct ReduceMultiDimReductionRank
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = multiReductionOp.getSourceVectorType();
+    int64_t srcRank = srcType.getRank();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    if (srcRank == 2)
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    auto reductionDims = llvm::to_vector<4>(
+        llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
+                        [](Attribute attr) -> int64_t {
+                          return attr.cast<IntegerAttr>().getInt();
+                        }));
+    llvm::sort(reductionDims);
+    int64_t reductionSize = reductionDims.size();
+
+    // When every dimension is reduced there is no parallel dimension left to
+    // flatten, and the parallel shape would be empty. Leave such ops alone
+    // rather than building an empty-shaped vector type.
+    if (reductionSize == srcRank)
+      return failure();
+
+    // Fails if not inner most reduction, i.e. the sorted reduction dims are
+    // not exactly [numParallelDims, srcRank).
+    int64_t numParallelDims = srcRank - reductionSize;
+    for (int64_t i = 0; i < reductionSize; ++i)
+      if (reductionDims[i] != numParallelDims + i)
+        return failure();
+
+    // Flattens the parallel dims into `parallelSize` and the reduction dims
+    // into `reductionVolume`.
+    int64_t parallelSize = 1;
+    int64_t reductionVolume = 1;
+    for (int64_t i = 0; i < srcRank; ++i) {
+      if (i < numParallelDims)
+        parallelSize *= srcShape[i];
+      else
+        reductionVolume *= srcShape[i];
+    }
+
+    // Creates shape cast for the input n_d -> 2d.
+    auto castedType = VectorType::get({parallelSize, reductionVolume},
+                                      srcType.getElementType());
+    auto castedOp = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+
+    // Creates the canonical form of 2d vector.multi_reduction with inner most
+    // dim as reduction.
+    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, castedOp.result(), ArrayRef<bool>{false, true},
+        multiReductionOp.kind());
+
+    // Creates shape cast for the flattened 1-d result back to the original
+    // result type. Casting straight to the op's own dest type guarantees the
+    // replacement value has exactly the type of the value it replaces.
+    Value castedOutput = rewriter.create<vector::ShapeCastOp>(
+        loc, multiReductionOp.getDestVectorType(), newOp.dest());
+
+    rewriter.replaceOp(multiReductionOp, castedOutput);
+    return success();
+  }
+};
+
+// Converts 2d vector.multi_reduction with inner most reduction dimension into a
+// sequence of vector.reduction ops: row i of the source is extracted, reduced
+// to a scalar with vector.reduction, and inserted at position i of the result.
+struct TwoDimMultiReductionToReduction
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 2)
+      return failure();
+
+    // Only the canonical [parallel, reduction] form is handled.
+    auto reductionMask = multiReductionOp.getReductionMask();
+    if (reductionMask[0] || !reductionMask[1])
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    auto destType = multiReductionOp.getDestVectorType();
+
+    // Placeholder result; every lane is overwritten by the loop below.
+    // getZeroAttr builds a zero splat matching the element type (any int,
+    // index or float width). Building the attribute from a C++ `int`/`float`
+    // literal is only valid when the element bitwidth matches that literal.
+    Value result = rewriter.create<ConstantOp>(loc, destType,
+                                               rewriter.getZeroAttr(destType));
+
+    // TODO: Add vector::CombiningKind attribute instead of string to
+    // vector.reduction.
+    auto getKindStr = [](vector::CombiningKind kind) -> const char * {
+      switch (kind) {
+      case vector::CombiningKind::ADD:
+        return "add";
+      case vector::CombiningKind::MUL:
+        return "mul";
+      case vector::CombiningKind::MIN:
+        return "min";
+      case vector::CombiningKind::MAX:
+        return "max";
+      case vector::CombiningKind::AND:
+        return "and";
+      case vector::CombiningKind::OR:
+        return "or";
+      case vector::CombiningKind::XOR:
+        return "xor";
+      }
+      // Without this, an out-of-range kind would fall off the end of a
+      // non-void function.
+      llvm_unreachable("unhandled vector::CombiningKind");
+    };
+    // The combining kind is the same for every row; build the attribute once.
+    auto kindAttr = rewriter.getStringAttr(getKindStr(multiReductionOp.kind()));
+    auto elementType = destType.getElementType();
+
+    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
+    for (int i = 0; i < outerDim; ++i) {
+      auto v = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
+      auto reducedValue = rewriter.create<vector::ReductionOp>(
+          loc, elementType, kindAttr, v, ValueRange{});
+      result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
+                                                        result, i);
+    }
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
void mlir::vector::populateVectorMaskMaterializationPatterns(
RewritePatternSet &patterns, bool enableIndexOptimizations) {
patterns.add<VectorCreateMaskOpConversion,
@@ -3645,3 +3837,9 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
TransferReadPermutationLowering, TransferOpReduceRank>(
patterns.getContext());
}
+
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns) {
+  // Registers the three multi_reduction lowering steps: move the reductions to
+  // the inner-most dims, flatten to 2-d, and unroll into vector.reduction ops.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<InnerDimReductionConversion>(ctx);
+  patterns.add<ReduceMultiDimReductionRank>(ctx);
+  patterns.add<TwoDimMultiReductionToReduction>(ctx);
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
new file mode 100644
index 0000000000000..6cfc4e035719d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
+
+func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
+  %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+  return %0 : vector<2xf32>
+}
+// Already in 2-d canonical form: lowered to one vector.reduction per row, and
+// each reduced scalar must be the one inserted into the result.
+// CHECK-LABEL: func @vector_multi_reduction
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//       CHECK:   %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<2xf32>
+//       CHECK:   %[[C0:.+]] = constant 0 : i32
+//       CHECK:   %[[C1:.+]] = constant 1 : i32
+//       CHECK:   %[[V0:.+]] = vector.extract %[[INPUT]][0]
+//       CHECK:   %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<4xf32> into f32
+//       CHECK:   %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<2xf32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[INPUT]][1]
+//       CHECK:   %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<4xf32> into f32
+//       CHECK:   %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
+//       CHECK:   return %[[RESULT_VEC]]
+
+func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
+  %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+  return %0 : vector<2x3xi32>
+}
+// Inner-most reduction of a 4-d vector: flattened to vector<6x20xi32>, reduced
+// row by row, and the vector<6xi32> result reshaped back to vector<2x3xi32>.
+// CHECK-LABEL: func @vector_reduction_inner
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>
+//       CHECK:   %[[FLAT_RESULT_VEC_0:.+]] = constant dense<0> : vector<6xi32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : i32
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : i32
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : i32
+//   CHECK-DAG:   %[[C3:.+]] = constant 3 : i32
+//   CHECK-DAG:   %[[C4:.+]] = constant 4 : i32
+//   CHECK-DAG:   %[[C5:.+]] = constant 5 : i32
+//       CHECK:   %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
+//       CHECK:   %[[V0R:.+]] = vector.reduction "add", %[[V0]] : vector<20xi32> into i32
+//       CHECK:   %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : i32] : vector<6xi32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
+//       CHECK:   %[[V1R:.+]] = vector.reduction "add", %[[V1]] : vector<20xi32> into i32
+//       CHECK:   %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : i32] : vector<6xi32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
+//       CHECK:   %[[V2R:.+]] = vector.reduction "add", %[[V2]] : vector<20xi32> into i32
+//       CHECK:   %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : i32] : vector<6xi32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
+//       CHECK:   %[[V3R:.+]] = vector.reduction "add", %[[V3]] : vector<20xi32> into i32
+//       CHECK:   %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : i32] : vector<6xi32>
+//       CHECK:   %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
+//       CHECK:   %[[V4R:.+]] = vector.reduction "add", %[[V4]] : vector<20xi32> into i32
+//       CHECK:   %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : i32] : vector<6xi32>
+//       CHECK:   %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
+//       CHECK:   %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
+//       CHECK:   %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
+//       CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
+//       CHECK:   return %[[RESULT]]
+
+
+func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
+  %0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+  return %0 : vector<2x5xf32>
+}
+
+// Non-inner-most reduction: the reduced dims [1, 2] are transposed to the end,
+// then the transposed vector is flattened to 2-d and reduced.
+// CHECK-LABEL: func @vector_multi_reduction_transposed
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xf32>
+//       CHECK:   %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
+//       CHECK:   %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
+//       CHECK:   %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
+//       CHECK:   return %[[RESULT]]
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index b125680efa70c..d78ce3db4bca5 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -376,6 +376,19 @@ struct TestVectorTransferLoweringPatterns
}
};
+// Test pass that applies the vector.multi_reduction lowering patterns to each
+// function with the greedy pattern driver (the driver's result is ignored).
+struct TestVectorMultiReductionLoweringPatterns
+    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
+                         FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // NOTE(review): the multi_reduction lowering patterns only appear to create
+    // vector and std ops; this memref registration may be unnecessary —
+    // confirm before removing.
+    registry.insert<memref::MemRefDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorMultiReductionLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
struct TestProgressiveVectorToSCFLoweringPatterns
: public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns,
FunctionPass> {
@@ -439,6 +452,12 @@ void registerTestVectorConversions() {
PassRegistration<TestProgressiveVectorToSCFLoweringPatterns> transferOpToSCF(
"test-progressive-convert-vector-to-scf",
"Test conversion patterns to progressively lower transfer ops to SCF");
+
+ PassRegistration<TestVectorMultiReductionLoweringPatterns>
+ multiDimReductionOpLoweringPass(
+ "test-vector-multi-reduction-lowering-patterns",
+ "Test conversion patterns to lower vector.multi_reduction to other "
+ "vector ops");
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list