[Mlir-commits] [mlir] 417e1c7 - [mlir][scf][bufferize][NFC] Split ForOp bufferization into smaller functions
Matthias Springer
llvmlistbot at llvm.org
Thu May 5 00:56:25 PDT 2022
Author: Matthias Springer
Date: 2022-05-05T16:55:44+09:00
New Revision: 417e1c7d520c5fb1868794341e8926f8037ef2a0
URL: https://github.com/llvm/llvm-project/commit/417e1c7d520c5fb1868794341e8926f8037ef2a0
DIFF: https://github.com/llvm/llvm-project/commit/417e1c7d520c5fb1868794341e8926f8037ef2a0.diff
LOG: [mlir][scf][bufferize][NFC] Split ForOp bufferization into smaller functions
This is in preparation for WhileOp bufferization, which reuses these functions.
Differential Revision: https://reviews.llvm.org/D124933
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index cc4c7288654bc..7dade29cc32d0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -259,6 +259,150 @@ struct IfOpInterface
}
};
+/// Helper function for loop bufferization. Return the indices of all values
+/// that have a tensor type.
+static DenseSet<int64_t> getTensorIndices(ValueRange values) {
+ DenseSet<int64_t> result;
+ for (const auto &it : llvm::enumerate(values))
+ if (it.value().getType().isa<TensorType>())
+ result.insert(it.index());
+ return result;
+}
+
+/// Helper function for loop bufferization. Return the indices of all
+/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
+DenseSet<int64_t> getEquivalentBuffers(ValueRange bbArgs,
+ ValueRange yieldedValues,
+ const AnalysisState &state) {
+ DenseSet<int64_t> result;
+ int64_t counter = 0;
+ for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
+ if (!std::get<0>(it).getType().isa<TensorType>())
+ continue;
+ if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
+ result.insert(counter);
+ counter++;
+ }
+ return result;
+}
+
+/// Helper function for loop bufferization. Cast the given buffer to the given
+/// memref type.
+static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
+ assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+ assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+ // If the buffer already has the correct type, no cast is needed.
+ if (buffer.getType() == type)
+ return buffer;
+ // TODO: In case `type` has a layout map that is not the fully dynamic
+ // one, we may not be able to cast the buffer. In that case, the loop
+ // iter_arg's layout map must be changed (see uses of `castBuffer`).
+ assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+ "scf.while op bufferization: cast incompatible");
+ return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
+}
+
+/// Helper function for loop bufferization. Return the bufferized values of the
+/// given OpOperands. If an operand is not a tensor, return the original value.
+static SmallVector<Value> getBuffers(RewriterBase &rewriter,
+ MutableArrayRef<OpOperand> operands,
+ BufferizationState &state) {
+ SmallVector<Value> result;
+ for (OpOperand &opOperand : operands) {
+ if (opOperand.get().getType().isa<TensorType>()) {
+ FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
+ if (failed(resultBuffer))
+ return {};
+ result.push_back(*resultBuffer);
+ } else {
+ result.push_back(opOperand.get());
+ }
+ }
+ return result;
+}
+
+/// Helper function for loop bufferization. Compute the buffer that should be
+/// yielded from a loop block (loop body or loop condition). If the given tensor
+/// is equivalent to the corresponding block argument (as indicated by
+/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
+/// copy must be yielded.
+///
+/// According to the `BufferizableOpInterface` implementation of scf loops,
+/// a bufferized OpResult may alias only with the corresponding bufferized
+/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
+/// the i-th init_arg, but not with any other OpOperand. If a corresponding
+/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
+/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
+/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
+/// alias with any buffer.)
+static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
+ BaseMemRefType type, bool isEquivalent,
+ BufferizationState &state) {
+ assert(tensor.getType().isa<TensorType>() && "expected tensor");
+ ensureToMemrefOpIsValid(tensor, type);
+ Value yieldedVal =
+ bufferization::lookupBuffer(rewriter, tensor, state.getOptions());
+
+ if (isEquivalent)
+ // Yielded value is equivalent to the corresponding iter_arg bbArg.
+ // Yield the value directly. Most IR should be like that. Everything
+ // else must be resolved with copies and is potentially inefficient.
+ // By default, such problematic IR would already have been rejected
+ // during `verifyAnalysis`, unless `allow-return-allocs`.
+ return castBuffer(rewriter, yieldedVal, type);
+
+ // It is not certain that the yielded value and the iter_arg bbArg
+ // have the same buffer. Allocate a new buffer and copy. The yielded
+ // buffer will get deallocated by `deallocateBuffers`.
+
+ // TODO: There are cases in which it is not necessary to return a new
+ // buffer allocation. E.g., when equivalent values are yielded in a
+ // different order. This could be resolved with copies.
+ Optional<Value> yieldedAlloc = state.createAlloc(
+ rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
+ // TODO: We should rollback, but for now just assume that this always
+ // succeeds.
+ assert(yieldedAlloc.hasValue() && "could not create alloc");
+ LogicalResult copyStatus = bufferization::createMemCpy(
+ rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
+ (void)copyStatus;
+ assert(succeeded(copyStatus) && "could not create memcpy");
+
+ // The iter_arg memref type may have a layout map. Cast the new buffer
+ // to the same type if needed.
+ return castBuffer(rewriter, *yieldedAlloc, type);
+}
+
+/// Helper function for loop bufferization. Given a range of values, apply
+/// `func` to each value whose index is contained in `tensorIndices`. All
+/// other values are stored unmodified in the result vector.
+static SmallVector<Value>
+convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
+ llvm::function_ref<Value(Value, int64_t)> func) {
+ SmallVector<Value> result;
+ for (const auto &it : llvm::enumerate(values)) {
+ size_t idx = it.index();
+ Value val = it.value();
+ result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
+ }
+ return result;
+}
+
+/// Helper function for loop bufferization. Given a list of pre-bufferization
+/// yielded values, compute the list of bufferized yielded values.
+SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
+ TypeRange bufferizedTypes,
+ const DenseSet<int64_t> &tensorIndices,
+ const DenseSet<int64_t> &equivalentTensors,
+ BufferizationState &state) {
+ return convertTensorValues(
+ values, tensorIndices, [&](Value val, int64_t index) {
+ return getYieldedBuffer(rewriter, val,
+ bufferizedTypes[index].cast<BaseMemRefType>(),
+ equivalentTensors.contains(index), state);
+ });
+}
+
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
@@ -312,78 +456,38 @@ struct ForOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto forOp = cast<scf::ForOp>(op);
- auto bufferizableOp = cast<BufferizableOpInterface>(op);
+ auto oldYieldOp =
+ cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
Block *oldLoopBody = &forOp.getLoopBody().front();
- // Helper function for casting MemRef buffers.
- auto castBuffer = [&](Value buffer, Type type) {
- assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
- assert(buffer.getType().isa<BaseMemRefType>() &&
- "expected BaseMemRefType");
- // If the buffer already has the correct type, no cast is needed.
- if (buffer.getType() == type)
- return buffer;
- // TODO: In case `type` has a layout map that is not the fully dynamic
- // one, we may not be able to cast the buffer. In that case, the loop
- // iter_arg's layout map must be changed (see uses of `castBuffer`).
- assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
- "scf.for op bufferization: cast incompatible");
- return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
- .getResult();
- };
-
// Indices of all iter_args that have tensor type. These are the ones that
// are bufferized.
- DenseSet<int64_t> indices;
+ DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
// For every yielded value, is the value equivalent to its corresponding
// bbArg?
- SmallVector<bool> equivalentYields;
- for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
- if (it.value().getType().isa<TensorType>()) {
- indices.insert(it.index());
- BufferRelation relation = bufferizableOp.bufferRelation(
- forOp->getResult(it.index()), state.getAnalysisState());
- equivalentYields.push_back(relation == BufferRelation::Equivalent);
- } else {
- equivalentYields.push_back(false);
- }
- }
+ DenseSet<int64_t> equivalentYields =
+ getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
+ state.getAnalysisState());
- // Given a range of values, apply `func` to those marked in `indices`.
- // Otherwise, store the unmodified value in the result vector.
- auto convert = [&](ValueRange values,
- llvm::function_ref<Value(Value, int64_t)> func) {
- SmallVector<Value> result;
- for (const auto &it : llvm::enumerate(values)) {
- size_t idx = it.index();
- Value val = it.value();
- result.push_back(indices.contains(idx) ? func(val, idx) : val);
- }
- return result;
- };
+ // The new memref init_args of the loop.
+ SmallVector<Value> initArgs =
+ getBuffers(rewriter, forOp.getIterOpOperands(), state);
+ if (initArgs.size() != indices.size())
+ return failure();
// Construct a new scf.for op with memref instead of tensor values.
- SmallVector<Value> initArgs;
- for (OpOperand &opOperand : forOp.getIterOpOperands()) {
- if (opOperand.get().getType().isa<TensorType>()) {
- FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
- if (failed(resultBuffer))
- return failure();
- initArgs.push_back(*resultBuffer);
- } else {
- initArgs.push_back(opOperand.get());
- }
- }
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), initArgs);
+ ValueRange initArgsRange(initArgs);
+ TypeRange initArgsTypes(initArgsRange);
Block *loopBody = &newForOp.getLoopBody().front();
// Set up new iter_args. The loop body uses tensors, so wrap the (memref)
// iter_args of the new loop in ToTensorOps.
rewriter.setInsertionPointToStart(loopBody);
- SmallVector<Value> iterArgs =
- convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
+ SmallVector<Value> iterArgs = convertTensorValues(
+ newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) {
return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
});
iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
@@ -399,42 +503,8 @@ struct ForOpInterface
auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
rewriter.setInsertionPoint(yieldOp);
SmallVector<Value> yieldValues =
- convert(yieldOp.getResults(), [&](Value val, int64_t index) {
- Type initArgType = initArgs[index].getType();
- ensureToMemrefOpIsValid(val, initArgType);
- Value yieldedVal =
- bufferization::lookupBuffer(rewriter, val, state.getOptions());
-
- if (equivalentYields[index])
- // Yielded value is equivalent to the corresponding iter_arg bbArg.
- // Yield the value directly. Most IR should be like that. Everything
- // else must be resolved with copies and is potentially inefficient.
- // By default, such problematic IR would already have been rejected
- // during `verifyAnalysis`, unless `allow-return-allocs`.
- return castBuffer(yieldedVal, initArgType);
-
- // It is not certain that the yielded value and the iter_arg bbArg
- // have the same buffer. Allocate a new buffer and copy. The yielded
- // buffer will get deallocated by `deallocateBuffers`.
-
- // TODO: There are cases in which it is not necessary to return a new
- // buffer allocation. E.g., when equivalent values are yielded in a
- // different order. This could be resolved with copies.
- Optional<Value> yieldedAlloc = state.createAlloc(
- rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
- // TODO: We should rollback, but for now just assume that this always
- // succeeds.
- assert(yieldedAlloc.hasValue() && "could not create alloc");
- LogicalResult copyStatus =
- bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
- *yieldedAlloc, state.getOptions());
- (void)copyStatus;
- assert(succeeded(copyStatus) && "could not create memcpy");
-
- // The iter_arg memref type may have a layout map. Cast the new buffer
- // to the same type if needed.
- return castBuffer(*yieldedAlloc, initArgType);
- });
+ getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
+ equivalentYields, state);
yieldOp.getResultsMutable().assign(yieldValues);
// Replace loop results.
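The equivalentYields set computed above is what selects between the two paths in getYieldedBuffer. A hand-written sketch (assumed SSA names and shapes, not taken from this commit) of the non-equivalent case, in which the loop body yields a buffer other than its iter_arg and must therefore allocate and copy:

  %0 = scf.for %iv = %lb to %ub step %s iter_args(%m = %init_buf) -> (memref<?xf32>) {
    ...
    // %other_buf is not equivalent to the iter_arg %m, so a fresh
    // allocation guarantees the yielded buffer aliases no other buffer.
    %a = memref.alloc(%d) : memref<?xf32>
    memref.copy %other_buf, %a : memref<?xf32> to memref<?xf32>
    scf.yield %a : memref<?xf32>
  }

In the code itself, the allocation and copy are emitted via state.createAlloc and bufferization::createMemCpy, so the actual ops may differ from plain memref.alloc/memref.copy depending on the bufferization options.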