[Mlir-commits] [mlir] 0e5f258 - [mlir][linalg][bufferize][NFC] Simplify InsertSliceOp bufferization
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 6 00:40:51 PST 2022
Author: Matthias Springer
Date: 2022-01-06T17:35:45+09:00
New Revision: 0e5f258452b053cc3374754efaeabe3c30f42482
URL: https://github.com/llvm/llvm-project/commit/0e5f258452b053cc3374754efaeabe3c30f42482
DIFF: https://github.com/llvm/llvm-project/commit/0e5f258452b053cc3374754efaeabe3c30f42482.diff
LOG: [mlir][linalg][bufferize][NFC] Simplify InsertSliceOp bufferization
There is no need to keep track of equivalent extract_slice / insert_slice tensors during bufferization. Just emit a copy; it will fold away.
Note: The analysis still keeps track of equivalent tensors in order to make the correct in-place bufferization decisions.
Differential Revision: https://reviews.llvm.org/D116684
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
index ca620138d6433..29355ef338f3a 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
@@ -9,8 +9,6 @@
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
-
namespace mlir {
class DialectRegistry;
@@ -19,12 +17,6 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace tensor_ext {
-struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep {
- LogicalResult run(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) override;
-};
-
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
} // namespace tensor_ext
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 0f91e52a5227e..9ee1d23d5d8af 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -23,20 +23,6 @@ namespace tensor_ext {
using tensor::ExtractSliceOp;
using tensor::InsertSliceOp;
-namespace {
-/// Extra bufferization state that is required for bufferization of tensor ops.
-struct TensorBufferizationState : public DialectBufferizationState {
- /// InsertSliceOps that bufferize inplace and do not require a copy.
- DenseSet<Operation *> insertSliceOpsWithoutCopy;
-};
-} // namespace
-
-static TensorBufferizationState &
-getTensorBufferizationState(BufferizationState &state) {
- return state.getDialectState<TensorBufferizationState>(
- tensor::TensorDialect::getDialectNamespace());
-}
-
struct CastOpInterface
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
tensor::CastOp> {
@@ -274,23 +260,6 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
return true;
}
-/// Return true if the source of a `insertSliceOp` bufferizes to an
-/// equivalent ExtractSliceOp that bufferizes inplace.
-static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
- const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
- bool foundOp = false;
- aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
- auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
- if (extractSliceOp &&
- areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
- insertSliceOp) &&
- aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
- foundOp = true;
- }
- });
- return foundOp;
-}
-
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
@@ -419,7 +388,6 @@ struct InsertSliceOpInterface
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
Location loc = insertSliceOp.getLoc();
- TensorBufferizationState &tensorState = getTensorBufferizationState(state);
// When bufferizing out-of-place, `getResultBuffer` allocates.
Value dstMemref =
@@ -427,24 +395,22 @@ struct InsertSliceOpInterface
if (!dstMemref)
return failure();
- bool needCopy =
- !tensorState.insertSliceOpsWithoutCopy.contains(insertSliceOp);
- if (needCopy) {
- // Take a subview of the dst.
- auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
- auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
- insertSliceOp.getSourceType().getRank(), dstMemrefType,
- insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
- insertSliceOp.getMixedStrides())
- .cast<MemRefType>();
- Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
- // Copy tensor.
- Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
- state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
- }
+ // Take a subview of the dst.
+ auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+ auto subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ insertSliceOp.getSourceType().getRank(), dstMemrefType,
+ insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+ insertSliceOp.getMixedStrides())
+ .cast<MemRefType>();
+ Value subView = rewriter.create<memref::SubViewOp>(
+ loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+ insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+
+ // Copy tensor. If this tensor.insert_slice has a matching
+ // tensor.extract_slice, the copy operation will eventually fold away.
+ Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
+ state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
state.replaceOp(rewriter, op, dstMemref);
return success();
@@ -456,25 +422,6 @@ struct InsertSliceOpInterface
} // namespace linalg
} // namespace mlir
-LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
- InplaceInsertSliceOpAnalysis::run(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) {
- auto &tensorState = getTensorBufferizationState(state);
- op->walk([&](InsertSliceOp insertSliceOp) {
- // A copy of the source buffer is needed if either:
- // - The producer of `source` is not inplace. This is the case where a
- // slice is computed out of place into the inplace full tensor.
- // - The result is not inplace. This is the case where the whole tensor is
- // cloned and the clone needs to be updated.
- if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
- insertSliceOp) &&
- state.isInPlace(insertSliceOp->getResult(0)))
- tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp);
- });
- return success();
-}
-
void mlir::linalg::comprehensive_bufferize::tensor_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index c5fdf402d9412..13e18001d82ee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -94,9 +94,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
// Enable InitTensorOp elimination.
options->addPostAnalysisStep<
linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
- // TODO: Find a way to enable this step automatically when bufferizing tensor
- // dialect ops.
- options->addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index f4a43aab1ebb9..59e53eaba56a9 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -99,9 +99,6 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
// Enable InitTensorOp elimination.
options->addPostAnalysisStep<
linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
- // TODO: Find a way to enable this step automatically when bufferizing
- // tensor dialect ops.
- options->addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
More information about the Mlir-commits
mailing list