[Mlir-commits] [mlir] 3bbc869 - [mlir][linalg][bufferize] Support scf::IfOp

Matthias Springer llvmlistbot at llvm.org
Thu Oct 21 18:13:07 PDT 2021


Author: Matthias Springer
Date: 2021-10-22T10:12:55+09:00
New Revision: 3bbc869e2ef26f3bc296d5b4e23ee8678a20fc0b

URL: https://github.com/llvm/llvm-project/commit/3bbc869e2ef26f3bc296d5b4e23ee8678a20fc0b
DIFF: https://github.com/llvm/llvm-project/commit/3bbc869e2ef26f3bc296d5b4e23ee8678a20fc0b.diff

LOG: [mlir][linalg][bufferize] Support scf::IfOp

This commit adds support for scf::IfOp to comprehensive bufferization. Support is currently limited to cases where both branches yield tensors that bufferize to the same buffer.

To keep the analysis simple, scf::IfOps are treated as memory writes for analysis purposes, even if no op inside any branch is writing. (scf::ForOps are handled in the same way.)

Differential Revision: https://reviews.llvm.org/D111929

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 4134cc042ebd..b4af0fd82903 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -442,6 +442,7 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
           ConstantOp,
           tensor::DimOp,
           ExtractSliceOp,
+          scf::IfOp,
           scf::ForOp,
           InsertSliceOp,
           InitTensorOp,
@@ -550,6 +551,16 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
   // clang-format on
 }
 
+/// Either one of the corresponding yield values from the then/else branches
+/// may alias with the result.
+static void populateAliasingOpOperands(scf::IfOp op, OpResult result,
+                                       SmallVector<OpOperand *> &operands) {
+  size_t resultNum = std::distance(op->getOpResults().begin(),
+                                   llvm::find(op->getOpResults(), result));
+  operands.push_back(&op.thenYield()->getOpOperand(resultNum));
+  operands.push_back(&op.elseYield()->getOpOperand(resultNum));
+}
+
 /// Determine which OpOperand* will alias with `result` if the op is bufferized
 /// in place. Note that multiple OpOperands can may potentially alias with an
 /// OpResult. E.g.: std.select in the future.
@@ -561,6 +572,7 @@ static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
   TypeSwitch<Operation *>(result.getDefiningOp())
       .Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); })
       .Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); })
+      .Case([&](scf::IfOp op) { populateAliasingOpOperands(op, result, r); })
       // In the case of scf::ForOp, this currently assumes the iter_args / yield
       // are 1-1. This may fail and is verified at the end.
       // TODO: update this.
@@ -730,6 +742,19 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
           if (bbArg.getType().isa<TensorType>())
             createAliasInfoEntry(bbArg);
   });
+
+  // The return value of an scf::IfOp aliases with both yield values.
+  rootOp->walk([&](scf::IfOp ifOp) {
+    if (ifOp->getNumResults() > 0) {
+      for (auto it : llvm::zip(ifOp.thenYield().results(),
+                               ifOp.elseYield().results(), ifOp.results())) {
+        aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
+        aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
+        equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it));
+        equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it));
+      }
+    }
+  });
 }
 
 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
@@ -834,13 +859,28 @@ void BufferizationAliasInfo::bufferizeOutOfPlace(OpResult result) {
 }
 
 /// Starting from `value`, follow the use-def chain in reverse, always selecting
-/// the corresponding aliasing OpOperand. Try to find and return a Value for
-/// which `condition` evaluates to true.
+/// the aliasing OpOperands. Find and return Values for which `condition`
+/// evaluates to true. OpOperands of such matching Values are not traversed any
+/// further.
 ///
-/// When reaching the end of the chain (BlockArgument or Value without aliasing
-/// OpOperands), return the last Value of the chain.
+/// When reaching the end of a chain (BlockArgument or Value without aliasing
+/// OpOperands), also return the last Value of that chain.
+///
+/// Example:
 ///
-/// Note: The returned SetVector contains exactly one element.
+///                               8
+///                               |
+///   6*         7*         +-----+----+
+///   |          |          |          |
+///   2*         3          4*         5
+///   |          |          |          |
+///   +----------+----------+----------+
+///              |
+///              1
+///
+/// In the above example, Values with a star satisfy the condition. When
+/// starting the traversal from Value 1, the resulting SetVector is:
+/// { 2, 7, 8, 5 }
 static llvm::SetVector<Value>
 findValueInReverseUseDefChain(Value value,
                               std::function<bool(Value)> condition) {
@@ -861,18 +901,22 @@ findValueInReverseUseDefChain(Value value,
       continue;
     }
 
-    assert(opOperands.size() == 1 && "multiple OpOperands not supported yet");
-    workingSet.insert(opOperands.front()->get());
+    for (OpOperand *o : opOperands)
+      workingSet.insert(o->get());
   }
 
   return result;
 }
 
-/// Find the Value (result) of the last preceding write of a given Value.
+/// Find the Value of the last preceding write of a given Value.
 ///
 /// Note: Unknown ops are handled conservatively and assumed to be writes.
 /// Furthermore, BlockArguments are also assumed to be writes. There is no
 /// analysis across block boundaries.
+///
+/// Note: To simplify the analysis, scf.if ops are considered writes. Treating
+/// a non-writing op as a writing op may introduce unnecessary out-of-place
+/// bufferizations, but is always safe from a correctness point of view.
 static Value findLastPrecedingWrite(Value value) {
   SetVector<Value> result =
       findValueInReverseUseDefChain(value, [](Value value) {
@@ -881,6 +925,8 @@ static Value findLastPrecedingWrite(Value value) {
           return true;
         if (!hasKnownBufferizationAliasingBehavior(op))
           return true;
+        if (isa<scf::IfOp>(op))
+          return true;
 
         SmallVector<OpOperand *> opOperands =
             getAliasingOpOperand(value.cast<OpResult>());
@@ -911,6 +957,21 @@ bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
                       condition);
 }
 
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b,
+                          const DominanceInfo &domInfo) {
+  do {
+    // TODO: Instead of isProperAncestor + properlyDominates, we should use
+    // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
+    if (a->isProperAncestor(b))
+      return false;
+    if (domInfo.properlyDominates(a, b))
+      return true;
+  } while ((a = a->getParentOp()));
+  return false;
+}
+
 /// Given sets of uses and writes, return true if there is a RaW conflict under
 /// the assumption that all given reads/writes alias the same buffer and that
 /// all given writes bufferize inplace.
@@ -935,7 +996,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
     // In the above example, if uRead is the OpOperand of reading_op, lastWrite
     // is %0. Note that operations that create an alias but do not write (such
     // as ExtractSliceOp) are skipped.
-    // TODO: With branches this should probably be a list of Values.
     Value lastWrite = findLastPrecedingWrite(uRead->get());
 
     // Look for conflicting memory writes. Potential conflicts are writes to an
@@ -949,21 +1009,35 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
       LDBG("Found potential conflict:\n");
       LDBG("READ = #" << uRead->getOperandNumber() << " of "
                       << printOperationInfo(readingOp) << "\n");
-      LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
       LDBG("CONFLICTING WRITE = #"
            << uConflictingWrite->getOperandNumber() << " of "
            << printOperationInfo(conflictingWritingOp) << "\n");
 
       // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
       // write is not visible when reading.
-      if (domInfo.properlyDominates(readingOp, conflictingWritingOp))
+      if (happensBefore(readingOp, conflictingWritingOp, domInfo))
+        continue;
+
+      // No conflict if the reading use equals the use of the conflicting write.
+      // A use cannot conflict with itself. Note: Just being the same op is not
+      // enough. It has to be the same use.
+      if (uConflictingWrite == uRead)
+        continue;
+
+      if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
         continue;
 
-      // No conflict if the conflicting write happens before the last write.
+      LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
+
+      // No conflict if the conflicting write happens before the last
+      // write.
       if (Operation *writingOp = lastWrite.getDefiningOp()) {
-        if (domInfo.properlyDominates(conflictingWritingOp, writingOp))
+        if (happensBefore(conflictingWritingOp, writingOp, domInfo))
           // conflictingWritingOp happens before writingOp. No conflict.
           continue;
+        // No conflict if conflictingWritingOp is contained in writingOp.
+        if (writingOp->isProperAncestor(conflictingWritingOp))
+          continue;
       } else {
         auto bbArg = lastWrite.cast<BlockArgument>();
         Block *block = bbArg.getOwner();
@@ -978,11 +1052,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
       if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
         continue;
 
-      // No conflict is the same use is the read and the conflicting write. A
-      // use cannot conflict with itself.
-      if (uConflictingWrite == uRead)
-        continue;
-
       // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
       // uRead is an InsertSliceOp...
       if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
@@ -1423,15 +1492,27 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
-  // TODO: Support multiple OpOperands.
-  assert(aliasingOperands.size() == 1 &&
-         "more than 1 OpOperand not supported yet");
+  assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
   Value operand = aliasingOperands.front()->get();
   Value operandBuffer = lookup(bvm, operand);
   assert(operandBuffer && "operand buffer not found");
+  // Make sure that all OpOperands are the same buffer. If this is not the case,
+  // we would have to materialize a memref value.
+  if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
+        return lookup(bvm, o->get()) == operandBuffer;
+      })) {
+    op->emitError("result buffer is ambiguous");
+    return Value();
+  }
 
   // If bufferizing out-of-place, allocate a new buffer.
-  if (getInPlace(result) != InPlaceSpec::True) {
+  bool needCopy =
+      getInPlace(result) != InPlaceSpec::True && !isa<scf::IfOp>(op);
+  if (needCopy) {
+    // Ops such as scf::IfOp can currently not bufferize out-of-place.
+    assert(
+        aliasingOperands.size() == 1 &&
+        "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
     Location loc = op->getLoc();
     // Allocate the result buffer.
     Value resultBuffer =
@@ -1771,6 +1852,31 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
   return success();
 }
 
+static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+
+  for (OpResult opResult : ifOp->getResults()) {
+    if (!opResult.getType().isa<TensorType>())
+      continue;
+    // TODO: Atm we bail on unranked TensorType because we don't know how to
+    // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+    assert(opResult.getType().isa<RankedTensorType>() &&
+           "unsupported unranked tensor");
+
+    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    if (!resultBuffer)
+      return failure();
+
+    aliasInfo.createAliasInfoEntry(resultBuffer);
+    map(bvm, opResult, resultBuffer);
+  }
+
+  return success();
+}
+
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
@@ -2038,7 +2144,6 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
       getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
   if (!dstMemref)
     return failure();
-
   auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
 
   Value srcMemref = lookup(bvm, insertSliceOp.source());
@@ -2127,6 +2232,9 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
     return success();
   }
 
+  if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
+    return success();
+
   scf::ForOp forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
   if (!forOp)
     return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
@@ -2344,6 +2452,13 @@ LogicalResult mlir::linalg::bufferizeOp(
         LDBG("Begin bufferize:\n" << op << '\n');
         return bufferize(b, op, bvm, aliasInfo);
       })
+      .Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, InitTensorOp,
+            InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
+            VectorTransferOpInterface, linalg::YieldOp, scf::YieldOp,
+            scf::IfOp>([&](auto op) {
+        LDBG("Begin bufferize:\n" << op << '\n');
+        return bufferize(b, op, bvm, aliasInfo);
+      })
       .Case([&](CallOpInterface op) {
         LDBG("Begin bufferize:\n" << op << '\n');
         if (!bufferizedFunctionTypes)

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 1283525ae33c..12897e2b4faa 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1087,3 +1087,291 @@ func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = t
   %2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
   return %2, %2 : tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// scf.if cases
+//===----------------------------------------------------------------------===//
+
+// This example passes analysis, but it fails when bufferizing.
+// CHECK-LABEL: func @scf_if_inplace1
+func @scf_if_inplace1(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                      %t2: tensor<?xf32> {linalg.inplaceable = true},
+                      %cond: i1) -> tensor<?xf32> {
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    scf.yield %t2 : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @scf_if_inplace2
+func @scf_if_inplace2(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                      %v: vector<5xf32>, %idx: index,
+                      %cond: i1) -> tensor<?xf32> {
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t2 : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace3
+func @scf_if_inplace3(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                      %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index,
+                      %cond: i1) -> tensor<?xf32> {
+  //      CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %t2 = vector.transfer_write %v1, %e[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t2 : tensor<?xf32>
+  } else {
+    // Writing the same tensor through an alias. This is OK.
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t3 : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_in_place4
+func @scf_if_in_place4(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                       %v: vector<5xf32>, %idx: index,
+                       %cond: i1, %cond2: i1) -> (tensor<?xf32>, vector<10xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t2 : tensor<?xf32>
+  }
+  %r_alias = scf.if %cond2 -> (tensor<?xf32>) {
+    // Reading %r is OK. No conflict.
+    scf.yield %r : tensor<?xf32>
+  } else {
+    scf.yield %r : tensor<?xf32>
+  }
+  %v2 = vector.transfer_read %r_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
+  return %r_alias, %v2 : tensor<?xf32>, vector<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace5
+func @scf_if_inplace5(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                      %idx: index, %cond: i1) -> tensor<?xf32> {
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    //      CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %e : tensor<?xf32>
+  } else {
+    //      CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %f = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %f : tensor<?xf32>
+  }
+
+  // Inserting into an equivalent tensor at the same offset. This bufferizes
+  // inplace.
+  //      CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  %r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
+  return %r2 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace6
+func @scf_if_inplace6(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                      %v1: vector<5xf32>, %v2: vector<5xf32>,
+                      %v3: vector<5xf32>, %idx: index,
+                      %cond: i1, %cond2: i1) -> tensor<?xf32> {
+  // Test nested scf.if ops.
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    %t2 = scf.if %cond2 -> (tensor<?xf32>) {
+      //      CHECK: vector.transfer_write
+      // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+      %t3 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+      scf.yield %t3 : tensor<?xf32>
+    } else {
+      //      CHECK: vector.transfer_write
+      // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+      %t4 = vector.transfer_write %v3, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+      scf.yield %t4 : tensor<?xf32>
+    }
+    scf.yield %t2 : tensor<?xf32>
+  } else {
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t3 : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace7
+func @scf_if_inplace7(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                      %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index,
+                      %idx2: index, %cond: i1) -> (tensor<?xf32>, vector<5xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %r, %v_r2 = scf.if %cond -> (tensor<?xf32>, vector<5xf32>) {
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %t2 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t2, %v1 : tensor<?xf32>, vector<5xf32>
+  } else {
+    // Writing the same tensor through an alias.
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+    %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    // Read the original value of %t1. This requires the write in this branch
+    // to be out-of-place. But the write in the other branch can still be
+    // inplace.
+    %v_r = vector.transfer_read %t1[%idx2], %cst : tensor<?xf32>, vector<5xf32>
+    scf.yield %t3, %v_r : tensor<?xf32>, vector<5xf32>
+  }
+  return %r, %v_r2 : tensor<?xf32>, vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place1a
+func @scf_if_out_of_place1a(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                            %idx: index, %idx2: index,
+                            %cond: i1) -> tensor<?xf32> {
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    //      CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+    %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %e : tensor<?xf32>
+  } else {
+    scf.yield %t1 : tensor<?xf32>
+  }
+
+  // Reading from and writing to the same tensor via different args. This is a
+  // conflict.
+  //      CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
+  return %r2 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place1b
+func @scf_if_out_of_place1b(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                            %idx: index, %idx2: index, %idx3: index,
+                            %cond: i1) -> tensor<?xf32> {
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    //      CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+    %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %e : tensor<?xf32>
+  } else {
+    //      CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+    %f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %f : tensor<?xf32>
+  }
+
+  // Reading from and writing to the same tensor via different args. This is a
+  // conflict. In contrast to scf_if_out_of_place1a, the fact that %r aliases
+  // with %t1 is only detected when analyzing the tensor.extract_slices. That's
+  // why the tensor.insert_slice is inplace and the two extract_slices are
+  // out-of-place.
+  //      CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  %r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
+  return %r2 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place1c
+func @scf_if_out_of_place1c(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                            %idx: index, %idx2: index, %cond: i1) -> tensor<?xf32> {
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    //      CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+    %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %e : tensor<?xf32>
+  } else {
+    // TODO: This one could bufferize inplace, but the analysis is too restrictive.
+    //      CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+    %f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %f : tensor<?xf32>
+  }
+
+  //      CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
+  %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
+  return %r2 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place2
+func @scf_if_out_of_place2(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                           %v: vector<5xf32>, %idx: index,
+                           %cond: i1) -> (tensor<?xf32>, vector<10xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t2 : tensor<?xf32>
+  }
+
+  // Read the old value of %t1. Forces the transfer_write to bufferize
+  // out-of-place.
+  %v2 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<10xf32>
+  return %r, %v2 : tensor<?xf32>, vector<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_out_of_place3
+func @scf_if_out_of_place3(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                           %v: vector<5xf32>, %idx: index,
+                           %cond: i1, %cond2: i1) -> (tensor<?xf32>, vector<10xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    //      CHECK: vector.transfer_write
+    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t2 : tensor<?xf32>
+  }
+  %t1_alias = scf.if %cond2 -> (tensor<?xf32>) {
+    // scf.yield bufferizes to a read. That is a conflict in this example.
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    scf.yield %t1 : tensor<?xf32>
+  }
+  %v2 = vector.transfer_read %t1_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
+  return %r, %v2 : tensor<?xf32>, vector<10xf32>
+}
+

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 88cfbb4d68b7..0584ebde985c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -113,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
 
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
+  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %b -> (tensor<4xf32>) {
-    // expected-error @+1 {{expected scf::ForOp parent for scf::YieldOp}}
     scf.yield %A : tensor<4xf32>
   } else {
     scf.yield %B : tensor<4xf32>

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b012409fd873..9d6227462c49 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -861,3 +861,25 @@ func @buffer_forwarding_no_conflict(
   return %r1: tensor<?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @scf_if_inplace(
+//  CHECK-SAME:     %[[cond:.*]]: i1, %[[t1:.*]]: memref<?xf32{{.*}}>, %[[v:.*]]: vector
+func @scf_if_inplace(%cond: i1,
+                     %t1: tensor<?xf32> {linalg.inplaceable = true},
+                     %v: vector<5xf32>, %idx: index) -> tensor<?xf32> {
+
+  //      CHECK: scf.if %[[cond]] {
+  // CHECK-NEXT: } else {
+  // CHECK-NEXT:   vector.transfer_write %[[v]], %[[t1]]
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
+    scf.yield %t2 : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}
+


        


More information about the Mlir-commits mailing list