[Mlir-commits] [mlir] 75d6529 - [mlir][linalg][bufferize][NFC] Clean up comments and minor code refactorings

Matthias Springer llvmlistbot at llvm.org
Thu Jan 6 13:26:17 PST 2022


Author: Matthias Springer
Date: 2022-01-07T06:23:01+09:00
New Revision: 75d65293ca83a7ff24e4c6634e46e63e8ae8c24c

URL: https://github.com/llvm/llvm-project/commit/75d65293ca83a7ff24e4c6634e46e63e8ae8c24c
DIFF: https://github.com/llvm/llvm-project/commit/75d65293ca83a7ff24e4c6634e46e63e8ae8c24c.diff

LOG: [mlir][linalg][bufferize][NFC] Clean up comments and minor code refactorings

Differential Revision: https://reviews.llvm.org/D116451

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 0bd42e34f047..921353a23ea7 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -64,14 +64,14 @@ struct AllocationCallbacks {
 std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
 
 /// PostAnalysisSteps can be registered with `BufferizationOptions` and are
-/// executed after the analysis, but before bufferization. They can be used
+/// executed after the analysis, but before bufferization. They can be used to
 /// implement custom dialect-specific optimizations.
 struct PostAnalysisStep {
   virtual ~PostAnalysisStep() {}
 
   /// Run the post analysis step. This function may modify the IR, but must keep
-  /// `aliasInfo` (inside `state`) consistent. Newly created operations and
-  /// operations that should be re-analyzed must be stored in `newOps`.
+  /// `aliasInfo` consistent. Newly created operations and operations that
+  /// should be re-analyzed must be added to `newOps`.
   virtual LogicalResult run(Operation *op, BufferizationState &state,
                             BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) = 0;
@@ -102,7 +102,8 @@ struct BufferizationOptions {
   }
 
   /// Allow-list the given dialects in the dialect filter. Only ops from
-  /// allow-listed dialects will be bufferized.
+  /// allow-listed dialects will be bufferized. If no dialect is added, ops from
+  /// any dialect will be bufferized.
   template <typename... DialectTs>
   void addToDialectFilter() {
     // The following expands a call to addToDialectFilterImpl for each dialect
@@ -288,17 +289,7 @@ struct DialectBufferizationState {
 };
 
 /// BufferizationState provides a variety of helper functions for dealing with
-/// tensor values and memref buffers. In particular,
-/// `BufferizableOpInterface::bufferize` implementation should utilize the
-/// following helper functions.
-///
-/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops
-///   that allocate and/or deallocate memref buffers.
-/// * `lookupBuffer` returns the memref buffer of a given tensor value.
-/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
-///   Based on inplace bufferization decisions of the analysis, it may either
-///   directly return a mapped buffer or allocate a new brand new buffer.
-/// * `replaceOp` replaces an op with new values.
+/// tensor values and memref buffers.
 class BufferizationState {
 public:
   BufferizationState(Operation *op, const BufferizationOptions &options);
@@ -396,7 +387,8 @@ class BufferizationState {
   /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization is necessary.
-  Value getResultBuffer(RewriterBase &rewriter, OpResult result) const;
+  FailureOr<Value> getResultBuffer(RewriterBase &rewriter,
+                                   OpResult result) const;
 
   /// Return dialect-specific bufferization state.
   template <typename StateT>
@@ -455,12 +447,9 @@ MemRefType getContiguousMemRefType(ShapedType shapedType,
                                    MemRefLayoutAttrInterface layout = {},
                                    Attribute memorySpace = {});
 
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace` or an UnrankedMemRefType otherwise.
-Type getContiguousOrUnrankedMemRefType(Type type,
-                                       MemRefLayoutAttrInterface layout = {},
-                                       Attribute memorySpace = {});
+/// Return an UnrankedMemRefType with the given element type and memory space.
+UnrankedMemRefType getUnrankedMemRefType(Type elementType,
+                                         Attribute memorySpace = {});
 
 /// Return a MemRefType to which the `tensorType` can be bufferized in a
 /// composable fashion. The layout must be the most dynamic possible and
@@ -493,7 +482,7 @@ struct AllocationHoistingBarrierOnly
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
-    return false;
+    return true;
   }
 
   SmallVector<OpOperand *>

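For illustration, a custom post-analysis step hooking into the interface
above might look as follows. This is a minimal sketch; the name
`MyDialectOptimization` is hypothetical and not part of this patch:

    struct MyDialectOptimization : public PostAnalysisStep {
      LogicalResult run(Operation *op, BufferizationState &state,
                        BufferizationAliasInfo &aliasInfo,
                        SmallVector<Operation *> &newOps) override {
        // Rewrite the IR here, keeping `aliasInfo` consistent. Newly created
        // ops and ops that should be re-analyzed go into `newOps`.
        return success();
      }
    };

Such steps are registered with `BufferizationOptions`; the exact registration
hook is not shown in this diff.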
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index 1da5903f1048..e56371617b97 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -23,6 +23,12 @@ class BufferizationAliasInfo;
 namespace linalg_ext {
 
 struct InitTensorEliminationStep : public PostAnalysisStep {
+  /// A function that matches anchor OpOperands for InitTensorOp elimination.
+  using AnchorMatchFn = std::function<bool(OpOperand &)>;
+
+  /// A function that rewrites matched anchors.
+  using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
+
   /// Try to eliminate InitTensorOps inside `op`.
   ///
   /// * `rewriteFunc` generates the replacement for the InitTensorOp.
@@ -33,12 +39,11 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
   ///   InitTensorOp.
   /// * The result of `rewriteFunc` must usually be analyzed for inplacability.
   ///   This analysis can be skipped with `skipAnalysis`.
-  LogicalResult eliminateInitTensors(
-      Operation *op, BufferizationState &state,
-      BufferizationAliasInfo &aliasInfo,
-      std::function<bool(OpOperand &)> anchorMatchFunc,
-      std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
-      SmallVector<Operation *> &newOps);
+  LogicalResult eliminateInitTensors(Operation *op, BufferizationState &state,
+                                     BufferizationAliasInfo &aliasInfo,
+                                     AnchorMatchFn anchorMatchFunc,
+                                     RewriteFn rewriteFunc,
+                                     SmallVector<Operation *> &newOps);
 };
 
 /// Try to eliminate InitTensorOps inside `op` that are anchored on an

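With the named callback types, a derived step can pass lambdas directly. A
minimal sketch (`MyInitTensorElimination` and both lambda bodies are
placeholders; the real InsertSliceOp-anchored usage appears later in this
patch):

    struct MyInitTensorElimination : public InitTensorEliminationStep {
      LogicalResult run(Operation *op, BufferizationState &state,
                        BufferizationAliasInfo &aliasInfo,
                        SmallVector<Operation *> &newOps) override {
        // Placeholder callbacks: match nothing, rewrite to nothing.
        AnchorMatchFn matchFn = [](OpOperand &operand) { return false; };
        RewriteFn rewriteFn = [](OpBuilder &b, Location loc,
                                 OpOperand &operand) { return Value(); };
        return eliminateInitTensors(op, state, aliasInfo, matchFn, rewriteFn,
                                    newOps);
      }
    };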
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
index 6dcc7c5fca92..bed69f19582f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
@@ -13,6 +13,8 @@
 
 void mlir::linalg::comprehensive_bufferize::affine_ext::
     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  // AffineParallelOp bufferization not implemented yet. However, never hoist
+  // memref allocations across AffineParallelOp boundaries.
   registry.addOpInterface<AffineParallelOp,
                           AllocationHoistingBarrierOnly<AffineParallelOp>>();
 }

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 40d54445c5e8..3c0926e3fae6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -20,23 +20,30 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace arith_ext {
 
+/// Bufferization of arith.constant. Replace with memref.get_global.
 struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
-           "not a constant ranked tensor");
+
+    // Only ranked tensors are supported.
+    if (!constantOp.getType().isa<RankedTensorType>())
+      return failure();
+
+    // Only constants inside a module are supported.
     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
     if (!moduleOp)
-      return constantOp.emitError(
-          "cannot bufferize constants not within builtin.module op");
+      return failure();
 
+    // Create global memory segment and replace tensor with memref pointing to
+    // that memory segment.
     GlobalCreator globalCreator(moduleOp);
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
         rewriter, op, globalMemref.type(), globalMemref.getName());
+
     return success();
   }
 

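The rewrite in IR terms (illustrative; the global's name is chosen by
`GlobalCreator` and shown here only as an example):

    // Before bufferization:
    %cst = arith.constant dense<[1.0, 2.0]> : tensor<2xf32>

    // After bufferization (illustrative global name):
    memref.global "private" constant @__constant_2xf32 : memref<2xf32> = dense<[1.0, 2.0]>
    %0 = memref.get_global @__constant_2xf32 : memref<2xf32>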
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 84dba99b0840..118e25a23148 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -74,6 +74,21 @@ mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
 BufferizationOptions::BufferizationOptions()
     : allocationFns(defaultAllocationCallbacks()) {}
 
+BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
+    BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
+  if (isOpAllowed(op))
+    return dyn_cast<BufferizableOpInterface>(op);
+  return nullptr;
+}
+
+BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
+    BufferizationOptions::dynCastBufferizableOp(Value value) const {
+  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
+    if (isOpAllowed(bufferizableOp.getOperation()))
+      return bufferizableOp;
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizationAliasInfo
 //===----------------------------------------------------------------------===//
@@ -180,21 +195,6 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
   }
 }
 
-BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
-    BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
-  if (isOpAllowed(op))
-    return dyn_cast<BufferizableOpInterface>(op);
-  return nullptr;
-}
-
-BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
-    BufferizationOptions::dynCastBufferizableOp(Value value) const {
-  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
-    if (isOpAllowed(bufferizableOp.getOperation()))
-      return bufferizableOp;
-  return nullptr;
-}
-
 /// Determine which OpOperand* will alias with `result` if the op is bufferized
 /// in place. Return an empty vector if the op is not bufferizable.
 SmallVector<OpOperand *>
@@ -358,8 +358,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::
-    getResultBuffer(RewriterBase &rewriter, OpResult result) const {
+FailureOr<Value>
+mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
+    RewriterBase &rewriter, OpResult result) const {
   OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
@@ -375,10 +376,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
   if (aliasingOperands.size() > 1 &&
       !llvm::all_of(aliasingOperands, [&](OpOperand *o) {
         return lookupBuffer(rewriter, o->get()) == operandBuffer;
-      })) {
-    op->emitError("result buffer is ambiguous");
-    return Value();
-  }
+      }))
+    return FailureOr<Value>(op->emitError("result buffer is ambiguous"));
 
   // If bufferizing out-of-place, allocate a new buffer.
   if (!aliasInfo.isInPlace(result)) {
@@ -610,10 +609,13 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
   // Insert to_memref op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, tensor);
-  Type memrefType =
-      tensor.getType().isa<RankedTensorType>()
-          ? getDynamicMemRefType(tensor.getType().cast<RankedTensorType>())
-          : getContiguousOrUnrankedMemRefType(tensor.getType());
+  Type memrefType;
+  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
+    memrefType = getDynamicMemRefType(rankedTensorType);
+  } else {
+    memrefType = getUnrankedMemRefType(
+        tensor.getType().cast<TensorType>().getElementType());
+  }
   return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                     tensor);
 }
@@ -630,13 +632,9 @@ MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
                          layout, memorySpace);
 }
 
-Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
-    Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
-  if (type.isa<RankedTensorType, MemRefType>())
-    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
-                                   memorySpace);
-  assert(!layout && "expected empty layout with UnrankedMemRefType");
-  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
+UnrankedMemRefType mlir::linalg::comprehensive_bufferize::getUnrankedMemRefType(
+    Type elementType, Attribute memorySpace) {
+  return UnrankedMemRefType::get(elementType, memorySpace);
 }
 
 MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(

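Since `getResultBuffer` now returns `FailureOr<Value>`, the caller-side
pattern is roughly the following sketch; several call sites in this patch
dereference the result directly with `*`:

    FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
    if (failed(resultBuffer))
      return failure();
    Value buffer = *resultBuffer;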
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 5051d43bb584..aaa304b2c91f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -25,6 +25,9 @@ namespace bufferization_ext {
 // TODO: These ops should implement BufferizableOpInterface directly when moved
 // to the Bufferization dialect.
 
+/// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded
+/// to x. Other to_memref ops are ignored during bufferization.
+///
 /// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
 /// location of the incoming tensor once it is bufferized. In the analysis,
 /// the incoming tensor is assumed to bufferize to a memory read and to an
@@ -41,7 +44,7 @@ struct ToMemrefOpInterface
                                                     bufferization::ToMemrefOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
-    // It is unknown whether the resulting MemRef will be read or not.
+    // It is unknown whether the resulting memref will be read or not.
     return true;
   }
 
@@ -58,10 +61,13 @@ struct ToMemrefOpInterface
     if (auto toTensorOp =
             toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
       Value buffer = toTensorOp.memref();
+
+      // Insert cast in case to_memref(to_tensor(x))'s type is different from
+      // x's type.
       if (toTensorOp.memref().getType() != toMemrefOp.getType())
         buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
                                                  toMemrefOp.getType());
-      rewriter.replaceOp(toMemrefOp, buffer);
+      replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
       return success();
     }
 
@@ -69,16 +75,19 @@ struct ToMemrefOpInterface
   }
 };
 
-/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
-/// not lower any further, and they should have disappeared by the time the
-/// input is fully bufferized.
+/// Bufferization of bufferization.to_tensor. Such ops cannot be bufferized.
+/// However, other ops that are using to_tensor's result will eventually be
+/// bufferized. At that point, they will start using to_tensor's memref operand.
+/// Once all users of to_tensor are bufferized, the op no longer has any
+/// users and can be removed by DCE.
 ///
-/// The analysis has no information about the memref that is loaded from by the
-/// ToTensorOp. We have to assume that the loaded tensor may after bufferization
-/// potentially alias with any other bufferized tensor. Since ToTensorOp and
-/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be encoded
-/// directly in the analysis. However, declaring ToTensorOp results as not
-/// writable also enforces a buffer copy and has the same effect.
+/// ToTensorOp conceptually loads a tensor from a memory location. The analysis
+/// has no information about the memref that is loaded from by ToTensorOp. We
+/// have to assume that the loaded tensor may after bufferization potentially
+/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp have
+/// no aliasing OpOperand/OpResult pairs, this cannot be encoded directly in the
+/// analysis. However, declaring ToTensorOp results as not writable enforces a
+/// buffer copy and has the same effect.
 struct ToTensorOpInterface
     : public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
                                                     bufferization::ToTensorOp> {
@@ -89,7 +98,7 @@ struct ToTensorOpInterface
 
   bool isWritable(Operation *op, Value value,
                   const BufferizationState &state) const {
-    // It is unknown whether the MemRef operand is writable or not.
+    // It is unknown whether the memref operand is writable or not.
     return false;
   }
 };

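The to_memref(to_tensor(x)) folding corresponds to IR like this
(illustrative):

    // Before (%x : memref<?xf32>):
    %t = bufferization.to_tensor %x : memref<?xf32>
    %m = bufferization.to_memref %t : memref<?xf32>

    // After: uses of %m are replaced with %x directly; if the two memref
    // types differ, a memref.cast from %x's type to %m's type is inserted.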
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 88a8f861c543..60ca3623fb95 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -6,98 +6,37 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Perform inplace bufferization within function boundaries.
-// This is a specialized pass that supports inplace analysis for a fixed subset
-// of ops that have well-defined inplace semantics.
-// This pass caters to high-performance codegen where buffer reuse is deemed
-// critical: the pass should fail if the bufferized form of the function needs
-// to return any buffer.
-// Generic control-flow and branching are unsupported.
-// Composability with extensible set of ops is not a first-class concern.
-//
-// Bufferization occurs by:
-//  a. performing an inPlace analysis `inPlaceAnalysis` which marks each
-//     operation within the op with the `kInPlaceResultsAttrName` attribute.
-//  b. traversing each operation in the op and rewriting it in
-//     buffer form and keeping a BlockAndValueMapping mapping of the
-//     rewrites. New allocations are introduced during this step.
-//     TODO: Allocation + depending op hoisting to outermost enclosing
-//     sequential scope.
-//  c. at the end of this bufferization, 3 cases may occur:
-//     i. inplaceable function arguments may be reused in place after the
-//        function itself has been bufferized. This is encoded by IR resembling:
-//
-//        ```
-//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//           func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-//              -> tensor<?xf32> {
-//            %0 = bufferization.to_memref %A : memref<?xf32, #map>
-//            // ... uses of %0
-//            %res = bufferization.to_tensor %0 : memref<?xf32, #map>
-//            return %res : tensor<?xf32>
-//          }
-//        ```
+// Comprehensive Bufferize bufferizes function bodies. Function boundaries
+// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
+// ModuleBufferization.cpp is an extension of Comprehensive Bufferize for simple
+// call graphs.
 //
-//        this is the cue for the bufferization of the function foo (and calls
-//        to it) may bufferize to `func @foo(%A: memref<?xf32, some_layout>)`.
-//        To fully achieve bufferization, an additional analysis is needed to
-//        determine whether function argument/operand pairs bufferize to a
-//        single inplace buffer argument (i.e. functions may return tensors in
-//        arbitrary order that may not match argument numbers).
+// Comprehensive Bufferize consists of two phases.
 //
-//    ii. results that don't map to an inplaceable function argument are
-//        generally allocated. Since memref semantics wrt ownership of the
-//        underlying memory region are not well-defined, comprehensive
-//        bufferization chooses to perform allocations in a scoped fashion:
-//        returning memrefs is always considered illegal.
-//        Such scenarios are encoded by IR resembling:
+// 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
+//    inserting buffer copies. The analysis queries op bufferization semantics
+//    via `BufferizableOpInterface`.
+// 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
+//    function does not generate buffer copies for OpResults that were decided
+//    to bufferize inplace during the analysis phase.
 //
-//        ```
-//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//          func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-//              -> tensor<?xf32> {
-//            %0 = bufferization.to_memref %A : memref<?xf32, #map>
-//            %1 = memref.dim %0, %c0 : memref<?xf32, #map>
-//            %2 = memref.alloc(%1) : memref<?xf32>
-//            %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
-//            // ... uses of %3
-//            memref.dealloc %2 : memref<?xf32, #map>
-//            %res = bufferization.to_tensor %3 : memref<?xf32, #map>
-//            return %res : tensor<?xf32>
-//          }
-//       ```
+// Inplace bufferization decisions are passed from the analysis to the
+// bufferization phase via `BufferizationState` and `BufferizationAliasInfo`.
+// They can be printed for debugging purposes with `testAnalysisOnly`.
 //
-//        this is the cue for the bufferization of the function foo (and calls
-//        to it) that it must bufferize to `func @foo(%A: memref<?xf32,
-//        some_layout>,
-//                   %B: memref<?xf32, some_layout>)` (i.e. make a cloned
-//        allocation of the result tensor)
-//        To fully achieve bufferization, the alloc/dealloc pair must be lifted
-//        out of the function at each call site.
+// Ops that do not implement `BufferizableOpInterface` can be analyzed, but
+// they are treated conservatively: e.g., the analysis has to assume that
+// their OpOperands bufferize to memory writes. Such ops are not bufferized
+// and remain in the IR; to_tensor and to_memref ops are inserted at the
+// bufferization boundary.
 //
-//   iii. as an optimization over ii., it may be possible to reuse an argument
-//        and only want to return a slice.
-//        This may forego allocation by letting *all* callers decide whether to
-//        pass a new *aliasing* memref function argument (i.e. a subview).
-//        Without loss of generality, callers may agree to allocate a new buffer
-//        to avoid this aliasing. Such scenarios are encoded by IR resembling:
+// Note: If `allowUnknownOps` is set to false, bufferization fails when an
+// unknown op (that does not implement `BufferizableOpInterface`) is found. No
+// to_tensor/to_memref ops are inserted.
 //
-//        ```
-//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//          func @foo(%arg0: tensor<?xf32> {linalg.inplaceable = true})
-//              -> tensor<4xf32> {
-//            %0 = bufferization.to_memref %arg0 : memref<?xf32, #map>
-//            %1 = memref.subview %0[0] [4] [1] : memref<?xf32, #map> to
-//                                                memref<4xf32, #map>
-//            // ... inplace computes into %1
-//            %3 = bufferization.to_tensor %1 : memref<4xf32, #map>
-//            return %3 : tensor<4xf32>
-//          }
-//        ```
-//
-//  Note: In the future, it may be worthwhile to design special bufferization
-//  ops to encode the desired semantics at function boundaries for i., ii. and
-//  iii.
+// This pass caters to high-performance codegen where buffer reuse is deemed
+// critical: the pass should fail if the bufferized form of the function needs
+// to return any buffer, unless `allowReturnMemref` is enabled.
 //
 //  Lastly, note that the layout map chosen to bufferize is the most dynamic
 //  canonical strided layout of the proper rank. This ensures compatibility with

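As an illustration of the bufferization boundary for unknown ops (the op
name below is a placeholder):

    // The unknown op is not bufferized; to_tensor/to_memref bridge between
    // its tensor operands/results and the surrounding bufferized ops.
    %0 = bufferization.to_tensor %buffer : memref<?xf32>
    %1 = "mydialect.unknown_op"(%0) : (tensor<?xf32>) -> tensor<?xf32>
    %2 = bufferization.to_memref %1 : memref<?xf32>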
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 5c5462527e9b..c4f42afb9828 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -38,6 +38,7 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
   if (!op.hasTensorSemantics())
     return op->emitError() << "op does not have tensor semantics";
 
+  // New input operands for the cloned op.
   SmallVector<Value> newInputBuffers;
   newInputBuffers.reserve(op.getNumInputs());
   for (OpOperand *opOperand : op.getInputOperands()) {
@@ -48,22 +49,23 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
     newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get()));
   }
 
+  // New output operands for the cloned op.
   SmallVector<Value> newOutputBuffers;
   for (OpOperand *opOperand : op.getOutputOperands()) {
     OpResult opResult = op.getTiedOpResult(opOperand);
     assert(opResult && "could not find correspond OpResult");
-    Value resultBuffer = state.getResultBuffer(rewriter, opResult);
-    if (!resultBuffer)
-      return failure();
-    newOutputBuffers.push_back(resultBuffer);
+    FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
+    newOutputBuffers.push_back(*resultBuffer);
   }
 
-  // Clone the newly bufferized op.
+  // Merge input/output operands.
   SmallVector<Value> newOperands = newInputBuffers;
   newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
 
   // Set insertion point now that potential alloc/dealloc are introduced.
   rewriter.setInsertionPoint(op);
+  // Clone the op, but use the new operands. Since the new op does not have any
+  // tensor results, it does not return anything.
   op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands);
 
   // Replace the results of the old op with the new output buffers.
@@ -135,18 +137,23 @@ static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
   return mapping;
 }
 
+/// Bufferization of linalg.generic. Replace with a new linalg.generic that
+/// operates entirely on memrefs.
 template <typename OpTy>
 struct LinalgOpInterface
     : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                     OpTy> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
+    // Operand is read if it is used in the computation.
     auto genericOp = cast<linalg::LinalgOp>(op);
     return genericOp.payloadUsesValueFromOperand(&opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
+    // Operand is written to if it has an aliasing OpResult. For more details,
+    // see `computeAliasingPairs`.
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
     return static_cast<bool>(
         bufferizableOp.getAliasingOpResult(opOperand, state));
@@ -156,6 +163,8 @@ struct LinalgOpInterface
   getAliasingOpOperand(Operation *op, OpResult opResult,
                        const BufferizationState &state) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
+
+    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
     DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
       if (pairs[opOperand] == opResult)
@@ -166,6 +175,8 @@ struct LinalgOpInterface
   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
+
+    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
     DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
     return pairs[&opOperand];
   }
@@ -207,22 +218,26 @@ struct InitTensorOpInterface
   }
 };
 
+/// Bufferization of linalg.tiled_loop. Replace with a new linalg.tiled_loop
+/// that operates entirely on memrefs.
 struct TiledLoopOpInterface
     : public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
                                                     linalg::TiledLoopOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
-    // TiledLoop alone doesn't bufferize to a memory read, one of the uses of
-    // its matching bbArg may.
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
+
+    // linalg.tiled_loop operands alone do not bufferize to a memory read, but
+    // one of the uses of their matching bbArgs may.
     return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
-    // TiledLoop alone doesn't bufferize to a memory write, one of the uses of
-    // its matching bbArg may.
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
+
+    // Only operands with an aliasing OpResult (i.e., output operands) bufferize
+    // to a memory write.
     return static_cast<bool>(
         bufferizableOp.getAliasingOpResult(opOperand, state));
   }
@@ -230,6 +245,8 @@ struct TiledLoopOpInterface
   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
+
+    // Output operands are tied to their corresponding OpResults.
     return tiledLoopOp.getTiedOpResult(opOperand);
   }
 
@@ -241,8 +258,8 @@ struct TiledLoopOpInterface
 
   bool isWritable(Operation *op, Value value,
                   const BufferizationState &state) const {
-    // Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed
-    // inplace from the perspective of ops nested under:
+    // Interestingly, linalg::TiledLoopOp's bbArgs can **always** be viewed
+    // inplace from the perspective of nested ops:
     //   1. Either the matching iter operand is not bufferized inplace and an
     //      alloc + optional copy makes the bbArg itself inplaceable.
     //   2. Or the matching iter operand is bufferized inplace and bbArg just
@@ -268,10 +285,10 @@ struct TiledLoopOpInterface
     int nextResultNum = 0;
     for (Value value : tiledLoopOp.outputs()) {
       if (value.getType().isa<TensorType>()) {
-        Value buffer = state.getResultBuffer(
+        FailureOr<Value> buffer = state.getResultBuffer(
             rewriter, tiledLoopOp->getResult(nextResultNum++));
-        newOutputs.push_back(buffer);
-        newResults.push_back(buffer);
+        newOutputs.push_back(*buffer);
+        newResults.push_back(*buffer);
       } else {
         newOutputs.push_back(value);
       }
@@ -349,6 +366,8 @@ struct TiledLoopOpInterface
   }
 };
 
+/// Bufferization of linalg.yield. Bufferized as part of linalg.tiled_loop's
+/// bufferization.
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     linalg::YieldOp> {
@@ -407,13 +426,12 @@ struct LinalgOpInterfaceHelper<> {
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single InitTensorOp.
-LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
-    InitTensorEliminationStep::eliminateInitTensors(
-        Operation *op, BufferizationState &state,
-        BufferizationAliasInfo &aliasInfo,
-        std::function<bool(OpOperand &)> anchorMatchFunc,
-        std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
-        SmallVector<Operation *> &newOps) {
+LogicalResult
+mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
+    eliminateInitTensors(Operation *op, BufferizationState &state,
+                         BufferizationAliasInfo &aliasInfo,
+                         AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
+                         SmallVector<Operation *> &newOps) {
   OpBuilder b(op->getContext());
 
   WalkResult status = op->walk([&](Operation *op) {
@@ -506,6 +524,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
         BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
   return eliminateInitTensors(
       op, state, aliasInfo,
+      /*anchorMatchFunc=*/
       [&](OpOperand &operand) {
         auto insertSliceOp =
             dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
@@ -516,6 +535,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
           return false;
         return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
       },
+      /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {
         auto insertSliceOp = cast<tensor::InsertSliceOp>(operand.getOwner());
         auto extractOp = b.create<tensor::ExtractSliceOp>(

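The InsertSliceOp-anchored elimination rewrites IR of roughly this shape
(illustrative, using the op syntax of this revision):

    // Before: %0 is anchored on the source operand of tensor.insert_slice.
    %0 = linalg.init_tensor [5] : tensor<5xf32>
    %1 = linalg.fill(%cst, %0) : f32, tensor<5xf32> -> tensor<5xf32>
    %2 = tensor.insert_slice %1 into %t[0] [5] [1]
        : tensor<5xf32> into tensor<?xf32>

    // After: the InitTensorOp is replaced with an ExtractSliceOp of the
    // insert_slice destination, so the fill can bufferize in place into %t.
    %0 = tensor.extract_slice %t[0] [5] [1] : tensor<?xf32> to tensor<5xf32>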
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 7d9a5648b128..5cac342296f8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -5,6 +5,88 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// Module bufferization is an extension of Comprehensive Bufferize that
+// bufferizes function boundaries. It provides `BufferizableOpInterface`
+// implementations for FuncOp, CallOp and ReturnOp, along with a few helper
+// functions that control the order in which functions are bufferized.
+//
+// Three cases can occur during bufferization of FuncOps.
+//
+//     i. inplaceable function arguments may be reused in place after the
+//        function itself has been bufferized. This is encoded by IR resembling:
+//
+//        ```
+//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+//           func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
+//              -> tensor<?xf32> {
+//            %0 = bufferization.to_memref %A : memref<?xf32, #map>
+//            // ... uses of %0
+//            %res = bufferization.to_tensor %0 : memref<?xf32, #map>
+//            return %res : tensor<?xf32>
+//          }
+//        ```
+//
+//        this is the cue that the function foo (and calls to it) may
+//        bufferize to `func @foo(%A: memref<?xf32, some_layout>)`.
+//        To fully achieve bufferization, an additional analysis is needed to
+//        determine whether function argument/operand pairs bufferize to a
+//        single inplace buffer argument (i.e. functions may return tensors in
+//        arbitrary order that may not match argument numbers).
+//
+//    ii. results that don't map to an inplaceable function argument are
+//        generally allocated. Since memref semantics wrt ownership of the
+//        underlying memory region are not well-defined, comprehensive
+//        bufferization chooses to perform allocations in a scoped fashion:
+//        returning memrefs is always considered illegal.
+//        Such scenarios are encoded by IR resembling:
+//
+//        ```
+//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+//          func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
+//              -> tensor<?xf32> {
+//            %0 = bufferization.to_memref %A : memref<?xf32, #map>
+//            %1 = memref.dim %0, %c0 : memref<?xf32, #map>
+//            %2 = memref.alloc(%1) : memref<?xf32>
+//            %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
+//            // ... uses of %3
+//            memref.dealloc %2 : memref<?xf32, #map>
+//            %res = bufferization.to_tensor %3 : memref<?xf32, #map>
+//            return %res : tensor<?xf32>
+//          }
+//       ```
+//
+//        this is the cue that the function foo (and calls to it) must
+//        bufferize to
+//        `func @foo(%A: memref<?xf32, some_layout>,
+//                   %B: memref<?xf32, some_layout>)`
+//        (i.e. make a cloned allocation of the result tensor).
+//        To fully achieve bufferization, the alloc/dealloc pair must be lifted
+//        out of the function at each call site.
+//
+//   iii. as an optimization over ii., it may be possible to reuse an argument
+//        and return only a slice of it.
+//        This may forego allocation by letting *all* callers decide whether to
+//        pass a new *aliasing* memref function argument (i.e. a subview).
+//        Without loss of generality, callers may agree to allocate a new buffer
+//        to avoid this aliasing. Such scenarios are encoded by IR resembling:
+//
+//        ```
+//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+//          func @foo(%arg0: tensor<?xf32> {linalg.inplaceable = true})
+//              -> tensor<4xf32> {
+//            %0 = bufferization.to_memref %arg0 : memref<?xf32, #map>
+//            %1 = memref.subview %0[0] [4] [1] : memref<?xf32, #map> to
+//                                                memref<4xf32, #map>
+//            // ... inplace computes into %1
+//            %3 = bufferization.to_tensor %1 : memref<4xf32, #map>
+//            return %3 : tensor<4xf32>
+//          }
+//        ```
+//
+//  Note: In the future, it may be worthwhile to design special bufferization
+//  ops to encode the desired semantics at function boundaries for i., ii. and
+//  iii.
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
 
@@ -161,7 +243,7 @@ static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
     if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
       return getDynamicMemRefType(rankedTensorType);
     if (auto tensorType = t.dyn_cast<TensorType>())
-      return getContiguousOrUnrankedMemRefType(tensorType);
+      return getUnrankedMemRefType(tensorType.getElementType());
     return t;
   };
   auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));

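With this change, `getBufferizedFunctionType` maps function types as
sketched below (illustrative):

    // Ranked tensors become memrefs with a fully dynamic layout map; all
    // other tensors (e.g., unranked) now become unranked memrefs:
    //   (tensor<?xf32>, tensor<*xf32>) -> tensor<?xf32>
    // becomes
    //   (memref<?xf32, #map>, memref<*xf32>) -> memref<?xf32, #map>
    // with #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>.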
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 1c4185c8ffef..5983d421aaed 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -19,6 +19,8 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace scf_ext {
 
+/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
+/// fully implemented at the moment.
 struct ExecuteRegionOpInterface
     : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                     scf::ExecuteRegionOp> {
@@ -79,6 +81,7 @@ struct ExecuteRegionOpInterface
   }
 };
 
+/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
 struct IfOpInterface
     : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
   SmallVector<OpOperand *>
@@ -212,6 +215,8 @@ struct IfOpInterface
   }
 };
 
+/// Bufferization of scf.for. Replace with a new scf.for that operates on
+/// memrefs.
 struct ForOpInterface
     : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                     scf::ForOp> {
@@ -292,7 +297,7 @@ struct ForOpInterface
     // Construct a new scf.for op with memref instead of tensor values.
     SmallVector<Value> initArgs =
         convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
-          return state.getResultBuffer(rewriter, forOp->getOpResult(index));
+          return *state.getResultBuffer(rewriter, forOp->getOpResult(index));
         });
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
@@ -399,6 +404,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
   return status;
 }
 
+/// Bufferization of scf.yield. Bufferized as part of its enclosing op, so
+/// this is for analysis only.
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     scf::YieldOp> {

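For scf.for, the rewrite replaces tensor iter_args with memref values,
roughly (illustrative):

    // Before: iterates on tensors.
    %r = scf.for %iv = %lb to %ub step %c1
        iter_args(%t = %t0) -> (tensor<?xf32>) { ... }

    // After: each init value is the result buffer returned by
    // `getResultBuffer` (a fresh allocation plus copy when bufferizing
    // out-of-place), and iteration is over memrefs.
    %b = scf.for %iv = %lb to %ub step %c1
        iter_args(%m = %buffer) -> (memref<?xf32, #map>) { ... }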
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index e1cd933de3b8..6b8b8983972a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -51,29 +51,38 @@ struct CastOpInterface
                           const BufferizationState &state) const {
     auto castOp = cast<tensor::CastOp>(op);
 
-    Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0));
-    if (!resultBuffer)
-      return failure();
-    Type sourceType = resultBuffer.getType();
-    auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
-    auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
-    assert(rankedMemRefType || unrankedMemRefType);
-    Attribute memorySpace = rankedMemRefType
-                                ? rankedMemRefType.getMemorySpace()
-                                : unrankedMemRefType.getMemorySpace();
-    TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
-    MemRefLayoutAttrInterface layout =
-        rankedMemRefType && tensorType.isa<RankedTensorType>()
-            ? rankedMemRefType.getLayout()
-            : MemRefLayoutAttrInterface();
-    Type memRefType = getContiguousOrUnrankedMemRefType(
-        castOp.getResult().getType(), layout, memorySpace);
-    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, memRefType,
-                                                 resultBuffer);
+    // The result buffer still has the old (pre-cast) type.
+    FailureOr<Value> resultBuffer =
+        state.getResultBuffer(rewriter, castOp->getResult(0));
+    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
+    Attribute memorySpace = sourceMemRefType.getMemorySpace();
+    TensorType resultTensorType =
+        castOp.getResult().getType().cast<TensorType>();
+    MemRefLayoutAttrInterface layout;
+
+    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
+      if (resultTensorType.isa<RankedTensorType>())
+        layout = rankedMemRefType.getLayout();
+
+    // Compute the new memref type.
+    Type resultMemRefType;
+    if (resultTensorType.isa<RankedTensorType>()) {
+      resultMemRefType =
+          getContiguousMemRefType(resultTensorType, layout, memorySpace);
+    } else {
+      resultMemRefType =
+          getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
+    }
+
+    // Replace the op with a memref.cast.
+    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
+                                                 *resultBuffer);
+
     return success();
   }
 };
 
+/// Bufferization of tensor.dim. Replace with memref.dim.
 struct DimOpInterface
     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                     tensor::DimOp> {
@@ -95,14 +104,13 @@ struct DimOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
-    if (!dimOp.source().getType().isa<RankedTensorType>())
-      return dimOp.emitError("unranked tensor not supported");
     Value v = state.lookupBuffer(rewriter, dimOp.source());
     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
     return success();
   }
 };
 
+/// Bufferization of tensor.extract_slice. Replace with memref.subview.
 struct ExtractSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                     tensor::ExtractSliceOp> {
@@ -156,7 +164,7 @@ struct ExtractSliceOpInterface
         loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
         extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
 
-    /// If not inplaceable, copy.
+    // If not inplaceable, copy.
     if (!inplace) {
       // Do not copy if the copied data is never read.
       if (state.isValueRead(extractSliceOp.result()))
@@ -169,6 +177,7 @@ struct ExtractSliceOpInterface
   }
 };
 
+/// Bufferization of tensor.extract. Replace with memref.load.
 struct ExtractOpInterface
     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                     tensor::ExtractOp> {
@@ -197,6 +206,7 @@ struct ExtractOpInterface
   }
 };
 
+/// Bufferization of tensor.insert. Replace with memref.store.
 struct InsertOpInterface
     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
@@ -226,12 +236,11 @@ struct InsertOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
-    Location loc = insertOp.getLoc();
-    Value destMemref =
+    FailureOr<Value> destMemref =
         state.getResultBuffer(rewriter, insertOp->getOpResult(0));
-    rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
-                                     insertOp.indices());
-    replaceOpWithBufferizedValues(rewriter, op, destMemref);
+    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
+                                     *destMemref, insertOp.indices());
+    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
     return success();
   }
 
@@ -276,6 +285,8 @@ static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
                       condition);
 }
 
+/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
+/// certain circumstances, this op can also be a no-op.
 struct InsertSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
@@ -391,13 +402,11 @@ struct InsertSliceOpInterface
     Location loc = insertSliceOp.getLoc();
 
     // When bufferizing out-of-place, `getResultBuffer` allocates.
-    Value dstMemref =
+    FailureOr<Value> dstMemref =
         state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
-    if (!dstMemref)
-      return failure();
 
     // Take a subview of the dst.
-    auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getRank(), dstMemrefType,
@@ -405,15 +414,15 @@ struct InsertSliceOpInterface
             insertSliceOp.getMixedStrides())
             .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+        loc, subviewMemRefType, *dstMemref, insertSliceOp.getMixedOffsets(),
         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
     Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
-    state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
+    state.createMemCpy(rewriter, loc, srcMemref, subView);
 
-    replaceOpWithBufferizedValues(rewriter, op, dstMemref);
+    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
     return success();
   }
 };

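Overall, the insert_slice bufferization amounts to a subview of the
destination buffer plus a copy (illustrative types; the concrete copy op is
produced by `state.createMemCpy`):

    // tensor.insert_slice %src into %dst[0] [5] [1] bufferizes to:
    %sv = memref.subview %dst_buffer[0] [5] [1]
        : memref<?xf32, #map> to memref<5xf32, #map>
    // ...followed by a copy of %src's buffer into %sv, which later folds
    // away if the insert_slice has a matching tensor.extract_slice.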
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index d4c57617b004..3c8d6a9c96e5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -17,6 +17,8 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace vector_ext {
 
+/// Bufferization of vector.transfer_read. Replace with a new
+/// vector.transfer_read that operates on a memref.
 struct TransferReadOpInterface
     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
                                                     vector::TransferReadOp> {
@@ -55,6 +57,8 @@ struct TransferReadOpInterface
   }
 };
 
+/// Bufferization of vector.transfer_write. Replace with a new
+/// vector.transfer_write that operates on a memref.
 struct TransferWriteOpInterface
     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
                                                     vector::TransferWriteOp> {
@@ -94,13 +98,12 @@ struct TransferWriteOpInterface
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
-    Value resultBuffer = state.getResultBuffer(rewriter, op->getResult(0));
-    if (!resultBuffer)
-      return failure();
+    FailureOr<Value> resultBuffer =
+        state.getResultBuffer(rewriter, op->getResult(0));
     rewriter.create<vector::TransferWriteOp>(
-        writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
+        writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
-    replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
+    replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
 
     return success();
   }

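The transfer_write rewrite produces IR along these lines (illustrative):

    // Before:
    %r = vector.transfer_write %v, %t[%c0] : vector<4xf32>, tensor<?xf32>

    // After: write into the result buffer; the old tensor-based
    // transfer_write is left behind and cleaned up as dead code.
    vector.transfer_write %v, %buffer[%c0]
        : vector<4xf32>, memref<?xf32, #map>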
