[Mlir-commits] [mlir] b57e2f0 - [mlir][Linalg] Add pad vectorization patterns into LinalgStrategyVectorize passes.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 23 11:48:04 PST 2021


Author: MaheshRavishankar
Date: 2021-11-23T11:47:54-08:00
New Revision: b57e2f071a2e47147a57c52a6a8c6aa062230cd8

URL: https://github.com/llvm/llvm-project/commit/b57e2f071a2e47147a57c52a6a8c6aa062230cd8
DIFF: https://github.com/llvm/llvm-project/commit/b57e2f071a2e47147a57c52a6a8c6aa062230cd8.diff

LOG: [mlir][Linalg] Add pad vectorization patterns into LinalgStrategyVectorize passes.

Add an option (`vectorize-padding`, off by default) to control whether
these patterns are added to the pattern list.

Differential Revision: https://reviews.llvm.org/D114290

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index df26b9537377a..d324256d43110 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -128,7 +128,8 @@ createLinalgStrategyVectorizePass(StringRef opName = "",
                                   linalg::LinalgVectorizationOptions opt =
                                       linalg::LinalgVectorizationOptions(),
                                   linalg::LinalgTransformationFilter filter =
-                                      linalg::LinalgTransformationFilter());
+                                      linalg::LinalgTransformationFilter(),
+                                  bool padVectorize = false);
 
 /// Create a LinalgStrategyEnablePass.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyEnablePass(

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index b781876918757..593fb89debae2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -319,6 +319,8 @@ def LinalgStrategyVectorizePass
       "Which func op is the anchor to latch on.">,
     Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
       "Which linalg op within the func is the anchor to latch on.">,
+    Option<"vectorizePadding", "vectorize-padding", "bool", "false",
+      "Enable vectorization of padding ops.">,
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index ea2afa01889b0..96ed3178c6f1b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -128,21 +128,27 @@ struct Interchange : public Transformation {
 /// Represent one application of createLinalgStrategyVectorizePass.
 struct Vectorize : public Transformation {
   explicit Vectorize(linalg::LinalgVectorizationOptions options,
-                     LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(), options(options) {}
+                     LinalgTransformationFilter::FilterFunction f = nullptr,
+                     bool padVectorize = false)
+      : Transformation(f), opName(), options(options),
+        vectorizePadding(padVectorize) {}
 
   Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
-            LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(name), options(options) {}
+            LinalgTransformationFilter::FilterFunction f = nullptr,
+            bool padVectorize = false)
+      : Transformation(f), opName(name), options(options),
+        vectorizePadding(padVectorize) {}
 
   void addToPassPipeline(OpPassManager &pm,
                          LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyVectorizePass(opName, options, m));
+    pm.addPass(createLinalgStrategyVectorizePass(opName, options, m,
+                                                 vectorizePadding));
   }
 
 private:
   std::string opName;
   linalg::LinalgVectorizationOptions options;
+  bool vectorizePadding;
 };
 
 /// Represent one application of createLinalgStrategyLowerVectorsPass.
@@ -260,18 +266,20 @@ struct CodegenStrategy {
   /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
   CodegenStrategy &
   vectorize(StringRef opName,
-            LinalgTransformationFilter::FilterFunction f = nullptr) {
+            LinalgTransformationFilter::FilterFunction f = nullptr,
+            bool vectorizePadding = false) {
     assert(!opName.empty() && "expected an op name");
     transformationSequence.emplace_back(std::make_unique<Vectorize>(
-        opName, linalg::LinalgVectorizationOptions(), f));
+        opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
     return *this;
   }
   /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
   /// operation.
   CodegenStrategy &
   vectorizeIf(bool b, StringRef opName,
-              LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? vectorize(opName, f) : *this;
+              LinalgTransformationFilter::FilterFunction f = nullptr,
+              bool vectorizePadding = false) {
+    return b ? vectorize(opName, f, vectorizePadding) : *this;
     return *this;
   }
   /// Append a pattern to lower all vector operations.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index acbc30a93c6c1..8ed43c8b0dbae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -260,9 +260,11 @@ struct LinalgStrategyVectorizePass
   LinalgStrategyVectorizePass() = default;
 
   LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
-                              LinalgTransformationFilter filt)
+                              LinalgTransformationFilter filt,
+                              bool padVectorize = false)
       : options(opt), filter(filt) {
     this->anchorOpName.setValue(opName.str());
+    this->vectorizePadding.setValue(padVectorize);
   }
 
   void runOnFunction() override {
@@ -284,6 +286,9 @@ struct LinalgStrategyVectorizePass
     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
                               linalg::LinalgCopyVTWForwardingPattern>(
         funcOp.getContext(), /*benefit=*/2);
+    if (vectorizePadding) {
+      linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns);
+    }
     (void)applyPatternsAndFoldGreedily(funcOp,
                                        std::move(vectorizationPatterns));
   }
@@ -471,11 +476,11 @@ mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
 }
 
 /// Create a LinalgStrategyVectorizePass.
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgStrategyVectorizePass(StringRef opName,
-                                        LinalgVectorizationOptions opt,
-                                        LinalgTransformationFilter filter) {
-  return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter);
+std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
+    StringRef opName, LinalgVectorizationOptions opt,
+    LinalgTransformationFilter filter, bool padVectorize) {
+  return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
+                                                       padVectorize);
 }
 
 /// Create a LinalgStrategyEnablePass.


        


More information about the Mlir-commits mailing list