[Mlir-commits] [mlir] 86974e3 - [mlir][SCF][bufferize] Support different iter_arg/init_arg types (scf.while)
Matthias Springer
llvmlistbot at llvm.org
Tue Aug 30 07:58:38 PDT 2022
Author: Matthias Springer
Date: 2022-08-30T16:58:21+02:00
New Revision: 86974e32a4f7c80fd6503c89b065cccf3a91300a
URL: https://github.com/llvm/llvm-project/commit/86974e32a4f7c80fd6503c89b065cccf3a91300a
DIFF: https://github.com/llvm/llvm-project/commit/86974e32a4f7c80fd6503c89b065cccf3a91300a.diff
LOG: [mlir][SCF][bufferize] Support different iter_arg/init_arg types (scf.while)
This change implements the same functionality as D132860, but for scf.while.
Differential Revision: https://reviews.llvm.org/D132927
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
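
Before the diff itself, a rough sketch of the kind of IR this change makes bufferizable. The function name, shapes, and the #dyn layout map below are illustrative assumptions, not taken from the patch: when the init_arg's buffer type does not match the buffer type inferred for the loop's iter_arg, bufferization now inserts a memref.cast at the loop boundary instead of failing.

#dyn = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Hypothetical bufferized output: the init buffer has an identity layout,
// while the iter_arg was inferred to use the fully dynamic layout #dyn, so
// the init_arg is cast before entering the loop.
func.func @init_arg_cast_sketch(%c: i1, %sz: index) {
  %alloc = memref.alloc(%sz) : memref<?xf32>
  %init = memref.cast %alloc : memref<?xf32> to memref<?xf32, #dyn>
  %r = scf.while (%t = %init) : (memref<?xf32, #dyn>) -> (memref<?xf32, #dyn>) {
    scf.condition(%c) %t : memref<?xf32, #dyn>
  } do {
  ^bb0(%b0: memref<?xf32, #dyn>):
    scf.yield %b0 : memref<?xf32, #dyn>
  }
  return
}

The new test @scf_while_buffer_type_mismatch at the end of the diff exercises this situation at the tensor level.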
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 7e458a1f3762..a13badaba7cd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -27,16 +27,78 @@ namespace mlir {
namespace scf {
namespace {
-// bufferization.to_memref is not allowed to change the rank.
-static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
-#ifndef NDEBUG
- auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
- assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
- rankedTensorType.getRank())) &&
- "to_memref would be invalid: mismatching ranks");
-#endif
+/// Helper function for loop bufferization. Cast the given buffer to the given
+/// memref type.
+static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
+ assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+ assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+ // If the buffer already has the correct type, no cast is needed.
+ if (buffer.getType() == type)
+ return buffer;
+ // TODO: In case `type` has a layout map that is not the fully dynamic
+ // one, we may not be able to cast the buffer. In that case, the loop
+ // iter_arg's layout map must be changed (see uses of `castBuffer`).
+ assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+ "scf.while op bufferization: cast incompatible");
+ return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}
+/// Bufferization of scf.condition.
+struct ConditionOpInterface
+ : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
+ scf::ConditionOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return false;
+ }
+
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return {};
+ }
+
+ bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ // Condition operands always bufferize inplace. Otherwise, an alloc + copy
+ // may be generated inside the block. We should not return/yield allocations
+ // when possible.
+ return true;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options) const {
+ auto conditionOp = cast<scf::ConditionOp>(op);
+ auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
+
+ SmallVector<Value> newArgs;
+ for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+ Value value = it.value();
+ if (value.getType().isa<TensorType>()) {
+ FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+ if (failed(maybeBuffer))
+ return failure();
+ FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ whileOp.getAfterArguments()[it.index()], options);
+ if (failed(resultType))
+ return failure();
+ Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
+ newArgs.push_back(buffer);
+ } else {
+ newArgs.push_back(value);
+ }
+ }
+
+ replaceOpWithNewBufferizedOp<scf::ConditionOp>(
+ rewriter, op, conditionOp.getCondition(), newArgs);
+ return success();
+ }
+};
+
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
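
The ConditionOpInterface added above bufferizes scf.condition by casting each tensor operand's buffer to the buffer type inferred for the corresponding "after" region bbArg. A hedged sketch of the resulting IR; names, sizes, and the #dyn layout map are illustrative assumptions, not output of the patch:

#dyn = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// The buffer computed in the "before" region has an identity layout, but the
// "after" bbArg uses the fully dynamic layout #dyn, so the scf.condition
// operand is cast first.
func.func @condition_cast_sketch(%init: memref<?xf32, #dyn>, %c: i1, %sz: index) {
  %r = scf.while (%t = %init) : (memref<?xf32, #dyn>) -> (memref<?xf32, #dyn>) {
    %buf = memref.alloc(%sz) : memref<?xf32>
    %casted = memref.cast %buf : memref<?xf32> to memref<?xf32, #dyn>
    scf.condition(%c) %casted : memref<?xf32, #dyn>
  } do {
  ^bb0(%b0: memref<?xf32, #dyn>):
    scf.yield %b0 : memref<?xf32, #dyn>
  }
  return
}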
@@ -283,22 +345,6 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
return result;
}
-/// Helper function for loop bufferization. Cast the given buffer to the given
-/// memref type.
-static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
- assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
- assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
- // If the buffer already has the correct type, no cast is needed.
- if (buffer.getType() == type)
- return buffer;
- // TODO: In case `type` has a layout map that is not the fully dynamic
- // one, we may not be able to cast the buffer. In that case, the loop
- // iter_arg's layout map must be changed (see uses of `castBuffer`).
- assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
- "scf.while op bufferization: cast incompatible");
- return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
-}
-
/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
@@ -319,60 +365,10 @@ getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
return result;
}
-/// Helper function for loop bufferization. Compute the buffer that should be
-/// yielded from a loop block (loop body or loop condition).
-static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
- BaseMemRefType type,
- const BufferizationOptions &options) {
- assert(tensor.getType().isa<TensorType>() && "expected tensor");
- ensureToMemrefOpIsValid(tensor, type);
- FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
- if (failed(yieldedVal))
- return failure();
- return castBuffer(rewriter, *yieldedVal, type);
-}
-
-/// Helper function for loop bufferization. Given a range of values, apply
-/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
-/// value in the result vector.
-static FailureOr<SmallVector<Value>>
-convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
- llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
- SmallVector<Value> result;
- for (const auto &it : llvm::enumerate(values)) {
- size_t idx = it.index();
- Value val = it.value();
- if (tensorIndices.contains(idx)) {
- FailureOr<Value> maybeVal = func(val, idx);
- if (failed(maybeVal))
- return failure();
- result.push_back(*maybeVal);
- } else {
- result.push_back(val);
- }
- }
- return result;
-}
-
-/// Helper function for loop bufferization. Given a list of pre-bufferization
-/// yielded values, compute the list of bufferized yielded values.
-FailureOr<SmallVector<Value>>
-getYieldedValues(RewriterBase &rewriter, ValueRange values,
- TypeRange bufferizedTypes,
- const DenseSet<int64_t> &tensorIndices,
- const BufferizationOptions &options) {
- return convertTensorValues(
- values, tensorIndices, [&](Value val, int64_t index) {
- return getYieldedBuffer(rewriter, val,
- bufferizedTypes[index].cast<BaseMemRefType>(),
- options);
- });
-}
-
/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
-SmallVector<Value>
+static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
const DenseSet<int64_t> &tensorIndices) {
SmallVector<Value> result;
@@ -390,6 +386,74 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
return result;
}
+/// Compute the bufferized type of a loop iter_arg. This type must be equal to
+/// the bufferized type of the corresponding init_arg and the bufferized type
+/// of the corresponding yielded value.
+///
+/// This function uses bufferization::getBufferType to compute the bufferized
+/// type of the init_arg and of the yielded value. (The computation of the
+/// latter usually requires computing the bufferized type of the corresponding
+/// iter_arg; the implementation of getBufferType traces back the use-def
+/// chain of the given value and computes a buffer type along the way.) If
+/// both buffer types are equal, no casts are needed and the computed buffer
+/// type can be used directly. Otherwise, the buffer types can only differ in
+/// their layout map and a cast must be inserted.
+static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+ BlockArgument iterArg, Value initArg, Value yieldedValue,
+ const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+ // Determine the buffer type of the init_arg.
+ auto initArgBufferType =
+ bufferization::getBufferType(initArg, options, fixedTypes);
+ if (failed(initArgBufferType))
+ return failure();
+
+ // Fix the iter_arg type, so that recursive lookups return the buffer type
+ // of the init_arg. This is to avoid infinite loops when calculating the
+ // buffer type of the yielded value.
+ //
+ // Note: For more precise layout map computation, a fixpoint iteration could
+ // be done (i.e., re-computing the yielded buffer type until the bufferized
+ // iter_arg type no longer changes). This current implementation immediately
+ // switches to a fully dynamic layout map when a mismatch between bufferized
+ // init_arg type and bufferized yield value type is detected.
+ DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
+ newFixedTypes[iterArg] = *initArgBufferType;
+
+ // Compute the buffer type of the yielded value.
+ BaseMemRefType yieldedValueBufferType;
+ if (yieldedValue.getType().isa<BaseMemRefType>()) {
+ // scf.yield was already bufferized.
+ yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+ } else {
+ auto maybeBufferType =
+ bufferization::getBufferType(yieldedValue, options, newFixedTypes);
+ if (failed(maybeBufferType))
+ return failure();
+ yieldedValueBufferType = *maybeBufferType;
+ }
+
+ // If yielded type and init_arg type are the same, use that type directly.
+ if (*initArgBufferType == yieldedValueBufferType)
+ return yieldedValueBufferType;
+
+ // If there is a mismatch between the yielded buffer type and the iter_arg
+ // buffer type, the buffer type must be promoted to a fully dynamic layout
+ // map.
+ auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+#ifndef NDEBUG
+ auto iterRanked = initArgBufferType->cast<MemRefType>();
+ assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
+ "expected same shape");
+ assert(yieldedRanked.getMemorySpaceAsInt() ==
+ iterRanked.getMemorySpaceAsInt() &&
+ "expected same memory space");
+#endif // NDEBUG
+ return getMemRefTypeWithFullyDynamicLayout(
+ iterArg.getType().cast<RankedTensorType>(),
+ yieldedRanked.getMemorySpaceAsInt());
+}
+
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
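
The promotion performed by computeLoopRegionIterArgBufferType above can be spelled out with concrete types. This is a hypothetical example, not taken from the patch: if the init_arg and the yielded value bufferize to memrefs that differ only in their layout map, the iter_arg buffer type is widened to a fully dynamic layout, and both sides stay memref.cast-compatible with it.

#offset = affine_map<(d0)[s0] -> (d0 + s0)>
#dyn = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// init_arg buffer type:  memref<5xf32>            (identity layout)
// yielded buffer type:   memref<5xf32, #offset>   (dynamic offset, stride 1)
// iter_arg buffer type:  memref<5xf32, #dyn>      (fully dynamic layout)
func.func @promotion_sketch(%init: memref<5xf32>, %yielded: memref<5xf32, #offset>) {
  %a = memref.cast %init : memref<5xf32> to memref<5xf32, #dyn>
  %b = memref.cast %yielded : memref<5xf32, #offset> to memref<5xf32, #dyn>
  return
}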
@@ -507,60 +571,14 @@ struct ForOpInterface
resultNum = value.cast<OpResult>().getResultNumber();
}
- // Determine the buffer type of the init_arg.
- Value initArg = forOp.getInitArgs()[resultNum];
- auto initArgBufferType =
- bufferization::getBufferType(initArg, options, fixedTypes);
- if (failed(initArgBufferType))
- return failure();
-
- // Fix the iter_arg type, so that recursive lookups return the buffer type
- // of the init_arg. This is to avoid infinite loops when calculating the
- // buffer type of the yielded value.
- //
- // Note: For more precise layout map computation, a fixpoint iteration could
- // be done (i.e., re-computing the yielded buffer type until the bufferized
- // iter_arg type no longer changes). This current implementation immediately
- // switches to a fully dynamic layout map when a mismatch between bufferized
- // init_arg type and bufferized yield value type is detected.
- DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
- newFixedTypes[forOp.getRegionIterArgs()[resultNum]] = *initArgBufferType;
-
- // Compute the buffer type of the yielded value.
+ // Compute the bufferized type.
auto yieldOp =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
Value yieldedValue = yieldOp.getOperand(resultNum);
- BaseMemRefType yieldedValueBufferType;
- if (yieldedValue.getType().isa<BaseMemRefType>()) {
- // scf.yield was already bufferized.
- yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
- } else {
- auto maybeBufferType =
- bufferization::getBufferType(yieldedValue, options, newFixedTypes);
- if (failed(maybeBufferType))
- return failure();
- yieldedValueBufferType = *maybeBufferType;
- }
-
- // If yielded type and init_arg type are the same, use that type directly.
- if (*initArgBufferType == yieldedValueBufferType)
- return yieldedValueBufferType;
-
- // If there is a mismatch between the yielded buffer type and the iter_arg
- // buffer type, the buffer type must be promoted to a fully dynamic layout
- // map.
- auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
-#ifndef NDEBUG
- auto iterRanked = initArgBufferType->cast<MemRefType>();
- assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
- "expected same shape");
- assert(yieldedRanked.getMemorySpaceAsInt() ==
- iterRanked.getMemorySpaceAsInt() &&
- "expected same memory space");
-#endif // NDEBUG
- return getMemRefTypeWithFullyDynamicLayout(
- forOp.getRegionIterArgs()[resultNum].getType().cast<RankedTensorType>(),
- yieldedRanked.getMemorySpaceAsInt());
+ BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
+ Value initArg = forOp.getInitArgs()[resultNum];
+ return computeLoopRegionIterArgBufferType(iterArg, initArg, yieldedValue,
+ options, fixedTypes);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -800,8 +818,6 @@ struct WhileOpInterface
return success();
}
- // TODO: Implement getBufferType interface method and infer buffer types.
-
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto whileOp = cast<scf::WhileOp>(op);
@@ -826,6 +842,17 @@ struct WhileOpInterface
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
+ // Cast init_args if necessary.
+ SmallVector<Value> castedInitArgs;
+ for (const auto &it : llvm::enumerate(initArgs)) {
+ Value initArg = it.value();
+ auto targetType = bufferization::getBufferType(
+ whileOp.getBeforeArguments()[it.index()], options);
+ if (failed(targetType))
+ return failure();
+ castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
+ }
+
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
@@ -834,13 +861,14 @@ struct WhileOpInterface
}));
// Construct a new scf.while op with memref instead of tensor values.
- ValueRange argsRangeBefore(initArgs);
+ ValueRange argsRangeBefore(castedInitArgs);
TypeRange argsTypesBefore(argsRangeBefore);
- auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
- argsTypesAfter, initArgs);
+ auto newWhileOp = rewriter.create<scf::WhileOp>(
+ whileOp.getLoc(), argsTypesAfter, castedInitArgs);
// Add before/after regions to the new op.
- SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
+ SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
+ whileOp.getLoc());
SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
whileOp.getLoc());
Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
@@ -856,19 +884,6 @@ struct WhileOpInterface
rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
- // Update scf.condition of new loop.
- auto newConditionOp = newWhileOp.getConditionOp();
- rewriter.setInsertionPoint(newConditionOp);
- // Only equivalent buffers or new buffer allocations may be yielded to the
- // "after" region.
- // TODO: This could be relaxed for better bufferization results.
- FailureOr<SmallVector<Value>> newConditionArgs =
- getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
- indicesAfter, options);
- if (failed(newConditionArgs))
- return failure();
- newConditionOp.getArgsMutable().assign(*newConditionArgs);
-
// Set up new iter_args and move the loop body block to the new op.
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
// in ToTensorOps.
@@ -877,25 +892,51 @@ struct WhileOpInterface
rewriter, newWhileOp.getAfterArguments(), indicesAfter);
rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
- // Update scf.yield of the new loop.
- auto newYieldOp = newWhileOp.getYieldOp();
- rewriter.setInsertionPoint(newYieldOp);
- // Only equivalent buffers or new buffer allocations may be yielded to the
- // "before" region.
- // TODO: This could be relaxed for better bufferization results.
- FailureOr<SmallVector<Value>> newYieldValues =
- getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
- indicesBefore, options);
- if (failed(newYieldValues))
- return failure();
- newYieldOp.getResultsMutable().assign(*newYieldValues);
-
// Replace loop results.
replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
return success();
}
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ auto whileOp = cast<scf::WhileOp>(op);
+ assert(getOwnerOfValue(value) == op && "invalid value");
+ assert(value.getType().isa<TensorType>() && "expected tensor type");
+
+ // Case 1: Block argument of the "before" region.
+ if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
+ Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
+ auto yieldOp = whileOp.getYieldOp();
+ Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
+ return computeLoopRegionIterArgBufferType(bbArg, initArg, yieldedValue,
+ options, fixedTypes);
+ }
+ }
+
+ // Case 2: OpResult of the loop or block argument of the "after" region.
+ // The bufferized "after" bbArg type can be directly computed from the
+ // bufferized "before" bbArg type.
+ unsigned resultNum;
+ if (auto opResult = value.dyn_cast<OpResult>()) {
+ resultNum = opResult.getResultNumber();
+ } else if (value.cast<BlockArgument>().getOwner()->getParent() ==
+ &whileOp.getAfter()) {
+ resultNum = value.cast<BlockArgument>().getArgNumber();
+ } else {
+ llvm_unreachable("invalid value");
+ }
+ Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
+ if (!conditionYieldedVal.getType().isa<TensorType>()) {
+ // scf.condition was already bufferized.
+ return conditionYieldedVal.getType().cast<BaseMemRefType>();
+ }
+ return bufferization::getBufferType(conditionYieldedVal, options,
+ fixedTypes);
+ }
+
/// Assert that yielded values of an scf.while op are equivalent to their
/// corresponding bbArgs. In that case, the buffer relations of the
/// corresponding OpResults are "Equivalent".
@@ -979,11 +1020,6 @@ struct YieldOpInterface
yieldOp->getParentOp()))
return yieldOp->emitError("unsupported scf::YieldOp parent");
- // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
- // together with scf.while.)
- if (isa<scf::WhileOp>(yieldOp->getParentOp()))
- return success();
-
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Value value = it.value();
@@ -992,15 +1028,20 @@ struct YieldOpInterface
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
- // In case of scf::ForOp / scf::IfOp, we may have to cast the value
- // before yielding it.
- // TODO: Do the same for scf::WhileOp.
+ // We may have to cast the value before yielding it.
if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
yieldOp->getParentOp()->getResult(it.index()), options);
if (failed(resultType))
return failure();
buffer = castBuffer(rewriter, buffer, *resultType);
+ } else if (auto whileOp =
+ dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
+ FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ whileOp.getBeforeArguments()[it.index()], options);
+ if (failed(resultType))
+ return failure();
+ buffer = castBuffer(rewriter, buffer, *resultType);
}
newResults.push_back(buffer);
} else {
@@ -1103,6 +1144,7 @@ struct PerformConcurrentlyOpInterface
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
+ ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
ForOp::attachInterface<ForOpInterface>(*ctx);
IfOp::attachInterface<IfOpInterface>(*ctx);
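
With the YieldOpInterface change above, scf.yield operands inside an scf.while are now also cast to the buffer types of the corresponding "before" region bbArgs, mirroring what ConditionOpInterface does for the "after" region. A hedged sketch with illustrative names and layouts:

#dyn = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// The buffer produced in the "do" region has an identity layout, but the
// "before" bbArg uses the fully dynamic layout #dyn, so the scf.yield operand
// is cast first.
func.func @yield_cast_sketch(%init: memref<?xf32, #dyn>, %c: i1, %sz: index) {
  %r = scf.while (%t = %init) : (memref<?xf32, #dyn>) -> (memref<?xf32, #dyn>) {
    scf.condition(%c) %t : memref<?xf32, #dyn>
  } do {
  ^bb0(%b0: memref<?xf32, #dyn>):
    %buf = memref.alloc(%sz) : memref<?xf32>
    %casted = memref.cast %buf : memref<?xf32> to memref<?xf32, #dyn>
    scf.yield %casted : memref<?xf32, #dyn>
  }
  return
}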
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index c8bf6125e252..61ec6faae86a 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -388,23 +388,23 @@ func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
// CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w0]], %[[a0]]
// CHECK: memref.dealloc %[[w0]]
- // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
- // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
- // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
- // CHECK: memref.dealloc %[[a0]]
- // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
+ // CHECK: %[[cloned1:.*]] = bufferization.clone %[[a1]]
// CHECK: memref.dealloc %[[a1]]
+ // CHECK: %[[cloned0:.*]] = bufferization.clone %[[a0]]
+ // CHECK: memref.dealloc %[[a0]]
// CHECK: scf.condition(%[[condition]]) %[[cloned1]], %[[cloned0]]
%condition = tensor.extract %w0[%idx] : tensor<5xi1>
scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
} do {
^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
// CHECK: } do {
- // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
+ // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
// CHECK: memref.store %{{.*}}, %[[b0]]
- // CHECK: %[[cloned2:.*]] = bufferization.clone %[[b1]]
+ // CHECK: %[[casted0:.*]] = memref.cast %[[b0]] : memref<5xi1> to memref<5xi1, #{{.*}}>
+ // CHECK: %[[casted1:.*]] = memref.cast %[[b1]] : memref<5xi1> to memref<5xi1, #{{.*}}>
+ // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted1]]
// CHECK: memref.dealloc %[[b1]]
- // CHECK: %[[cloned3:.*]] = bufferization.clone %[[b0]]
+ // CHECK: %[[cloned3:.*]] = bufferization.clone %[[casted0]]
// CHECK: memref.dealloc %[[b0]]
// CHECK: scf.yield %[[cloned3]], %[[cloned2]]
// CHECK: }
@@ -441,25 +441,24 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
// CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w0]], %[[a0]]
// CHECK: memref.dealloc %[[w0]]
- // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
- // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
- // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
- // CHECK: memref.dealloc %[[a0]]
- // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
+ // CHECK: %[[cloned1:.*]] = bufferization.clone %[[a1]]
// CHECK: memref.dealloc %[[a1]]
+ // CHECK: %[[cloned0:.*]] = bufferization.clone %[[a0]]
+ // CHECK: memref.dealloc %[[a0]]
// CHECK: scf.condition(%[[condition]]) %[[cloned1]], %[[cloned0]]
%condition = tensor.extract %w0[%idx] : tensor<5xi1>
scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
} do {
^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
// CHECK: } do {
- // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
+ // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
// CHECK: memref.store %{{.*}}, %[[b0]]
// CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[b1]], %[[a3]]
// CHECK: memref.dealloc %[[b1]]
// CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[b0]], %[[a2]]
+ // CHECK: memref.dealloc %[[b0]]
// CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
// CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
// CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
@@ -764,3 +763,32 @@ func.func @scf_for_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
%x = tensor.extract %r[%c1] : tensor<?xf32>
return %x : f32
}
+
+// -----
+
+// We just check that this example bufferizes to valid IR.
+
+// CHECK-LABEL: func @scf_while_buffer_type_mismatch
+func.func @scf_while_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %cst = arith.constant 5.5 : f32
+ %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
+ %e2 = tensor.extract_slice %0[1][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+ // init_arg and iter_arg have different buffer types. This must be resolved
+ // with casts.
+ %r = scf.while (%t = %e2) : (tensor<?xf32>) -> (tensor<?xf32>) {
+ %c = "test.condition"() : () -> (i1)
+ %s = "test.dummy"() : () -> (index)
+ %e = tensor.extract_slice %t[1][%s][1] : tensor<?xf32> to tensor<?xf32>
+ scf.condition(%c) %e : tensor<?xf32>
+ } do {
+ ^bb0(%b0: tensor<?xf32>):
+ %s2 = "test.dummy"() : () -> (index)
+ %n = tensor.insert %cst into %b0[%s2] : tensor<?xf32>
+ scf.yield %n : tensor<?xf32>
+ }
+ %x = tensor.extract %r[%c1] : tensor<?xf32>
+ return %x : f32
+}