[Mlir-commits] [mlir] b57e2f0 - [mlir][Linalg] Add pad vectorization patterns into LinalgStrategyVectorize passes.
llvmlistbot at llvm.org
Tue Nov 23 11:48:04 PST 2021
Author: MaheshRavishankar
Date: 2021-11-23T11:47:54-08:00
New Revision: b57e2f071a2e47147a57c52a6a8c6aa062230cd8
URL: https://github.com/llvm/llvm-project/commit/b57e2f071a2e47147a57c52a6a8c6aa062230cd8
DIFF: https://github.com/llvm/llvm-project/commit/b57e2f071a2e47147a57c52a6a8c6aa062230cd8.diff
LOG: [mlir][Linalg] Add pad vectorization patterns into LinalgStrategyVectorize passes.
Add an option to control whether these patterns are added to the
pattern list or not.
Differential Revision: https://reviews.llvm.org/D114290
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index df26b9537377a..d324256d43110 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -128,7 +128,8 @@ createLinalgStrategyVectorizePass(StringRef opName = "",
linalg::LinalgVectorizationOptions opt =
linalg::LinalgVectorizationOptions(),
linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter());
+ linalg::LinalgTransformationFilter(),
+ bool padVectorize = false);
/// Create a LinalgStrategyEnablePass.
std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyEnablePass(
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index b781876918757..593fb89debae2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -319,6 +319,8 @@ def LinalgStrategyVectorizePass
"Which func op is the anchor to latch on.">,
Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
"Which linalg op within the func is the anchor to latch on.">,
+ Option<"vectorizePadding", "vectorize-padding", "bool", "false",
+ "Enable vectorization of padding ops.">,
];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index ea2afa01889b0..96ed3178c6f1b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -128,21 +128,27 @@ struct Interchange : public Transformation {
/// Represent one application of createLinalgStrategyVectorizePass.
struct Vectorize : public Transformation {
explicit Vectorize(linalg::LinalgVectorizationOptions options,
- LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(), options(options) {}
+ LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool padVectorize = false)
+ : Transformation(f), opName(), options(options),
+ vectorizePadding(padVectorize) {}
Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
- LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(name), options(options) {}
+ LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool padVectorize = false)
+ : Transformation(f), opName(name), options(options),
+ vectorizePadding(padVectorize) {}
void addToPassPipeline(OpPassManager &pm,
LinalgTransformationFilter m) const override {
- pm.addPass(createLinalgStrategyVectorizePass(opName, options, m));
+ pm.addPass(createLinalgStrategyVectorizePass(opName, options, m,
+ vectorizePadding));
}
private:
std::string opName;
linalg::LinalgVectorizationOptions options;
+ bool vectorizePadding;
};
/// Represent one application of createLinalgStrategyLowerVectorsPass.
@@ -260,18 +266,20 @@ struct CodegenStrategy {
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
CodegenStrategy &
vectorize(StringRef opName,
- LinalgTransformationFilter::FilterFunction f = nullptr) {
+ LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool vectorizePadding = false) {
assert(!opName.empty() && "expected an op name");
transformationSequence.emplace_back(std::make_unique<Vectorize>(
- opName, linalg::LinalgVectorizationOptions(), f));
+ opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
return *this;
}
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
/// operation.
CodegenStrategy &
vectorizeIf(bool b, StringRef opName,
- LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? vectorize(opName, f) : *this;
+ LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool vectorizePadding = false) {
+ return b ? vectorize(opName, f, vectorizePadding) : *this;
return *this;
}
/// Append a pattern to lower all vector operations.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index acbc30a93c6c1..8ed43c8b0dbae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -260,9 +260,11 @@ struct LinalgStrategyVectorizePass
LinalgStrategyVectorizePass() = default;
LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
- LinalgTransformationFilter filt)
+ LinalgTransformationFilter filt,
+ bool padVectorize = false)
: options(opt), filter(filt) {
this->anchorOpName.setValue(opName.str());
+ this->vectorizePadding.setValue(padVectorize);
}
void runOnFunction() override {
@@ -284,6 +286,9 @@ struct LinalgStrategyVectorizePass
vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(
funcOp.getContext(), /*benefit=*/2);
+ if (vectorizePadding) {
+ linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns);
+ }
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorizationPatterns));
}
@@ -471,11 +476,11 @@ mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
}
/// Create a LinalgStrategyVectorizePass.
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgStrategyVectorizePass(StringRef opName,
- LinalgVectorizationOptions opt,
- LinalgTransformationFilter filter) {
- return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter);
+std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
+ StringRef opName, LinalgVectorizationOptions opt,
+ LinalgTransformationFilter filter, bool padVectorize) {
+ return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
+ padVectorize);
}
/// Create a LinalgStrategyEnablePass.
More information about the Mlir-commits mailing list