[Mlir-commits] [mlir] fa087b4 - [mlir][scf][bufferize][NFC] Lookup buffer using helper function
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 12 02:15:41 PDT 2022
Author: Matthias Springer
Date: 2022-04-12T18:09:30+09:00
New Revision: fa087b43529cfac223432cc5545923fb9c6544af
URL: https://github.com/llvm/llvm-project/commit/fa087b43529cfac223432cc5545923fb9c6544af
DIFF: https://github.com/llvm/llvm-project/commit/fa087b43529cfac223432cc5545923fb9c6544af.diff
LOG: [mlir][scf][bufferize][NFC] Lookup buffer using helper function
Look up iter_arg buffers using `lookupBuffer` instead of always creating a new `ToMemrefOp`. Also cast all yielded buffers (if necessary), regardless of whether they are an equivalent buffer or a new allocation.
Note: This should have been part of D123369.
Differential Revision: https://reviews.llvm.org/D123383
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 7c6bbcd414d71..33f1af28fdf1b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -314,6 +314,23 @@ struct ForOpInterface
auto bufferizableOp = cast<BufferizableOpInterface>(op);
Block *oldLoopBody = &forOp.getLoopBody().front();
+ // Helper function for casting MemRef buffers.
+ auto castBuffer = [&](Value buffer, Type type) {
+ assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+ assert(buffer.getType().isa<BaseMemRefType>() &&
+ "expected BaseMemRefType");
+ // If the buffer already has the correct type, no cast is needed.
+ if (buffer.getType() == type)
+ return buffer;
+ // TODO: In case `type` has a layout map that is not the fully dynamic
+ // one, we may not be able to cast the buffer. In that case, the loop
+ // iter_arg's layout map must be changed (see uses of `castBuffer`).
+ assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+ "scf.for op bufferization: cast incompatible");
+ return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
+ .getResult();
+ };
+
// Indices of all iter_args that have tensor type. These are the ones that
// are bufferized.
DenseSet<int64_t> indices;
@@ -382,9 +399,10 @@ struct ForOpInterface
rewriter.setInsertionPoint(yieldOp);
SmallVector<Value> yieldValues =
convert(yieldOp.getResults(), [&](Value val, int64_t index) {
- ensureToMemrefOpIsValid(val, initArgs[index].getType());
- Value yieldedVal = rewriter.create<bufferization::ToMemrefOp>(
- val.getLoc(), initArgs[index].getType(), val);
+ Type initArgType = initArgs[index].getType();
+ ensureToMemrefOpIsValid(val, initArgType);
+ Value yieldedVal =
+ bufferization::lookupBuffer(rewriter, val, state.getOptions());
if (equivalentYields[index])
// Yielded value is equivalent to the corresponding iter_arg bbArg.
@@ -392,7 +410,7 @@ struct ForOpInterface
// else must be resolved with copies and is potentially inefficient.
// By default, such problematic IR would already have been rejected
// during `verifyAnalysis`, unless `allow-return-allocs`.
- return yieldedVal;
+ return castBuffer(yieldedVal, initArgType);
// It is not certain that the yielded value and the iter_arg bbArg
// have the same buffer. Allocate a new buffer and copy. The yielded
@@ -412,21 +430,9 @@ struct ForOpInterface
(void)copyStatus;
assert(succeeded(copyStatus) && "could not create memcpy");
- if (yieldedVal.getType() == yieldedAlloc->getType())
- return *yieldedAlloc;
-
- // The iter_arg memref type has a layout map. Cast the new buffer to
- // the same type.
- // TODO: In case the iter_arg has a layout map that is not the fully
- // dynamic one, we cannot cast the new buffer. In that case, the
- // iter_arg must be changed to the fully dynamic layout map. (And then
- // the new buffer can be casted.)
- assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(),
- yieldedVal.getType()) &&
- "scf.for op bufferization: cast incompatible");
- Value casted = rewriter.create<memref::CastOp>(
- val.getLoc(), yieldedVal.getType(), *yieldedAlloc);
- return casted;
+ // The iter_arg memref type may have a layout map. Cast the new buffer
+ // to the same type if needed.
+ return castBuffer(*yieldedAlloc, initArgType);
});
yieldOp.getResultsMutable().assign(yieldValues);
More information about the Mlir-commits mailing list