[Mlir-commits] [mlir] 3964c1d - [mlir][vector] Split populateVectorContractLoweringPatterns

Lei Zhang llvmlistbot at llvm.org
Thu Oct 7 06:40:11 PDT 2021


Author: Lei Zhang
Date: 2021-10-07T09:39:26-04:00
New Revision: 3964c1db915b00fffb77764892b890a3075e181e

URL: https://github.com/llvm/llvm-project/commit/3964c1db915b00fffb77764892b890a3075e181e
DIFF: https://github.com/llvm/llvm-project/commit/3964c1db915b00fffb77764892b890a3075e181e.diff

LOG: [mlir][vector] Split populateVectorContractLoweringPatterns

This entry point was bundling quite a lot of patterns that convert
high-D vector ops into low-D elementary ops. A particular downstream
user might not want all of those patterns to be applied. For example,
`ShapeCastOpRewritePattern` rewrites `vector.shape_cast` into data
movement extract/insert ops.

Instead, split the entry point into multiple ones so users
can pull in patterns on demand.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D111225

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 9bc2cd4e35acf..a98ca36025e82 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -159,23 +159,29 @@ struct VectorTransformsOptions {
   }
 };
 
-/// Collect a set of transformation patterns that are related to contracting
-/// or expanding vector operations:
-///   ContractionOpLowering,
-///   ShapeCastOp2DDownCastRewritePattern,
-///   ShapeCastOp2DUpCastRewritePattern
-///   BroadcastOpLowering,
-///   OuterproductOpLowering
-/// These transformation express higher level vector ops in terms of more
-/// elementary extraction, insertion, reduction, product, and broadcast ops.
+/// Collects patterns to progressively lower vector.broadcast ops on high-D
+/// vectors to low-D vector ops.
+void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
+
+/// Collects patterns to progressively lower vector contraction ops on high-D
+/// vectors into low-D reduction and product ops.
 void populateVectorContractLoweringPatterns(
     RewritePatternSet &patterns,
-    VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+    VectorTransformsOptions options = VectorTransformsOptions());
+
+/// Collects patterns to progressively lower vector mask ops into elementary
+/// selection and insertion ops.
+void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
+
+/// Collects patterns to progressively lower vector.shape_cast ops on high-D
+/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
+/// ops.
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
 
 /// Insert TransposeLowering patterns into extraction/insertion.
 void populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
-    VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+    VectorTransformsOptions options = VectorTransformsOptions());
 
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 1a708dc4da6cc..d920bb7b0f9a6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
+    populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns);
+    populateVectorMaskOpLoweringPatterns(patterns);
+    populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns);
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 8f1d9fc08593e..999f37fd9dfea 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3847,27 +3847,35 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
 }
 
+void mlir::vector::populateVectorBroadcastLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BroadcastOpLowering>(patterns.getContext());
+}
+
+void mlir::vector::populateVectorMaskOpLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
+      patterns.getContext());
+}
+
+void mlir::vector::populateVectorShapeCastLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ShapeCastOp2DDownCastRewritePattern,
+               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
+      patterns.getContext());
+}
+
 void mlir::vector::populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions parameters) {
-  // clang-format off
-  patterns.add<BroadcastOpLowering,
-                  CreateMaskOpLowering,
-                  ConstantMaskOpLowering,
-                  OuterProductOpLowering,
-                  ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern,
-                  ShapeCastOpRewritePattern>(patterns.getContext());
-  patterns.add<ContractionOpLowering,
-                  ContractionOpToMatmulOpLowering,
-                  ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
-  // clang-format on
+    RewritePatternSet &patterns, VectorTransformsOptions options) {
+  patterns.add<OuterProductOpLowering>(patterns.getContext());
+  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
+               ContractionOpToOuterProductOpLowering>(options,
+                                                      patterns.getContext());
 }
 
 void mlir::vector::populateVectorTransposeLoweringPatterns(
-    RewritePatternSet &patterns,
-    VectorTransformsOptions vectorTransformOptions) {
-  patterns.add<TransposeOpLowering>(vectorTransformOptions,
-                                    patterns.getContext());
+    RewritePatternSet &patterns, VectorTransformsOptions options) {
+  patterns.add<TransposeOpLowering>(options, patterns.getContext());
 }
 
 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(

diff  --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
index bf5cf5a166f21..45c985bbef1ea 100644
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -112,8 +112,11 @@ void TestConvVectorization::runOnOperation() {
 
   // Programmatic controlled lowering of vector.contract only.
   RewritePatternSet vectorContractLoweringPatterns(context);
+  populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns);
   populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
                                          vectorTransformOptions);
+  populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns);
+  populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns);
   populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
                                           vectorTransformOptions);
   (void)applyPatternsAndFoldGreedily(module,

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 907f9aedfdb17..d4182f5c1f649 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -164,7 +164,10 @@ struct TestVectorContractionConversion
     if (lowerToFlatTranspose)
       transposeLowering = VectorTransposeLowering::Flat;
     VectorTransformsOptions options{contractLowering, transposeLowering};
+    populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, options);
+    populateVectorMaskOpLoweringPatterns(patterns);
+    populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }


        


More information about the Mlir-commits mailing list