[Mlir-commits] [mlir] 7ce427e - [mlir][linalg][bufferize][NFC] Clean up BufferizationState
Matthias Springer
llvmlistbot at llvm.org
Mon Dec 6 17:08:40 PST 2021
Author: Matthias Springer
Date: 2021-12-07T10:05:39+09:00
New Revision: 7ce427e3bc0bcf50cb2f3b5944852219be03db9e
URL: https://github.com/llvm/llvm-project/commit/7ce427e3bc0bcf50cb2f3b5944852219be03db9e
DIFF: https://github.com/llvm/llvm-project/commit/7ce427e3bc0bcf50cb2f3b5944852219be03db9e.diff
LOG: [mlir][linalg][bufferize][NFC] Clean up BufferizationState
Make the fields of BufferizationState private and clean up its interface. In particular, `BufferizableOpInterface::bufferize` implementations no longer have access to `aliasInfo`. Such access was potentially dangerous because some of the ops registered in BufferizationAliasInfo may already have been deleted by the time `bufferize` runs.
Differential Revision: https://reviews.llvm.org/D114931
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index a76007ad9a97a..1e0c96f0114a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -29,7 +29,10 @@ namespace comprehensive_bufferize {
// TODO: from some HW description.
static constexpr int64_t kBufferAlignments = 128;
-struct BufferizationState;
+class BufferizationAliasInfo;
+struct BufferizationOptions;
+class BufferizationState;
+struct PostAnalysisStep;
/// Callback functions that are used to allocate/deallocate/copy memory buffers.
/// Comprehensive Bufferize provides default implementations of these functions.
@@ -68,6 +71,7 @@ struct PostAnalysisStep {
/// `aliasInfo` (inside `state`) consistent. Newly created operations and
/// operations that should be re-analyzed must be stored in `newOps`.
virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) = 0;
};
@@ -281,9 +285,20 @@ struct DialectBufferizationState {
virtual ~DialectBufferizationState() = default;
};
-/// BufferizationState keeps track of bufferization state and provides access to
-/// the results of the analysis.
-struct BufferizationState {
+/// BufferizationState keeps track of memory buffers and provides a variety of
+/// helper functions for dealing with them. In particular,
+/// `BufferizableOpInterface::bufferize` implementations should use the
+/// following helper functions.
+///
+/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` create ops
+/// that allocate and/or deallocate memref buffers.
+/// * `mapBuffer` maps a tensor value to a memref buffer during bufferization.
+/// * `lookupBuffer` returns the mapped memref buffer of a given tensor value.
+/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
+/// Based on inplace bufferization decisions of the analysis, it may either
+/// directly return a mapped buffer or allocate a brand new buffer.
+class BufferizationState {
+public:
BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
: aliasInfo(moduleOp), options(options),
builder(moduleOp->getContext()) {}
@@ -291,11 +306,21 @@ struct BufferizationState {
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
- /// A function that creates an alloc-dealloc pair. This function may perform
- /// additional optimizations such as buffer allocation hoisting. This function
- /// calls `allocationFn` and `deallocationFn` to create (de)allocations.
- Value createAllocDeallocFn(OpBuilder &builder, Location loc,
- Value shapedValue);
+ /// Creates a memref allocation.
+ Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ArrayRef<Value> dynShape);
+
+ /// Creates an alloc-dealloc pair. This function may perform additional
+ /// optimizations such as buffer allocation hoisting.
+ Value createAllocDeallocPair(OpBuilder &builder, Location loc,
+ Value shapedValue);
+
+ /// Creates a memref deallocation. The given memref buffer must have been
+ /// allocated using `createAlloc`.
+ void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer);
+
+ /// Creates a memcpy between two given buffers.
+ void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
/// Map tensor values to memref buffers.
void mapBuffer(ValueRange tensors, ValueRange buffers);
@@ -307,6 +332,9 @@ struct BufferizationState {
/// Asserts if no buffer is associated.
Value lookupBuffer(Value tensor);
+ /// Return `true` if the analysis decided to bufferize `opResult` in place.
+ bool isInPlace(OpResult opResult) const;
+
/// Return `true` if the given value is mapped.
bool isMapped(Value value) const;
@@ -329,7 +357,24 @@ struct BufferizationState {
return static_cast<StateT &>(*dialectState[name]);
}
- /// `aliasInfo` keeps track of aliasing and equivalent values.
+ /// Return a reference to the BufferizationOptions.
+ const BufferizationOptions &getOptions() const { return options; }
+
+ /// Return a reference to the OpBuilder.
+ OpBuilder &getBuilder() { return builder; }
+
+private:
+ friend LogicalResult
+ runComprehensiveBufferize(FuncOp funcOp, const BufferizationOptions &options,
+ BufferizationState &state,
+ const PostAnalysisStepList &extraSteps);
+
+ friend LogicalResult
+ runComprehensiveBufferize(ModuleOp moduleOp,
+ const BufferizationOptions &options);
+
+ /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
+ /// functions and `runComprehensiveBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
/// The mapping of tensors to buffers.
@@ -428,7 +473,7 @@ struct AllocationHoistingBarrierOnly
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (any_of(op->getOperandTypes(), isaTensor) ||
any_of(op->getResultTypes(), isaTensor))
- if (!state.options.allowUnknownOps)
+ if (!state.getOptions().allowUnknownOps)
return op->emitError() << "unsupported op with tensors";
for (Region &region : op->getRegions())
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index 185878701d710..9b7cb9421dd98 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -35,6 +35,7 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
/// This analysis can be skipped with `skipAnalysis`.
LogicalResult eliminateInitTensors(
FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
std::function<bool(OpOperand &)> anchorMatchFunc,
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
SmallVector<Operation *> &newOps);
@@ -46,6 +47,7 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
struct InsertSliceAnchoredInitTensorEliminationStep
: public InitTensorEliminationStep {
LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override;
};
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
index 3ab5cc3525fc3..3e557e962b2aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -23,6 +23,7 @@ namespace scf_ext {
/// equivalent to their corresponding loop yield values.
struct AssertDestinationPassingStyle : public PostAnalysisStep {
LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override;
};
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
index dbda53743d9b0..61b1f9356d545 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
@@ -21,6 +21,7 @@ namespace tensor_ext {
struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep {
LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override;
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 7682c4ae49393..4cb4c5bd38166 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -367,7 +367,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(builder, operandBuffer);
// Allocate the result buffer.
- Value resultBuffer = createAllocDeallocFn(builder, loc, operandBuffer);
+ Value resultBuffer = createAllocDeallocPair(builder, loc, operandBuffer);
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -389,8 +389,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
if (!skipCopy) {
// The copy happens right before the op that is bufferized.
builder.setInsertionPoint(op);
- options.allocationFns->memCpyFn(builder, loc, operandBuffer,
- resultBuffer);
+ createMemCpy(builder, loc, operandBuffer, resultBuffer);
}
return resultBuffer;
}
@@ -420,7 +419,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
LogicalResult
mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
BufferizationState &state) {
- OpBuilder &b = state.builder;
+ OpBuilder &b = state.getBuilder();
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -443,7 +442,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
}
// `op` is an unbufferizable tensor op.
- if (!state.options.allowUnknownOps)
+ if (!state.getOptions().allowUnknownOps)
return op->emitError() << "unsupported op with tensors";
// Replace all OpOperands with "to-tensor casted" bufferized values.
@@ -550,7 +549,7 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
- createAllocDeallocFn(OpBuilder &b, Location loc, Value shapedValue) {
+ createAllocDeallocPair(OpBuilder &b, Location loc, Value shapedValue) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -561,8 +560,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
- Optional<Value> allocated =
- options.allocationFns->allocationFn(b, loc, allocMemRefType, dynShape);
+ Optional<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
@@ -573,10 +571,29 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
- options.allocationFns->deallocationFn(b, loc, allocated.getValue());
+ createDealloc(b, loc, allocated.getValue());
return casted;
}
+/// Create a memref allocation via `options.allocationFns->allocationFn`.
+Optional<Value>
+mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
+ OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape) {
+ return options.allocationFns->allocationFn(b, loc, type, dynShape);
+}
+
+/// Create a memref deallocation via `options.allocationFns->deallocationFn`.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
+ OpBuilder &b, Location loc, Value allocatedBuffer) {
+ return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
+}
+
+/// Create a memcpy from buffer `from` to buffer `to` via `memCpyFn`.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
+ OpBuilder &b, Location loc, Value from, Value to) {
+ return options.allocationFns->memCpyFn(b, loc, from, to);
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
@@ -648,9 +665,15 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
Value value) const {
+ assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
return mapping.contains(value);
}
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
+ OpResult opResult) const {
+ return aliasInfo.isInPlace(opResult);
+}
+
void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
Operation *op) {
obsoleteOps.push_back(op);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6cbbda3f97146..19b6361027514 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -732,7 +732,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
SmallVector<Operation *> newOps;
- if (failed(step->run(funcOp, state, newOps)))
+ if (failed(step->run(funcOp, state, aliasInfo, newOps)))
return failure();
// Analyze ops that were created by the PostAnalysisStep.
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index d9231c0445164..4d66f44724a3f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -159,8 +159,8 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();
- Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
- initTensorOp.result());
+ Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
+ initTensorOp.result());
state.mapBuffer(initTensorOp.result(), alloc);
return success();
}
@@ -379,11 +379,11 @@ struct LinalgOpInterfaceHelper<> {
LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
InitTensorEliminationStep::eliminateInitTensors(
FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
std::function<bool(OpOperand &)> anchorMatchFunc,
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
SmallVector<Operation *> &newOps) {
OpBuilder b(funcOp->getContext());
- BufferizationAliasInfo &aliasInfo = state.aliasInfo;
WalkResult status = funcOp->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
@@ -474,16 +474,16 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
InsertSliceAnchoredInitTensorEliminationStep::run(
FuncOp funcOp, BufferizationState &state,
- SmallVector<Operation *> &newOps) {
+ BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
return eliminateInitTensors(
- funcOp, state,
+ funcOp, state, aliasInfo,
[&](OpOperand &operand) {
auto insertSliceOp =
dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
if (!insertSliceOp)
return false;
// Only inplace bufferized InsertSliceOps are eligible.
- if (!state.aliasInfo.isInPlace(insertSliceOp->getOpResult(0)))
+ if (!aliasInfo.isInPlace(insertSliceOp->getOpResult(0)))
return false;
return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
},
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index ebe7a5feb8d00..3a64ae711b9fc 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -88,6 +88,7 @@ struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
}
LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override {
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
@@ -99,12 +100,12 @@ struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
if (returnVal.get().getType().isa<RankedTensorType>())
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<RankedTensorType>())
- if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
- bbArg)) {
+ if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
+ bbArg)) {
moduleState
.equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
bbArg.getArgNumber();
- if (state.options.testAnalysisOnly)
+ if (state.getOptions().testAnalysisOnly)
annotateReturnOp(returnVal, bbArg);
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 55156f949635d..db70cfd571523 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -265,6 +265,7 @@ struct ForOpInterface
LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
LogicalResult status = success();
funcOp->walk([&](scf::YieldOp yieldOp) {
@@ -280,8 +281,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
- bbArg)) {
+ if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
status =
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 7f1bdb703d18d..21872d8407bea 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -144,10 +144,10 @@ struct ExtractSliceOpInterface
extractSliceOp.result().getType().cast<RankedTensorType>();
// If not inplaceable, alloc.
- bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
+ bool inplace = state.isInPlace(extractSliceOp->getResult(0));
Value alloc;
if (!inplace)
- alloc = state.createAllocDeallocFn(b, loc, extractSliceOp.result());
+ alloc = state.createAllocDeallocPair(b, loc, extractSliceOp.result());
// Bufferize to subview.
auto subviewMemRefType =
@@ -159,15 +159,12 @@ struct ExtractSliceOpInterface
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
- // Insert new alias.
- state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
/// If not inplaceable, copy.
if (!inplace) {
// Do not copy if the copied data is never read.
if (isValueRead(extractSliceOp.result()))
- state.options.allocationFns->memCpyFn(b, extractSliceOp.getLoc(),
- subView, alloc);
+ state.createMemCpy(b, extractSliceOp.getLoc(), subView, alloc);
subView = alloc;
}
@@ -421,8 +418,7 @@ struct InsertSliceOpInterface
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Copy tensor.
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
- state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
- srcMemref, subView);
+ state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView);
}
state.mapBuffer(insertSliceOp.result(), dstMemref);
@@ -437,6 +433,7 @@ struct InsertSliceOpInterface
LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
auto &tensorState = getTensorBufferizationState(state);
funcOp.walk([&](InsertSliceOp insertSliceOp) {
@@ -445,9 +442,9 @@ LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
// slice is computed out of place into the inplace full tensor.
// - The result is not inplace. This is the case where the whole tensor is
// cloned and the clone needs to be updated.
- if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
+ if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
insertSliceOp) &&
- state.aliasInfo.isInPlace(insertSliceOp->getResult(0)))
+ state.isInPlace(insertSliceOp->getResult(0)))
tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp);
});
return success();
More information about the Mlir-commits
mailing list