[Mlir-commits] [mlir] 19efb84 - [mlir][shape][bufferize][NFC] Bufferize block terminators separately
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 27 04:08:24 PDT 2022
Author: Matthias Springer
Date: 2022-06-27T13:08:13+02:00
New Revision: 19efb84c7a03a951cdc54662a21e0fd71a33601c
URL: https://github.com/llvm/llvm-project/commit/19efb84c7a03a951cdc54662a21e0fd71a33601c
DIFF: https://github.com/llvm/llvm-project/commit/19efb84c7a03a951cdc54662a21e0fd71a33601c.diff
LOG: [mlir][shape][bufferize][NFC] Bufferize block terminators separately
This allows for better type inference during bufferization and is in preparation for supporting memory spaces.
Differential Revision: https://reviews.llvm.org/D128579
Added:
Modified:
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 02a073de45699..9a4f8e187a8b7 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -61,42 +61,17 @@ struct AssumingOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto assumingOp = cast<shape::AssumingOp>(op);
-
- // Compute new result types.
- SmallVector<Type> newResultTypes;
- for (Type type : assumingOp->getResultTypes()) {
- if (auto tensorType = type.dyn_cast<TensorType>()) {
- // TODO: Infer the result type instead of computing it.
- newResultTypes.push_back(getMemRefType(tensorType, options));
- } else {
- newResultTypes.push_back(type);
- }
- }
+ assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+ "only 1 block supported");
+ auto yieldOp = cast<shape::AssumingYieldOp>(
+ assumingOp.getDoRegion().front().getTerminator());
// Create new op and move over region.
+ TypeRange newResultTypes(yieldOp.operands());
auto newOp = rewriter.create<shape::AssumingOp>(
op->getLoc(), newResultTypes, assumingOp.getWitness());
newOp.getDoRegion().takeBody(assumingOp.getRegion());
- // Update terminator.
- assert(newOp.getDoRegion().getBlocks().size() == 1 &&
- "only 1 block supported");
- Block *newBlock = &newOp.getDoRegion().front();
- auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
- rewriter.setInsertionPoint(yieldOp);
- SmallVector<Value> newYieldValues;
- for (const auto &it : llvm::enumerate(yieldOp.operands())) {
- Value val = it.value();
- if (val.getType().isa<TensorType>()) {
- newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
- yieldOp.getLoc(), newResultTypes[it.index()], val));
- } else {
- newYieldValues.push_back(val);
- }
- }
- rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
- newYieldValues);
-
// Update all uses of the old op.
rewriter.setInsertionPointAfter(newOp);
SmallVector<Value> newResults;
@@ -153,7 +128,14 @@ struct AssumingYieldOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
- // Op is bufferized as part of AssumingOp.
+ auto yieldOp = cast<shape::AssumingYieldOp>(op);
+ SmallVector<Value> newResults;
+ for (Value value : yieldOp.operands())
+ newResults.push_back(value.getType().isa<TensorType>()
+ ? getBuffer(rewriter, value, options)
+ : value);
+ replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
+ newResults);
return success();
}
};
More information about the Mlir-commits
mailing list