[Mlir-commits] [mlir] 5d50f51 - [mlir][bufferization][NFC] Add error handling to getBuffer
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 27 04:51:37 PDT 2022
Author: Matthias Springer
Date: 2022-06-27T13:48:01+02:00
New Revision: 5d50f51c970f8bc7cb76c785f81cb13bab94d14e
URL: https://github.com/llvm/llvm-project/commit/5d50f51c970f8bc7cb76c785f81cb13bab94d14e
DIFF: https://github.com/llvm/llvm-project/commit/5d50f51c970f8bc7cb76c785f81cb13bab94d14e.diff
LOG: [mlir][bufferization][NFC] Add error handling to getBuffer
This is in preparation for adding memory space support.
Differential Revision: https://reviews.llvm.org/D128277
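
As a rough sketch of the new calling convention (not part of the commit; `MyOp` and the helper function below are hypothetical and shown only for illustration), a bufferization pattern now checks the FailureOr results of `getBuffer` / `getBufferType` before dereferencing them, assuming the declarations from BufferizableOpInterface.h:

    // Hypothetical bufferization helper illustrating the new error handling.
    LogicalResult bufferizeMyOp(RewriterBase &rewriter, MyOp op,
                                const BufferizationOptions &options) {
      // getBuffer now returns FailureOr<Value>; propagate failure instead of
      // assuming the lookup always succeeds.
      FailureOr<Value> srcBuffer = getBuffer(rewriter, op.getSource(), options);
      if (failed(srcBuffer))
        return failure();

      // getBufferType likewise returns FailureOr<BaseMemRefType>.
      FailureOr<BaseMemRefType> bufferType =
          getBufferType(op.getSource(), options);
      if (failed(bufferType))
        return failure();

      // Dereference the wrappers (*srcBuffer, *bufferType) where the plain
      // values were used before.
      replaceOpWithBufferizedValues(rewriter, op, *srcBuffer);
      return success();
    }
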
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f852e7bcfe961..b609a7fd78fb6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -479,14 +479,15 @@ Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
-Value getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options);
+FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
+ const BufferizationOptions &options);
/// Return the buffer type for a given Value (tensor) after bufferization.
///
/// Note: Op implementations should preferrably call `getBuffer()->getType()`.
/// This function should only be used if `getBuffer` cannot be used.
-BaseMemRefType getBufferType(Value value, const BufferizationOptions &options);
+FailureOr<BaseMemRefType> getBufferType(Value value,
+ const BufferizationOptions &options);
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 4c56a8196d455..8f739d758a91d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -343,7 +343,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Return the bufferized type of the given tensor block argument. The
block argument is guaranteed to belong to a block of this op.
}],
- /*retType=*/"BaseMemRefType",
+ /*retType=*/"FailureOr<BaseMemRefType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "BlockArgument":$bbArg,
"const BufferizationOptions &":$options),
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index bd4b9d7d4a6be..24657bf3558bf 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -84,8 +84,10 @@ struct IndexCastOpInterface
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = castOp.getType().cast<TensorType>();
- Value source = getBuffer(rewriter, castOp.getIn(), options);
- auto sourceType = source.getType().cast<BaseMemRefType>();
+ FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+ if (failed(source))
+ return failure();
+ auto sourceType = source->getType().cast<BaseMemRefType>();
// Result type should have same layout and address space as the source type.
BaseMemRefType resultType;
@@ -100,7 +102,7 @@ struct IndexCastOpInterface
}
replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
- source);
+ *source);
return success();
}
};
@@ -140,8 +142,14 @@ struct SelectOpInterface
// instead of its OpOperands. In the worst case, 2 copies are inserted at
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
- Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options);
- Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options);
+ FailureOr<Value> maybeTrueBuffer =
+ getBuffer(rewriter, selectOp.getTrueValue(), options);
+ FailureOr<Value> maybeFalseBuffer =
+ getBuffer(rewriter, selectOp.getFalseValue(), options);
+ if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
+ return failure();
+ Value trueBuffer = *maybeTrueBuffer;
+ Value falseBuffer = *maybeFalseBuffer;
// The "true" and the "false" operands must have the same type. If the
// buffers have different types, they differ only in their layout map. Cast
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8a01acaf9374a..7e5ccd031abeb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,8 +480,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#endif
}
-Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options) {
+FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
+ const BufferizationOptions &options) {
#ifndef NDEBUG
auto tensorType = value.getType().dyn_cast<TensorType>();
assert(tensorType && "unexpected non-tensor type");
@@ -494,14 +494,17 @@ Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- Type memrefType = getBufferType(value, options);
- ensureToMemrefOpIsValid(value, memrefType);
- return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
- value);
+ FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+ if (failed(memrefType))
+ return failure();
+ ensureToMemrefOpIsValid(value, *memrefType);
+ return rewriter
+ .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
+ .getResult();
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-BaseMemRefType
+FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
auto tensorType = value.getType().dyn_cast<TensorType>();
assert(tensorType && "unexpected non-tensor type");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 158f6edd1fa37..188c08b674c0e 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -165,8 +165,12 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
// Get "copy" buffer.
Value copyBuffer;
- if (getCopy())
- copyBuffer = getBuffer(rewriter, getCopy(), options);
+ if (getCopy()) {
+ FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
+ if (failed(maybeCopyBuffer))
+ return failure();
+ copyBuffer = *maybeCopyBuffer;
+ }
// Compute memory space of this allocation.
unsigned memorySpace;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 3689522dd065e..b5d5cf3c52c23 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -305,8 +305,13 @@ struct CallOpInterface
// Retrieve buffers for tensor operands.
Value buffer = newOperands[idx];
- if (!buffer)
- buffer = getBuffer(rewriter, opOperand.get(), options);
+ if (!buffer) {
+ FailureOr<Value> maybeBuffer =
+ getBuffer(rewriter, opOperand.get(), options);
+ if (failed(maybeBuffer))
+ return failure();
+ buffer = *maybeBuffer;
+ }
// Caller / callee type mismatch is handled with a CastOp.
auto memRefType = funcType.getInput(idx);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index cc27b4403d898..a904d7c2e1d6d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -44,15 +44,21 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
newInputBuffers.push_back(opOperand->get());
continue;
}
- newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options));
+ FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
+ if (failed(buffer))
+ return failure();
+ newInputBuffers.push_back(*buffer);
}
// New output operands for the cloned op.
SmallVector<Value> newOutputBuffers;
for (OpResult opResult : op->getOpResults()) {
OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
- Value resultBuffer = getBuffer(rewriter, opOperand->get(), options);
- newOutputBuffers.push_back(resultBuffer);
+ FailureOr<Value> resultBuffer =
+ getBuffer(rewriter, opOperand->get(), options);
+ if (failed(resultBuffer))
+ return failure();
+ newOutputBuffers.push_back(*resultBuffer);
}
// Merge input/output operands.
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index d55e2784c1994..16ef7b5cad130 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -281,14 +281,17 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
-static SmallVector<Value> getBuffers(RewriterBase &rewriter,
- MutableArrayRef<OpOperand> operands,
- const BufferizationOptions &options) {
+static FailureOr<SmallVector<Value>>
+getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+ const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
if (opOperand.get().getType().isa<TensorType>()) {
- Value resultBuffer = getBuffer(rewriter, opOperand.get(), options);
- result.push_back(resultBuffer);
+ FailureOr<Value> resultBuffer =
+ getBuffer(rewriter, opOperand.get(), options);
+ if (failed(resultBuffer))
+ return failure();
+ result.push_back(*resultBuffer);
} else {
result.push_back(opOperand.get());
}
@@ -298,36 +301,46 @@ static SmallVector<Value> getBuffers(RewriterBase &rewriter,
/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
-static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
- BaseMemRefType type,
- const BufferizationOptions &options) {
+static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
+ BaseMemRefType type,
+ const BufferizationOptions &options) {
assert(tensor.getType().isa<TensorType>() && "expected tensor");
ensureToMemrefOpIsValid(tensor, type);
- Value yieldedVal = getBuffer(rewriter, tensor, options);
- return castBuffer(rewriter, yieldedVal, type);
+ FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
+ if (failed(yieldedVal))
+ return failure();
+ return castBuffer(rewriter, *yieldedVal, type);
}
/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
-static SmallVector<Value>
+static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
- llvm::function_ref<Value(Value, int64_t)> func) {
+ llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
SmallVector<Value> result;
for (const auto &it : llvm::enumerate(values)) {
size_t idx = it.index();
Value val = it.value();
- result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
+ if (tensorIndices.contains(idx)) {
+ FailureOr<Value> maybeVal = func(val, idx);
+ if (failed(maybeVal))
+ return failure();
+ result.push_back(*maybeVal);
+ } else {
+ result.push_back(val);
+ }
}
return result;
}
/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
-SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
- TypeRange bufferizedTypes,
- const DenseSet<int64_t> &tensorIndices,
- const BufferizationOptions &options) {
+FailureOr<SmallVector<Value>>
+getYieldedValues(RewriterBase &rewriter, ValueRange values,
+ TypeRange bufferizedTypes,
+ const DenseSet<int64_t> &tensorIndices,
+ const BufferizationOptions &options) {
return convertTensorValues(
values, tensorIndices, [&](Value val, int64_t index) {
return getYieldedBuffer(rewriter, val,
@@ -342,10 +355,19 @@ SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
const DenseSet<int64_t> &tensorIndices) {
- return convertTensorValues(
- bbArgs, tensorIndices, [&](Value val, int64_t index) {
- return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
- });
+ SmallVector<Value> result;
+ for (const auto &it : llvm::enumerate(bbArgs)) {
+ size_t idx = it.index();
+ Value val = it.value();
+ if (tensorIndices.contains(idx)) {
+ result.push_back(
+ rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
+ .getResult());
+ } else {
+ result.push_back(val);
+ }
+ }
+ return result;
}
/// Bufferization of scf.for. Replace with a new scf.for that operates on
@@ -445,8 +467,9 @@ struct ForOpInterface
return success();
}
- BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
- const BufferizationOptions &options) const {
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, BlockArgument bbArg,
+ const BufferizationOptions &options) const {
auto forOp = cast<scf::ForOp>(op);
return bufferization::getBufferType(
forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
@@ -462,8 +485,11 @@ struct ForOpInterface
DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
// The new memref init_args of the loop.
- SmallVector<Value> initArgs =
+ FailureOr<SmallVector<Value>> maybeInitArgs =
getBuffers(rewriter, forOp.getIterOpOperands(), options);
+ if (failed(maybeInitArgs))
+ return failure();
+ SmallVector<Value> initArgs = *maybeInitArgs;
// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = rewriter.create<scf::ForOp>(
@@ -689,13 +715,17 @@ struct WhileOpInterface
getTensorIndices(whileOp.getAfterArguments());
// The new memref init_args of the loop.
- SmallVector<Value> initArgs =
+ FailureOr<SmallVector<Value>> maybeInitArgs =
getBuffers(rewriter, whileOp->getOpOperands(), options);
+ if (failed(maybeInitArgs))
+ return failure();
+ SmallVector<Value> initArgs = *maybeInitArgs;
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
- return bufferization::getBufferType(bbArg, options).cast<Type>();
+ // TODO: error handling
+ return bufferization::getBufferType(bbArg, options)->cast<Type>();
}));
// Construct a new scf.while op with memref instead of tensor values.
@@ -727,10 +757,12 @@ struct WhileOpInterface
// Only equivalent buffers or new buffer allocations may be yielded to the
// "after" region.
// TODO: This could be relaxed for better bufferization results.
- SmallVector<Value> newConditionArgs =
+ FailureOr<SmallVector<Value>> newConditionArgs =
getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
indicesAfter, options);
- newConditionOp.getArgsMutable().assign(newConditionArgs);
+ if (failed(newConditionArgs))
+ return failure();
+ newConditionOp.getArgsMutable().assign(*newConditionArgs);
// Set up new iter_args and move the loop body block to the new op.
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
@@ -746,10 +778,12 @@ struct WhileOpInterface
// Only equivalent buffers or new buffer allocations may be yielded to the
// "before" region.
// TODO: This could be relaxed for better bufferization results.
- SmallVector<Value> newYieldValues =
+ FailureOr<SmallVector<Value>> newYieldValues =
getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
indicesBefore, options);
- newYieldOp.getResultsMutable().assign(newYieldValues);
+ if (failed(newYieldValues))
+ return failure();
+ newYieldOp.getResultsMutable().assign(*newYieldValues);
// Replace loop results.
replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
@@ -849,13 +883,18 @@ struct YieldOpInterface
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Value value = it.value();
if (value.getType().isa<TensorType>()) {
- Value buffer = getBuffer(rewriter, value, options);
+ FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+ if (failed(maybeBuffer))
+ return failure();
+ Value buffer = *maybeBuffer;
if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
- BaseMemRefType resultType =
+ FailureOr<BaseMemRefType> resultType =
cast<BufferizableOpInterface>(forOp.getOperation())
.getBufferType(forOp.getRegionIterArgs()[it.index()],
options);
- buffer = castBuffer(rewriter, buffer, resultType);
+ if (failed(resultType))
+ return failure();
+ buffer = castBuffer(rewriter, buffer, *resultType);
}
newResults.push_back(buffer);
} else {
@@ -1078,16 +1117,22 @@ struct ParallelInsertSliceOpInterface
// If the op bufferizes out-of-place, allocate the copy before the
// ForeachThreadOp.
rewriter.setInsertionPoint(foreachThreadOp);
- Value destBuffer = getBuffer(rewriter, insertOp.getDest(), options);
+ FailureOr<Value> destBuffer =
+ getBuffer(rewriter, insertOp.getDest(), options);
+ if (failed(destBuffer))
+ return failure();
// Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
rewriter.setInsertionPoint(performConcurrentlyOp);
- Value srcBuffer = getBuffer(rewriter, insertOp.getSource(), options);
+ FailureOr<Value> srcBuffer =
+ getBuffer(rewriter, insertOp.getSource(), options);
+ if (failed(srcBuffer))
+ return failure();
Value subview = rewriter.create<memref::SubViewOp>(
- insertOp.getLoc(), destBuffer, insertOp.getMixedOffsets(),
+ insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
insertOp.getMixedSizes(), insertOp.getMixedStrides());
// This memcpy will fold away if everything bufferizes in-place.
- if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), srcBuffer,
+ if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
subview)))
return failure();
rewriter.eraseOp(op);
@@ -1095,7 +1140,7 @@ struct ParallelInsertSliceOpInterface
// Replace all uses of ForeachThreadOp (just the corresponding result).
rewriter.setInsertionPointAfter(foreachThreadOp);
Value toTensorOp =
- rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), destBuffer);
+ rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
unsigned resultNum = 0;
for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
if (&nextOp == op)
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 9a4f8e187a8b7..68580b8680afc 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -130,10 +130,16 @@ struct AssumingYieldOpInterface
const BufferizationOptions &options) const {
auto yieldOp = cast<shape::AssumingYieldOp>(op);
SmallVector<Value> newResults;
- for (Value value : yieldOp.operands())
- newResults.push_back(value.getType().isa<TensorType>()
- ? getBuffer(rewriter, value, options)
- : value);
+ for (Value value : yieldOp.operands()) {
+ if (value.getType().isa<TensorType>()) {
+ FailureOr<Value> buffer = getBuffer(rewriter, value, options);
+ if (failed(buffer))
+ return failure();
+ newResults.push_back(*buffer);
+ } else {
+ newResults.push_back(value);
+ }
+ }
replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
newResults);
return success();
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 7f8b6a35491d5..e7e31dcd42f54 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -52,8 +52,11 @@ struct CastOpInterface
auto castOp = cast<tensor::CastOp>(op);
// The result buffer still has the old (pre-cast) type.
- Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
- auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
+ FailureOr<Value> resultBuffer =
+ getBuffer(rewriter, castOp.getSource(), options);
+ if (failed(resultBuffer))
+ return failure();
+ auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
TensorType resultTensorType =
castOp.getResult().getType().cast<TensorType>();
MemRefLayoutAttrInterface layout;
@@ -68,11 +71,11 @@ struct CastOpInterface
sourceMemRefType.getMemorySpaceAsInt());
// Replace the op with a memref.cast.
- assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
+ assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
resultMemRefType) &&
"CallOp::bufferize: cast incompatible");
replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
- resultBuffer);
+ *resultBuffer);
return success();
}
@@ -108,7 +111,11 @@ struct CollapseShapeOpInterface
const BufferizationOptions &options) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
- Value buffer = getBuffer(rewriter, collapseShapeOp.getSrc(), options);
+ FailureOr<Value> maybeBuffer =
+ getBuffer(rewriter, collapseShapeOp.getSrc(), options);
+ if (failed(maybeBuffer))
+ return failure();
+ Value buffer = *maybeBuffer;
auto bufferType = buffer.getType().cast<MemRefType>();
if (tensorResultType.getRank() == 0) {
@@ -187,9 +194,11 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto dimOp = cast<tensor::DimOp>(op);
- auto v = getBuffer(rewriter, dimOp.getSource(), options);
- replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v,
- dimOp.getIndex());
+ FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
+ if (failed(v))
+ return failure();
+ replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
+ dimOp.index());
return success();
}
};
@@ -224,12 +233,15 @@ struct ExpandShapeOpInterface
const BufferizationOptions &options) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto tensorResultType = expandShapeOp.getResultType();
- auto buffer = getBuffer(rewriter, expandShapeOp.getSrc(), options);
+ FailureOr<Value> buffer =
+ getBuffer(rewriter, expandShapeOp.getSrc(), options);
+ if (failed(buffer))
+ return failure();
// Memref result type is inferred by the builder based on reassociation
// indices and result shape.
replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
- rewriter, op, tensorResultType.getShape(), buffer,
+ rewriter, op, tensorResultType.getShape(), *buffer,
expandShapeOp.getReassociationIndices());
return success();
}
@@ -268,8 +280,11 @@ struct ExtractSliceOpInterface
// Even if this op was decided to bufferize out-of-place, do not insert the
// buffer copy yet. This is done later in this function.
- auto srcMemref = getBuffer(rewriter, extractSliceOp.getSource(), options);
- auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+ FailureOr<Value> srcMemref =
+ getBuffer(rewriter, extractSliceOp.getSource(), options);
+ if (failed(srcMemref))
+ return failure();
+ auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.getResult().getType().cast<RankedTensorType>();
@@ -279,7 +294,7 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
OffsetSizeAndStrideOpInterface::expandToRank(
- srcMemref, mixedOffsets, mixedSizes, mixedStrides,
+ *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
[&](Value target, int64_t dim) -> OpFoldResult {
auto shapedType = target.getType().cast<ShapedType>();
if (shapedType.isDynamicDim(dim))
@@ -292,7 +307,7 @@ struct ExtractSliceOpInterface
mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
+ loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
mixedStrides);
replaceOpWithBufferizedValues(rewriter, op, subView);
@@ -322,9 +337,12 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto extractOp = cast<tensor::ExtractOp>(op);
- Value srcMemref = getBuffer(rewriter, extractOp.getTensor(), options);
- replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
- extractOp.getIndices());
+ FailureOr<Value> srcMemref =
+ getBuffer(rewriter, extractOp.getTensor(), options);
+ if (failed(srcMemref))
+ return failure();
+ replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
+ extractOp.indices());
return success();
}
};
@@ -497,10 +515,13 @@ struct InsertOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto insertOp = cast<tensor::InsertOp>(op);
- Value destMemref = getBuffer(rewriter, insertOp.getDest(), options);
+ FailureOr<Value> destMemref =
+ getBuffer(rewriter, insertOp.getDest(), options);
+ if (failed(destMemref))
+ return failure();
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
- destMemref, insertOp.getIndices());
- replaceOpWithBufferizedValues(rewriter, op, destMemref);
+ *destMemref, insertOp.getIndices());
+ replaceOpWithBufferizedValues(rewriter, op, *destMemref);
return success();
}
@@ -655,7 +676,10 @@ struct InsertSliceOpInterface
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
Location loc = insertSliceOp.getLoc();
- Value dstMemref = getBuffer(rewriter, insertSliceOp.getDest(), options);
+ FailureOr<Value> dstMemref =
+ getBuffer(rewriter, insertSliceOp.getDest(), options);
+ if (failed(dstMemref))
+ return failure();
// Expand offsets, sizes and strides to the full rank to handle the
// rank-reducing case.
@@ -663,7 +687,7 @@ struct InsertSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
OffsetSizeAndStrideOpInterface::expandToRank(
- dstMemref, mixedOffsets, mixedSizes, mixedStrides,
+ *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
[&](Value target, int64_t dim) -> OpFoldResult {
auto shapedType = target.getType().cast<ShapedType>();
if (shapedType.isDynamicDim(dim))
@@ -671,23 +695,26 @@ struct InsertSliceOpInterface
return rewriter.getIndexAttr(shapedType.getDimSize(dim));
});
// Take a subview of the dst.
- auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+ auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstMemrefType,
mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes,
+ loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
mixedStrides);
// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
- auto srcMemref = getBuffer(rewriter, insertSliceOp.getSource(), options);
- if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
+ FailureOr<Value> srcMemref =
+ getBuffer(rewriter, insertSliceOp.getSource(), options);
+ if (failed(srcMemref))
+ return failure();
+ if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
return failure();
- replaceOpWithBufferizedValues(rewriter, op, dstMemref);
+ replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
return success();
}
};
@@ -714,9 +741,11 @@ struct RankOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto rankOp = cast<tensor::RankOp>(op);
- auto v = getBuffer(rewriter, rankOp.getTensor(), options);
+ FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
+ if (failed(v))
+ return failure();
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
- v);
+ *v);
return success();
}
};
@@ -750,12 +779,16 @@ struct ReshapeOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
- Value srcBuffer = getBuffer(rewriter, reshapeOp.getSource(), options);
- Value shapeBuffer = getBuffer(rewriter, reshapeOp.getShape(), options);
+ FailureOr<Value> srcBuffer =
+ getBuffer(rewriter, reshapeOp.getSource(), options);
+ FailureOr<Value> shapeBuffer =
+ getBuffer(rewriter, reshapeOp.getShape(), options);
+ if (failed(srcBuffer) || failed(shapeBuffer))
+ return failure();
auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
auto resultMemRefType = getMemRefType(resultTensorType, options);
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
- rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
+ rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 142e09bbb3da5..77f895a929c33 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -50,9 +50,11 @@ struct TransferReadOpInterface
auto readOp = cast<vector::TransferReadOp>(op);
assert(readOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
- Value buffer = getBuffer(rewriter, readOp.getSource(), options);
+ FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
+ if (failed(buffer))
+ return failure();
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
- rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(),
+ rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());
return success();
@@ -97,12 +99,15 @@ struct TransferWriteOpInterface
"only tensor types expected");
// Create a new transfer_write on buffer that doesn't have a return value.
- Value resultBuffer = getBuffer(rewriter, writeOp.getSource(), options);
+ FailureOr<Value> resultBuffer =
+ getBuffer(rewriter, writeOp.getSource(), options);
+ if (failed(resultBuffer))
+ return failure();
rewriter.create<vector::TransferWriteOp>(
- writeOp.getLoc(), writeOp.getVector(), resultBuffer,
+ writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
writeOp.getIndices(), writeOp.getPermutationMapAttr(),
writeOp.getInBoundsAttr());
- replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
+ replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
return success();
}