[Mlir-commits] [mlir] ce4f99e - [mlir][Linalg] Add comprehensive bufferization support for subtensor (5/n)

Nicolas Vasilache llvmlistbot at llvm.org
Thu May 27 05:51:10 PDT 2021


Author: Nicolas Vasilache
Date: 2021-05-27T12:48:08Z
New Revision: ce4f99e7f272d481d0689c551d9818019992c841

URL: https://github.com/llvm/llvm-project/commit/ce4f99e7f272d481d0689c551d9818019992c841
DIFF: https://github.com/llvm/llvm-project/commit/ce4f99e7f272d481d0689c551d9818019992c841.diff

LOG: [mlir][Linalg] Add comprehensive bufferization support for subtensor (5/n)

This revision refactors and simplifies the pattern detection logic: thanks to SSA value properties, we can simply inspect all uses of a given value instead of pattern-matching specific chains of operations.

A bufferization pattern for subtensor is added, and a dedicated inplaceability analysis is implemented for the simple subtensor case. More advanced use cases will follow.
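
A minimal sketch of what the use-based analysis boils down to (condensed
from inPlaceAnalysisFuncOpInternals in the diff below; the SubTensorOp
special case and debug output are omitted):

  // Walk every op; for each operand, ask which result (if any) may reuse
  // the operand's buffer, then query inplaceability directly on the use.
  funcOp.walk([&](Operation *op) {
    for (OpOperand &opOperand : op->getOpOperands()) {
      OpResult result = getMatchingOpResult(opOperand);
      if (result && isBufferizableInPlace(opOperand, domInfo))
        setInPlaceOpResult(result);
    }
  });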

Differential revision: https://reviews.llvm.org/D102512

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.td
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2023f8820d0a4..d36d655638a2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -34,6 +34,11 @@ def LinalgComprehensiveFuncBufferize :
     same buffers. The analysis is performed on SSA use-def chains starting from
     function operands that are annotated with the 'inplaceable' attribute
   }];
+  let options = [
+    Option<"testAnalysisOnly", "test-analysis-only", "bool",
+            /*default=*/"false",
+           "Only runs inplaceability analysis (for testing purposes only)">
+  ];
   let constructor = "mlir::createLinalgComprehensiveFuncBufferizePass()";
 }
 

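The new test-analysis-only option is exercised from mlir-opt exactly as in
the RUN line added to the test file at the bottom of this diff:

  mlir-opt %s -linalg-comprehensive-func-bufferize=test-analysis-only -split-input-file
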
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 60f20432d3c3b..0094fffeea966 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -32,6 +32,10 @@ class OffsetSizeAndStrideOpInterface;
 
 namespace detail {
 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
+
+bool sameOffsetsSizesAndStrides(
+    OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
+    llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp);
 } // namespace detail
 } // namespace mlir
 

diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index f94350fe50e31..e26a02f61966a 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -419,6 +419,23 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
         return $_op.getOperand(getIndexOfDynamicStride(idx));
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if all of `other`'s offsets, sizes and strides are the
+        same as this op's. Takes a custom `cmp` comparison function on
+        OpFoldResult to avoid introducing a dialect dependency.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isSameAs",
+      /*args=*/(ins "OffsetSizeAndStrideOpInterface":$other,
+                   "llvm::function_ref<bool(OpFoldResult, OpFoldResult)>":$cmp),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return detail::sameOffsetsSizesAndStrides(
+          ::mlir::cast<::mlir::OffsetSizeAndStrideOpInterface>(
+            $_op.getOperation()), other, cmp);
+      }]
+    >,
   ];
 
   let extraClassDeclaration = [{

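A hypothetical call site for the new isSameAs method, using the simplest
possible cmp (plain OpFoldResult equality; a fuller comparator could also
equate a static attribute with an SSA constant holding the same value):

  // Sketch only: `a` and `b` are ops implementing
  // OffsetSizeAndStrideOpInterface.
  bool same = a.isSameAs(
      b, [](OpFoldResult x, OpFoldResult y) { return x == y; });
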
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 4ad8095505025..d14281a6b19c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -88,6 +88,7 @@
 #include "mlir/Transforms/BufferUtils.h"
 
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #define DEBUG_TYPE "comprehensive-func-bufferize"
@@ -152,12 +153,18 @@ OpResult getMatchingOpResult(SubTensorInsertOp op, OpOperand &opOperand) {
 /// analysis to determine which op results reuse the same buffer as some
 /// operand.
 OpResult getMatchingOpResult(OpOperand &opOperand) {
-  OpResult res =
-      llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-          .Case<LinalgOp, SubTensorInsertOp, VectorTransferOpInterface>(
-              [&](auto op) { return getMatchingOpResult(op, opOperand); })
-          .Default([&](Operation *op) { return OpResult(); });
-  return res;
+  return llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+      // clang-format off
+        // Ops that perform destructive updates on operand(s) to produce
+        // result(s).
+        .Case<LinalgOp,
+              SubTensorInsertOp,
+              VectorTransferOpInterface>(
+            [&](auto op) { return getMatchingOpResult(op, opOperand); })
+        // Other ops.
+        .Case<SubTensorOp>([&](auto op) { return OpResult(); })
+        .Default([&](Operation *op) { return OpResult(); });
+  // clang-format on
 }
 
 //===----------------------------------------------------------------------===//
@@ -290,70 +297,6 @@ static Value lookup(BlockAndValueMapping &bvm, Value key) {
   return bvm.lookup(key);
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization-specific support.
-//===----------------------------------------------------------------------===//
-
-/// Determine whether any subsequent read of the tensor `opOperand` may occur.
-/// For now, this assumes any use is a read. If any use of the tensor does not
-/// properly dominate `opOperand.getOwner()`, then the tensor cannot be
-/// bufferized inPlace.
-// TODO: For now, this assumes any use is a read. Refine this.
-bool hasInterferingTensorRead(OpOperand &opOperand,
-                              const DominanceInfo &domInfo) {
-  if (!opOperand.get().getType().isa<RankedTensorType>())
-    return false;
-  for (auto &use : opOperand.get().getUses()) {
-    Operation *user = use.getOwner();
-
-    // If properly dominate, there  is a clear sequence point and we can dismiss
-    // read.
-    if (domInfo.properlyDominates(user, opOperand.getOwner()))
-      continue;
-    // Otherwise, we need to analyze self-dependencies, for now just let it go.
-    // TODO: proper self-dependence analysis.
-    if (domInfo.dominates(user, opOperand.getOwner()))
-      continue;
-    if (user == opOperand.getOwner() &&
-        use.getOperandNumber() == opOperand.getOperandNumber())
-      continue;
-    LLVM_DEBUG(DBGS() << "found interfering read operand #"
-                      << opOperand.getOperandNumber()
-                      << " in op: " << *opOperand.getOwner() << "\n");
-    return true;
-  }
-  LLVM_DEBUG(DBGS() << "no interfering read\n");
-  return false;
-}
-
-/// Return false if either:
-/// 1. `opOperand` is produced by a constant op. For now this is assumed to be
-///    bufferized to a GlobalMemrefOp that cannot be written. Generalize in the
-///    future.
-/// 2.`opOperand` is a BlockArgument of a FuncOp that is not known to be
-///    bufferizable inplace.
-/// 3.`opOperand` has an interfering tensor read.
-/// Return true otherwise.
-bool isBufferizableInPlace(OpOperand &opOperand, const DominanceInfo &domInfo) {
-  // Constant tensors are deemed not bufferizable for now.
-  if (auto constantOp =
-          dyn_cast_or_null<ConstantOp>(opOperand.get().getDefiningOp()))
-    return !constantOp.getResult().getType().isa<RankedTensorType>();
-  if (auto bbArg = opOperand.get().dyn_cast<BlockArgument>()) {
-    // Uses of function arguments that may not be written-to need to be copied.
-    // If the function argument itself is not inplaceable, early return false.
-    // If is is inplaceable, interfering tensor read need to be checked.
-    //
-    // TODO: better propagate the fact that we want a single clone inside the
-    // function. Atm every user that wants to write inplace will create its own
-    // alloc, irrespective of whether or not interfering reads occur.
-    if (isa<FuncOp>(bbArg.getOwner()->getParentOp()))
-      if (getInPlace(bbArg) != InPlaceSpec::True)
-        return false;
-  }
-  return !hasInterferingTensorRead(opOperand, domInfo);
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization-specific MemRefType support.
 //===----------------------------------------------------------------------===//
@@ -399,26 +342,6 @@ static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
                          stridedLayout, addressSpace);
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization-specific inPlace pattern matching support.
-//===----------------------------------------------------------------------===//
-
-/// First assign `op` if `slice.back()` isa `T`, then check condition.
-/// If anything fails just return failure. Otherwise update `sliceRef` by
-/// dropping `sliceRef.back()`, then return success().
-template <typename T>
-static LogicalResult
-matchAndDropBack(ArrayRef<Operation *> &sliceRef, T &op,
-                 llvm::function_ref<LogicalResult(T)> condition = nullptr) {
-  if (sliceRef.empty())
-    return failure();
-  op = dyn_cast<T>(sliceRef.back());
-  if (!op || (condition && failed(condition(op))))
-    return failure();
-  sliceRef = sliceRef.drop_back();
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization-specific scoped alloc/dealloc insertion support.
 //===----------------------------------------------------------------------===//
@@ -470,121 +393,6 @@ static Value createNewAllocDeallocPairForShapedValue(
   return casted;
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization-specific inPlace analysis support.
-//===----------------------------------------------------------------------===//
-
-/// Detect the simple terminator pattern:
-/// ```
-///    candidate -> ... -> inplaceable_op(candidate) -> term
-/// ```
-template <typename ContainerOp, typename TerminatorOp>
-static LogicalResult detectInplaceOpToTerminator(Operation *parentOp,
-                                                 BlockArgument candidate,
-                                                 ArrayRef<Operation *> slice) {
-  assert(parentOp && "Unexpected null parent op");
-  if (!isa<ContainerOp>(parentOp))
-    return failure();
-  TerminatorOp terminatorOp;
-  // Match returnOp and update slice.
-  if (failed(matchAndDropBack(slice, terminatorOp))) {
-    LLVM_DEBUG(DBGS() << "FAIL: inplaceOpToTerm pattern -> slice must end with "
-                         "a known terminator\n");
-    return failure();
-  }
-  return success();
-}
-
-/// The following uses internal knowledge of the position of tied operand /
-/// results.
-static void propagateInPlace(const SmallVector<OpOperand *> &initalWorklist,
-                             const DominanceInfo &domInfo) {
-  LLVM_DEBUG(DBGS() << "\n\n");
-  LLVM_DEBUG(DBGS() << "Start propagateInPlace from initial WL\n");
-  LLVM_DEBUG(for (OpOperand *operand
-                  : initalWorklist) DBGS()
-             << "WL item: " << operand->get() << " used by "
-             << *operand->getOwner() << "\n");
-  SmallVector<OpOperand *> worklist(initalWorklist);
-  for (unsigned idx = 0; idx < worklist.size(); ++idx) {
-    // TODO: bail on subtensor/subtensor_insert and vector.transfer_read/write
-    // that should have been already captured in destructive update patterns?
-    OpOperand &operand = *worklist[idx];
-    LLVM_DEBUG(DBGS() << "WL item: " << *operand.getOwner() << "\n");
-    // If the owner turns out to be a CallOp without
-    // `kWriteableFuncBufferArgsAttrName` this will be a noop.
-    if (isBufferizableInPlace(operand, domInfo)) {
-      LLVM_DEBUG(DBGS() << "bufferizable inplace\n");
-      setInPlaceOpResult(getMatchingOpResult(operand));
-    }
-    LLVM_DEBUG(DBGS() << "propagatedInPlace: " << *operand.getOwner() << "\n");
-    // use can have interfering reads that prevent it from being written inPlace
-    // but the values it produces are still themselves candidates for inPlace at
-    // their point of use.
-    for (Value v : operand.getOwner()->getResults()) {
-      LLVM_DEBUG(DBGS() << "propagate result: " << v << "\n");
-      for (auto &use : v.getUses()) {
-        LLVM_DEBUG(DBGS() << "add use to WL: " << use.get() << "\n");
-        worklist.push_back(&use);
-      }
-    }
-  }
-  LLVM_DEBUG(DBGS() << "\n\n");
-}
-
-static void propagateInPlace(BlockArgument &bbArg,
-                             const DominanceInfo &domInfo) {
-  SmallVector<OpOperand *> worklist;
-  for (auto &use : bbArg.getUses())
-    worklist.push_back(&use);
-  propagateInPlace(worklist, domInfo);
-}
-
-/// Iterate over bbArgs of `parentOp` and determine if they are the root of a
-/// known destructive update chain. Such a destructive update is related to
-/// traditional loop nest + memory analysis but provides a simpler SSA use-def
-/// chain-based abstraction.
-static void destructiveUpdateAnalysis(Block *block,
-                                      const DominanceInfo &domInfo) {
-  Operation *parentOp = block->getParentOp();
-  for (BlockArgument candidate : block->getArguments()) {
-    LLVM_DEBUG(llvm::dbgs() << "\n\n");
-    LLVM_DEBUG(DBGS() << "Destructive update analysis on candidate: "
-                      << candidate << "\nof:\n"
-                      << *parentOp << "\n");
-
-    if (!candidate.getType().isa<ShapedType>()) {
-      LLVM_DEBUG(DBGS() << "Not a tensor\n");
-      continue;
-    }
-
-    // FuncOp arguments must be inplaceable otherwise they cannot be the root of
-    // a destructive update chain.
-    if (isa<FuncOp>(parentOp) && getInPlace(candidate) != InPlaceSpec::True) {
-      LLVM_DEBUG(DBGS() << "Not inplace\n");
-      continue;
-    }
-
-    llvm::SetVector<Operation *> slice;
-    getForwardSlice(candidate, &slice,
-                    [&](Operation *op) { return op->getBlock() == block; });
-
-    LLVM_DEBUG(DBGS() << "Slice:\n");
-    LLVM_DEBUG(for (auto *op : slice) DBGS() << *op << "\n");
-
-    bool failedDetectingDestructiveUpdate =
-        // func / return inplace patterns.
-        failed(detectInplaceOpToTerminator<FuncOp, ReturnOp>(
-            parentOp, candidate, slice.getArrayRef()));
-    if (failedDetectingDestructiveUpdate) {
-      LLVM_DEBUG(DBGS() << "Failed to detect a destructive update pattern\n");
-      continue;
-    }
-
-    propagateInPlace(candidate, domInfo);
-  }
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization as simple BlockAndValueMapping rewrites.
 //===----------------------------------------------------------------------===//
@@ -748,6 +556,55 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
   return success();
 }
 
+/// Bufferize SubTensorOp to subview with optional alloc + copy depending on
+/// whether or not it is marked inplaceable.
+/// Note that `getMatchingOpResult` on a SubTensorOp always returns null.
+/// As a consequence, a SubTensorOp in isolation always bufferizes to alloc + copy.
+static LogicalResult bufferize(OpBuilder &b, SubTensorOp subTensorOp,
+                               BlockAndValueMapping &bvm) {
+  LLVM_DEBUG(DBGS() << "bufferize: " << *subTensorOp << "\n");
+
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(subTensorOp);
+
+  Location loc = subTensorOp.getLoc();
+  // Bail if source was not bufferized.
+  Value srcMemref = lookup(bvm, subTensorOp.source());
+  if (!srcMemref)
+    return failure();
+  auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+  auto dstTensorType = subTensorOp.result().getType().cast<RankedTensorType>();
+
+  // If not inplaceable, alloc.
+  Value alloc;
+  auto inPlace = getInPlace(subTensorOp->getResult(0));
+  if (inPlace != InPlaceSpec::True) {
+    alloc =
+        createNewAllocDeallocPairForShapedValue(b, loc, subTensorOp.result());
+    b.setInsertionPointAfter(alloc.getDefiningOp());
+  }
+
+  // Bufferize to subview.
+  auto subviewMemRefType =
+      memref::SubViewOp::inferRankReducedResultType(
+          dstTensorType.getRank(), srcMemrefType, subTensorOp.getMixedOffsets(),
+          subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides())
+          .cast<MemRefType>();
+  Value subView = b.create<memref::SubViewOp>(
+      loc, subviewMemRefType, srcMemref, subTensorOp.getMixedOffsets(),
+      subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides());
+
+  // If not inplaceable, copy.
+  if (alloc) {
+    b.create<CopyOp>(subTensorOp.getLoc(), subView, alloc);
+    subView = alloc;
+  }
+
+  map(bvm, subTensorOp.result(), subView);
+  return success();
+}
+
 static LogicalResult bufferize(OpBuilder &b,
                                SubTensorInsertOp subTensorInsertOp,
                                BlockAndValueMapping &bvm) {
@@ -765,7 +622,7 @@ static LogicalResult bufferize(OpBuilder &b,
   if (inPlace != InPlaceSpec::True) {
     // Since subtensor_insert arise from tiling and introducing loops, this case
     // is generally a deal breaker. When used with loops, this ends up cloning
-    // the whole tensor on every single iteration and is a symtpom of a
+    // the whole tensor on every single iteration and is a symptom of a
     // catastrophically bad scheduling decision.
     // TODO: be very loud about it or even consider failing the pass.
     Value newDstMemref = createNewAllocDeallocPairForShapedValue(
@@ -865,13 +722,165 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Functions and calls bufferization support.
+//===----------------------------------------------------------------------===//
+
+/// Determine whether any subsequent read of the tensor `opOperand` may occur.
+/// For now, this assumes any use is a read. If any use of the tensor does not
+/// properly dominate `opOperand.getOwner()`, then the tensor cannot be
+/// bufferized inPlace.
+// TODO: For now, this assumes any use is a read. Refine this.
+bool hasInterferingTensorRead(OpOperand &opOperand,
+                              const DominanceInfo &domInfo) {
+  if (!opOperand.get().getType().isa<RankedTensorType>())
+    return false;
+  for (auto &use : opOperand.get().getUses()) {
+    Operation *user = use.getOwner();
+    // If the user properly dominates, there is a clear sequence point and we
+    // can dismiss the read.
+    if (domInfo.properlyDominates(user, opOperand.getOwner()))
+      continue;
+    // Otherwise, we need to analyze self-dependencies; for now just let it go.
+    // TODO: proper self-dependence analysis.
+    if (domInfo.dominates(user, opOperand.getOwner()))
+      continue;
+    if (user == opOperand.getOwner() &&
+        use.getOperandNumber() == opOperand.getOperandNumber())
+      continue;
+    LLVM_DEBUG(DBGS() << "found interfering read operand #"
+                      << opOperand.getOperandNumber()
+                      << " in op: " << *opOperand.getOwner() << "\n");
+    return true;
+  }
+  LLVM_DEBUG(DBGS() << "no interfering read\n");
+  return false;
+}
+
+/// Return false if either:
+/// 1. `opOperand` is produced by a constant op. For now this is assumed to be
+///    bufferized to a GlobalMemrefOp that cannot be written. Generalize in the
+///    future.
+/// 2.`opOperand` is a BlockArgument of a FuncOp that is not known to be
+///    bufferizable inplace.
+/// Return true otherwise.
+static bool bufferizeToWriteable(OpOperand &opOperand) {
+  // Constant tensors are deemed not bufferizable for now.
+  if (auto constantOp =
+          dyn_cast_or_null<ConstantOp>(opOperand.get().getDefiningOp()))
+    return !constantOp.getResult().getType().isa<RankedTensorType>();
+  if (auto bbArg = opOperand.get().dyn_cast<BlockArgument>()) {
+    // Uses of function arguments that may not be written-to need to be copied.
+    // If the function argument itself is not inplaceable, early return false.
+    // If it is inplaceable, interfering tensor reads need to be checked.
+    //
+    // TODO: better propagate the fact that we want a single clone inside the
+    // function. Atm every user that wants to write inplace will create its own
+    // alloc, irrespective of whether or not interfering reads occur.
+    if (isa<FuncOp>(bbArg.getOwner()->getParentOp())) {
+      if (getInPlace(bbArg) != InPlaceSpec::True)
+        return false;
+    } else {
+      // Conservatively treat any other block argument as non-writeable for now.
+      return false;
+    }
+  }
+  return true;
+}
+
+/// Return false if either:
+/// 1. `opOperand` is produced by a constant op. For now this is assumed to be
+///    bufferized to a GlobalMemrefOp that cannot be written. Generalize in the
+///    future.
+/// 2.`opOperand` is a BlockArgument of a FuncOp that is not known to be
+///    bufferizable inplace.
+/// 3.`opOperand` has an interfering tensor read.
+/// Return true otherwise.
+static bool isBufferizableInPlace(OpOperand &opOperand,
+                                  const DominanceInfo &domInfo) {
+  return bufferizeToWriteable(opOperand) &&
+         !hasInterferingTensorRead(opOperand, domInfo);
+}
+
+/// Return true if `operand` bufferizes to a buffer that is known to never be
+/// written.
+static bool bufferizeToReadOnly(OpOperand &operand) {
+  return llvm::TypeSwitch<Operation *, bool>(operand.getOwner())
+      .Case([&](LinalgOp linalgOp) { return linalgOp.isInputTensor(&operand); })
+      .Default([&](Operation *op) { return false; });
+}
+
+/// Assume operand is a use of a `subTensorOp`.
+/// Return true if this use bufferizes to a buffer that is known to never be
+/// written.
+/// Note: This function takes into consideration uses of subTensorOp and whether
+/// the owner of those uses is inplaceable. This needs to be run in postorder to
+/// provide the most accurate analysis; otherwise it is conservative.
+static bool subTensorUseBufferizesToReadOnly(OpOperand &operand) {
+  assert(operand.get().getDefiningOp<SubTensorOp>() && "expected subtensor op");
+  if (auto subTensorInsertOp =
+          dyn_cast<SubTensorInsertOp>(operand.getOwner())) {
+    return operand.getOperandNumber() == 0 /* source of the subTensorInsert */ &&
+           // If the subTensorInsertOp is not inplace, there is no possible
+           // internal aliasing with subTensorOp, which is inplaceable.
+           getInPlace(subTensorInsertOp->getResult(0)) != InPlaceSpec::True;
+  }
+  return bufferizeToReadOnly(operand);
+}
+
+/// Return true if `dominator.getOwner()` dominates all other uses of
+/// `dominator.get()`.
+static bool dominatesAllOtherUses(OpOperand &dominator,
+                                  const DominanceInfo &domInfo) {
+  for (OpOperand &use : dominator.get().getUses()) {
+    // Same use.
+    if (use.getOwner() == dominator.getOwner() &&
+        use.getOperandNumber() == dominator.getOperandNumber())
+      continue;
+    if (!domInfo.properlyDominates(dominator.getOwner(), use.getOwner()))
+      return false;
+  }
+  return true;
+}
+
+/// SubTensorOp introduces potential aliasing; several conditions must be
+/// checked together to determine whether it is inplaceable.
+static void analyzeInPlaceSubTensor(SubTensorOp subTensorOp,
+                                    const DominanceInfo &domInfo) {
+  // Case 1:
+  //   a. All uses are known to bufferize to readonly buffers.
+  //   b. The source has no use that is not dominated by subTensorOp.
+  // This can skip bufferizeToWriteable analysis / function boundary annotation.
+  if (llvm::all_of(subTensorOp.result().getUses(),
+                   subTensorUseBufferizesToReadOnly) &&
+      dominatesAllOtherUses(subTensorOp->getOpOperand(0), domInfo))
+    return setInPlaceOpResult(subTensorOp->getResult(0), InPlaceSpec::True);
+
+  // TODO: Implement more advanced use cases. There is a notion of transitivity
+  // and interference sets lurking.
+}
+
+/// Analyze the internals of a FuncOp to determine inplaceable ops.
 static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
                                            const DominanceInfo &domInfo) {
   assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
          "expected a funcOp definition with a body");
 
-  // Start propagating from FuncOp bbArgs.
-  destructiveUpdateAnalysis(&funcOp.body().front(), domInfo);
+  funcOp.walk([&](Operation *op) {
+    // Skip SubTensorOp in a first pass.
+    if (auto subTensorOp = dyn_cast<SubTensorOp>(op))
+      return analyzeInPlaceSubTensor(subTensorOp, domInfo);
+
+    // All other ops are checked for `isBufferizableInPlace`.
+    for (OpOperand &opOperand : op->getOpOperands()) {
+      OpResult result = getMatchingOpResult(opOperand);
+      if (result && isBufferizableInPlace(opOperand, domInfo)) {
+        LLVM_DEBUG(DBGS() << "bufferizable inplace operand #"
+                          << opOperand.getOperandNumber() << " in " << *op);
+        setInPlaceOpResult(result);
+      }
+    }
+  });
 }
 
 static LogicalResult bufferizeFuncOpInternals(
@@ -881,15 +890,22 @@ static LogicalResult bufferizeFuncOpInternals(
   /// Start by bufferizing `funcOp` arguments.
   if (failed(bufferize(b, funcOp, bvm)))
     return failure();
-  WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
+  WalkResult result = funcOp.walk<WalkOrder::PostOrder>([&](Operation *op) {
     LogicalResult status =
         llvm::TypeSwitch<Operation *, LogicalResult>(op)
             // Skip BufferCast and TensorLoad ops.
-            .Case<memref::BufferCastOp, memref::TensorLoadOp>(
+            // clang-format off
+            .Case<memref::BufferCastOp,
+                  memref::TensorLoadOp>(
                 [&](auto) { return success(); })
-            .Case<memref::DimOp, LinalgOp, ReturnOp, SubTensorInsertOp,
+            .Case<memref::DimOp,
+                  LinalgOp,
+                  ReturnOp,
+                  SubTensorOp,
+                  SubTensorInsertOp,
                   VectorTransferOpInterface>(
                 [&](auto op) { return bufferize(b, op, bvm); })
+            // clang-format on
             .Default([&](Operation *op) {
               auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
               if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
@@ -925,8 +941,17 @@ void LinalgComprehensiveFuncBufferize::runOnFunction() {
   DominanceInfo domInfo(funcOp);
   BlockAndValueMapping bvm;
   DenseMap<FuncOp, SmallVector<int64_t>> tiedResultsMap;
+  LLVM_DEBUG(llvm::dbgs() << "\n\n");
+  LLVM_DEBUG(DBGS() << "Begin InPlaceAnalysisFuncOpInternals:\n"
+                    << funcOp << "\n");
   inPlaceAnalysisFuncOpInternals(funcOp, domInfo);
+  LLVM_DEBUG(DBGS() << "End InPlaceAnalysisFuncOpInternals:\n"
+                    << funcOp << "\n");
+
+  if (testAnalysisOnly)
+    return;
 
+  LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LLVM_DEBUG(DBGS() << "Begin BufferizeFuncOpInternals:\n" << funcOp << "\n");
   auto guard = llvm::make_scope_exit([&] {
     funcOp.walk(

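To summarize the SubTensorOp bufferization added above: the subview is
always created, and the alloc + copy pair only materializes when the
analysis did not mark the result inplaceable. A condensed sketch
(insertion-point bookkeeping omitted):

  // Condensed from bufferize(OpBuilder &, SubTensorOp, BlockAndValueMapping &).
  Value subView = b.create<memref::SubViewOp>(
      loc, subviewMemRefType, srcMemref, subTensorOp.getMixedOffsets(),
      subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides());
  if (getInPlace(subTensorOp->getResult(0)) != InPlaceSpec::True) {
    // Not inplaceable: allocate a private buffer and copy the subview into it.
    Value alloc =
        createNewAllocDeallocPairForShapedValue(b, loc, subTensorOp.result());
    b.create<CopyOp>(loc, subView, alloc);
    subView = alloc;
  }
  map(bvm, subTensorOp.result(), subView);
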
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index ad2b6b49feb7e..4a963a1d54fda 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -155,3 +155,24 @@ ParseResult mlir::parseOperandsOrIntegersSizesList(
   return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
                                                                integers);
 }
+
+bool mlir::detail::sameOffsetsSizesAndStrides(
+    OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
+    llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
+  if (a.static_offsets().size() != b.static_offsets().size())
+    return false;
+  if (a.static_sizes().size() != b.static_sizes().size())
+    return false;
+  if (a.static_strides().size() != b.static_strides().size())
+    return false;
+  for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
+    if (!cmp(std::get<0>(it), std::get<1>(it)))
+      return false;
+  for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
+    if (!cmp(std::get<0>(it), std::get<1>(it)))
+      return false;
+  for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
+    if (!cmp(std::get<0>(it), std::get<1>(it)))
+      return false;
+  return true;
+}

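Note that sameOffsetsSizesAndStrides is not yet called from the
bufferization analysis in this revision; presumably it is groundwork for
the "matching" subtensor / subtensor_insert pairs exercised in the tests
below, where two ops match when their mixed offsets, sizes and strides
compare equal element-wise under cmp.
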
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
index 19599b99866bf..674483b7acdc4 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -linalg-comprehensive-func-bufferize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-func-bufferize=test-analysis-only -split-input-file | FileCheck %s --check-prefix=ANALYSIS
 
 // CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
@@ -218,3 +219,140 @@ func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32> {linalg.inplaceable =
   %r1 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @subtensor_fun
+func @subtensor_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
+  ->  tensor<4xf32>
+{
+  //      CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf32>
+  // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]][0] [4] [1]
+  // CHECK: linalg.copy(%[[SV]], %[[ALLOC]])
+  %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+  return %r0: tensor<4xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_readonly_use
+func @subtensor_readonly_use(
+    %A : tensor<?x?xf32> {linalg.inplaceable = true},
+    %B : tensor<4x4xf32>, %C : tensor<4x4xf32>) ->  tensor<4x4xf32>
+{
+  // subtensor is only used as a read.
+  //     ANALYSIS: subtensor {{.*}} {__inplace_results_attr__ = ["true"]}
+  %sA = subtensor %A[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
+  // matmul output operand is not inplaceable at the function boundary.
+  //     ANALYSIS: linalg.matmul {{.*}}
+  // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+  %D = linalg.matmul  ins(%sA, %B: tensor<4x4xf32>, tensor<4x4xf32>)
+                     outs(%B: tensor<4x4xf32>)
+    -> tensor<4x4xf32>
+  return %D: tensor<4x4xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_nonmatching_subtensor_insert_inplace
+func @subtensor_nonmatching_subtensor_insert_inplace(
+    %A : tensor<?xf32> {linalg.inplaceable = true}, %idx: index)
+  ->  tensor<?xf32>
+{
+  // subtensor has no matching subtensor_insert and is not just used by known
+  // readonly ops.
+  //     ANALYSIS: subtensor {{.*}}
+  // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+  %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+  // subtensor_insert can bufferize inplace fine.
+  //     ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]}
+  %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_nonmatching_subtensor_insert_non_inplace
+func @subtensor_nonmatching_subtensor_insert_non_inplace(
+    %A : tensor<?xf32> {linalg.inplaceable = false}, %idx: index)
+  ->  tensor<?xf32>
+{
+  // subtensor has no matching subtensor_insert and is not just used by known
+  // readonly ops.
+  //     ANALYSIS: subtensor {{.*}} {__inplace_results_attr__ = ["true"]}
+  %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+  // subtensor_insert cannot bufferize inplace.
+  //     ANALYSIS: subtensor_insert {{.*}}
+  // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+  %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_matching_subtensor_insert
+func @subtensor_matching_subtensor_insert(%A : tensor<?xf32> {linalg.inplaceable = true})
+  ->  tensor<?xf32>
+{
+  // subtensor has a matching subtensor_insert that bufferizes inplace.
+  // TODO: Atm subtensor is not inplaceable but can be.
+  // In the grander scheme, this will canonicalize away beforehand.
+  //     ANALYSIS: subtensor {{.*}}
+  // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+  %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+  // subtensor_insert can bufferize inplace fine.
+  //     ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]}
+  %r1 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_matching_and_nonmatching_1
+func @subtensor_matching_and_nonmatching_1(%A : tensor<?xf32> {linalg.inplaceable = true}, %idx: index)
+  ->  (tensor<?xf32>, tensor<?xf32>)
+{
+  // %r1 is not inplaceable and %r2 is a matching subtensor_insert so %r0 could
+  // be inplaceable.
+  // In the grander scheme, %r2 will canonicalize away beforehand but %r0 will still
+  // not be inplaceable as the production of %r1 may involve a self-copy.
+  //     ANALYSIS: subtensor {{.*}}
+  // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+  %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+  //     ANALYSIS: subtensor_insert {{.*}}
+  // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+  %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+  //     ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]}
+  %r2 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r1, %r2: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_matching_and_nonmatching_2
+func @subtensor_matching_and_nonmatching_2(%A : tensor<?xf32> {linalg.inplaceable = true}, %idx: index)
+  ->  (tensor<?xf32>, tensor<?xf32>)
+{
+  // %r1 is not inplaceable and %r2 is a matching subtensor_insert so %r0 should
+  // be inplaceable.
+  // In the grander scheme, %r2 will canonicalize away beforehand and %r0 will become
+  // inplaceable by reducing to the `subtensor_nonmatching_subtensor_insert_non_inplace`
+  // case.
+  //     ANALYSIS: subtensor {{.*}}
+  // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+  %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+  //     ANALYSIS: subtensor_insert {{.*}}
+  // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+  %r2 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  //     ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]}
+  %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  return %r1, %r2: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// TODO: unknown ops, linalg chain success, linalg chain failure.
+
