[Mlir-commits] [mlir] 45b995c - [mlir][bufferize][NFC] Change signature of allocateTensorForShapedValue
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 27 07:03:37 PDT 2022
Author: Matthias Springer
Date: 2022-06-27T16:00:06+02:00
New Revision: 45b995cda4611d6c0f1b4c2e8b7903303ce5d49c
URL: https://github.com/llvm/llvm-project/commit/45b995cda4611d6c0f1b4c2e8b7903303ce5d49c
DIFF: https://github.com/llvm/llvm-project/commit/45b995cda4611d6c0f1b4c2e8b7903303ce5d49c.diff
LOG: [mlir][bufferize][NFC] Change signature of allocateTensorForShapedValue
Add a failure return value and a bufferization options argument. This keeps a subsequent change smaller.
Differential Revision: https://reviews.llvm.org/D128278
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index b609a7fd78fb..ff8db00f7644 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -472,9 +472,10 @@ class AnalysisState {
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
-Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
- Value shapedValue, bool escape,
- bool copy = true);
+FailureOr<Value>
+allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
+ bool escape, const BufferizationOptions &options,
+ bool copy = true);
/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 7e5ccd031abe..6073c931e53a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -46,9 +46,9 @@ constexpr const ::llvm::StringLiteral
/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
-Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
- Value shapedValue,
- bool escape, bool copy) {
+FailureOr<Value> bufferization::allocateTensorForShapedValue(
+ OpBuilder &b, Location loc, Value shapedValue, bool escape,
+ const BufferizationOptions &options, bool copy) {
Value tensor;
if (shapedValue.getType().isa<RankedTensorType>()) {
tensor = shapedValue;
@@ -88,7 +88,7 @@ Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
copy ? tensor : Value());
allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
b.getBoolArrayAttr({escape}));
- return allocTensorOp;
+ return allocTensorOp.getResult();
}
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
@@ -147,26 +147,30 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// Insert copies of OpOperands.
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
- Value copy = allocateTensorForShapedValue(
+ FailureOr<Value> copy = allocateTensorForShapedValue(
rewriter, op->getLoc(), opOperand->get(),
- escapingOpOperandCopies.contains(opOperand),
+ escapingOpOperandCopies.contains(opOperand), state.getOptions(),
copiedOpOperands.contains(opOperand));
- rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
+ if (failed(copy))
+ return failure();
+ rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
}
// Insert copies of OpResults.
rewriter.setInsertionPointAfter(op);
for (OpResult opResult : outOfPlaceOpResults) {
- Value copy =
- allocateTensorForShapedValue(rewriter, op->getLoc(), opResult,
- escapingOpResultCopies.contains(opResult),
- copiedOpResults.count(opResult));
+ FailureOr<Value> copy = allocateTensorForShapedValue(
+ rewriter, op->getLoc(), opResult,
+ escapingOpResultCopies.contains(opResult), state.getOptions(),
+ copiedOpResults.count(opResult));
+ if (failed(copy))
+ return failure();
SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
opResult.getUses(), [](OpOperand &use) { return &use; }));
for (OpOperand *use : uses) {
// Do not update the alloc_tensor op that we just created.
- if (use->getOwner() != copy.getDefiningOp())
- rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
+ if (use->getOwner() != copy->getDefiningOp())
+ rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
}
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 36d0f0cefbae..0bf4fd3405dd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -458,9 +458,12 @@ struct ForOpInterface
yieldValues.push_back(value);
continue;
}
- Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
- value, /*escape=*/true);
- yieldValues.push_back(alloc);
+ FailureOr<Value> alloc =
+ allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
+ /*escape=*/true, state.getOptions());
+ if (failed(alloc))
+ return failure();
+ yieldValues.push_back(*alloc);
}
rewriter.updateRootInPlace(
@@ -669,9 +672,12 @@ struct WhileOpInterface
beforeYieldValues.push_back(value);
continue;
}
- Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
- value, /*escape=*/true);
- beforeYieldValues.push_back(alloc);
+ FailureOr<Value> alloc =
+ allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
+ /*escape=*/true, state.getOptions());
+ if (failed(alloc))
+ return failure();
+ beforeYieldValues.push_back(*alloc);
}
rewriter.updateRootInPlace(conditionOp, [&]() {
conditionOp.getArgsMutable().assign(beforeYieldValues);
@@ -687,9 +693,12 @@ struct WhileOpInterface
afterYieldValues.push_back(value);
continue;
}
- Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
- value, /*escape=*/true);
- afterYieldValues.push_back(alloc);
+ FailureOr<Value> alloc =
+ allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
+ /*escape=*/true, state.getOptions());
+ if (failed(alloc))
+ return failure();
+ afterYieldValues.push_back(*alloc);
}
rewriter.updateRootInPlace(yieldOp, [&]() {
yieldOp.getResultsMutable().assign(afterYieldValues);
@@ -972,13 +981,15 @@ struct ForeachThreadOpInterface
// Insert tensor allocation.
bool isYielded = state.isTensorYielded(opResult);
- Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
- destOperands.front()->get(),
- /*escape=*/isYielded);
+ FailureOr<Value> alloc = allocateTensorForShapedValue(
+ rewriter, op->getLoc(), destOperands.front()->get(),
+ /*escape=*/isYielded, state.getOptions());
+ if (failed(alloc))
+ return failure();
// Update terminator operand.
rewriter.updateRootInPlace(destOperands.front()->getOwner(),
- [&]() { destOperands.front()->set(alloc); });
+ [&]() { destOperands.front()->set(*alloc); });
}
return success();
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index e7e31dcd42f5..6f2484352210 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -154,15 +154,17 @@ struct CollapseShapeOpInterface
if (!canBeCollapsed) {
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
- Value tensorAlloc = allocateTensorForShapedValue(
+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), collapseShapeOp.getSrc(),
- analysisState.isTensorYielded(collapseShapeOp.getResult()));
+ analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
+ if (failed(tensorAlloc))
+ return failure();
auto memrefType =
MemRefType::get(collapseShapeOp.getSrcType().getShape(),
collapseShapeOp.getSrcType().getElementType(),
AffineMap(), bufferType.getMemorySpaceAsInt());
buffer = rewriter.create<bufferization::ToMemrefOp>(
- op->getLoc(), memrefType, tensorAlloc);
+ op->getLoc(), memrefType, *tensorAlloc);
}
// Result type is inferred by the builder.
@@ -383,14 +385,16 @@ struct FromElementsOpInterface
auto shape = tensorType.getShape();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
- Value tensorAlloc = allocateTensorForShapedValue(
+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, fromElementsOp.getResult(),
- analysisState.isTensorYielded(fromElementsOp.getResult()),
+ analysisState.isTensorYielded(fromElementsOp.getResult()), options,
/*copy=*/false);
+ if (failed(tensorAlloc))
+ return failure();
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
- op->getLoc(), memrefType, tensorAlloc);
+ op->getLoc(), memrefType, *tensorAlloc);
// Case: tensor<0xelem_type>.
if (fromElementsOp.getElements().empty()) {
@@ -436,14 +440,16 @@ struct GenerateOpInterface
Location loc = op->getLoc();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
- Value tensorAlloc = allocateTensorForShapedValue(
+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, generateOp.getResult(),
- analysisState.isTensorYielded(generateOp.getResult()),
+ analysisState.isTensorYielded(generateOp.getResult()), options,
/*copy=*/false);
+ if (failed(tensorAlloc))
+ return failure();
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
- op->getLoc(), memrefType, tensorAlloc);
+ op->getLoc(), memrefType, *tensorAlloc);
// Collect loop bounds.
int64_t rank = memrefType.getRank();
More information about the Mlir-commits
mailing list