[Mlir-commits] [mlir] be8e280 - [mlir][vector][NFC] split TransposeOp lowering out of contractLowering

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 3 10:24:25 PDT 2021


Author: thomasraoux
Date: 2021-05-03T10:23:45-07:00
New Revision: be8e2801a4f305213fef99ae3f7e80cfeb2a7c08

URL: https://github.com/llvm/llvm-project/commit/be8e2801a4f305213fef99ae3f7e80cfeb2a7c08
DIFF: https://github.com/llvm/llvm-project/commit/be8e2801a4f305213fef99ae3f7e80cfeb2a7c08.diff

LOG: [mlir][vector][NFC] split TransposeOp lowering out of contractLowering

Move TransposeOp lowering into its own populate function, since in some
cases it is better to keep the TransposeOp during ContractOp lowering so
that it can be canonicalized, rather than emitting scalar insert/extract.

Differential Revision: https://reviews.llvm.org/D101647

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Transforms/TestConvVectorization.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 399a90645df22..20c6e9601166c 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -173,7 +173,6 @@ struct VectorTransformsOptions {
 ///   ShapeCastOp2DDownCastRewritePattern,
 ///   ShapeCastOp2DUpCastRewritePattern
 ///   BroadcastOpLowering,
-///   TransposeOpLowering
 ///   OuterproductOpLowering
 /// These transformation express higher level vector ops in terms of more
 /// elementary extraction, insertion, reduction, product, and broadcast ops.
@@ -181,6 +180,11 @@ void populateVectorContractLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
 
+/// Insert TransposeLowering patterns into extraction/insertion.
+void populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index ec6c707a36565..85cd065162590 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateVectorToVectorCanonicalizationPatterns(patterns);
     populateVectorSlicesLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns);
+    populateVectorTransposeLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index dab1c14e51979..75017973d30f3 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3823,13 +3823,19 @@ void mlir::vector::populateVectorContractLoweringPatterns(
                   ShapeCastOp2DDownCastRewritePattern,
                   ShapeCastOp2DUpCastRewritePattern,
                   ShapeCastOpRewritePattern>(patterns.getContext());
-  patterns.add<TransposeOpLowering,
-                  ContractionOpLowering,
+  patterns.add<ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,
                   ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
   // clang-format on
 }
 
+void mlir::vector::populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransformsOptions vectorTransformOptions) {
+  patterns.add<TransposeOpLowering>(vectorTransformOptions,
+                                    patterns.getContext());
+}
+
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
   patterns

diff  --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
index 3ccfa93e0a0d3..74cef197f7adb 100644
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -109,6 +109,8 @@ void TestConvVectorization::runOnOperation() {
   RewritePatternSet vectorContractLoweringPatterns(context);
   populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
                                          vectorTransformsOptions);
+  populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
+                                          vectorTransformsOptions);
   (void)applyPatternsAndFoldGreedily(module,
                                      std::move(vectorContractLoweringPatterns));
 

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index d78ce3db4bca5..0e62d2c09be76 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -140,6 +140,7 @@ struct TestVectorContractionConversion
       transposeLowering = VectorTransposeLowering::Flat;
     VectorTransformsOptions options{contractLowering, transposeLowering};
     populateVectorContractLoweringPatterns(patterns, options);
+    populateVectorTransposeLoweringPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };


        


More information about the Mlir-commits mailing list