[Mlir-commits] [mlir] 176a0ea - [mlir][Linalg] NFC - Add option to hook vector.multi_reduction lowering into strategies.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Oct 25 04:35:49 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-25T11:31:39Z
New Revision: 176a0ea535d4403b3a8482153f366e55141c7e50
URL: https://github.com/llvm/llvm-project/commit/176a0ea535d4403b3a8482153f366e55141c7e50
DIFF: https://github.com/llvm/llvm-project/commit/176a0ea535d4403b3a8482153f366e55141c7e50.diff
LOG: [mlir][Linalg] NFC - Add option to hook vector.multi_reduction lowering into strategies.
Differential Revision: https://reviews.llvm.org/D112414
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4c2f004b5361e..640e1221aeb53 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -899,6 +899,12 @@ struct LinalgVectorLoweringOptions {
contractionLowering = val;
return *this;
}
+ /// Enable lowering of vector.multi_reduction.
+ bool multiReductionLowering = false;
+ LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
+ multiReductionLowering = val;
+ return *this;
+ }
/// Enable lowering of vector.transfer to scf.
bool transferToSCFConversion = false;
LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 5a5eadf6ad7d1..dd56cd1ea1926 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,6 +40,76 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail
+/// Enum to control the lowering of `vector.contract` operations.
+enum class VectorContractLowering {
+ /// Progressively lower to finer grained `vector.contract` and dot-products.
+ Dot = 0,
+ /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
+ Matmul = 1,
+ /// Lower to `vector.outerproduct`.
+ OuterProduct = 2,
+};
+/// Enum to control the lowering of `vector.multi_reduction` operations.
+enum class VectorMultiReductionLowering {
+ /// Lower multi_reduction into outer-reduction and inner-parallel ops.
+ InnerParallel = 0,
+ /// Lower multi_reduction into outer-parallel and inner-reduction ops.
+ InnerReduction = 1,
+};
+/// Enum to control the lowering of `vector.transpose` operations.
+enum class VectorTransposeLowering {
+ /// Lower transpose into element-wise extract and inserts.
+ EltWise = 0,
+ /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+ /// intrinsics.
+ Flat = 1,
+};
+/// Enum to control the splitting of `vector.transfer` operations into
+/// in-bounds and out-of-bounds variants.
+enum class VectorTransferSplit {
+ /// Do not split vector transfer operations.
+ None = 0,
+ /// Split using in-bounds + out-of-bounds vector.transfer operations.
+ VectorTransfer = 1,
+ /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
+ /// operations.
+ LinalgCopy = 2,
+ /// Do not split vector transfer operation but instead mark it as "in-bounds".
+ ForceInBounds = 3
+};
+/// Structure to control the behavior of vector transform patterns.
+struct VectorTransformsOptions {
+ /// Option to control the lowering of vector.contract.
+ VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
+ VectorTransformsOptions &
+ setVectorTransformsOptions(VectorContractLowering opt) {
+ vectorContractLowering = opt;
+ return *this;
+ }
+ /// Option to control the lowering of vector.multi_reduction.
+ VectorMultiReductionLowering vectorMultiReductionLowering =
+ VectorMultiReductionLowering::InnerParallel;
+ VectorTransformsOptions &
+ setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
+ vectorMultiReductionLowering = opt;
+ return *this;
+ }
+ /// Option to control the lowering of vector.transpose.
+ VectorTransposeLowering vectorTransposeLowering =
+ VectorTransposeLowering::EltWise;
+ VectorTransformsOptions &
+ setVectorTransposeLowering(VectorTransposeLowering opt) {
+ vectorTransposeLowering = opt;
+ return *this;
+ }
+ /// Option to control the splitting of vector transfers.
+ VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+ VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+ vectorTransferSplit = opt;
+ return *this;
+ }
+};
+
/// Return whether `srcType` can be broadcast to `dstVectorType` under the
/// semantics of the `vector.broadcast` op.
enum class BroadcastableToResult {
@@ -114,7 +184,9 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
/// the other patterns can kick in, thus fully exiting out of the
/// vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns, bool useInnerDimsForReduction = false);
+ RewritePatternSet &patterns,
+ VectorMultiReductionLowering options =
+ vector::VectorMultiReductionLowering::InnerParallel);
/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
/// chain.
@@ -136,61 +208,6 @@ class CombiningKindAttr
static Attribute parse(DialectAsmParser &parser);
};
-/// Enum to control the lowering of `vector.contract` operations.
-enum class VectorContractLowering {
- /// Progressively lower to finer grained `vector.contract` and dot-products.
- Dot = 0,
- /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
- Matmul = 1,
- /// Lower to `vector.outerproduct`.
- OuterProduct = 2,
-};
-/// Enum to control the lowering of `vector.transpose` operations.
-enum class VectorTransposeLowering {
- /// Lower transpose into element-wise extract and inserts.
- EltWise = 0,
- /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
- /// intrinsics.
- Flat = 1,
-};
-/// Enum to control the splitting of `vector.transfer` operations into
-/// in-bounds and out-of-bounds variants.
-enum class VectorTransferSplit {
- /// Do not split vector transfer operations.
- None = 0,
- /// Split using in-bounds + out-of-bounds vector.transfer operations.
- VectorTransfer = 1,
- /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
- /// operations.
- LinalgCopy = 2,
- /// Do not split vector transfer operation but instead mark it as "in-bounds".
- ForceInBounds = 3
-};
-/// Structure to control the behavior of vector transform patterns.
-struct VectorTransformsOptions {
- /// Option to control the lowering of vector.contract.
- VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
- VectorTransformsOptions &
- setVectorTransformsOptions(VectorContractLowering opt) {
- vectorContractLowering = opt;
- return *this;
- }
- /// Option to control the lowering of vector.transpose.
- VectorTransposeLowering vectorTransposeLowering =
- VectorTransposeLowering::EltWise;
- VectorTransformsOptions &
- setVectorTransposeLowering(VectorTransposeLowering opt) {
- vectorTransposeLowering = opt;
- return *this;
- }
- /// Option to control the splitting of vector transfers.
- VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
- VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
- vectorTransferSplit = opt;
- return *this;
- }
-};
-
/// Collects patterns to progressively lower vector.broadcast ops on high-D
/// vectors to low-D vector ops.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 8cc06d486371b..b7506eb91aa9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -263,6 +263,7 @@ struct LinalgStrategyLowerVectorsPass
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
+ vector::populateVectorToVectorCanonicalizationPatterns(patterns);
if (options.transferLowering) {
vector::populateVectorTransferLoweringPatterns(patterns,
options.maxTransferRank);
@@ -277,6 +278,11 @@ struct LinalgStrategyLowerVectorsPass
options.vectorTransformOptions, context);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
+ if (options.multiReductionLowering) {
+ vector::populateVectorMultiReductionLoweringPatterns(
+ patterns,
+ options.vectorTransformOptions.vectorMultiReductionLowering);
+ }
if (options.transferToSCFConversion) {
populateVectorToSCFConversionPatterns(patterns,
options.vectorTransferToSCFOptions);
diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
index f738364529776..67d0db4d2cd45 100644
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -35,10 +35,11 @@ class InnerOuterDimReductionConversion
public:
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
- explicit InnerOuterDimReductionConversion(MLIRContext *context,
- bool useInnerDimsForReduction)
+ explicit InnerOuterDimReductionConversion(
+ MLIRContext *context, vector::VectorMultiReductionLowering options)
: mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
- useInnerDimsForReduction(useInnerDimsForReduction) {}
+ useInnerDimsForReduction(
+ options == vector::VectorMultiReductionLowering::InnerReduction) {}
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -103,10 +104,11 @@ class ReduceMultiDimReductionRank
public:
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
- explicit ReduceMultiDimReductionRank(MLIRContext *context,
- bool useInnerDimsForReduction)
+ explicit ReduceMultiDimReductionRank(
+ MLIRContext *context, vector::VectorMultiReductionLowering options)
: mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
- useInnerDimsForReduction(useInnerDimsForReduction) {}
+ useInnerDimsForReduction(
+ options == vector::VectorMultiReductionLowering::InnerReduction) {}
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -398,11 +400,11 @@ struct OneDimMultiReductionToTwoDim
};
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns, bool useInnerDimsForReduction) {
- patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank,
- OneDimMultiReductionToTwoDim>(patterns.getContext(),
- useInnerDimsForReduction);
- if (useInnerDimsForReduction)
+ RewritePatternSet &patterns, VectorMultiReductionLowering options) {
+ patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
+ patterns.getContext(), options);
+ patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
+ if (options == VectorMultiReductionLowering ::InnerReduction)
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
else
patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
index 45c985bbef1ea..f73cd6c0a1eb0 100644
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -94,7 +94,8 @@ void TestConvVectorization::runOnOperation() {
//===--------------------------------------------------------------------===//
VectorTransformsOptions vectorTransformOptions{
- VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
+ VectorContractLowering::Dot, VectorMultiReductionLowering::InnerParallel,
+ VectorTransposeLowering::EltWise};
RewritePatternSet vectorTransferPatterns(context);
// Pattern is not applied because rank-reducing vector transfer is not yet
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d5538b09dadda..b95c45d21633a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -159,11 +159,14 @@ struct TestVectorContractionConversion
VectorContractLowering contractLowering = VectorContractLowering::Dot;
if (lowerToFlatMatrix)
contractLowering = VectorContractLowering::Matmul;
+ VectorMultiReductionLowering vectorMultiReductionLowering =
+ VectorMultiReductionLowering::InnerParallel;
VectorTransposeLowering transposeLowering =
VectorTransposeLowering::EltWise;
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
- VectorTransformsOptions options{contractLowering, transposeLowering};
+ VectorTransformsOptions options{
+ contractLowering, vectorMultiReductionLowering, transposeLowering};
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
populateVectorMaskOpLoweringPatterns(patterns);
@@ -461,7 +464,10 @@ struct TestVectorMultiReductionLoweringPatterns
llvm::cl::init(false)};
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
- populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions);
+ populateVectorMultiReductionLoweringPatterns(
+ patterns, useOuterReductions
+ ? vector::VectorMultiReductionLowering::InnerParallel
+ : vector::VectorMultiReductionLowering::InnerReduction);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
More information about the Mlir-commits
mailing list