[Mlir-commits] [mlir] 89d55d3 - [mlir][Linalg] Retire CodegenStrategy::transform
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 22 13:30:00 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-22T20:27:14Z
New Revision: 89d55d3c86f07178d20be36d3724d50a2e9322b7
URL: https://github.com/llvm/llvm-project/commit/89d55d3c86f07178d20be36d3724d50a2e9322b7
DIFF: https://github.com/llvm/llvm-project/commit/89d55d3c86f07178d20be36d3724d50a2e9322b7.diff
LOG: [mlir][Linalg] Retire CodegenStrategy::transform
Instead, each pass should construct a nested OpPassManager and call runPipeline on it.
Differential Revision: https://reviews.llvm.org/D112308
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 3633fba85393..75783d654f89 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -127,6 +127,10 @@ createLinalgStrategyLowerVectorsPass(linalg::LinalgVectorLoweringOptions opt =
linalg::LinalgVectorLoweringOptions(),
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyRemoveMarkersPass (drops Linalg transformation markers).
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyRemoveMarkersPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index f70441975f6c..d0744a891328 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -315,4 +315,15 @@ def LinalgStrategyLowerVectorsPass
];
}
+def LinalgStrategyRemoveMarkersPass
+ : FunctionPass<"linalg-strategy-remove-markers-pass"> {
+ let summary = "Cleanup pass that drops markers.";
+ let constructor = "mlir::createLinalgStrategyRemoveMarkersPass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">, // Empty: run on every func.
+ ];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index 7454c91886c7..322ac15fe282 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -247,7 +247,6 @@ struct CodegenStrategy {
/// Apply the transformation patterns in sequence with cleanup
/// transformations interleaved.
- LogicalResult transform(FuncOp func) const;
void configurePassPipeline(OpPassManager &pm, MLIRContext *context) const;
private:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index 1770cd9fce86..c4218ed89d35 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -60,15 +60,5 @@ void mlir::linalg::CodegenStrategy::configurePassPipeline(
vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions;
vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions;
pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
-}
-
-LogicalResult mlir::linalg::CodegenStrategy::transform(FuncOp funcOp) const {
- PassManager pm(funcOp.getContext(), funcOp.getOperationName());
- configurePassPipeline(pm, funcOp.getContext());
- LogicalResult res = pm.run(funcOp);
- // Ensure we drop the marker in the end.
- funcOp.walk([](LinalgOp op) {
- op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
- });
- return res;
+ pm.addPass(createLinalgStrategyRemoveMarkersPass());
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 391c21bcd138..069838cddf28 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -287,6 +287,21 @@ struct LinalgStrategyLowerVectorsPass
LinalgVectorLoweringOptions options;
LinalgTransformationFilter filter;
};
+
+/// Cleanup pass that drops the Linalg transformation marker from LinalgOps.
+struct LinalgStrategyRemoveMarkersPass
+ : public LinalgStrategyRemoveMarkersPassBase<
+ LinalgStrategyRemoveMarkersPass> {
+ /// A non-empty anchorFuncName restricts the cleanup to that function.
+ void runOnFunction() override {
+ auto funcOp = getFunction();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+ funcOp.walk([](LinalgOp op) {
+ op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
+ });
+ }
+};
} // namespace
/// Create a LinalgStrategyTilePass.
@@ -340,3 +355,9 @@ mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
LinalgTransformationFilter filter) {
return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
}
+
+/// Create a LinalgStrategyRemoveMarkersPass with default options (no anchor).
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyRemoveMarkersPass() {
+ return std::make_unique<LinalgStrategyRemoveMarkersPass>();
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index ff213d32ddfc..7ad0ccb0341d 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -162,7 +162,13 @@ void TestLinalgCodegenStrategy::runStrategy(
.setVectorTransferSplit(vectorTransferSplit))
.setVectorTransferToSCFOptions(
VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
- (void)strategy.transform(getFunction());
+
+ // Create a nested OpPassManager and run it.
+ FuncOp funcOp = getFunction();
+ OpPassManager dynamicPM("builtin.func");
+ strategy.configurePassPipeline(dynamicPM, funcOp.getContext());
+ if (failed(runPipeline(dynamicPM, funcOp)))
+ return signalPassFailure();
}
} // end anonymous namespace
More information about the Mlir-commits
mailing list