[Mlir-commits] [mlir] 3964c1d - [mlir][vector] Split populateVectorContractLoweringPatterns
Lei Zhang
llvmlistbot at llvm.org
Thu Oct 7 06:40:11 PDT 2021
Author: Lei Zhang
Date: 2021-10-07T09:39:26-04:00
New Revision: 3964c1db915b00fffb77764892b890a3075e181e
URL: https://github.com/llvm/llvm-project/commit/3964c1db915b00fffb77764892b890a3075e181e
DIFF: https://github.com/llvm/llvm-project/commit/3964c1db915b00fffb77764892b890a3075e181e.diff
LOG: [mlir][vector] Split populateVectorContractLoweringPatterns
`populateVectorContractLoweringPatterns` bundled many patterns that
convert high-D vector ops into low-D elementary ops, and a particular
downstream user may not want all of them applied. For example,
`ShapeCastOpRewritePattern` rewrites `vector.shape_cast` into
data-movement extract/insert ops. This change therefore splits the
entry point into several smaller ones so users can pull in patterns
on demand.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D111225
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 9bc2cd4e35acf..a98ca36025e82 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -159,23 +159,29 @@ struct VectorTransformsOptions {
}
};
-/// Collect a set of transformation patterns that are related to contracting
-/// or expanding vector operations:
-/// ContractionOpLowering,
-/// ShapeCastOp2DDownCastRewritePattern,
-/// ShapeCastOp2DUpCastRewritePattern
-/// BroadcastOpLowering,
-/// OuterproductOpLowering
-/// These transformation express higher level vector ops in terms of more
-/// elementary extraction, insertion, reduction, product, and broadcast ops.
+/// Collects patterns to progressively lower vector.broadcast ops on high-D
+/// vectors to low-D vector ops.
+void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
+
+/// Collects patterns to progressively lower vector contraction ops on high-D
+/// vectors into low-D reduction and product ops.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
- VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+ VectorTransformsOptions options = VectorTransformsOptions());
+
+/// Collects patterns to progressively lower vector mask ops into elementary
+/// selection and insertion ops.
+void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
+
+/// Collects patterns to progressively lower vector.shape_cast ops on high-D
+/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
+/// ops.
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
/// Insert TransposeLowering patterns into extraction/insertion.
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
- VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+ VectorTransformsOptions options = VectorTransformsOptions());
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 1a708dc4da6cc..d920bb7b0f9a6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns);
+ populateVectorMaskOpLoweringPatterns(patterns);
+ populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 8f1d9fc08593e..999f37fd9dfea 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3847,27 +3847,35 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}
+void mlir::vector::populateVectorBroadcastLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<BroadcastOpLowering>(patterns.getContext());
+}
+
+void mlir::vector::populateVectorMaskOpLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
+ patterns.getContext());
+}
+
+void mlir::vector::populateVectorShapeCastLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ShapeCastOp2DDownCastRewritePattern,
+ ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
+ patterns.getContext());
+}
+
void mlir::vector::populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions parameters) {
- // clang-format off
- patterns.add<BroadcastOpLowering,
- CreateMaskOpLowering,
- ConstantMaskOpLowering,
- OuterProductOpLowering,
- ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern,
- ShapeCastOpRewritePattern>(patterns.getContext());
- patterns.add<ContractionOpLowering,
- ContractionOpToMatmulOpLowering,
- ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
- // clang-format on
+ RewritePatternSet &patterns, VectorTransformsOptions options) {
+ patterns.add<OuterProductOpLowering>(patterns.getContext());
+ patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
+ ContractionOpToOuterProductOpLowering>(options,
+ patterns.getContext());
}
void mlir::vector::populateVectorTransposeLoweringPatterns(
- RewritePatternSet &patterns,
- VectorTransformsOptions vectorTransformOptions) {
- patterns.add<TransposeOpLowering>(vectorTransformOptions,
- patterns.getContext());
+ RewritePatternSet &patterns, VectorTransformsOptions options) {
+ patterns.add<TransposeOpLowering>(options, patterns.getContext());
}
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
index bf5cf5a166f21..45c985bbef1ea 100644
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -112,8 +112,11 @@ void TestConvVectorization::runOnOperation() {
// Programmatic controlled lowering of vector.contract only.
RewritePatternSet vectorContractLoweringPatterns(context);
+ populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns);
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
vectorTransformOptions);
+ populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns);
+ populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns);
populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
vectorTransformOptions);
(void)applyPatternsAndFoldGreedily(module,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 907f9aedfdb17..d4182f5c1f649 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -164,7 +164,10 @@ struct TestVectorContractionConversion
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
VectorTransformsOptions options{contractLowering, transposeLowering};
+ populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
+ populateVectorMaskOpLoweringPatterns(patterns);
+ populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
More information about the Mlir-commits
mailing list