[Mlir-commits] [mlir] 800694a - [mlir][Linalg] Make a LinalgStrategyDecomposePass available.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Nov 11 09:47:31 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-11T17:47:27Z
New Revision: 800694a6977c02cfd9770ef3bd5530e6fb4ff2f7
URL: https://github.com/llvm/llvm-project/commit/800694a6977c02cfd9770ef3bd5530e6fb4ff2f7
DIFF: https://github.com/llvm/llvm-project/commit/800694a6977c02cfd9770ef3bd5530e6fb4ff2f7.diff
LOG: [mlir][Linalg] Make a LinalgStrategyDecomposePass available.
Differential Revision: https://reviews.llvm.org/D113684
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 1c9b7252942f3..c0173ec2f443a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -67,7 +67,7 @@ std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
/// buffers instead.
std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
-/// Create a pass to conver named Linalg operations to Linalg generic
+/// Create a pass to convert named Linalg operations to Linalg generic
/// operations.
std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
@@ -108,6 +108,11 @@ createLinalgStrategyGeneralizePass(StringRef opName = "",
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter());
+/// Create a LinalgStrategyDecomposePass: lowers coarser-grained named linalg
+/// ops into finer-grained named versions.
+// TODO: atm this is applied to all supported ops. If/when we need finer control
+// this should be exposed with an opName + filter and a proper pattern.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyDecomposePass();
+
/// Create a LinalgStrategyInterchangePass.
std::unique_ptr<OperationPass<FuncOp>>
createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange = {},
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 5d05039b5ea59..c9bcfebecb022 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -287,6 +287,19 @@ def LinalgStrategyGeneralizePass
];
}
+// TODO: at the moment this is applied to all supported ops. If/when we need
+// finer control this should be exposed with an opName + filter and a proper
+// pattern.
+def LinalgStrategyDecomposePass
+    : FunctionPass<"linalg-strategy-decompose-pass"> {
+  let summary = "Configurable pass to apply pattern-based decomposition.";
+  let constructor = "mlir::createLinalgStrategyDecomposePass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+  ];
+}
+
def LinalgStrategyInterchangePass
: FunctionPass<"linalg-strategy-interchange-pass"> {
let summary = "Configurable pass to apply pattern-based iterator interchange.";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 4a97b926655fa..24cec12cec62d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -134,6 +134,25 @@ struct LinalgStrategyGeneralizePass
LinalgTransformationFilter filter;
};
+/// Configurable pass that rewrites coarser-grained named linalg ops into
+/// finer-grained named versions, using the patterns registered by
+/// populateDecomposeConvolutionPatterns.
+struct LinalgStrategyDecomposePass
+    : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
+  LinalgStrategyDecomposePass() = default;
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    // An empty anchor means "every function"; otherwise only the named
+    // function is rewritten.
+    if (!anchorFuncName.empty() && func.getName() != anchorFuncName)
+      return;
+
+    RewritePatternSet patterns(func.getContext());
+    populateDecomposeConvolutionPatterns(patterns);
+    LogicalResult result =
+        applyPatternsAndFoldGreedily(func, std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+
/// Configurable pass to apply pattern-based linalg generalization.
struct LinalgStrategyInterchangePass
: public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
@@ -389,6 +408,13 @@ mlir::createLinalgStrategyGeneralizePass(StringRef opName,
LinalgTransformationFilter filter) {
return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
}
+/// Create a LinalgStrategyDecomposePass with default options (empty
+/// anchor-func, so the pass does not restrict itself to a single function).
+// TODO: atm this is applied to all supported ops. If/when we need finer control
+// this should be exposed with an opName + filter and a proper pattern.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyDecomposePass() {
+ return std::make_unique<LinalgStrategyDecomposePass>();
+}
/// Create a LinalgStrategyInterchangePass.
std::unique_ptr<OperationPass<FuncOp>>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index c3415c6e2a5d6..8f0b43afafcb1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -14,13 +14,14 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
@@ -554,12 +555,6 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
-static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
- RewritePatternSet patterns(funcOp.getContext());
- populateDecomposeConvolutionPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
@@ -726,8 +721,13 @@ void TestLinalgTransforms::runOnFunction() {
if (testTileScalarizeDynamicDims)
return applyTilePattern(getFunction(), loopType, tileSizes,
/*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
- if (testDecomposeConvolutionPattern)
- return applyDecomposeConvolutionPatterns(getFunction());
+ if (testDecomposeConvolutionPattern) {
+ // TODO: thread all tests through LinalgStrategy passes.
+ OpPassManager dynamicPM("builtin.func");
+ dynamicPM.addPass(createLinalgStrategyDecomposePass());
+ if (failed(runPipeline(dynamicPM, getFunction())))
+ return signalPassFailure();
+ }
}
namespace mlir {
More information about the Mlir-commits
mailing list