[Mlir-commits] [mlir] 1b2bda8 - [mlir][linalg][bufferize] Add PostAnalysisStep
Matthias Springer
llvmlistbot at llvm.org
Thu Nov 11 16:52:34 PST 2021
Author: Matthias Springer
Date: 2021-11-12T09:51:06+09:00
New Revision: 1b2bda8d1a82b0ffc49eb485824fca93db1b0aac
URL: https://github.com/llvm/llvm-project/commit/1b2bda8d1a82b0ffc49eb485824fca93db1b0aac
DIFF: https://github.com/llvm/llvm-project/commit/1b2bda8d1a82b0ffc49eb485824fca93db1b0aac.diff
LOG: [mlir][linalg][bufferize] Add PostAnalysisStep
This helper struct allows users of ComprehensiveBufferize to inject "post analysis" steps that are executed after the analysis but before the bufferization.
Differential Revision: https://reviews.llvm.org/D113458
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 3c3f601a9385b..07f065d909e56 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -18,6 +18,8 @@
namespace mlir {
class BlockAndValueMapping;
+class DominanceInfo;
+class FuncOp;
namespace linalg {
namespace comprehensive_bufferize {
@@ -266,6 +268,20 @@ struct BufferizationState {
/// bufferization is necessary.
Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
+/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
+/// executed after the analysis, but before bufferization. They can be used
+/// to implement custom dialect-specific optimizations.
+struct PostAnalysisStep {
+ virtual ~PostAnalysisStep() = default;
+
+ /// Run the post analysis step. This function may modify the IR, but must keep
+ /// `aliasInfo` consistent. Newly created operations and operations that
+ /// should be re-analyzed must be stored in `newOps`.
+ virtual LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+ DominanceInfo &domInfo,
+ SmallVector<Operation *> &newOps) = 0;
+};
+
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index 72a5c700c10a4..5df3b63aa60aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -63,40 +63,59 @@ bufferizeOp(Operation *op, BufferizationState &state,
/// Register external models implemented for the `BufferizableOpInterface`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
-/// Try to eliminate InitTensorOps inside `funcOp`.
-///
-/// * `rewriteFunc` generates the replacement for the InitTensorOp.
-/// * Only InitTensorOps that are anchored on a matching OpOperand as per
-/// `anchorMatchFunc` are considered. "Anchored" means that there is a path on
-/// the reverse SSA use-def chain, starting from the OpOperand and always
-/// following the aliasing OpOperand, that eventually ends at a single
-/// InitTensorOp.
-/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
-/// This analysis can be skipped with `skipAnalysis`.
-LogicalResult initTensorElimination(
- FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
- std::function<bool(OpOperand &)> anchorMatchFunc,
- std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
- bool skipAnalysis = false);
-
-/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
-/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
-/// (and some other conditions are met).
-LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
- FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
-
struct BufferizationOptions {
BufferizationOptions();
+ /// Register a "post analysis" step of type `Step`, constructed from `args`.
+ /// Steps run in registration order after the analysis, before bufferization.
+ template <typename Step, typename... Args>
+ void addPostAnalysisStep(Args &&...args) {
+ postAnalysisSteps.emplace_back(
+ std::make_unique<Step>(std::forward<Args>(args)...));
+ }
+
std::unique_ptr<AllocationCallbacks> allocationFns;
bool allowReturnMemref = false;
unsigned analysisFuzzerSeed = 0;
bool testAnalysisOnly = false;
+ std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
};
LogicalResult runComprehensiveBufferize(ModuleOp moduleOp,
const BufferizationOptions &options);
+namespace linalg_ext {
+
+struct InitTensorEliminationStep : public PostAnalysisStep {
+ /// Try to eliminate InitTensorOps inside `funcOp`.
+ ///
+ /// * `rewriteFunc` generates the replacement for the InitTensorOp.
+ /// * Only InitTensorOps that are anchored on a matching OpOperand as per
+ /// `anchorMatchFunc` are considered. "Anchored" means that there is a path
+ /// on the reverse SSA use-def chain, starting from the OpOperand and always
+ /// following the aliasing OpOperand, that eventually ends at a single
+ /// InitTensorOp.
+ /// * The op defining each replacement value (if any) is appended to `newOps`
+ /// so that it can be analyzed for inplacability afterwards.
+ LogicalResult eliminateInitTensors(
+ FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+ std::function<bool(OpOperand &)> anchorMatchFunc,
+ std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+ SmallVector<Operation *> &newOps);
+};
+
+/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
+/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
+/// (and some other conditions are met).
+struct InsertSliceAnchoredInitTensorEliminationStep
+ : public InitTensorEliminationStep {
+ LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+ DominanceInfo &domInfo,
+ SmallVector<Operation *> &newOps) override;
+};
+
+} // namespace linalg_ext
+
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index cbfbef467b6f3..bd6babfc456f3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1595,11 +1595,13 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
-LogicalResult mlir::linalg::comprehensive_bufferize::initTensorElimination(
- FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
- std::function<bool(OpOperand &)> anchorMatchFunc,
- std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
- bool skipAnalysis) {
+LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
+ InitTensorEliminationStep::eliminateInitTensors(
+ FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+ DominanceInfo &domInfo,
+ std::function<bool(OpOperand &)> anchorMatchFunc,
+ std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+ SmallVector<Operation *> &newOps) {
OpBuilder b(funcOp->getContext());
WalkResult status = funcOp->walk([&](Operation *op) {
@@ -1647,14 +1649,9 @@ LogicalResult mlir::linalg::comprehensive_bufferize::initTensorElimination(
aliasInfo.unionAliasSets(initTensor, replacement);
aliasInfo.unionEquivalenceClasses(initTensor, replacement);
- // Run analysis on the newly created op.
- if (auto opResult = replacement.dyn_cast<OpResult>()) {
- if (!skipAnalysis) {
- SmallVector<Operation *> ops(1, replacement.getDefiningOp());
- if (failed(inPlaceAnalysis(ops, aliasInfo, domInfo)))
- return WalkResult::interrupt();
- }
- }
+ // Register replacement ops.
+ if (Operation *newOp = replacement.getDefiningOp())
+ newOps.push_back(newOp);
}
// Advance to the next operation.
@@ -1692,11 +1689,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::initTensorElimination(
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
-LogicalResult mlir::linalg::comprehensive_bufferize::
- eliminateInsertSliceAnchoredInitTensorOps(FuncOp funcOp,
- BufferizationAliasInfo &aliasInfo,
- DominanceInfo &domInfo) {
- return initTensorElimination(
+LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
+ InsertSliceAnchoredInitTensorEliminationStep::run(
+ FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+ DominanceInfo &domInfo, SmallVector<Operation *> &newOps) {
+ return eliminateInitTensors(
funcOp, aliasInfo, domInfo,
[&](OpOperand &operand) {
auto insertSliceOp = dyn_cast<InsertSliceOp>(operand.getOwner());
@@ -1713,7 +1710,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::
loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
return extractOp.result();
- });
+ },
+ newOps);
}
#ifndef NDEBUG
@@ -1793,11 +1791,15 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
options.analysisFuzzerSeed)))
return failure();
- // Try to eliminate InitTensorOps to avoid new allocations during the
- // bufferization phase.
- if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
- domInfo)))
- return failure();
+ for (const std::unique_ptr<PostAnalysisStep> &step :
+ options.postAnalysisSteps) {
+ SmallVector<Operation *> newOps;
+ if (failed(step->run(funcOp, aliasInfo, domInfo, newOps)))
+ return failure();
+ // Analyze ops that were created by the PostAnalysisStep.
+ if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
+ return failure();
+ }
// Bufferization phase.
if (!options.testAnalysisOnly) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index f14ab5af50a6a..20b1894755fda 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -71,6 +71,10 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
options.analysisFuzzerSeed = analysisFuzzerSeed;
options.testAnalysisOnly = testAnalysisOnly;
+ // Enable InitTensorOp elimination.
+ options.addPostAnalysisStep<
+ linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);
More information about the Mlir-commits
mailing list