[Mlir-commits] [mlir] 3ff93f8 - [mlir][SCF][bufferize][NFC] Bufferize scf.for terminator separately
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 27 04:38:36 PDT 2022
Author: Matthias Springer
Date: 2022-06-27T13:26:32+02:00
New Revision: 3ff93f838ebd04a910d76b7f3242140d60a2e56d
URL: https://github.com/llvm/llvm-project/commit/3ff93f838ebd04a910d76b7f3242140d60a2e56d
DIFF: https://github.com/llvm/llvm-project/commit/3ff93f838ebd04a910d76b7f3242140d60a2e56d.diff
LOG: [mlir][SCF][bufferize][NFC] Bufferize scf.for terminator separately
This allows for better type inference during bufferization and is in preparation for supporting memory spaces.
Differential Revision: https://reviews.llvm.org/D128422
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b9236b3573dd..d55e2784c199 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -445,6 +445,13 @@ struct ForOpInterface
return success();
}
+ BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
+ const BufferizationOptions &options) const {
+ auto forOp = cast<scf::ForOp>(op);
+ return bufferization::getBufferType(
+ forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
+ }
+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto forOp = cast<scf::ForOp>(op);
@@ -474,20 +481,9 @@ struct ForOpInterface
getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
- // Erase terminator if present.
- if (iterArgs.size() == 1)
- rewriter.eraseOp(loopBody->getTerminator());
-
// Move loop body to new loop.
rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
- // Update scf.yield of new loop.
- auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
- rewriter.setInsertionPoint(yieldOp);
- SmallVector<Value> yieldValues = getYieldedValues(
- rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
- yieldOp.getResultsMutable().assign(yieldValues);
-
// Replace loop results.
replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
@@ -844,9 +840,9 @@ struct YieldOpInterface
yieldOp->getParentOp()))
return yieldOp->emitError("unsupported scf::YieldOp parent");
- // TODO: Bufferize scf.yield inside scf.while/scf.for here.
- // (Currently bufferized together with scf.while/scf.for.)
- if (isa<scf::ForOp, scf::WhileOp>(yieldOp->getParentOp()))
+ // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
+ // together with scf.while.)
+ if (isa<scf::WhileOp>(yieldOp->getParentOp()))
return success();
SmallVector<Value> newResults;
@@ -854,6 +850,13 @@ struct YieldOpInterface
Value value = it.value();
if (value.getType().isa<TensorType>()) {
Value buffer = getBuffer(rewriter, value, options);
+ if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+ BaseMemRefType resultType =
+ cast<BufferizableOpInterface>(forOp.getOperation())
+ .getBufferType(forOp.getRegionIterArgs()[it.index()],
+ options);
+ buffer = castBuffer(rewriter, buffer, resultType);
+ }
newResults.push_back(buffer);
} else {
newResults.push_back(value);
More information about the Mlir-commits
mailing list