[Mlir-commits] [mlir] 5f60c48 - [mlir][linalg][bufferize][NFC] Make init_tensor elimination a separate pre-processing pass
Matthias Springer
llvmlistbot at llvm.org
Wed May 4 01:19:28 PDT 2022
Author: Matthias Springer
Date: 2022-05-04T17:17:27+09:00
New Revision: 5f60c4825b351189379f147ad636f19fa5060e5c
URL: https://github.com/llvm/llvm-project/commit/5f60c4825b351189379f147ad636f19fa5060e5c
DIFF: https://github.com/llvm/llvm-project/commit/5f60c4825b351189379f147ad636f19fa5060e5c.diff
LOG: [mlir][linalg][bufferize][NFC] Make init_tensor elimination a separate pre-processing pass
This commit decouples init_tensor elimination from the rest of the bufferization.
Differential Revision: https://reviews.llvm.org/D124853
Added:
mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir
mlir/test/Dialect/Linalg/one-shot-bufferize-init-tensor-elimination.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
Removed:
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 73226b2587a54..3510b2f1f984a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -73,6 +73,10 @@ std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
const bufferization::OneShotBufferizationOptions &options);
+/// Create a pass that tries to eliminate init_tensor ops that are anchored on
+/// insert_slice ops.
+std::unique_ptr<Pass> createLinalgInitTensorEliminationPass();
+
/// Create a pass to convert Linalg operations which work on tensors to use
/// buffers instead.
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 06f0e217986d7..2c0287de0fcac 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -62,10 +62,6 @@ def LinalgComprehensiveModuleBufferize :
Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
/*default=*/"0",
"Analyze ops in random order with a given seed (fuzzer)">,
- Option<"initTensorElimination", "init-tensor-elimination", "bool",
- /*default=*/"false",
- "(Experimental) Try to eliminate init_tensor operations that are "
- "anchored at an insert_slice op">,
Option<"createDeallocs", "create-deallocs", "bool", /*default=*/"true",
"Specify if buffers should be deallocated. For compatibility with "
"core bufferization passes.">,
@@ -73,6 +69,18 @@ def LinalgComprehensiveModuleBufferize :
let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
}
+def LinalgInitTensorElimination : Pass<"linalg-eliminate-init-tensors"> {
+ let summary = "Try to eliminate all init_tensor ops.";
+ let description = [{
+ This pass tries to eliminate all insert_slice op-anchored init_tensor ops.
+ I.e., when a value that is aliasing with an init_tensor op is inserted into
+ another tensor, this pass tries to rewrite the IR in such a way that the
+ destination tensor of the insert_slice op is used directly instead of the
+ init_tensor result.
+ }];
+ let constructor = "mlir::createLinalgInitTensorEliminationPass()";
+}
+
def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
index 64c3232e86372..8fc406ceef80c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
@@ -33,21 +33,16 @@ using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
/// on the reverse SSA use-def chain, starting from the OpOperand and always
/// following the aliasing OpOperand, that eventually ends at a single
/// InitTensorOp.
-/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
-/// This analysis can be skipped with `skipAnalysis`.
-LogicalResult
-eliminateInitTensors(Operation *op, bufferization::AnalysisState &state,
- bufferization::BufferizationAliasInfo &aliasInfo,
- AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
- SmallVector<Operation *> &newOps);
+LogicalResult eliminateInitTensors(RewriterBase &rewriter, Operation *op,
+ bufferization::AnalysisState &state,
+ AnchorMatchFn anchorMatchFunc,
+ RewriteFn rewriteFunc);
/// Try to eliminate InitTensorOps inside `op` that are anchored on an
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
/// (and some other conditions are met).
LogicalResult insertSliceAnchoredInitTensorEliminationStep(
- Operation *op, bufferization::AnalysisState &state,
- bufferization::BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps);
+ RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 6146175debb7d..11fea3adb76c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -334,16 +334,17 @@ findValidInsertionPoint(Operation *initTensorOp,
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
-LogicalResult mlir::linalg::eliminateInitTensors(
- Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
- AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
- SmallVector<Operation *> &newOps) {
- OpBuilder b(op->getContext());
+LogicalResult mlir::linalg::eliminateInitTensors(RewriterBase &rewriter,
+ Operation *op,
+ AnalysisState &state,
+ AnchorMatchFn anchorMatchFunc,
+ RewriteFn rewriteFunc) {
+ OpBuilder::InsertionGuard g(rewriter);
WalkResult status = op->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
// Skip operands that do not bufferize inplace.
- if (!aliasInfo.isInPlace(operand))
+ if (!state.isInPlace(operand))
continue;
// All values that are needed to create the replacement op.
SmallVector<Value> neededValues;
@@ -359,14 +360,14 @@ LogicalResult mlir::linalg::eliminateInitTensors(
SmallVector<OpOperand *> opOperands =
state.getAliasingOpOperand(opResult);
if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
- return aliasInfo.isInPlace(*operand);
+ return state.isInPlace(*operand);
}))
return true;
// Only equivalent tensors are supported at the moment.
// TODO: Support cases such as extract_slice(init_tensor)
return !llvm::all_of(opOperands, [&](OpOperand *operand) {
- return aliasInfo.areEquivalentBufferizedValues(operand->get(),
- opResult);
+ return state.areEquivalentBufferizedValues(operand->get(),
+ opResult);
});
});
@@ -384,21 +385,13 @@ LogicalResult mlir::linalg::eliminateInitTensors(
continue;
// Create a replacement for the InitTensorOp.
- b.setInsertionPoint(insertionPoint);
- Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
+ rewriter.setInsertionPoint(insertionPoint);
+ Value replacement = rewriteFunc(rewriter, initTensor.getLoc(), operand);
if (!replacement)
continue;
- // Uses of the InitTensorOp are replaced here, but the op is not deleted.
- // InitTensorOps without uses are ignored by the bufferization.
- initTensor.replaceAllUsesWith(replacement);
- aliasInfo.createAliasInfoEntry(replacement);
- aliasInfo.unionAliasSets(initTensor, replacement);
- aliasInfo.unionEquivalenceClasses(initTensor, replacement);
-
- // Register replacement ops.
- if (Operation *newOp = replacement.getDefiningOp())
- newOps.push_back(newOp);
+ // Replace the InitTensorOp.
+ rewriter.replaceOp(initTensor.getDefiningOp(), replacement);
}
// Advance to the next operation.
@@ -428,28 +421,20 @@ LogicalResult mlir::linalg::eliminateInitTensors(
///
/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
-/// * The InsertSliceOp was decided to bufferize inplace.
/// * On the reverse use-def chain path from the InsertSliceOp to the
/// InitTensorOp, all ops were decided to bufferize inplace and the buffer
/// relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
-///
-/// Note that the newly inserted ExtractSliceOp may have to bufferize
-/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
- Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) {
+ RewriterBase &rewriter, Operation *op, AnalysisState &state) {
return eliminateInitTensors(
- op, state, aliasInfo,
+ rewriter, op, state,
/*anchorMatchFunc=*/
[&](OpOperand &operand, SmallVector<Value> &neededValues) {
auto insertSliceOp =
dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
if (!insertSliceOp)
return false;
- // Only inplace bufferized InsertSliceOps are eligible.
- if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
- return false;
if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
return false;
@@ -487,8 +472,7 @@ LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
auto extractOp = b.create<tensor::ExtractSliceOp>(
loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
return extractOp.result();
- },
- newOps);
+ });
}
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 13b760c6dd959..bbb013d955332 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -63,6 +63,17 @@ struct LinalgComprehensiveModuleBufferize
private:
llvm::Optional<OneShotBufferizationOptions> options;
};
+
+struct LinalgInitTensorElimination
+ : public LinalgInitTensorEliminationBase<LinalgInitTensorElimination> {
+ LinalgInitTensorElimination() = default;
+
+ void runOnOperation() override;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
+ }
+};
} // namespace
static void applyEnablingTransformations(ModuleOp moduleOp) {
@@ -100,9 +111,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
opt.testAnalysisOnly = testAnalysisOnly;
opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
opt.bufferizeFunctionBoundaries = true;
- if (initTensorElimination) {
- opt.addPostAnalysisStep(insertSliceAnchoredInitTensorEliminationStep);
- }
} else {
opt = *options;
}
@@ -125,6 +133,20 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
(void)runPipeline(cleanupPipeline, moduleOp);
}
+void LinalgInitTensorElimination::runOnOperation() {
+ Operation *op = getOperation();
+ OneShotBufferizationOptions options;
+ OneShotAnalysisState state(op, options);
+ if (failed(analyzeOp(op, state))) {
+ signalPassFailure();
+ return;
+ }
+
+ IRRewriter rewriter(op->getContext());
+ if (failed(insertSliceAnchoredInitTensorEliminationStep(rewriter, op, state)))
+ signalPassFailure();
+}
+
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
return std::make_unique<LinalgComprehensiveModuleBufferize>();
}
@@ -133,3 +155,7 @@ std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
const OneShotBufferizationOptions &options) {
return std::make_unique<LinalgComprehensiveModuleBufferize>(options);
}
+
+std::unique_ptr<Pass> mlir::createLinalgInitTensorEliminationPass() {
+ return std::make_unique<LinalgInitTensorElimination>();
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir
similarity index 92%
rename from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
rename to mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir
index dcad535385bb5..512407907aa3c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs init-tensor-elimination" -split-input-file | FileCheck %s
-
-// -----
+// RUN: mlir-opt %s -linalg-eliminate-init-tensors -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s
//===----------------------------------------------------------------------===//
// InitTensorOp elimination
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-init-tensor-elimination.mlir
similarity index 96%
rename from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
rename to mlir/test/Dialect/Linalg/one-shot-bufferize-init-tensor-elimination.mlir
index 72c350fea5cb8..053efe8728837 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-init-tensor-elimination.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs init-tensor-elimination" -canonicalize -split-input-file | FileCheck %s
-
-// -----
+// RUN: mlir-opt %s -linalg-eliminate-init-tensors -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs" -canonicalize -split-input-file | FileCheck %s
// CHECK: func @buffer_forwarding_conflict(
// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
More information about the Mlir-commits
mailing list