[Mlir-commits] [mlir] 3bbc869 - [mlir][linalg][bufferize] Support scf::IfOp
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 21 18:13:07 PDT 2021
Author: Matthias Springer
Date: 2021-10-22T10:12:55+09:00
New Revision: 3bbc869e2ef26f3bc296d5b4e23ee8678a20fc0b
URL: https://github.com/llvm/llvm-project/commit/3bbc869e2ef26f3bc296d5b4e23ee8678a20fc0b
DIFF: https://github.com/llvm/llvm-project/commit/3bbc869e2ef26f3bc296d5b4e23ee8678a20fc0b.diff
LOG: [mlir][linalg][bufferize] Support scf::IfOp
This commit adds support for scf::IfOp to comprehensive bufferization. Support is currently limited to cases where both branches yield tensors that bufferize to the same buffer.
To keep the analysis simple, scf::IfOps are treated as memory writes, even if no op inside either branch writes. (scf::ForOps are handled in the same way.)
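For illustration, a minimal sketch of the supported pattern, adapted from the scf_if_inplace test added below: both branches yield values that bufferize into the same buffer (%t1), so no copy is needed.

    func @scf_if_inplace(%cond: i1,
                         %t1: tensor<?xf32> {linalg.inplaceable = true},
                         %v: vector<5xf32>, %idx: index) -> tensor<?xf32> {
      %r = scf.if %cond -> (tensor<?xf32>) {
        scf.yield %t1 : tensor<?xf32>
      } else {
        %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
        scf.yield %t2 : tensor<?xf32>
      }
      return %r : tensor<?xf32>
    }

If the branches yield tensors that bufferize to different buffers (e.g., two distinct function arguments), bufferization now fails with "result buffer is ambiguous" instead of the previous "expected scf::ForOp parent for scf::YieldOp" error (see the updated invalid test).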
Differential Revision: https://reviews.llvm.org/D111929
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 4134cc042ebd..b4af0fd82903 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -442,6 +442,7 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
ConstantOp,
tensor::DimOp,
ExtractSliceOp,
+ scf::IfOp,
scf::ForOp,
InsertSliceOp,
InitTensorOp,
@@ -550,6 +551,16 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
// clang-format on
}
+/// Either one of the corresponding yield values from the then/else branches
+/// may alias with the result.
+static void populateAliasingOpOperands(scf::IfOp op, OpResult result,
+ SmallVector<OpOperand *> &operands) {
+ size_t resultNum = std::distance(op->getOpResults().begin(),
+ llvm::find(op->getOpResults(), result));
+ operands.push_back(&op.thenYield()->getOpOperand(resultNum));
+ operands.push_back(&op.elseYield()->getOpOperand(resultNum));
+}
+
/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Note that multiple OpOperands may potentially alias with an
/// OpResult. E.g.: std.select in the future.
@@ -561,6 +572,7 @@ static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
TypeSwitch<Operation *>(result.getDefiningOp())
.Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); })
.Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); })
+ .Case([&](scf::IfOp op) { populateAliasingOpOperands(op, result, r); })
// In the case of scf::ForOp, this currently assumes the iter_args / yield
// are 1-1. This may fail and is verified at the end.
// TODO: update this.
@@ -730,6 +742,19 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
if (bbArg.getType().isa<TensorType>())
createAliasInfoEntry(bbArg);
});
+
+ // The return value of an scf::IfOp aliases with both yield values.
+ rootOp->walk([&](scf::IfOp ifOp) {
+ if (ifOp->getNumResults() > 0) {
+ for (auto it : llvm::zip(ifOp.thenYield().results(),
+ ifOp.elseYield().results(), ifOp.results())) {
+ aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
+ aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
+ equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it));
+ equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it));
+ }
+ }
+ });
}
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
@@ -834,13 +859,28 @@ void BufferizationAliasInfo::bufferizeOutOfPlace(OpResult result) {
}
/// Starting from `value`, follow the use-def chain in reverse, always selecting
-/// the corresponding aliasing OpOperand. Try to find and return a Value for
-/// which `condition` evaluates to true.
+/// the aliasing OpOperands. Find and return Values for which `condition`
+/// evaluates to true. OpOperands of such matching Values are not traversed any
+/// further.
///
-/// When reaching the end of the chain (BlockArgument or Value without aliasing
-/// OpOperands), return the last Value of the chain.
+/// When reaching the end of a chain (BlockArgument or Value without aliasing
+/// OpOperands), also return the last Value of that chain.
+///
+/// Example:
///
-/// Note: The returned SetVector contains exactly one element.
+///                                 8
+///                                 |
+///    6*         7*         +-----+----+
+///    |          |          |          |
+///    2*         3          4          5*
+///    |          |          |          |
+///    +----------+----------+----------+
+///               |
+///               1
+///
+/// In the above example, Values with a star satisfy the condition. When
+/// starting the traversal from Value 1, the resulting SetVector is:
+/// { 2, 7, 8, 5 }
static llvm::SetVector<Value>
findValueInReverseUseDefChain(Value value,
std::function<bool(Value)> condition) {
@@ -861,18 +901,22 @@ findValueInReverseUseDefChain(Value value,
continue;
}
- assert(opOperands.size() == 1 && "multiple OpOperands not supported yet");
- workingSet.insert(opOperands.front()->get());
+ for (OpOperand *o : opOperands)
+ workingSet.insert(o->get());
}
return result;
}
-/// Find the Value (result) of the last preceding write of a given Value.
+/// Find the Value of the last preceding write of a given Value.
///
/// Note: Unknown ops are handled conservatively and assumed to be writes.
/// Furthermore, BlockArguments are also assumed to be writes. There is no
/// analysis across block boundaries.
+///
+/// Note: To simplify the analysis, scf.if ops are considered writes. Treating
+/// a non-writing op as a writing op may introduce unnecessary out-of-place
+/// bufferizations, but is always safe from a correctness point of view.
static Value findLastPrecedingWrite(Value value) {
SetVector<Value> result =
findValueInReverseUseDefChain(value, [](Value value) {
@@ -881,6 +925,8 @@ static Value findLastPrecedingWrite(Value value) {
return true;
if (!hasKnownBufferizationAliasingBehavior(op))
return true;
+ if (isa<scf::IfOp>(op))
+ return true;
SmallVector<OpOperand *> opOperands =
getAliasingOpOperand(value.cast<OpResult>());
@@ -911,6 +957,21 @@ bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
condition);
}
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b,
+ const DominanceInfo &domInfo) {
+ do {
+ // TODO: Instead of isProperAncestor + properlyDominates, we should use
+ // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
+ if (a->isProperAncestor(b))
+ return false;
+ if (domInfo.properlyDominates(a, b))
+ return true;
+ } while ((a = a->getParentOp()));
+ return false;
+}
+
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
@@ -935,7 +996,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
// is %0. Note that operations that create an alias but do not write (such
// as ExtractSliceOp) are skipped.
- // TODO: With branches this should probably be a list of Values.
Value lastWrite = findLastPrecedingWrite(uRead->get());
// Look for conflicting memory writes. Potential conflicts are writes to an
@@ -949,21 +1009,35 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
LDBG("Found potential conflict:\n");
LDBG("READ = #" << uRead->getOperandNumber() << " of "
<< printOperationInfo(readingOp) << "\n");
- LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
LDBG("CONFLICTING WRITE = #"
<< uConflictingWrite->getOperandNumber() << " of "
<< printOperationInfo(conflictingWritingOp) << "\n");
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
// write is not visible when reading.
- if (domInfo.properlyDominates(readingOp, conflictingWritingOp))
+ if (happensBefore(readingOp, conflictingWritingOp, domInfo))
+ continue;
+
+ // No conflict if the reading use equals the use of the conflicting write.
+ // A use cannot conflict with itself. Note: Just being the same op is not
+ // enough. It has to be the same use.
+ if (uConflictingWrite == uRead)
+ continue;
+
+ if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
continue;
- // No conflict if the conflicting write happens before the last write.
+ LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
+
+ // No conflict if the conflicting write happens before the last
+ // write.
if (Operation *writingOp = lastWrite.getDefiningOp()) {
- if (domInfo.properlyDominates(conflictingWritingOp, writingOp))
+ if (happensBefore(conflictingWritingOp, writingOp, domInfo))
// conflictingWritingOp happens before writingOp. No conflict.
continue;
+ // No conflict if conflictingWritingOp is contained in writingOp.
+ if (writingOp->isProperAncestor(conflictingWritingOp))
+ continue;
} else {
auto bbArg = lastWrite.cast<BlockArgument>();
Block *block = bbArg.getOwner();
@@ -978,11 +1052,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
continue;
- // No conflict is the same use is the read and the conflicting write. A
- // use cannot conflict with itself.
- if (uConflictingWrite == uRead)
- continue;
-
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
@@ -1423,15 +1492,27 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
OpBuilder::InsertionGuard guard(b);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
- // TODO: Support multiple OpOperands.
- assert(aliasingOperands.size() == 1 &&
- "more than 1 OpOperand not supported yet");
+ assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
Value operand = aliasingOperands.front()->get();
Value operandBuffer = lookup(bvm, operand);
assert(operandBuffer && "operand buffer not found");
+ // Make sure that all OpOperands map to the same buffer. If this is not the case,
+ // we would have to materialize a memref value.
+ if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
+ return lookup(bvm, o->get()) == operandBuffer;
+ })) {
+ op->emitError("result buffer is ambiguous");
+ return Value();
+ }
// If bufferizing out-of-place, allocate a new buffer.
- if (getInPlace(result) != InPlaceSpec::True) {
+ bool needCopy =
+ getInPlace(result) != InPlaceSpec::True && !isa<scf::IfOp>(op);
+ if (needCopy) {
+ // Ops such as scf::IfOp can currently not bufferize out-of-place.
+ assert(
+ aliasingOperands.size() == 1 &&
+ "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
Location loc = op->getLoc();
// Allocate the result buffer.
Value resultBuffer =
@@ -1771,6 +1852,31 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
return success();
}
+static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
+ BlockAndValueMapping &bvm,
+ BufferizationAliasInfo &aliasInfo) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+
+ for (OpResult opResult : ifOp->getResults()) {
+ if (!opResult.getType().isa<TensorType>())
+ continue;
+ // TODO: Atm we bail on unranked TensorType because we don't know how to
+ // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+ assert(opResult.getType().isa<RankedTensorType>() &&
+ "unsupported unranked tensor");
+
+ Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+ if (!resultBuffer)
+ return failure();
+
+ aliasInfo.createAliasInfoEntry(resultBuffer);
+ map(bvm, opResult, resultBuffer);
+ }
+
+ return success();
+}
+
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
BlockAndValueMapping &bvm,
@@ -2038,7 +2144,6 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
if (!dstMemref)
return failure();
-
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
Value srcMemref = lookup(bvm, insertSliceOp.source());
@@ -2127,6 +2232,9 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
return success();
}
+ if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
+ return success();
+
scf::ForOp forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
if (!forOp)
return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
@@ -2344,6 +2452,13 @@ LogicalResult mlir::linalg::bufferizeOp(
LDBG("Begin bufferize:\n" << op << '\n');
return bufferize(b, op, bvm, aliasInfo);
})
+ .Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, InitTensorOp,
+ InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
+ VectorTransferOpInterface, linalg::YieldOp, scf::YieldOp,
+ scf::IfOp>([&](auto op) {
+ LDBG("Begin bufferize:\n" << op << '\n');
+ return bufferize(b, op, bvm, aliasInfo);
+ })
.Case([&](CallOpInterface op) {
LDBG("Begin bufferize:\n" << op << '\n');
if (!bufferizedFunctionTypes)
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 1283525ae33c..12897e2b4faa 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1087,3 +1087,291 @@ func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = t
%2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
return %2, %2 : tensor<?xf32>, tensor<?xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// scf.if cases
+//===----------------------------------------------------------------------===//
+
+// This example passes analysis, but it fails when bufferizing.
+// CHECK-LABEL: func @scf_if_inplace1
+func @scf_if_inplace1(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %t2: tensor<?xf32> {linalg.inplaceable = true},
+ %cond: i1) -> tensor<?xf32> {
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ scf.yield %t1 : tensor<?xf32>
+ } else {
+ scf.yield %t2 : tensor<?xf32>
+ }
+ return %r : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @scf_if_inplace2
+func @scf_if_inplace2(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %v: vector<5xf32>, %idx: index,
+ %cond: i1) -> tensor<?xf32> {
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ scf.yield %t1 : tensor<?xf32>
+ } else {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t2 : tensor<?xf32>
+ }
+ return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace3
+func @scf_if_inplace3(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index,
+ %cond: i1) -> tensor<?xf32> {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %t2 = vector.transfer_write %v1, %e[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t2 : tensor<?xf32>
+ } else {
+ // Writing the same tensor through an alias. This is OK.
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t3 : tensor<?xf32>
+ }
+ return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_in_place4
+func @scf_if_in_place4(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %v: vector<5xf32>, %idx: index,
+ %cond: i1, %cond2: i1) -> (tensor<?xf32>, vector<10xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ scf.yield %t1 : tensor<?xf32>
+ } else {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t2 : tensor<?xf32>
+ }
+ %r_alias = scf.if %cond2 -> (tensor<?xf32>) {
+ // Reading %r is OK. No conflict.
+ scf.yield %r : tensor<?xf32>
+ } else {
+ scf.yield %r : tensor<?xf32>
+ }
+ %v2 = vector.transfer_read %r_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
+ return %r_alias, %v2 : tensor<?xf32>, vector<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace5
+func @scf_if_inplace5(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %idx: index, %cond: i1) -> tensor<?xf32> {
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+ scf.yield %e : tensor<?xf32>
+ } else {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %f = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+ scf.yield %f : tensor<?xf32>
+ }
+
+ // Inserting into an equivalent tensor at the same offset. This bufferizes
+ // inplace.
+ // CHECK: tensor.insert_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
+ return %r2 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace6
+func @scf_if_inplace6(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %v1: vector<5xf32>, %v2: vector<5xf32>,
+ %v3: vector<5xf32>, %idx: index,
+ %cond: i1, %cond2: i1) -> tensor<?xf32> {
+ // Test nested scf.if ops.
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ %t2 = scf.if %cond2 -> (tensor<?xf32>) {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %t3 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t3 : tensor<?xf32>
+ } else {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %t4 = vector.transfer_write %v3, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t4 : tensor<?xf32>
+ }
+ scf.yield %t2 : tensor<?xf32>
+ } else {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t3 : tensor<?xf32>
+ }
+ return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace7
+func @scf_if_inplace7(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index,
+ %idx2: index, %cond: i1) -> (tensor<?xf32>, vector<5xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %r, %v_r2 = scf.if %cond -> (tensor<?xf32>, vector<5xf32>) {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %t2 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t2, %v1 : tensor<?xf32>, vector<5xf32>
+ } else {
+ // Writing the same tensor through an alias.
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ // Read the original value of %t1. This requires the write in this branch
+ // to be out-of-place. But the write in the other branch can still be
+ // inplace.
+ %v_r = vector.transfer_read %t1[%idx2], %cst : tensor<?xf32>, vector<5xf32>
+ scf.yield %t3, %v_r : tensor<?xf32>, vector<5xf32>
+ }
+ return %r, %v_r2 : tensor<?xf32>, vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place1a
+func @scf_if_out_of_place1a(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %idx: index, %idx2: index,
+ %cond: i1) -> tensor<?xf32> {
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+ scf.yield %e : tensor<?xf32>
+ } else {
+ scf.yield %t1 : tensor<?xf32>
+ }
+
+ // Reading from and writing to the same tensor via different args. This is a
+ // conflict.
+ // CHECK: tensor.insert_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
+ return %r2 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place1b
+func @scf_if_out_of_place1b(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %idx: index, %idx2: index, %idx3: index,
+ %cond: i1) -> tensor<?xf32> {
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+ scf.yield %e : tensor<?xf32>
+ } else {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor<?xf32> to tensor<?xf32>
+ scf.yield %f : tensor<?xf32>
+ }
+
+ // Reading from and writing to the same tensor via different args. This is a
+ // conflict. In contrast to scf_if_out_of_place1a, the fact that %r aliases
+ // with %t1 is only detected when analyzing the tensor.extract_slices. That's
+ // why the tensor.insert_slice is inplace and the two extract_slices are
+ // out-of-place.
+ // CHECK: tensor.insert_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
+ return %r2 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place1c
+func @scf_if_out_of_place1c(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %idx: index, %idx2: index, %cond: i1) -> tensor<?xf32> {
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+ scf.yield %e : tensor<?xf32>
+ } else {
+ // TODO: This one could bufferize inplace, but the analysis is too restrictive.
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor<?xf32> to tensor<?xf32>
+ scf.yield %f : tensor<?xf32>
+ }
+
+ // CHECK: tensor.insert_slice
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+ %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
+ return %r2 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place2
+func @scf_if_out_of_place2(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %v: vector<5xf32>, %idx: index,
+ %cond: i1) -> (tensor<?xf32>, vector<10xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ scf.yield %t1 : tensor<?xf32>
+ } else {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t2 : tensor<?xf32>
+ }
+
+ // Read the old value of %t1. Forces the transfer_write to bufferize
+ // out-of-place.
+ %v2 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<10xf32>
+ return %r, %v2 : tensor<?xf32>, vector<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place3
+func @scf_if_out_of_place3(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %v: vector<5xf32>, %idx: index,
+ %cond: i1, %cond2: i1) -> (tensor<?xf32>, vector<10xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ scf.yield %t1 : tensor<?xf32>
+ } else {
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t2 : tensor<?xf32>
+ }
+ %t1_alias = scf.if %cond2 -> (tensor<?xf32>) {
+ // scf.yield bufferizes to a read. That is a conflict in this example.
+ scf.yield %t1 : tensor<?xf32>
+ } else {
+ scf.yield %t1 : tensor<?xf32>
+ }
+ %v2 = vector.transfer_read %t1_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
+ return %r, %v2 : tensor<?xf32>, vector<10xf32>
+}
+
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 88cfbb4d68b7..0584ebde985c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -113,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
{
+ // expected-error @+1 {{result buffer is ambiguous}}
%r = scf.if %b -> (tensor<4xf32>) {
- // expected-error @+1 {{expected scf::ForOp parent for scf::YieldOp}}
scf.yield %A : tensor<4xf32>
} else {
scf.yield %B : tensor<4xf32>
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b012409fd873..9d6227462c49 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -861,3 +861,25 @@ func @buffer_forwarding_no_conflict(
return %r1: tensor<?xf32>
}
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace(
+// CHECK-SAME: %[[cond:.*]]: i1, %[[t1:.*]]: memref<?xf32{{.*}}>, %[[v:.*]]: vector
+func @scf_if_inplace(%cond: i1,
+ %t1: tensor<?xf32> {linalg.inplaceable = true},
+ %v: vector<5xf32>, %idx: index) -> tensor<?xf32> {
+
+ // CHECK: scf.if %[[cond]] {
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: vector.transfer_write %[[v]], %[[t1]]
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ %r = scf.if %cond -> (tensor<?xf32>) {
+ scf.yield %t1 : tensor<?xf32>
+ } else {
+ %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+ scf.yield %t2 : tensor<?xf32>
+ }
+ return %r : tensor<?xf32>
+}
+