[Mlir-commits] [mlir] bb83520 - [mlir][linalg][bufferize] Generalize InitTensorOp elimination
Matthias Springer
llvmlistbot at llvm.org
Wed Nov 3 21:59:05 PDT 2021
Author: Matthias Springer
Date: 2021-11-04T13:53:12+09:00
New Revision: bb83520dce13b9804b3282f29a8ff0886384a362
URL: https://github.com/llvm/llvm-project/commit/bb83520dce13b9804b3282f29a8ff0886384a362
DIFF: https://github.com/llvm/llvm-project/commit/bb83520dce13b9804b3282f29a8ff0886384a362.diff
LOG: [mlir][linalg][bufferize] Generalize InitTensorOp elimination
This allows for external users of Comprehensive Bufferize to specify their own InitTensorOp elimination procedures.
Differential Revision: https://reviews.llvm.org/D112686
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index b7ce96d9fedf3..e3b59d5daa60d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -195,6 +195,29 @@ bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
/// Register external models implemented for the `BufferizableOpInterface`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
+
+/// Try to eliminate InitTensorOps inside `funcOp`.
+///
+/// * Only InitTensorOps that are anchored on an OpOperand accepted by
+///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
+///   on the reverse SSA use-def chain, starting from the OpOperand and always
+///   following the aliasing OpOperand, that eventually ends at a single
+///   InitTensorOp.
+/// * `rewriteFunc` generates the replacement for the InitTensorOp. It is
+///   called with the builder's insertion point set right before the
+///   InitTensorOp, so every value it uses must dominate that point. Returning
+///   a null Value leaves that InitTensorOp unchanged.
+/// * Uses of the InitTensorOp are redirected to the replacement; the
+///   InitTensorOp itself is not erased.
+/// * The result of `rewriteFunc` must usually be analyzed for inplace
+///   bufferization. This analysis can be skipped with `skipAnalysis`.
+///
+/// Returns failure if the inplace analysis of a replacement fails.
+LogicalResult initTensorElimination(
+ FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+ std::function<bool(OpOperand &)> anchorMatchFunc,
+ std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+ bool skipAnalysis = false);
+
+/// Try to eliminate InitTensorOps inside `funcOp` that are anchored on the
+/// source operand of an in-place InsertSliceOp, i.e., InitTensorOps whose
+/// value is eventually inserted into another tensor (and some other
+/// conditions are met). Such an InitTensorOp is replaced with an
+/// ExtractSliceOp of the InsertSliceOp's destination.
+LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
+ FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 1fe5835d282bf..4fbde63c4d108 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -2150,6 +2150,78 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
}
}
+/// Try to eliminate InitTensorOps inside `funcOp`. An InitTensorOp is replaced
+/// with the result of `rewriteFunc` if it is anchored on a matching OpOperand.
+/// "Anchored" means that there is a path on the reverse SSA use-def chain,
+/// starting from the OpOperand and always following the aliasing OpOperand,
+/// that eventually ends at a single InitTensorOp.
+///
+/// * `anchorMatchFunc` is queried for each OpOperand of each op visited by the
+///   walk over `funcOp`; non-matching OpOperands are ignored.
+/// * `rewriteFunc` is invoked with the builder's insertion point set right
+///   before the InitTensorOp, so every value it uses must dominate that point.
+///   Returning a null Value leaves the InitTensorOp untouched.
+/// * Uses of the InitTensorOp are redirected to the replacement; the op itself
+///   is not erased.
+/// * Unless `skipAnalysis` is set, the op defining the replacement (if any) is
+///   run through the inplace analysis. The walk stops at the first analysis
+///   failure, and failure is returned.
+LogicalResult mlir::linalg::initTensorElimination(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+    std::function<bool(OpOperand &)> anchorMatchFunc,
+    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+    bool skipAnalysis) {
+  OpBuilder b(funcOp->getContext());
+
+  WalkResult status = funcOp->walk([&](Operation *op) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      // Is this a matching OpOperand?
+      if (!anchorMatchFunc(operand))
+        continue;
+
+      SetVector<Value> maybeInitTensor =
+          findValueInReverseUseDefChain(operand.get(), [](Value val) {
+            // Continue traversal until this function returns true.
+            OpResult opResult = val.dyn_cast<OpResult>();
+            if (!opResult)
+              return true;
+            if (getInPlace(opResult) != InPlaceSpec::True)
+              return true;
+            // Only equivalent tensors are supported at the moment.
+            // TODO: Support cases such as extract_slice(init_tensor).
+            SmallVector<OpOperand *> opOperands =
+                getAliasingOpOperand(opResult);
+            if (!llvm::all_of(opOperands, [](OpOperand *aliasingOperand) {
+                  return bufferRelation(*aliasingOperand) ==
+                         BufferRelation::Equivalent;
+                }))
+              return true;
+            return false;
+          });
+
+      // Replace only if the reverse use-def chain ends at exactly one
+      // InitTensorOp. Otherwise move on to the next OpOperand of this op:
+      // leaving the walk callback here would silently skip the remaining
+      // (possibly matching) OpOperands of `op`.
+      if (maybeInitTensor.size() != 1 ||
+          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
+        continue;
+      Value initTensor = maybeInitTensor.front();
+
+      // Create a replacement for the InitTensorOp.
+      b.setInsertionPoint(initTensor.getDefiningOp());
+      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
+      if (!replacement)
+        continue;
+
+      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
+      // InitTensorOps without uses are ignored by the bufferization.
+      initTensor.replaceAllUsesWith(replacement);
+      aliasInfo.createAliasInfoEntry(replacement);
+
+      // Run analysis on the newly created op. A replacement without a
+      // defining op (a block argument) has nothing to analyze.
+      if (skipAnalysis)
+        continue;
+      if (Operation *newOp = replacement.getDefiningOp()) {
+        SmallVector<Operation *> ops(1, newOp);
+        if (failed(inPlaceAnalysis(ops, aliasInfo, domInfo)))
+          return WalkResult::interrupt();
+      }
+    }
+
+    // Advance to the next operation.
+    return WalkResult::advance();
+  });
+
+  return failure(status.wasInterrupted());
+}
+
/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
@@ -2178,60 +2250,26 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
-static LogicalResult runInitTensorElimination(FuncOp funcOp,
- BufferizationAliasInfo &aliasInfo,
- DominanceInfo &domInfo) {
- OpBuilder b(funcOp->getContext());
-
- WalkResult status = funcOp->walk([&](tensor::InsertSliceOp insertOp) {
- // Only inplace bufferized InsertSliceOps are eligible.
- if (getInPlace(insertOp->getOpResult(0)) != InPlaceSpec::True)
- return WalkResult::skip();
-
- SetVector<Value> maybeInitTensor =
- findValueInReverseUseDefChain(insertOp.source(), [](Value val) {
- // Continue traversal until this function returns true.
- OpResult opResult = val.dyn_cast<OpResult>();
- if (!opResult)
- return true;
- if (getInPlace(opResult) != InPlaceSpec::True)
- return true;
- // Only equivalent tensors are supported at the moment. E.g., when
- // taking a tensor.extract_slice of an init_tensor, we can currently
- // not eliminate the init_tensor.
- SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
- if (!llvm::all_of(opOperands, [](OpOperand *operand) {
- return bufferRelation(*operand) == BufferRelation::Equivalent;
- }))
- return true;
+LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
+ FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo) {
+ return initTensorElimination(
+ funcOp, aliasInfo, domInfo,
+ // anchorMatchFunc: match only the source OpOperand (#0) of an
+ // InsertSliceOp whose result is marked to bufferize in place.
+ [](OpOperand &operand) {
+ auto insertSliceOp = dyn_cast<InsertSliceOp>(operand.getOwner());
+ if (!insertSliceOp)
return false;
- });
- // Replace only if the InsertSliceOp source originates from exactly one
- // InitTensorOp.
- if (maybeInitTensor.size() != 1 ||
- !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
- return WalkResult::skip();
- Value initTensor = maybeInitTensor.front();
-
- b.setInsertionPoint(initTensor.getDefiningOp());
- auto extractOp = b.create<tensor::ExtractSliceOp>(
- initTensor.getLoc(), insertOp.dest(), insertOp.getMixedOffsets(),
- insertOp.getMixedSizes(), insertOp.getMixedStrides());
- // Uses of the InitTensorOp are replaced here, but the op is not deleted.
- // InitTensorOps without uses are ignored by the bufferization.
- initTensor.replaceAllUsesWith(extractOp.result());
- aliasInfo.createAliasInfoEntry(extractOp.result());
-
- // Run analysis on the ExtractSliceOp.
- if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(
- extractOp->getOpOperand(0), aliasInfo, domInfo)))
- return WalkResult::interrupt();
-
- // Advance to the next operation.
- return WalkResult::advance();
- });
-
- return failure(status.wasInterrupted());
+ // Only inplace bufferized InsertSliceOps are eligible.
+ if (getInPlace(insertSliceOp->getOpResult(0)) != InPlaceSpec::True)
+ return false;
+ return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
+ },
+ // rewriteFunc: replace the InitTensorOp with an ExtractSliceOp that reads
+ // the region of the InsertSliceOp's destination that the source is
+ // inserted into (same offsets, sizes and strides). The cast below cannot
+ // fail because anchorMatchFunc only accepts InsertSliceOp operands.
+ // NOTE(review): the ExtractSliceOp is built at the builder's insertion
+ // point (right before the InitTensorOp, see initTensorElimination); `dest`
+ // and the offset/size/stride values are assumed to dominate that point,
+ // which is not checked here — confirm.
+ [](OpBuilder &b, Location loc, OpOperand &operand) {
+ auto insertSliceOp = cast<InsertSliceOp>(operand.getOwner());
+ auto extractOp = b.create<tensor::ExtractSliceOp>(
+ loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
+ insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+ return extractOp.result();
+ });
}
void LinalgComprehensiveModuleBufferize::runOnOperation() {
@@ -2291,7 +2329,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
// Try to eliminate InitTensorOps to avoid new allocations during the
// bufferization phase.
- if (failed(runInitTensorElimination(funcOp, aliasInfo, domInfo))) {
+ if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
+ domInfo))) {
signalPassFailure();
return;
}
More information about the Mlir-commits
mailing list