[Mlir-commits] [mlir] 57fe7fd - [mlir][Linalg] Add support for scf::ForOp in comprehensive bufferization (7/n)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 24 08:04:34 PDT 2021


Author: Nicolas Vasilache
Date: 2021-06-24T15:03:28Z
New Revision: 57fe7fd37dcd1f144f600976b3f33d5d792e89fd

URL: https://github.com/llvm/llvm-project/commit/57fe7fd37dcd1f144f600976b3f33d5d792e89fd
DIFF: https://github.com/llvm/llvm-project/commit/57fe7fd37dcd1f144f600976b3f33d5d792e89fd.diff

LOG: [mlir][Linalg] Add support for scf::ForOp in comprehensive bufferization (7/n)

scf::ForOp bufferization analysis proceeds just like for any other op (including FuncOp) at its boundaries; i.e. if:

1. The tensor operand is inplaceable.
2. The matching result has no subsequent read (i.e. all reads dominate the scf::ForOp).
3. Bufferizing the operand inplace does not create a RAW interference.

then it can bufferize inplace.

Still there are a few differences:

1. bbArgs of an scf::ForOp are always considered inplaceable when seen from ops inside the body. This is because either a) the matching tensor operand is not inplaceable and an alloc will be inserted (which makes the bbArg itself inplaceable); or b) the tensor operand and bbArg are both already inplaceable.
2. Bufferization within the scf::ForOp body has implications for the outside world: the scf.yield terminator may well ping-pong values of the same type. This muddies the waters for alias analysis and is not supported at the moment. Such cases result in a pass failure.

Differential revision: https://reviews.llvm.org/D104490

Added: 
    mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis-invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index e4915b3a6249a..a5584392aa610 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -261,6 +261,8 @@ def ForOp : SCF_Op<"for",
       return getOperation()->getNumOperands() - getNumControlOperands();
     }
     /// Get the region iter arg that corresponds to an OpOperand.
+    /// This helper prevents internal op implementation detail leakage to
+    /// clients by hiding the operand / block argument mapping.
     BlockArgument getRegionIterArgForOpOperand(OpOperand &opOperand) {
       assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
              "expected an iter args operand");
@@ -270,6 +272,8 @@ def ForOp : SCF_Op<"for",
         opOperand.getOperandNumber() - getNumControlOperands()];
     }
     /// Get the OpOperand& that corresponds to a region iter arg.
+    /// This helper prevents internal op implementation detail leakage to
+    /// clients by hiding the operand / block argument mapping.
     OpOperand &getOpOperandForRegionIterArg(BlockArgument bbArg) {
       assert(bbArg.getArgNumber() >= getNumInductionVars() &&
              "expected a bbArg that is not an induction variable");
@@ -278,6 +282,27 @@ def ForOp : SCF_Op<"for",
       return getOperation()->getOpOperand(
         getNumControlOperands() + bbArg.getArgNumber() - getNumInductionVars());
     }
+    /// Get the OpResult that corresponds to an OpOperand.
+    /// Assert that opOperand is an iterArg.
+    /// This helper prevents internal op implementation detail leakage to
+    /// clients by hiding the operand / block argument mapping.
+    OpResult getResultForOpOperand(OpOperand &opOperand) {
+      assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
+             "expected an iter args operand");
+      assert(opOperand.getOwner() == getOperation() &&
+             "opOperand does not belong to this scf::ForOp operation");
+      return getOperation()->getResult(
+        opOperand.getOperandNumber() - getNumControlOperands());
+    }
+    /// Get the OpOperand& that corresponds to an OpResult.
+    /// This helper prevents internal op implementation detail leakage to
+    /// clients by hiding the operand / block argument mapping.
+    OpOperand &getOpOperandForResult(OpResult opResult) {
+      assert(opResult.getDefiningOp() == getOperation() &&
+             "opResult does not belong to the scf::ForOp operation");
+      return getOperation()->getOpOperand(
+        getNumControlOperands() + opResult.getResultNumber());
+    }
 
     /// Return operands used when entering the region at 'index'. These operands
     /// correspond to the loop iterator operands, i.e., those exclusing the

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 0fcae75624c34..d02570af3622b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -109,6 +109,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
@@ -206,10 +207,10 @@ static Optional<InPlaceSpec> symbolize(StringRef str) {
       .Default(None);
 }
 
-/// Mark whether OpResult can actually be bufferized inplace. If `inPlace` is
-/// `InPlaceSpec::True`, the use-def chain analysis has guaranteed that no
-/// subsequent write would occur to the bufferized tensor value (i.e. the result
-/// can be bufferized inPlace).
+/// Mark whether OpResult can actually be bufferized inplace.
+/// If `inPlace` is `InPlaceSpec::True`, the use-def chain analysis has
+/// guaranteed that no subsequent write would occur to the bufferized
+/// tensor value (i.e. the result can be bufferized inPlace).
 static void setInPlaceOpResult(OpResult opResult,
                                InPlaceSpec inPlace = InPlaceSpec::True) {
   if (!opResult)
@@ -252,16 +253,26 @@ static InPlaceSpec getInPlace(OpResult opResult) {
 }
 
 /// Get inPlace information for `bbArg`.
-/// If it does not come from a function, return InPlaceSpec::False.
+/// FuncOp allows argument attributes; we use those to encode the information.
+/// scf::ForOp BlockArguments are always inplaceable; other cases are unknown.
 static InPlaceSpec getInPlace(BlockArgument bbArg) {
-  auto funcOp = dyn_cast<FuncOp>(bbArg.getOwner()->getParentOp());
-  if (!funcOp)
-    return InPlaceSpec::False;
-  auto attr = funcOp.getArgAttrOfType<BoolAttr>(
-      bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName);
-  if (!attr)
-    return InPlaceSpec::None;
-  return attr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
+  if (auto funcOp = dyn_cast<FuncOp>(bbArg.getOwner()->getParentOp())) {
+    BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
+        bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName);
+    if (!inplaceAttr)
+      return InPlaceSpec::None;
+    return inplaceAttr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
+  }
+  // Interestingly, scf::ForOp's bbArg can **always** be viewed inplace from the
+  // perspective of ops nested under it:
+  //   1. Either the matching iter operand is not bufferized inplace and an
+  //      alloc + optional copy makes the bbArg itself inplaceable.
+  //   2. Or the matching iter operand is bufferized inplace and bbArg just
+  //      bufferizes to that too.
+  if (auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp()))
+    return InPlaceSpec::True;
+  // Unknown cases.
+  return InPlaceSpec::None;
 }
 
 LLVM_ATTRIBUTE_UNUSED static InPlaceSpec getInPlace(Value v) {
@@ -293,11 +304,13 @@ LLVM_ATTRIBUTE_UNUSED static InPlaceSpec getInPlace(Value v) {
 static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
   return
       // clang-format off
-      isa<LinalgOp,
+      isa<scf::ForOp,
+          LinalgOp,
           ReturnOp,
           ExtractSliceOp,
           InsertSliceOp,
-          VectorTransferOpInterface>(op)
+          VectorTransferOpInterface,
+          scf::YieldOp>(op)
       // clang-format on
       || (none_of(op->getResultTypes(),
                   [](Type t) { return t.isa<TensorType>(); }) &&
@@ -305,6 +318,15 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
                   [](Type t) { return t.isa<TensorType>(); }));
 }
 
+/// Return the OpResult that may bufferize into the same buffer as `opOperand`
+/// when the op is bufferized inplace.
+/// Return null if no such result exists.
+static OpResult getInplaceableOpResult(scf::ForOp forOp, OpOperand &opOperand) {
+  if (!opOperand.get().getType().isa<RankedTensorType>())
+    return OpResult();
+  return forOp.getResultForOpOperand(opOperand);
+}
+
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
 /// when the op is bufferized inplace.
 /// Return null if no such result exists.
@@ -355,7 +377,8 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
       // clang-format off
         // Ops that perform destructive updates on operand(s) to produce
         // result(s).
-        .Case<LinalgOp,
+        .Case<scf::ForOp,
+              LinalgOp,
               InsertSliceOp,
               VectorTransferOpInterface>(
             [&](auto op) { return getInplaceableOpResult(op, opOperand); })
@@ -377,12 +400,15 @@ static Optional<OpResult> getAliasingOpResult(OpOperand &opOperand) {
   if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner()))
     return None;
   return TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-      // ReturnOp has no result.
-      .Case([&](ReturnOp op) { return OpResult(); })
+      // These terminators legitimately have no result.
+      .Case<ReturnOp, linalg::YieldOp, scf::YieldOp>(
+          [&](auto op) { return OpResult(); })
       // ExtractSliceOp is different: its result is not inplaceable on op.source
       // but when bufferized inplace, the result is an aliasing subregion of
       // op.source.
       .Case([&](ExtractSliceOp op) { return op->getResult(0); })
+      // All other ops, including scf::ForOp, return the result of
+      // `getInplaceableOpResult`.
       .Default(
           [&](Operation *op) { return getInplaceableOpResult(opOperand); });
 }
@@ -398,6 +424,10 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
   // may.
   if (isa<ExtractSliceOp>(opOperand.getOwner()))
     return false;
+  // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
+  // matching bbArg may.
+  if (isa<scf::ForOp>(opOperand.getOwner()))
+    return false;
   if (auto linalgOp = dyn_cast<LinalgOp>(opOperand.getOwner()))
     return linalgOp.isInputTensor(&opOperand) ||
            linalgOp.isInitTensor(&opOperand);
@@ -422,8 +452,8 @@ bufferizesToMemoryWrite(OpOperand &opOperand,
   // This does not bufferize to a write.
   if (!*maybeOpResult)
     return false;
-  // A ReturnOp is not a write.
-  if (isa<ReturnOp>(opOperand.getOwner()))
+  // These terminators are not writes.
+  if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner()))
     return false;
   // ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
   // may.
@@ -472,10 +502,14 @@ class BufferizationAliasInfo {
   /// to some use that would bufferize to a write to a buffer.
   bool aliasesInPlaceWrite(ExtractSliceOp extractSliceOp) const;
 
+  /// Set the inPlace bufferization spec to true.
   /// Merge result's and operand's aliasing sets and iterate to a fixed point.
   void bufferizeInPlace(OpResult result, OpOperand &operand,
                         BufferRelation bufferRelation = BufferRelation::None);
 
+  /// Set the inPlace bufferization spec to false.
+  void bufferizeOutOfPlace(OpResult result);
+
   /// Return true if it is possible to find an inplace write W among the uses of
   /// aliasInfo[rootWrite], and a read R among the uses of aliasInfo[rootRead],
   /// such that W and R interfere.
@@ -496,7 +530,13 @@ class BufferizationAliasInfo {
   bool existsNonDominatingRead(OpOperand &opOperand,
                                const DominanceInfo &domInfo) const;
 
-  /// Return true if the source of a `insertSliceOp` bufferizes to an
+  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+  bool areEquivalentBufferizedValues(Value v1, Value v2) const {
+    return equivalentInfo.getLeaderValue(v1) ==
+           equivalentInfo.getLeaderValue(v2);
+  }
+
+  /// Return true if the source of an `insertSliceOp` bufferizes to an
   /// equivalent ExtractSliceOp.
   bool isSourceEquivalentToAMatchingExtractSliceOp(
       InsertSliceOp insertSliceOp) const;
@@ -601,14 +641,6 @@ class BufferizationAliasInfo {
 } // namespace
 
 BufferizationAliasInfo::BufferizationAliasInfo(FuncOp funcOp) {
-  for (auto bbArg : funcOp.getArguments()) {
-    if (!bbArg.getType().isa<TensorType>())
-      continue;
-    DenseSet<Value> selfSet;
-    selfSet.insert(bbArg);
-    aliasInfo.try_emplace(bbArg, selfSet);
-    equivalentInfo.insert(bbArg);
-  }
   funcOp.walk([&](Operation *op) {
     for (Value v : op->getResults()) {
       if (!v.getType().isa<TensorType>())
@@ -620,6 +652,18 @@ BufferizationAliasInfo::BufferizationAliasInfo(FuncOp funcOp) {
       aliasInfo.try_emplace(v, selfSet);
       equivalentInfo.insert(v);
     }
+    for (Region &r : op->getRegions()) {
+      for (Block &b : r.getBlocks()) {
+        for (auto bbArg : b.getArguments()) {
+          if (!bbArg.getType().isa<TensorType>())
+            continue;
+          DenseSet<Value> selfSet;
+          selfSet.insert(bbArg);
+          aliasInfo.try_emplace(bbArg, selfSet);
+          equivalentInfo.insert(bbArg);
+        }
+      }
+    }
   });
 }
 
@@ -634,13 +678,10 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
   for (Value v : getAliasInfoRef(operand.get())) {
     LDBG("-----------examine: " << v << '\n');
     if (auto bbArg = v.dyn_cast<BlockArgument>()) {
-      // Uses of function arguments that may be written-to can be skipped.
-      if (isa<FuncOp>(bbArg.getOwner()->getParentOp()) &&
-          getInPlace(bbArg) == InPlaceSpec::True) {
+      if (getInPlace(bbArg) == InPlaceSpec::True) {
         LDBG("-----------bbArg is writeable -> skip: " << bbArg << '\n');
         continue;
       }
-      // Conservatively dump any other block argument for now.
       LDBG("-----------notWriteable: " << v << '\n');
       return true;
     }
@@ -675,14 +716,23 @@ bool BufferizationAliasInfo::aliasesInPlaceWrite(
   return false;
 }
 
+/// Set the inPlace bufferization spec to true.
 /// Merge result's and operand's aliasing sets and iterates to a fixed point.
 void BufferizationAliasInfo::bufferizeInPlace(OpResult result,
                                               OpOperand &operand,
                                               BufferRelation bufferRelation) {
+  setInPlaceOpResult(result, InPlaceSpec::True);
   if (mergeAliases(result, operand.get()))
     mergeAliasesToFixedPoint();
   if (bufferRelation == BufferRelation::Equivalent)
     equivalentInfo.unionSets(result, operand.get());
+  // Dump the updated analysis.
+  LLVM_DEBUG(dump());
+}
+
+/// Set the inPlace bufferization spec to false.
+void BufferizationAliasInfo::bufferizeOutOfPlace(OpResult result) {
+  setInPlaceOpResult(result, InPlaceSpec::False);
 }
 
 /// Return true if merging the alias sets of `rootWrite` and `rootRead` would
@@ -1217,6 +1267,44 @@ static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp,
   return success();
 }
 
+static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
+                               BlockAndValueMapping &bvm,
+                               const BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  Location loc = forOp.getLoc();
+
+  LLVM_DEBUG(DBGS() << "bufferize: " << *forOp << "\n");
+
+  // If inPlace, just forward the buffer.
+  // Otherwise alloc and copy.
+  b.setInsertionPoint(forOp);
+  for (OpResult opResult : forOp->getResults()) {
+    // TODO: Atm we bail on unranked TensorType because we don't know how to
+    // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+    if (!opResult.getType().isa<RankedTensorType>())
+      return failure();
+    OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
+    Value operand = opOperand.get();
+    Value operandBuffer = lookup(bvm, operand);
+    Value resultBuffer = operandBuffer;
+    if (getInPlace(opResult) != InPlaceSpec::True) {
+      resultBuffer = createNewAllocDeallocPairForShapedValue(b, loc, operand);
+      // If the tensor comes from `linalg::InitTensorOp`, the value is
+      // uninitialized and we do not need to copy.
+      // TODO: checking whether the matching bbArg bufferizes to a read would be
+      // more general.
+      if (!operand.getDefiningOp<linalg::InitTensorOp>())
+        b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
+    }
+    BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
+    map(bvm, bbArg, resultBuffer);
+    map(bvm, opResult, resultBuffer);
+  }
+
+  return success();
+}
+
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
@@ -1429,6 +1517,31 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   return success();
 }
 
+static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
+                               BlockAndValueMapping &bvm,
+                               const BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(yieldOp);
+
+  scf::ForOp forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+  assert(forOp && "only support scf::ForOp parent for scf::YieldOp");
+  for (OpOperand &operand : yieldOp->getOpOperands()) {
+    auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+    if (!tensorType)
+      continue;
+    OpOperand &forOperand = forOp.getOpOperandForResult(
+        forOp->getResult(operand.getOperandNumber()));
+    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+    if (getInPlace(bbArg) == InPlaceSpec::True)
+      operand.set(bbArg);
+    else
+      operand.set(
+          b.create<memref::TensorLoadOp>(yieldOp.getLoc(), lookup(bvm, bbArg)));
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
@@ -1447,11 +1560,12 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
 ///
 /// An analysis is required to ensure inplace bufferization would not result in
 /// RaW dependence violations.
-static void bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
-                                        BufferizationAliasInfo &aliasInfo,
-                                        const DominanceInfo &domInfo) {
+static LogicalResult
+bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
+                            BufferizationAliasInfo &aliasInfo,
+                            const DominanceInfo &domInfo) {
   LDBG('\n');
-  LDBG("Try to bufferize extract_slice inplace: " << *extractSliceOp << '\n');
+  LDBG("Inplace analysis for extract_slice: " << *extractSliceOp << '\n');
 
   // If `extractSliceOp` were to be bufferized inplace, it cannot end up
   // aliasing a write into a non-writeable buffer.
@@ -1461,35 +1575,38 @@ static void bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
 
   if (wouldCreateAliasingWriteToNonWriteableBuffer)
     LDBG("->the corresponding buffer is not writeable\n");
-  LDBG("->bufferizes to writeable inplace buffer\n");
+  else
+    LDBG("->bufferizes to writeable inplace buffer\n");
 
   // In any of extractSliceOp.result's aliases, can we find 2 such that we hit
   // an interfering write?
-  Value s = extractSliceOp.source(), r = extractSliceOp.result();
+  OpResult r = extractSliceOp->getResult(0);
+  OpOperand &s = extractSliceOp->getOpOperand(0);
   bool foundInterference = wouldCreateAliasingWriteToNonWriteableBuffer ||
                            // Do not consider (s, s) and (r, r) as all the
                            // aliasings already exist by construction; we are
                            // interested in new interfering aliases only.
                            aliasInfo.wouldCreateReadAfterWriteInterference(
-                               s, r, extractSliceOp, domInfo) ||
+                               s.get(), r, extractSliceOp, domInfo) ||
                            aliasInfo.wouldCreateReadAfterWriteInterference(
-                               r, s, extractSliceOp, domInfo);
-  if (foundInterference) {
-    setInPlaceOpResult(extractSliceOp->getResult(0), InPlaceSpec::False);
-  } else {
-    setInPlaceOpResult(extractSliceOp->getResult(0), InPlaceSpec::True);
-    aliasInfo.bufferizeInPlace(extractSliceOp->getResult(0),
-                               extractSliceOp->getOpOperand(0));
-  }
-  LDBG("Done bufferizing extract_slice\n");
+                               r, s.get(), extractSliceOp, domInfo);
+  if (foundInterference)
+    aliasInfo.bufferizeOutOfPlace(r);
+  else
+    aliasInfo.bufferizeInPlace(r, s);
+
+  LDBG("Done inplace analysis for extract_slice\n");
+
+  return success();
 }
 
 /// Analyze the (opOperand, result) pair to determine whether the result can
 /// be bufferized inPlace. If successful, InPlaceSpec::True is set for
 /// `result`. Otherwise, InPlaceSpec::False is set for `result`.
-static void bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
-                                        BufferizationAliasInfo &aliasInfo,
-                                        const DominanceInfo &domInfo) {
+static LogicalResult
+bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
+                            BufferizationAliasInfo &aliasInfo,
+                            const DominanceInfo &domInfo) {
   Operation *op = result.getDefiningOp();
   assert(result && !isa<ExtractSliceOp>(op) &&
          "expected OpResult not coming from a ExtractSliceOp");
@@ -1497,9 +1614,9 @@ static void bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
   int64_t resultNumber = result.getResultNumber();
   (void)resultNumber;
   LDBG('\n');
-  LDBG("Try to bufferize inplace result #"
-       << resultNumber << " (operand #" << operand.getOperandNumber() << ") in "
-       << result << '\n');
+  LDBG("Inplace analysis for result #" << resultNumber << " (operand #"
+                                       << operand.getOperandNumber() << ") in "
+                                       << result << '\n');
 
   // `result` must bufferize to a writeable buffer to be a candidate.
   // This means the use->def chain not backpropagate to a function that is
@@ -1508,7 +1625,8 @@ static void bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
       aliasInfo.aliasesNonWriteableBuffer(operand);
   if (wouldCreateAliasingWriteToNonWriteableBuffer)
     LDBG("->the corresponding buffer is not writeable\n");
-  LDBG("->bufferizes to writeable inplace buffer\n");
+  else
+    LDBG("->bufferizes to writeable inplace buffer\n");
 
   Value s = operand.get(), r = result;
   bool foundInterference =
@@ -1520,22 +1638,56 @@ static void bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
       aliasInfo.wouldCreateReadAfterWriteInterference(s, r, op, domInfo) ||
       aliasInfo.wouldCreateReadAfterWriteInterference(r, s, op, domInfo);
 
-  if (foundInterference) {
-    setInPlaceOpResult(result, InPlaceSpec::False);
-  } else {
-    setInPlaceOpResult(result, InPlaceSpec::True);
+  if (foundInterference)
+    aliasInfo.bufferizeOutOfPlace(result);
+  else
     // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support
     // more cases on a per-need basis.
     aliasInfo.bufferizeInPlace(
         result, operand, BufferizationAliasInfo::BufferRelation::Equivalent);
+
+  LDBG("Done inplace analysis for result #" << resultNumber << '\n');
+
+  return success();
+}
+
+/// scf::YieldOps are not explicitly bufferized, so perform a separate sanity
+/// check for now. Return `failure()` if the parent is not an scf::ForOp, or if
+/// an operand matching an inplace result is not equivalent to its iter_arg.
+static LogicalResult
+bufferizationSanityCheck(scf::YieldOp yieldOp,
+                         const BufferizationAliasInfo &aliasInfo) {
+  auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
+  if (!parentForOp)
+    return failure();
+
+  for (OpOperand &operand : yieldOp->getOpOperands()) {
+    OpResult matchingForOpResult =
+        parentForOp->getResult(operand.getOperandNumber());
+    // Nothing to do if operand bufferizes out of place.
+    if (getInPlace(matchingForOpResult) != InPlaceSpec::True)
+      continue;
+    OpOperand &machingForOpOperand =
+        parentForOp.getOpOperandForResult(matchingForOpResult);
+    BlockArgument matchingForOpIterArg =
+        parentForOp.getRegionIterArgForOpOperand(machingForOpOperand);
+    if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
+                                                 operand.get())) {
+      yieldOp->emitError()
+          << "Yield operand #" << operand.getOperandNumber()
+          << " does not bufferize to an equivalent buffer to the matching"
+          << " enclosing scf::for operand -> Fail the pass\n";
+      return failure();
+    }
   }
-  LDBG("Done bufferizing result #" << resultNumber << '\n');
+
+  return success();
 }
 
 /// Analyze the `funcOp` body to determine which OpResults are inplaceable.
-static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
-                                           BufferizationAliasInfo &aliasInfo,
-                                           const DominanceInfo &domInfo) {
+static LogicalResult
+inPlaceAnalysisFuncOpInternals(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+                               const DominanceInfo &domInfo) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
   assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
@@ -1565,9 +1717,10 @@ static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
   // Walk InsertSliceOp in reverse for better interference behavior.
   for (InsertSliceOp insertSliceOp : reverse(insertSliceOps)) {
     OpOperand &destOpOperand = insertSliceOp->getOpOperand(1);
-    bufferizableInPlaceAnalysis(destOpOperand,
-                                getInplaceableOpResult(destOpOperand),
-                                aliasInfo, domInfo);
+    if (failed(bufferizableInPlaceAnalysis(
+            destOpOperand, getInplaceableOpResult(destOpOperand), aliasInfo,
+            domInfo)))
+      return failure();
   }
 
   // Bufferize all ops except ExtractSliceOp and InsertSliceOp which are handled
@@ -1576,15 +1729,25 @@ static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
   for (Operation *op : reverse(nonSliceOps))
     for (OpOperand &opOperand : op->getOpOperands())
       if (OpResult result = getInplaceableOpResult(opOperand))
-        bufferizableInPlaceAnalysis(opOperand, result, aliasInfo, domInfo);
+        if (failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
+                                               domInfo)))
+          return failure();
 
   // Finally, bufferize ExtractSliceOp.
   // Walk ExtractSliceOps in reverse for better clobbering behavior: it is
   // easier to detect clobbers of smaller slices before larger ones.
   for (ExtractSliceOp extractSliceOp : reverse(extractSliceOps))
-    bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo);
+    if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
+      return failure();
+
+  // Sanity checks.
+  auto walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult {
+    return bufferizationSanityCheck(yieldOp, aliasInfo);
+  });
 
   LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
+
+  return success(!walkResult.wasInterrupted());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1600,7 +1763,8 @@ bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm,
   /// Start by bufferizing `funcOp` arguments.
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
-  WalkResult result = funcOp.walk<WalkOrder::PostOrder>([&](Operation *op) {
+  // Walk in PreOrder to ensure ops with regions are handled before their body.
+  WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
     LogicalResult status =
         TypeSwitch<Operation *, LogicalResult>(op)
             // Skip BufferCast and TensorLoad ops.
@@ -1609,12 +1773,17 @@ bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm,
                   memref::TensorLoadOp>(
                 [&](auto) { return success(); })
             .Case<memref::DimOp,
+                  scf::ForOp,
                   LinalgOp,
                   ReturnOp,
                   ExtractSliceOp,
                   InsertSliceOp,
-                  VectorTransferOpInterface>(
-                [&](auto op) { return bufferize(b, op, bvm, aliasInfo); })
+                  VectorTransferOpInterface,
+                  scf::YieldOp>(
+                [&](auto op) {
+                  LDBG("Begin buferize:\n" << op << '\n');
+                  return bufferize(b, op, bvm, aliasInfo);
+                })
             // clang-format on
             .Default([&](Operation *op) {
               auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -1652,7 +1821,12 @@ void LinalgComprehensiveFuncBufferize::runOnFunction() {
   // Analysis phase.
   DominanceInfo domInfo(funcOp);
   BufferizationAliasInfo aliasInfo(funcOp);
-  inPlaceAnalysisFuncOpInternals(funcOp, aliasInfo, domInfo);
+  // If the analysis fails, just return. This is expected to reset the IR and no
+  // single OpResult should be marked inPlace.
+  if (failed(inPlaceAnalysisFuncOpInternals(funcOp, aliasInfo, domInfo))) {
+    signalPassFailure();
+    return;
+  }
 
   if (testAnalysisOnly)
     return;

diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis-invalid.mlir
new file mode 100644
index 0000000000000..41e698f97c873
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis-invalid.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -linalg-comprehensive-func-bufferize=test-analysis-only -split-input-file -verify-diagnostics
+
+// -----
+
+func @scf_for(%A : tensor<?xf32>,
+              %B : tensor<?xf32> {linalg.inplaceable = true},
+              %C : tensor<4xf32>,
+              %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // Throw a wrench in the system by swapping yielded values: this results in a
+    // ping-pong of values at each iteration on which we currently want to fail.
+
+    // expected-error @+1 {{Yield operand #1 does not bufferize to an equivalent buffer}}
+    scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}
+

diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir
index 5ee495eca8e8c..5234d85b0b5b1 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir
@@ -412,3 +412,63 @@ func @nested_extract_slice_and_insert(
   return %rA, %rB, %rC: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// Simple loop cases
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @scf_for_yield_only
+func @scf_for_yield_only(%A : tensor<?xf32>,
+                         %B : tensor<?xf32> {linalg.inplaceable = true},
+                         %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  //      CHECK: scf.for
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
+  %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  //      CHECK: scf.for
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
+  %r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %B) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_with_tensor.insert_slice
+func @scf_for_with_tensor.insert_slice(%A : tensor<?xf32>,
+              %B : tensor<?xf32> {linalg.inplaceable = true},
+              %C : tensor<4xf32>,
+              %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  //      CHECK: scf.for
+  // scf.for bbArgs are always inplaceable seen from ops inside the body:
+  //   1. Either the matching tensor is not inplaceable and an alloc occurs
+  //      which makes bbArg inplaceable.
+  //   2. Or it is already inplaceable and so is bbArg.
+  // CHECK-NEXT:   tensor.insert_slice
+  // CHECK-SAME:     {__inplace_results_attr__ = ["true"]}
+  // CHECK-NEXT:   tensor.insert_slice
+  // CHECK-SAME:     {__inplace_results_attr__ = ["true"]}
+  // CHECK-NEXT:   scf.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["false", "true"]}
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}
+

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
index 0b16800dfa5e4..e217a7062a94f 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -273,3 +273,81 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   //     CHECK: return %[[RES]]
   return %r0: tensor<4xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// Simple loop cases
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @scf_for_yield_only
+func @scf_for_yield_only(%A : tensor<?xf32>,
+                         %B : tensor<?xf32> {linalg.inplaceable = true},
+                         %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  //     CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
+  //     CHECK:   %[[BUFFER_CAST_A:.*]] = memref.buffer_cast
+  //     CHECK:   %[[BUFFER_CAST_B:.*]] = memref.buffer_cast
+  //     CHECK:   linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC_FOR_A]])
+
+  // The first scf.for remains but just turns into dead code.
+  %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  // The second scf.for remains but just turns into dead code.
+  %r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %B) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  // Cross function call alloc/dealloc pattern must be hoisted out.
+  //     CHECK:   memref.dealloc %[[ALLOC_FOR_A]] : memref<?xf32>
+  //     CHECK:   %[[rA:.*]] = memref.tensor_load %[[ALLOC_FOR_A]]
+  // Returning tensor_load of the buffer cast makes the %r1 loop dead.
+  //     CHECK:   %[[rB:.*]] = memref.tensor_load %[[BUFFER_CAST_B]]
+  //     CHECK:   return %[[rA]], %[[rB]] : tensor<?xf32>, tensor<?xf32>
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_with_tensor.insert_slice
+func @scf_for_with_tensor.insert_slice(
+   %A : tensor<?xf32>,
+              %B : tensor<?xf32> {linalg.inplaceable = true},
+              %C : tensor<4xf32>,
+              %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  //     CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
+  //     CHECK:   %[[BUFFER_CAST_A:.*]] = memref.buffer_cast
+  //     CHECK:   %[[BUFFER_CAST_B:.*]] = memref.buffer_cast
+  //     CHECK:   %[[BUFFER_CAST_C:.*]] = memref.buffer_cast
+  //     CHECK:   linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC_FOR_A]])
+
+  //     CHECK:   scf.for {{.*}} iter_args(%[[bbA:.*]] = %{{.*}}, %[[bbB:.*]] = %{{.*}})
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    //     CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
+    // %ttA bufferizes to a direct copy of %BUFFER_CAST_C into %svA, a subview of %ALLOC_FOR_A
+    //     CHECK: linalg.copy(%[[BUFFER_CAST_C]], %[[svA]])
+    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // %ttB bufferizes to a direct copy of %BUFFER_CAST_C into %svB, a subview of %BUFFER_CAST_B
+    //     CHECK:   %[[svB:.*]] = memref.subview %[[BUFFER_CAST_B]][0] [4] [1]
+    //     CHECK:   linalg.copy(%[[BUFFER_CAST_C]], %[[svB]])
+    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // Yielding bbA and bbB will canonicalize away into oblivion.
+    //     CHECK:   scf.yield %[[bbA]], %[[bbB]] : tensor<?xf32>, tensor<?xf32>
+    scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  //     CHECK:  memref.dealloc %[[ALLOC_FOR_A]] : memref<?xf32>
+  //     CHECK:  %[[rA:.*]] = memref.tensor_load %[[ALLOC_FOR_A]] : memref<?xf32>
+  //     CHECK:  %[[rB:.*]] = memref.tensor_load %[[BUFFER_CAST_B]] : memref<?xf32, #map>
+  //     CHECK:  return %[[rA]], %[[rB]] : tensor<?xf32>, tensor<?xf32>
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list