[Mlir-commits] [mlir] 4170141 - [mlir][linalg][bufferize] Replace remaining bvm usage with new API
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 15 06:25:01 PST 2021
Author: Matthias Springer
Date: 2021-12-15T23:21:39+09:00
New Revision: 417014170bd581186277c5e899a2338c482db69d
URL: https://github.com/llvm/llvm-project/commit/417014170bd581186277c5e899a2338c482db69d
DIFF: https://github.com/llvm/llvm-project/commit/417014170bd581186277c5e899a2338c482db69d.diff
LOG: [mlir][linalg][bufferize] Replace remaining bvm usage with new API
* Call `replaceOp` (or the new `replaceOpWithNewOp` helper) instead of `mapBuffer`.
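* Remove the tensor-to-buffer BlockAndValueMapping ("bvm") and all helper functions around it (`mapBuffer`, `isMapped`, `markOpObsolete`, `eraseObsoleteOps`).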
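* Simplify FuncOp bufferization: instead of eagerly creating ToMemrefOps for all tensor function BlockArguments, rely on `lookupBuffer`, which now inserts a ToMemrefOp on demand when a function argument is looked up.
* `tensor.dim` on an unranked tensor now emits an "unranked tensor not supported" error instead of being silently left unbufferized.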
Differential Revision: https://reviews.llvm.org/D115515
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 53c1840b18c6d..35955af49efaa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -268,6 +268,9 @@ class BufferizationAliasInfo {
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
};
+/// Return `true` if the given value is a BlockArgument of a FuncOp.
+bool isFunctionArgument(Value value);
+
/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
@@ -342,18 +345,18 @@ struct DialectBufferizationState {
DialectBufferizationState(const DialectBufferizationState &) = delete;
};
-/// BufferizationState keeps track of memory buffers and provides a variety of
-/// helper functions for dealing with them. In particular,
+/// BufferizationState provides a variety of helper functions for dealing with
+/// tensor values and memref buffers. In particular,
/// `BufferizableOpInterface::bufferize` implementation should utilize the
/// following helper functions.
///
/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops
/// that allocate and/or deallocate memref buffers.
-/// * `mapBuffer` maps a tensor value to a memref buffer during bufferization.
-/// * `lookupBuffer` returns the mapped memref buffer of a given tensor value.
+/// * `lookupBuffer` returns the memref buffer of a given tensor value.
/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
/// Based on inplace bufferization decisions of the analysis, it may either
/// directly return a mapped buffer or allocate a new brand new buffer.
+/// * `replaceOp` replaces an op with new values.
class BufferizationState {
public:
BufferizationState(Operation *op, const BufferizationOptions &options)
@@ -378,16 +381,19 @@ class BufferizationState {
/// Creates a memcpy between two given buffers.
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
- /// Replace an op with replacement values. The op is deleted.
+ /// Replace an op with replacement values. The op is deleted. Tensor OpResults
+ /// must be replaced with memref values.
void replaceOp(Operation *op, ValueRange values);
- /// Map tensor values to memref buffers.
- // TODO: Deprecated. Remove all uses of this op. Use `replaceOp` instead.
- void mapBuffer(ValueRange tensors, ValueRange buffers);
-
- /// Map a tensor value to a memref buffer.
- // TODO: Deprecated. Remove all uses of this op. Use `replaceOp` instead.
- void mapBuffer(Value tensor, Value buffer);
+ /// Replace an op with a new op. Tensor OpResults must be replaced with memref
+ /// values.
+ template <typename OpTy, typename... Args>
+ OpTy replaceOpWithNewOp(OpBuilder &b, Operation *op, Args &&...args) {
+ Operation *newOp =
+ b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+ replaceOp(op, newOp->getResults());
+ return cast<OpTy>(newOp);
+ }
/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
@@ -396,23 +402,11 @@ class BufferizationState {
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpResult opResult) const;
- /// Return `true` if the given value is mapped.
- // TODO: Deprecated. Remove all uses of this op.
- bool isMapped(Value value) const;
-
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value getResultBuffer(OpResult result);
- /// Mark `op` as obsolete, so that it is deleted after bufferization.
- // TODO: Deprecated. Remove all uses of this op.
- void markOpObsolete(Operation *op);
-
- /// Erase all ops that were marked obsolete.
- // TODO: Deprecated. Remove all uses of this op.
- void eraseObsoleteOps();
-
/// Return dialect-specific bufferization state.
template <typename StateT> StateT &getDialectState(StringRef name) {
// Create state if it does not exist yet.
@@ -441,12 +435,6 @@ class BufferizationState {
/// functions and `runComprehensiveBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
- /// The mapping of tensors to buffers.
- BlockAndValueMapping mapping;
-
- /// Obsolete ops that should be deleted after bufferization.
- SmallVector<Operation *> obsoleteOps;
-
/// Dialect-specific bufferization state.
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 1f3eaab91d2c5..e370d3f430421 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -35,10 +35,8 @@ struct ConstantOpInterface
GlobalCreator globalCreator(moduleOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
- Value memref = b.create<memref::GetGlobalOp>(
- constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
- state.mapBuffer(constantOp, memref);
-
+ state.replaceOpWithNewOp<memref::GetGlobalOp>(b, op, globalMemref.type(),
+ globalMemref.getName());
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index e33562f3b6f9f..d6d3d28a1022f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -498,19 +498,6 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
if (!state.getOptions().allowUnknownOps)
return op->emitError() << "unsupported op with tensors";
- // Replace all OpOperands with "to-tensor casted" bufferized values.
- for (OpOperand &operand : op->getOpOperands()) {
- if (operand.get().getType().isa<TensorType>() &&
- state.isMapped(operand.get())) {
- assert(state.getOptions().allowUnknownOps &&
- "unsupported op error should have been emitted earlier");
- b.setInsertionPoint(op);
- Value toTensorOp = b.create<bufferization::ToTensorOp>(
- op->getLoc(), state.lookupBuffer(operand.get()));
- operand.set(toTensorOp);
- }
- }
-
// Bufferize all regions.
for (Region ®ion : op->getRegions())
if (failed(bufferize(®ion, state)))
@@ -654,38 +641,13 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
-/// Wrapper for better debugging.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
- ValueRange tensors, ValueRange buffers) {
- assert(!tensors.empty() && "unexpected empty tensors");
-#ifndef NDEBUG
- for (Value tensor : tensors) {
- assert(tensor && "unexpected empty tensor");
- assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
- }
- for (Value buffer : buffers) {
- assert(buffer && "unexpected empty buffer");
- assert((buffer.getType().isa<MemRefType>() ||
- buffer.getType().isa<UnrankedMemRefType>()) &&
- "expected that tensor is mapped to memref");
- }
-#endif // NDEBUG
- return mapping.map(tensors, buffers);
-}
-
-/// Wrapper for better debugging.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
- Value tensor, Value buffer) {
- assert(tensor && "unexpected empty tensor");
- assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
- assert(buffer && "unexpected empty buffer");
- assert((buffer.getType().isa<MemRefType>() ||
- buffer.getType().isa<UnrankedMemRefType>()) &&
- "expected that tensor is mapped to memref");
- return mapping.map(tensor, buffer);
+bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
+ auto bbArg = value.dyn_cast<BlockArgument>();
+ if (!bbArg)
+ return false;
+ return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}
-/// Wrapper for better debugging.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
Value tensor) {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
@@ -694,37 +656,29 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.memref();
- Value buffer = mapping.lookupOrNull(tensor);
- if (!buffer) {
- if (options.allowUnknownOps) {
- // `tensor` was not bufferized yet. This should never happen with
- // bufferizable ops.
- assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped");
- // Insert to_memref op.
- OpBuilder b(tensor.getContext());
- setInsertionPointAfter(b, tensor);
- return b.create<bufferization::ToMemrefOp>(
- tensor.getLoc(),
- getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()),
- tensor);
+ if (!isFunctionArgument(tensor)) {
+ if (static_cast<bool>(options.dynCastBufferizableOp(tensor))) {
+ // Dump tensor for easier debugging.
+ tensor.dump();
+ llvm_unreachable("op is known, but has not been bufferized yet");
+ return Value();
+ }
+ if (!options.allowUnknownOps) {
+ // Dump tensor for easier debugging.
+ tensor.dump();
+ // Note: An assertion should already have failed earlier.
+ llvm_unreachable("unknown ops are not allowed");
+ return Value();
}
-
- // Dump tensor for easier debugging.
- tensor.dump();
- llvm_unreachable("tensor is not mapped");
- return Value();
}
- assert((buffer.getType().isa<MemRefType>() ||
- buffer.getType().isa<UnrankedMemRefType>()) &&
- "expected that tensor is mapped to memref");
- return buffer;
-}
-
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
- Value value) const {
- assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
- return mapping.contains(value);
+ // Insert to_memref op.
+ OpBuilder &b = getBuilder();
+ OpBuilder::InsertionGuard g(b);
+ setInsertionPointAfter(b, tensor);
+ return b.create<bufferization::ToMemrefOp>(
+ tensor.getLoc(),
+ getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()), tensor);
}
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
@@ -732,18 +686,6 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
return aliasInfo.isInPlace(opResult);
}
-void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
- Operation *op) {
- obsoleteOps.push_back(op);
-}
-
-void mlir::linalg::comprehensive_bufferize::BufferizationState::
- eraseObsoleteOps() {
- for (Operation *op : obsoleteOps)
- op->erase();
- obsoleteOps.clear();
-}
-
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
ShapedType shapedType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 6836a02dfd3db..97345350835a9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -63,21 +63,9 @@ struct ToMemrefOpInterface
// If a ToMemrefOp's tensor operand has not been bufferized yet, the op
// remains unchanged. All IR up to this ToMemrefOp has already been
// bufferized, unless there were unknown ops that could be bufferized.
- if (!state.isMapped(toMemrefOp.tensor())) {
- assert(state.getOptions().allowUnknownOps &&
- "expected that tensor is mapped");
- return success();
- }
-
- // If a ToMemrefOp's tensor operand has been bufferized, the op can be
- // removed.
- Value memref = state.lookupBuffer(toMemrefOp.tensor());
- // Do not replace a ToMemrefOp with itself. E.g., when bufferizing a
- // function body, ToMemrefOps were inserted before starting bufferization of
- // the function body. Such ToMemrefOps are replaced in a separate step after
- // the function body has been bufferized.
- if (toMemrefOp.getResult() != memref)
- toMemrefOp.replaceAllUsesWith(memref);
+ assert((isFunctionArgument(toMemrefOp.tensor()) ||
+ state.getOptions().allowUnknownOps) &&
+ "expected that tensor is mapped");
return success();
}
@@ -98,8 +86,6 @@ struct ToTensorOpInterface
bufferization::ToTensorOp> {
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
- auto tensorLoadOp = cast<bufferization::ToTensorOp>(op);
- state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref());
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 028be806236c4..b9dde90e63ee2 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -699,8 +699,5 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (failed(bufferize(op, state)))
return failure();
- // Erase all obsolete ops.
- state.eraseObsoleteOps();
-
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 451923d768fbb..9984ae1ad122c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -56,7 +56,6 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
if (!resultBuffer)
return failure();
newOutputBuffers.push_back(resultBuffer);
- state.mapBuffer(opResult, resultBuffer);
}
// Clone the newly bufferized op.
@@ -68,7 +67,9 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
auto bufferizedOp = cast<LinalgOp>(
op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
- // The original op will be DCE'd away later.
+ // Replace the results of the old op with the new output buffers.
+ state.replaceOp(op, newOutputBuffers);
+
return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state);
}
@@ -194,7 +195,7 @@ struct InitTensorOpInterface
Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
initTensorOp.result());
- state.mapBuffer(initTensorOp.result(), alloc);
+ state.replaceOp(op, alloc);
return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index cb0d75cb366f7..0e391a9a4f04a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -551,9 +551,6 @@ struct CallOpInterface
.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
- // Add CallOp operand/result equivalence: this is interprocedural
- // info.
- state.mapBuffer(oldRes, buffer);
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This ToTensorOp must fold/DCE away or bufferization should be
@@ -561,8 +558,6 @@ struct CallOpInterface
Value toTensorOp =
b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
- // Add new op equivalence info.
- state.mapBuffer(toTensorOp, buffer);
continue;
}
@@ -603,8 +598,6 @@ struct CallOpInterface
if (buffer.getType() != memRefType) {
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
- // Add new op equivalence info.
- state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);
@@ -616,7 +609,7 @@ struct CallOpInterface
newCallOp->setAttrs(callOp->getAttrs());
// 5. Delete the op at the end of bufferization.
- state.markOpObsolete(callOp);
+ callOp->erase();
return success();
}
@@ -651,7 +644,6 @@ struct ReturnOpInterface
Value returnTensor = b.create<bufferization::ToTensorOp>(
returnOp.getLoc(), v);
operand.set(returnTensor);
- state.mapBuffer(returnTensor, v);
}
return success();
}
@@ -662,23 +654,6 @@ struct FuncOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto funcOp = cast<FuncOp>(op);
- b.setInsertionPointToStart(&funcOp.body().front());
-
- // Create BufferCastOps for function args.
- for (auto bbArg : funcOp.getArguments()) {
- auto tensorType = bbArg.getType().dyn_cast<TensorType>();
- if (!tensorType)
- continue;
- auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
- // Cast the tensor to the most dynamic buffer possible. Further
- // canonicalizations will clean up.
- Type memRefType = rankedTensorType
- ? getDynamicMemRefType(rankedTensorType)
- : getContiguousOrUnrankedMemRefType(tensorType);
- Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
- memRefType, bbArg);
- state.mapBuffer(bbArg, bufferCast);
- }
// Bufferize function body.
return comprehensive_bufferize::bufferize(&funcOp.body(), state);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index cfc04be793b12..7558d792facf6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -78,9 +78,7 @@ struct CastOpInterface
: MemRefLayoutAttrInterface();
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), layout, memorySpace);
- Value res =
- b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
- state.mapBuffer(castOp.getResult(), res);
+ state.replaceOpWithNewOp<memref::CastOp>(b, op, memRefType, resultBuffer);
return success();
}
};
@@ -103,11 +101,10 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
- if (dimOp.source().getType().isa<RankedTensorType>()) {
- Value v = state.lookupBuffer(dimOp.source());
- dimOp.result().replaceAllUsesWith(
- b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
- }
+ if (!dimOp.source().getType().isa<RankedTensorType>())
+ return dimOp.emitError("unranked tensor not supported");
+ Value v = state.lookupBuffer(dimOp.source());
+ state.replaceOpWithNewOp<memref::DimOp>(b, op, v, dimOp.index());
return success();
}
};
@@ -168,7 +165,7 @@ struct ExtractSliceOpInterface
subView = alloc;
}
- state.mapBuffer(extractSliceOp.result(), subView);
+ state.replaceOp(op, subView);
return success();
}
};
@@ -191,10 +188,9 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
- Location loc = extractOp.getLoc();
Value srcMemref = state.lookupBuffer(extractOp.tensor());
- Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
- extractOp.replaceAllUsesWith(l);
+ state.replaceOpWithNewOp<memref::LoadOp>(b, op, srcMemref,
+ extractOp.indices());
return success();
}
};
@@ -228,7 +224,7 @@ struct InsertOpInterface
Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
- state.mapBuffer(insertOp, destMemref);
+ state.replaceOp(op, destMemref);
return success();
}
@@ -423,7 +419,7 @@ struct InsertSliceOpInterface
state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView);
}
- state.mapBuffer(insertSliceOp.result(), dstMemref);
+ state.replaceOp(op, dstMemref);
return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index c2f33b876fff3..3ccfb5065ed21 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -38,13 +38,17 @@ struct TransferReadOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
- auto transferReadOp = cast<vector::TransferReadOp>(op);
- assert(transferReadOp.getShapedType().isa<TensorType>() &&
+ auto readOp = cast<vector::TransferReadOp>(op);
+ assert(readOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
// TransferReadOp always reads from the bufferized op.source().
- Value v = state.lookupBuffer(transferReadOp.source());
- transferReadOp.sourceMutable().assign(v);
+ Value buffer = state.lookupBuffer(readOp.source());
+ Value read = b.create<vector::TransferReadOp>(
+ readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
+ readOp.permutation_map(), readOp.padding(), readOp.mask(),
+ readOp.in_boundsAttr());
+ state.replaceOp(op, read);
return success();
}
};
@@ -90,7 +94,7 @@ struct TransferWriteOpInterface
b.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
- state.mapBuffer(op->getResult(0), resultBuffer);
+ state.replaceOp(op, resultBuffer);
return success();
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index 24f8de7aac0e6..971d6c6e88a2f 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -30,11 +30,11 @@ func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
// CHECK: %[[dim:.*]] = tensor.dim %[[A]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK: memref.copy %[[A_memref]], %[[casted]]
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
- // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK: return %[[res_tensor]]
return %0 : tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index de8717f22d72d..967b231d73134 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -52,10 +52,10 @@ func @use_of_unknown_op_3(%t1: tensor<?xf32> {linalg.inplaceable = true})
-> (vector<5xf32>, vector<5xf32>) {
%idx = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
+ // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
// CHECK: %[[v1:.*]] = vector.transfer_read %[[m1]]
%1 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<5xf32>
- // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
// CHECK: %[[dummy:.*]] = "test.dummy_op"(%[[m1_tensor]])
%0 = "test.dummy_op"(%t1) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]]
@@ -114,11 +114,11 @@ func @use_of_bufferizable_op_in_unbufferizable_op(
func @unused_unknown_op(%t1 : tensor<?xf32>) -> vector<5xf32> {
%idx = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
+ // ToTensorOp is inserted to pass in the result of the above bufferized op.
+ // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
// CHECK: vector.transfer_read %[[m1]]
%1 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<5xf32>
- // ToTensorOp is inserted to pass in the result of the above bufferized op.
- // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
// CHECK: "test.dummy_op"(%[[m1_tensor]])
"test.dummy_op"(%t1) : (tensor<?xf32>) -> ()
@@ -158,10 +158,10 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
// CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
// CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
// CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
%0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
- // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK-TENSOR: return %[[casted_tensor]]
return %0 : tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 1094c21ed0537..f55f3008f2bd2 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -168,25 +168,25 @@ func @insert_slice_fun(%A0 : tensor<?xf32>,
-> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
{
// Hoisted allocs.
- // CHECK: %[[REALLOC_A0_2:.*]] = memref.alloc
- // CHECK: %[[REALLOC_A0:.*]] = memref.alloc
- // CHECK: %[[REALLOC_A1:.*]] = memref.alloc
+ // CHECK: %[[REALLOC1:.*]] = memref.alloc
+ // CHECK: %[[REALLOC2:.*]] = memref.alloc
+ // CHECK: %[[REALLOC3:.*]] = memref.alloc
// Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
- // CHECK: linalg.copy(%[[A0]], %[[REALLOC_A0]]
- // CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC_A0]]
+ // CHECK: linalg.copy(%[[A0]], %[[REALLOC3]]
+ // CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC3]]
// CHECK: linalg.copy(%[[t0]], %[[SV_A0]])
%r0 = tensor.insert_slice %t0 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>
// Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
// CHECK: linalg.copy(%[[A0]]
- // CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC_A0_2]]
+ // CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC2]]
// CHECK: linalg.copy(%[[t1]], %[[SV_A0_2]])
%r1 = tensor.insert_slice %t1 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>
// Still alloc the large tensor because %A1 is read after. Copy the tensor.extract_slice.
// CHECK: linalg.copy(%[[A1]]
- // CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC_A1]]
+ // CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC1]]
// CHECK: linalg.copy(%[[t0]], %[[SV_A1]])
%r2 = tensor.insert_slice %t0 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>
@@ -196,7 +196,7 @@ func @insert_slice_fun(%A0 : tensor<?xf32>,
// CHECK: linalg.copy(%[[t1]], %[[SV_A1_2]])
%r3 = tensor.insert_slice %t1 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>
- // CHECK: return %[[REALLOC_A0]], %[[REALLOC_A0_2]], %[[REALLOC_A1]] :
+ // CHECK: return %[[REALLOC3]], %[[REALLOC2]], %[[REALLOC1]] :
// CHECK-SAME: memref<?xf32>, memref<?xf32>, memref<?xf32>
return %r0, %r1, %r2, %r3: tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
}
More information about the Mlir-commits
mailing list