[Mlir-commits] [mlir] aeb1c8d - [mlir][linalg][bufferize] Group helpers in BufferizationState
Matthias Springer
llvmlistbot at llvm.org
Thu Nov 11 01:28:40 PST 2021
Author: Matthias Springer
Date: 2021-11-11T18:24:13+09:00
New Revision: aeb1c8d0cae8e13cc97481e54ef11867808fe5b8
URL: https://github.com/llvm/llvm-project/commit/aeb1c8d0cae8e13cc97481e54ef11867808fe5b8
DIFF: https://github.com/llvm/llvm-project/commit/aeb1c8d0cae8e13cc97481e54ef11867808fe5b8.diff
LOG: [mlir][linalg][bufferize] Group helpers in BufferizationState
This groups the `BufferizationAliasInfo`, `AllocationCallbacks` and tensor-to-buffer `BlockAndValueMapping` helpers into a single `BufferizationState` object, which simplifies the signature of `bufferize`.
Differential Revision: https://reviews.llvm.org/D113388
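In short, three loose parameters are folded into one state object. A condensed sketch of the grouping, with types and signatures taken from the headers in the diff below (an illustration, not the complete declarations):

    // Before this change, every bufferize() implementation threaded three
    // helpers through its signature:
    LogicalResult bufferize(Operation *op, OpBuilder &b,
                            BlockAndValueMapping &bvm,
                            BufferizationAliasInfo &aliasInfo,
                            AllocationCallbacks &allocationFn);

    // After: the helpers live in one struct, which also provides
    // debug-friendly wrappers around the tensor-to-buffer mapping.
    struct BufferizationState {
      void mapBuffer(Value tensor, Value buffer); // record tensor -> memref
      Value lookupBuffer(Value tensor) const;     // asserts if unmapped

      BufferizationAliasInfo &aliasInfo;       // aliasing/equivalence analysis
      AllocationCallbacks &allocationFns;      // alloc/dealloc/memcpy callbacks
      BlockAndValueMapping &tensorToBufferMap; // the underlying mapping
    };

    LogicalResult bufferize(Operation *op, OpBuilder &b,
                            BufferizationState &state);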
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 42908bc6c5c2..3c3f601a9385 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -197,6 +197,8 @@ findValueInReverseUseDefChain(Value value,
/// is returned regardless of whether it is a memory write or not.
Value findLastPrecedingWrite(Value value);
+struct BufferizationState;
+
/// Callback functions that are used to allocate/deallocate/copy memory buffers.
/// Comprehensive Bufferize provides default implementations of these functions.
// TODO: Could be replaced with a "bufferization strategy" object with virtual
@@ -207,8 +209,7 @@ struct AllocationCallbacks {
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
using CreateAllocDeallocFn =
- std::function<Value(OpBuilder &, Location, Value,
- BufferizationAliasInfo &, AllocationCallbacks &)>;
+ std::function<Value(OpBuilder &, Location, Value, BufferizationState &)>;
AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
@@ -230,13 +231,40 @@ struct AllocationCallbacks {
CreateAllocDeallocFn createAllocDeallocFn;
};
+/// BufferizationState bundles the helper objects needed during bufferization
+/// and provides access to the results of the analysis.
+struct BufferizationState {
+ BufferizationState(BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFns,
+ BlockAndValueMapping &tensorToBufferMap)
+ : aliasInfo(aliasInfo), allocationFns(allocationFns),
+ tensorToBufferMap(tensorToBufferMap) {}
+
+ /// Map tensor values to memref buffers.
+ void mapBuffer(ValueRange tensors, ValueRange buffers);
+
+ /// Map a tensor value to a memref buffer.
+ void mapBuffer(Value tensor, Value buffer);
+
+ /// Look up the memref buffer associated with the given tensor value.
+ /// Asserts if no buffer is associated.
+ Value lookupBuffer(Value tensor) const;
+
+ /// `aliasInfo` keeps track of aliasing and equivalent values.
+ BufferizationAliasInfo &aliasInfo;
+
+ /// `allocationFns` contains helper functions for creating alloc ops, dealloc
+ /// ops and memcpy ops.
+ AllocationCallbacks &allocationFns;
+
+ /// The mapping of tensors to buffers.
+ BlockAndValueMapping &tensorToBufferMap;
+};
+
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
-Value getResultBuffer(OpBuilder &b, OpResult result,
- const BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks allocationFns);
+Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
} // namespace comprehensive_bufferize
} // namespace linalg
@@ -280,9 +308,7 @@ struct AllocationHoistingBarrierOnly
bool isWritable(Operation *op, Value value) const { return false; }
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (any_of(op->getOperandTypes(), isaTensor) ||
any_of(op->getResultTypes(), isaTensor))
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 8c8846753304..66e26b3e6e61 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -160,8 +160,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
llvm_unreachable("bufferRelation not implemented");
}]
>,
- // TODO: Simplify method signature: Pass an OpBuilder and a
- // BufferizationState object.
InterfaceMethod<
/*desc=*/[{
Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@@ -171,9 +169,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "OpBuilder &":$b,
- "BlockAndValueMapping &":$bvm,
- "BufferizationAliasInfo &":$aliasInfo,
- "AllocationCallbacks &":$allocationFn),
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index 35b6e6f2abe2..72a5c700c10a 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -27,6 +27,8 @@ namespace comprehensive_bufferize {
// TODO: from some HW description.
static constexpr int64_t kBufferAlignments = 128;
+struct BufferizationState;
+
/// Analyze the `ops` to determine which OpResults are inplaceable.
LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
@@ -55,9 +57,7 @@ std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
LogicalResult
-bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks allocationFns,
+bufferizeOp(Operation *op, BufferizationState &state,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
/// Register external models implemented for the `BufferizableOpInterface`.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 3c5c95e87a04..bada67c7c1b7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -7,8 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"
namespace mlir {
@@ -319,30 +322,28 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
- OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns) {
+ OpBuilder &b, OpResult result, BufferizationState &state) {
OpBuilder::InsertionGuard guard(b);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
OpOperand *opOperand = aliasingOperands.front();
Value operand = opOperand->get();
- Value operandBuffer = bvm.lookupOrNull(operand);
- assert(operandBuffer && "operand buffer not found");
+ Value operandBuffer = state.lookupBuffer(operand);
// Make sure that all OpOperands are the same buffer. If this is not the case,
// we would have to materialize a memref value.
// TODO: Should be checking for "equivalent buffers" instead of
// operator== here, but equivalent buffers for scf.if yield values are not
// set up yet.
if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
- return bvm.lookup(o->get()) == operandBuffer;
+ return state.lookupBuffer(o->get()) == operandBuffer;
})) {
op->emitError("result buffer is ambiguous");
return Value();
}
// If bufferizing out-of-place, allocate a new buffer.
- if (!aliasInfo.isInPlace(result)) {
+ if (!state.aliasInfo.isInPlace(result)) {
// Ops with multiple aliasing operands can currently not bufferize
// out-of-place.
assert(
@@ -350,8 +351,8 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
Location loc = op->getLoc();
// Allocate the result buffer.
- Value resultBuffer = allocationFns.createAllocDeallocFn(
- b, loc, operand, aliasInfo, allocationFns);
+ Value resultBuffer =
+ state.allocationFns.createAllocDeallocFn(b, loc, operand, state);
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -373,7 +374,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
if (!skipCopy) {
// Set insertion point now that potential alloc/dealloc are introduced.
b.setInsertionPoint(op);
- allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
+ state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
}
return resultBuffer;
}
@@ -381,3 +382,39 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
// Bufferizing in-place. No need to allocate a new buffer.
return operandBuffer;
}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific BlockAndValueMapping support with debugging.
+//===----------------------------------------------------------------------===//
+
+/// Wrapper for better debugging.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
+ ValueRange tensors, ValueRange buffers) {
+ assert(!tensors.empty() && "unexpected empty tensors");
+ return tensorToBufferMap.map(tensors, buffers);
+}
+
+/// Wrapper for better debugging.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
+ Value tensor, Value buffer) {
+ assert(tensor && "unexpected empty tensor");
+ assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+ return tensorToBufferMap.map(tensor, buffer);
+}
+
+/// Wrapper for better debugging.
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
+ Value tensor) const {
+ // TODO: if key comes from bbArg, forward.
+ assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+ Value v = tensorToBufferMap.lookupOrNull(tensor);
+
+ if (!v) {
+ // Dump tensor for easier debugging.
+ tensor.dump();
+ llvm_unreachable("tensor is not mapped");
+ return Value();
+ }
+
+ return v;
+}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 233c896648b9..60b0e86d8628 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -172,47 +172,6 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
return returnOp;
}
-//===----------------------------------------------------------------------===//
-// Bufferization-specific BlockAndValueMapping support with debugging.
-//===----------------------------------------------------------------------===//
-
-/// Wrapper for better debugging.
-static void map(BlockAndValueMapping &bvm, ValueRange keys, ValueRange values) {
- assert(!keys.empty() && "Unexpected empty keys");
- LDBG("\n\tMap: " << printValueInfo(keys.front())
- << "\n\tto: " << printValueInfo(values.front()) << '\n');
- return bvm.map(keys, values);
-}
-
-/// Wrapper for better debugging.
-static void map(BlockAndValueMapping &bvm, Value key, Value value) {
- LDBG("\n\tMap: " << printValueInfo(key) << "\n\tto: " << printValueInfo(value)
- << '\n');
- return bvm.map(key, value);
-}
-
-/// Wrapper for better debugging.
-static Value lookup(const BlockAndValueMapping &bvm, Value key) {
- // TODO: if key comes from bbArg, forward.
- assert(key.getType().isa<TensorType>());
- Value v = bvm.lookupOrNull(key);
- if (v)
- return v;
-
- Operation *parentOp;
- if (auto bbArg = key.dyn_cast<BlockArgument>()) {
- if (isa<FuncOp>(key.getParentBlock()->getParentOp()))
- parentOp = key.getParentBlock()->getParentOp();
- else
- parentOp = key.getParentBlock()->getParentOp()->getParentOfType<FuncOp>();
- } else {
- parentOp = key.getDefiningOp()->getParentOfType<FuncOp>();
- }
- LDBG("In func:\n" << *parentOp << "\nNO VALUE FOR KEY: " << key << '\n');
- (void)parentOp;
- return Value();
-}
-
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
@@ -878,8 +837,7 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
static Value createNewAllocDeallocPairForShapedValue(
- OpBuilder &b, Location loc, Value shapedValue,
- BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
+ OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -891,19 +849,19 @@ static Value createNewAllocDeallocPairForShapedValue(
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Optional<Value> allocated =
- allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
+ state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
- aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
+ state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
- allocationFns.deallocationFn(b, loc, allocated.getValue());
+ state.allocationFns.deallocationFn(b, loc, allocated.getValue());
return casted;
}
@@ -915,8 +873,7 @@ static Value createNewAllocDeallocPairForShapedValue(
/// inplaceable. For now, it is the responsibility of the `callOp` bufferization
/// to allow FuncOp that are inplaceable to write inPlace.
static LogicalResult
-bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
+bufferize(OpBuilder &b, CallOpInterface callOp, BufferizationState &state,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@@ -962,14 +919,13 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
// If return operand is equivalent to some bbArg, no need to return it.
Value returnVal = returnOperand.get();
if (BlockArgument bbArg =
- getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) {
+ getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
int64_t idx = bbArg.getArgNumber();
- Value buffer = lookup(bvm, callOp->getOperand(idx));
- assert(buffer && "expected bufferized value");
+ Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural info.
- aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
- map(bvm, oldRes, buffer);
+ state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+ state.mapBuffer(oldRes, buffer);
// Add a TensorLoadOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This TensorLoadOp must fold/DCE away or bufferization should be
@@ -978,13 +934,13 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(tensorLoad);
// Add new op equivalence info.
- aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
- map(bvm, tensorLoad, buffer);
+ state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
+ state.mapBuffer(tensorLoad, buffer);
continue;
}
// TODO: Need to hoist above function boundary.
- if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
+ if (Operation *allocOp = getEquivalentAlloc(returnVal, state.aliasInfo)) {
hoistedArguments.push_back(allocOp->getResult(0));
continue;
}
@@ -1023,8 +979,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
// Tensor operands are guaranteed to have been bufferized.
int64_t idx = opOperand.getOperandNumber();
- Value buffer = lookup(bvm, tensorOperand);
- assert(buffer && "expected bufferized value");
+ Value buffer = state.lookupBuffer(tensorOperand);
// Caller / callee type mismatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
@@ -1037,8 +992,8 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
- aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
- map(bvm, tensorOperand, castBuffer);
+ state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+ state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);
@@ -1054,9 +1009,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) {
+ BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&funcOp.body().front());
@@ -1072,8 +1025,8 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
: getContiguousOrUnrankedMemRefType(tensorType);
Value bufferCast =
b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
- aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
- map(bvm, bbArg, bufferCast);
+ state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+ state.mapBuffer(bbArg, bufferCast);
}
return success();
}
@@ -1230,8 +1183,7 @@ void mlir::linalg::comprehensive_bufferize::defaultMemCpyFn(OpBuilder &b,
}
LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
- Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks allocationFns,
+ Operation *op, BufferizationState &state,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
OpBuilder b(op->getContext());
@@ -1241,8 +1193,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
if (!bufferizedFunctionTypes)
llvm_unreachable(
"null bufferizedFunctionTypes when bufferizing CallOpInterface");
- return bufferize(b, callOp, bvm, aliasInfo, allocationFns,
- *bufferizedFunctionTypes);
+ return bufferize(b, callOp, state, *bufferizedFunctionTypes);
}
// Skip BufferCast and TensorLoad ops.
@@ -1251,7 +1202,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
// Bufferize using `BufferizableOpInterface`.
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
- return bufferizableOp.bufferize(b, bvm, aliasInfo, allocationFns);
+ return bufferizableOp.bufferize(b, state);
// Other op with tensors. No bufferization method specified.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -1262,23 +1213,21 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
}
static LogicalResult bufferizeFuncOpInternals(
- FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFns,
+ FuncOp funcOp, BufferizationState &state,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
OpBuilder b(funcOp->getContext());
// Start by bufferizing `funcOp` arguments.
- if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
+ if (failed(bufferize(b, funcOp, state)))
return failure();
// Cannot erase ops during the traversal. Do that afterwards.
SmallVector<Operation *> toErase;
auto walkFunc = [&](Operation *op) -> WalkResult {
- if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
- &bufferizedFunctionTypes)))
+ if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes)))
return failure();
// Register post-walk erasure, if necessary.
@@ -1852,9 +1801,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Bufferization phase.
if (!options.testAnalysisOnly) {
BlockAndValueMapping tensorToBufferMap;
- if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
- *options.allocationFns,
- bufferizedFunctionTypes)))
+ BufferizationState state(aliasInfo, *options.allocationFns,
+ tensorToBufferMap);
+ if (failed(
+ bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
return failure();
}
}
@@ -1926,9 +1876,7 @@ struct ConstantOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
if (!isaTensor(constantOp.getResult().getType()))
return success();
@@ -1948,8 +1896,8 @@ struct ConstantOpInterface
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
- aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
- map(bvm, constantOp, memref);
+ state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
+ state.mapBuffer(constantOp, memref);
return success();
}
@@ -1969,10 +1917,10 @@ namespace linalg_ext {
/// Helper function for LinalgOp bufferization.
/// When allocating a new buffer, analyze whether `op` wants to read from that
/// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult allocateBuffersForResults(
- OpBuilder &b, Location loc, LinalgOp op,
- SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
+static LogicalResult
+allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+ SmallVectorImpl<Value> &resultBuffers,
+ BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -1983,24 +1931,21 @@ static LogicalResult allocateBuffersForResults(
OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
.getAliasingOpResult(*opOperand);
assert(opResult && "could not find corresponding OpResult");
- Value resultBuffer =
- getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns);
+ Value resultBuffer = getResultBuffer(b, opResult, state);
if (!resultBuffer)
return failure();
resultBuffers.push_back(resultBuffer);
}
if (op->getNumResults())
- map(bvm, op->getResults(), resultBuffers);
+ state.mapBuffer(op->getResults(), resultBuffers);
return success();
}
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFns) {
+ BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -2017,13 +1962,11 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
newInputBuffers.push_back(opOperand->get());
continue;
}
- newInputBuffers.push_back(lookup(bvm, opOperand->get()));
- assert(newInputBuffers.back() && "missing buffer");
+ newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
}
SmallVector<Value> newOutputBuffers;
// Try to allocate new buffers depending on op's inplace semantics.
- if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
- aliasInfo, allocationFns)))
+ if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state)))
return failure();
// Clone the newly bufferized op.
@@ -2036,7 +1979,7 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
// Replace the results of the old op with the new output buffers.
if (op->getNumResults())
- map(bvm, op->getResults(), newOutputBuffers);
+ state.mapBuffer(op->getResults(), newOutputBuffers);
// The original op will be DCE'd away later.
@@ -2087,11 +2030,8 @@ struct LinalgOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
- return bufferizeLinalgOp(b, cast<LinalgOp>(op), bvm, aliasInfo,
- allocationFn);
+ BufferizationState &state) const {
+ return bufferizeLinalgOp(b, cast<LinalgOp>(op), state);
}
};
@@ -2109,9 +2049,7 @@ struct InitTensorOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto initTensorOp = cast<linalg::InitTensorOp>(op);
// The InitTensorOp may have been eliminated.
@@ -2123,9 +2061,8 @@ struct InitTensorOpInterface
b.setInsertionPoint(initTensorOp);
Value alloc = createNewAllocDeallocPairForShapedValue(
- b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
- allocationFn);
- map(bvm, initTensorOp.result(), alloc);
+ b, initTensorOp->getLoc(), initTensorOp.result(), state);
+ state.mapBuffer(initTensorOp.result(), alloc);
return success();
}
};
@@ -2178,9 +2115,7 @@ struct TiledLoopOpInterface
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
// Take a guard before anything else.
@@ -2222,15 +2157,14 @@ struct TiledLoopOpInterface
const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
- Value resultBuffer =
- getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+ Value resultBuffer = getResultBuffer(b, opResult, state);
if (!resultBuffer)
return failure();
// Insert mapping and aliasing info.
- aliasInfo.createAliasInfoEntry(resultBuffer);
- aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
- map(bvm, opResult, resultBuffer);
+ state.aliasInfo.createAliasInfoEntry(resultBuffer);
+ state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
+ state.mapBuffer(opResult, resultBuffer);
// Insert new operand and bbArg.
tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer);
@@ -2238,9 +2172,10 @@ struct TiledLoopOpInterface
body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
// Insert mapping and aliasing info.
- aliasInfo.createAliasInfoEntry(newBufferBBArg);
- aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
- map(bvm, oldTensorBBArg, newBufferBBArg);
+ state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
+ state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
+ newBufferBBArg);
+ state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Set operand of `linalg.yield` to the bbArg so it just canonicalizes
// away later.
@@ -2268,8 +2203,7 @@ struct TiledLoopOpInterface
continue;
}
- Value inputBuffer = lookup(bvm, oldInputTensor);
- assert(inputBuffer && " missing buffer for operand");
+ Value inputBuffer = state.lookupBuffer(oldInputTensor);
// Insert new operand and bbArg.
tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer);
@@ -2278,9 +2212,10 @@ struct TiledLoopOpInterface
BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
// Insert mapping and aliasing info.
- aliasInfo.createAliasInfoEntry(newBufferBBArg);
- aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
- map(bvm, oldTensorBBArg, newBufferBBArg);
+ state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
+ state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
+ newBufferBBArg);
+ state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Increment indices.
++numNewInputBuffers;
@@ -2318,9 +2253,7 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto yieldOp = cast<linalg::YieldOp>(op);
// Take a guard before anything else.
@@ -2394,9 +2327,7 @@ struct IfOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
// scf::IfOp is bufferized after scf::YieldOp in the else branch.
return success();
}
@@ -2405,9 +2336,7 @@ struct IfOpInterface
/// Bufferize the scf::IfOp. This function is called after the YieldOp was
/// bufferized.
static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) {
+ BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(ifOp);
@@ -2420,13 +2349,12 @@ static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
assert(opResult.getType().isa<RankedTensorType>() &&
"unsupported unranked tensor");
- Value resultBuffer =
- getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+ Value resultBuffer = getResultBuffer(b, opResult, state);
if (!resultBuffer)
return failure();
- aliasInfo.createAliasInfoEntry(resultBuffer);
- map(bvm, opResult, resultBuffer);
+ state.aliasInfo.createAliasInfoEntry(resultBuffer);
+ state.mapBuffer(opResult, resultBuffer);
}
return success();
@@ -2477,9 +2405,7 @@ struct ForOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
// Note: This method is just setting up the mappings for the block arguments
// and the result buffer. The op is bufferized after the scf::YieldOp.
@@ -2497,17 +2423,16 @@ struct ForOpInterface
"unsupported unranked tensor");
// TODO: More general: Matching bbArg does not bufferize to a read.
- Value resultBuffer =
- getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+ Value resultBuffer = getResultBuffer(b, opResult, state);
if (!resultBuffer)
return failure();
OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
- aliasInfo.createAliasInfoEntry(resultBuffer);
- aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
- map(bvm, bbArg, resultBuffer);
- map(bvm, opResult, resultBuffer);
+ state.aliasInfo.createAliasInfoEntry(resultBuffer);
+ state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
+ state.mapBuffer(bbArg, resultBuffer);
+ state.mapBuffer(opResult, resultBuffer);
}
return success();
@@ -2517,9 +2442,7 @@ struct ForOpInterface
/// Bufferize the scf::ForOp. This function is called after the YieldOp was
/// bufferized.
static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) {
+ BufferizationState &state) {
auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
for (OpOperand &operand : yieldOp->getOpOperands()) {
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
@@ -2529,9 +2452,10 @@ static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- Value yieldedBuffer = lookup(bvm, operand.get());
- Value bbArgBuffer = lookup(bvm, bbArg);
- if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
+ Value yieldedBuffer = state.lookupBuffer(operand.get());
+ Value bbArgBuffer = state.lookupBuffer(bbArg);
+ if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
+ bbArgBuffer)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
return yieldOp->emitError()
@@ -2567,9 +2491,7 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto yieldOp = cast<scf::YieldOp>(op);
if (auto execOp = dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp())) {
@@ -2584,12 +2506,12 @@ struct YieldOpInterface
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
if (ifOp.elseYield() != yieldOp)
return success();
- return bufferizeIfOp(ifOp, b, bvm, aliasInfo, allocationFn);
+ return bufferizeIfOp(ifOp, b, state);
}
// Bufferize scf::ForOp after bufferizing the scf::YieldOp.
if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
- return bufferizeForOp(forOp, b, bvm, aliasInfo, allocationFn);
+ return bufferizeForOp(forOp, b, state);
return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
}
@@ -2635,9 +2557,7 @@ struct CallOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
llvm_unreachable("CallOps are handled separately");
return failure();
}
@@ -2659,9 +2579,7 @@ struct ReturnOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto returnOp = cast<ReturnOp>(op);
// Take a guard before anything else.
@@ -2675,12 +2593,11 @@ struct ReturnOpInterface
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
- Value v = lookup(bvm, operand.get());
- assert(v && "missing buffer for result");
+ Value v = state.lookupBuffer(operand.get());
Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
operand.set(returnTensor);
- aliasInfo.insertNewBufferEquivalence(returnTensor, v);
- map(bvm, returnTensor, v);
+ state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+ state.mapBuffer(returnTensor, v);
}
return success();
}
@@ -2715,17 +2632,14 @@ struct CastOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
- Value resultBuffer =
- getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
+ Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
if (!resultBuffer)
return failure();
Type sourceType = resultBuffer.getType();
@@ -2744,8 +2658,8 @@ struct CastOpInterface
castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
- aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
- map(bvm, castOp.getResult(), res);
+ state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+ state.mapBuffer(castOp.getResult(), res);
return success();
}
};
@@ -2766,9 +2680,7 @@ struct DimOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
// Take a guard before anything else.
@@ -2776,8 +2688,7 @@ struct DimOpInterface
b.setInsertionPoint(dimOp);
if (dimOp.source().getType().isa<RankedTensorType>()) {
- Value v = lookup(bvm, dimOp.source());
- assert(v && "missing buffer");
+ Value v = state.lookupBuffer(dimOp.source());
dimOp.result().replaceAllUsesWith(
b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
}
@@ -2812,9 +2723,7 @@ struct ExtractSliceOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
// Take a guard before anything else.
@@ -2824,18 +2733,16 @@ struct ExtractSliceOpInterface
Location loc = extractSliceOp.getLoc();
// Bail if source was not bufferized.
- Value srcMemref = lookup(bvm, extractSliceOp.source());
- if (!srcMemref)
- return failure();
+ Value srcMemref = state.lookupBuffer(extractSliceOp.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
// If not inplaceable, alloc.
Value alloc;
- if (!aliasInfo.isInPlace(extractSliceOp->getResult(0)))
+ if (!state.aliasInfo.isInPlace(extractSliceOp->getResult(0)))
alloc = createNewAllocDeallocPairForShapedValue(
- b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
+ b, loc, extractSliceOp.result(), state);
// Set insertion point now that potential alloc/dealloc are introduced.
b.setInsertionPoint(extractSliceOp);
@@ -2851,17 +2758,18 @@ struct ExtractSliceOpInterface
loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
// Insert new alias.
- aliasInfo.insertNewBufferAlias(subView, srcMemref);
+ state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
/// If not inplaceable, copy.
if (alloc) {
// Do not copy if the copied data is never read.
if (isValueRead(extractSliceOp.result()))
- allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc);
+ state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
+ alloc);
subView = alloc;
}
- map(bvm, extractSliceOp.result(), subView);
+ state.mapBuffer(extractSliceOp.result(), subView);
return success();
}
};
@@ -2882,9 +2790,7 @@ struct ExtractOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
// Take a guard before anything else.
@@ -2892,7 +2798,7 @@ struct ExtractOpInterface
b.setInsertionPoint(extractOp);
Location loc = extractOp.getLoc();
- Value srcMemref = lookup(bvm, extractOp.tensor());
+ Value srcMemref = state.lookupBuffer(extractOp.tensor());
Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
extractOp.replaceAllUsesWith(l);
return success();
@@ -2950,9 +2856,7 @@ struct InsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
// Take a guard before anything else.
@@ -2969,15 +2873,12 @@ struct InsertSliceOpInterface
// TODO: be very loud about it or even consider failing the pass.
// Alloc a copy for `insertSliceOp.dest()`, it will become the result
// buffer.
- Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
- aliasInfo, allocationFn);
+ Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
if (!dstMemref)
return failure();
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
- Value srcMemref = lookup(bvm, insertSliceOp.source());
- if (!srcMemref)
- return failure();
+ Value srcMemref = state.lookupBuffer(insertSliceOp.source());
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstMemrefType,
@@ -2991,9 +2892,9 @@ struct InsertSliceOpInterface
// - The result is not inplace. This is the case where the whole tensor is
// cloned and the clone needs to be updated.
// TODO: Is this necessary?
- if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
+ if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
insertSliceOp) ||
- !aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
+ !state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
<< " -> copy\n");
// Take a subview of the dst.
@@ -3001,11 +2902,12 @@ struct InsertSliceOpInterface
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Insert new alias.
- aliasInfo.insertNewBufferAlias(subView, dstMemref);
- allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView);
+ state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+ state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
+ subView);
}
- map(bvm, insertSliceOp.result(), dstMemref);
+ state.mapBuffer(insertSliceOp.result(), dstMemref);
return success();
}
@@ -3035,9 +2937,7 @@ struct TransferReadOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto transferReadOp = cast<vector::TransferReadOp>(op);
// Take a guard before anything else.
@@ -3048,8 +2948,7 @@ struct TransferReadOpInterface
return failure();
// TransferReadOp always reads from the bufferized op.source().
- Value v = lookup(bvm, transferReadOp.source());
- assert(v && "missing buffer");
+ Value v = state.lookupBuffer(transferReadOp.source());
transferReadOp.sourceMutable().assign(v);
return success();
}
@@ -3086,9 +2985,7 @@ struct TransferWriteOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFn) const {
+ BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
// Take a guard before anything else.
@@ -3101,15 +2998,14 @@ struct TransferWriteOpInterface
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
- Value resultBuffer =
- getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
+ Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
if (!resultBuffer)
return failure();
b.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_map(),
writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
- map(bvm, op->getResult(0), resultBuffer);
+ state.mapBuffer(op->getResult(0), resultBuffer);
return success();
}
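For reference, this is the call-site setup after the change, condensed from the `runComprehensiveBufferize` hunk above (the analysis that populates `aliasInfo` is elided):

    // Once the in-place analysis has populated `aliasInfo`:
    BlockAndValueMapping tensorToBufferMap;
    BufferizationState state(aliasInfo, *options.allocationFns,
                             tensorToBufferMap);
    if (failed(bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
      return failure();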