[Mlir-commits] [mlir] 73bea97 - [mlir][Linalg] Add support for CallOp bufferization (10/n)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jul 1 04:13:23 PDT 2021


Author: Nicolas Vasilache
Date: 2021-07-01T10:33:12Z
New Revision: 73bea97a336ba2da276ef34fd21b2c5c676b0a97

URL: https://github.com/llvm/llvm-project/commit/73bea97a336ba2da276ef34fd21b2c5c676b0a97
DIFF: https://github.com/llvm/llvm-project/commit/73bea97a336ba2da276ef34fd21b2c5c676b0a97.diff

LOG: [mlir][Linalg] Add support for CallOp bufferization (10/n)

Cross function boundary bufferization support is added.
This is enabled by cross-function boundary alias analysis, for which the bufferization process is extended: it can now modify the BufferizationAliasInfo as new ops are introduced.

A number of simplifying assumptions are made:

1. By default, we bufferize to the most dynamic strided memref type; subsequent memref::CastOp canonicalizations are expected to clean up the IR.
2. In the current implementation, the stride information is always erased at function boundaries. A subsequent pass will be required to analyze the meet of all call ops to a function and decide whether more static buffer types can be used. This will potentially clone functions when it is deemed profitable to do so (e.g. when the stride-1 dimension may vary).
3. External functions always bufferize to the most dynamic strided memref version. This may require special annotations to specify that particular operands of top-level functions have a contiguous buffer layout.

An alternative to point 3 would be to support tensor layout annotations, which MLIR does not currently support.

Differential revision: https://reviews.llvm.org/D104873

Added: 
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 14acc36fbf22..824092df292c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -114,7 +114,9 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/Passes.h"
 
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/EquivalenceClasses.h"
@@ -136,6 +138,8 @@ using namespace tensor;
 // Generic helpers.
 //===----------------------------------------------------------------------===//
 
+static bool isaTensor(Type t) { return t.isa<TensorType>(); } // any_of pred.
+
 /// Return the FuncOp called by `callOp`.
 static FuncOp getCalledFunction(CallOpInterface callOp) {
   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
@@ -145,6 +149,20 @@ static FuncOp getCalledFunction(CallOpInterface callOp) {
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
+/// Return the unique ReturnOp that terminates a block of `funcOp`.
+/// Return nullptr if no block, or more than one block, ends in a ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+  ReturnOp returnOp;
+  for (Block &b : funcOp.body()) {
+    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+      if (returnOp) // A second ReturnOp: not unique.
+        return nullptr;
+      returnOp = candidateOp;
+    }
+  }
+  return returnOp;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -163,7 +181,7 @@ static void map(BlockAndValueMapping &bvm, Value key, Value value) {
 }
 
 /// Wrapper for better debugging.
-static Value lookup(BlockAndValueMapping &bvm, Value key) {
+static Value lookup(const BlockAndValueMapping &bvm, Value key) {
   // TODO: if key comes from bbArg, forward.
   assert(key.getType().isa<TensorType>());
   Value v = bvm.lookupOrNull(key);
@@ -347,10 +365,8 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
           VectorTransferOpInterface,
           scf::YieldOp>(op)
       // clang-format on
-      || (none_of(op->getResultTypes(),
-                  [](Type t) { return t.isa<TensorType>(); }) &&
-          none_of(op->getOperandTypes(),
-                  [](Type t) { return t.isa<TensorType>(); }));
+      || (none_of(op->getResultTypes(), isaTensor) &&
+          none_of(op->getOperandTypes(), isaTensor));
 }
 
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
@@ -577,14 +593,22 @@ class BufferizationAliasInfo {
   /// beginning the alias and equivalence sets only contain `v` itself.
   void createAliasInfoEntry(Value v);
 
+  /// Insert an info entry for `newValue` and merge its alias set with that of
+  /// `alias`.
+  void insertNewBufferAlias(Value newValue, Value alias);
+
+  /// Insert an info entry for `newValue` and merge its alias set with that of
+  /// `alias`. Additionally, merge their equivalence classes.
+  void insertNewBufferEquivalence(Value newValue, Value alias);
+
   /// Return true if the buffer to which `operand` would bufferize aliases a
   /// buffer that is known to not be writeable. This implies that the matching
   /// OpResult cannot be bufferized inplace.
   bool aliasesNonWriteableBuffer(OpOperand &operand) const;
 
   /// Return true if the buffer to which `operand` would bufferize is equivalent
-  /// to some use that would bufferize to a write to a buffer.
-  bool aliasesInPlaceWrite(ExtractSliceOp extractSliceOp) const;
+  /// to some buffer write.
+  bool aliasesInPlaceWrite(Value v) const;
 
   /// Set the inPlace bufferization spec to true.
   /// Merge result's and operand's aliasing sets and iterate to a fixed point.
@@ -619,6 +643,9 @@ class BufferizationAliasInfo {
   bool isSourceEquivalentToAMatchingExtractSliceOp(
       InsertSliceOp insertSliceOp) const;
 
+  /// Apply `fun` to all the members of the equivalence class of `v`.
+  void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
+
   /// Print to `os`.
   void print(raw_ostream &os) const;
 
@@ -626,8 +653,9 @@ class BufferizationAliasInfo {
   void dump() const { print(llvm::errs()); }
 
 private:
-  /// Check aliasInfo for `v` exists and return a reference to it.
+  /// Check that aliasInfo for `v` exists and return a reference to it.
   DenseSet<Value> &getAliasInfoRef(Value v);
+
   const DenseSet<Value> &getAliasInfoRef(Value v) const {
     return const_cast<BufferizationAliasInfo *>(this)->getAliasInfoRef(v);
   }
@@ -740,6 +768,23 @@ void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
   equivalentInfo.insert(v);
 }
 
+/// Create an alias-info entry for `newValue` (`alias` must already have one),
+/// merge both alias sets, then call mergeAliasesToFixedPoint().
+void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
+  assert(aliasInfo.find(alias) != aliasInfo.end() && "Missing alias entry");
+  createAliasInfoEntry(newValue);
+  mergeAliases(newValue, alias);
+  mergeAliasesToFixedPoint();
+}
+
+/// Same as insertNewBufferAlias(newValue, alias), then additionally union the
+/// equivalence classes of `newValue` and `alias`.
+void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
+                                                        Value alias) {
+  insertNewBufferAlias(newValue, alias);
+  equivalentInfo.unionSets(newValue, alias);
+}
+
 /// Return true if the buffer to which `operand` would bufferize aliases a
 /// buffer that is known to not be writeable. This implies that the matching
 /// OpResult cannot be bufferized inplace.
@@ -755,13 +800,13 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
         LDBG("-----------bbArg is writeable -> skip: " << bbArg << '\n');
         continue;
       }
-      LDBG("-----------notWriteable: " << v << '\n');
+      LDBG("-----------notWriteable\n");
       return true;
     }
 
     if (Operation *op = v.getDefiningOp()) {
       if (isa<ConstantOp>(op) || !hasKnownBufferizationAliasingBehavior(op)) {
-        LDBG("-----------notWriteable: " << v << '\n');
+        LDBG("-----------notWriteable\n");
         return true;
       }
     }
@@ -771,12 +816,11 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
 }
 
 /// Return true if the buffer to which `operand` would bufferize is equivalent
-/// to some use that would bufferize to a write to a buffer.
-bool BufferizationAliasInfo::aliasesInPlaceWrite(
-    ExtractSliceOp extractSliceOp) const {
+/// to some buffer write.
+bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
   LDBG("----Start aliasesInPlaceWrite\n");
-  LDBG("-------for op: " << *extractSliceOp.getOperation() << '\n');
-  for (Value v : getAliasInfoRef(extractSliceOp.result())) {
+  LDBG("-------for : " << value << '\n');
+  for (Value v : getAliasInfoRef(value)) {
     for (auto &use : v.getUses()) {
       if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
         LDBG("-----------wants to bufferize to inPlace write: "
@@ -785,7 +829,7 @@ bool BufferizationAliasInfo::aliasesInPlaceWrite(
       }
     }
   }
-  LDBG("----------->extract_slice does not alias an inplace write");
+  LDBG("----------->does not alias an inplace write\n");
   return false;
 }
 
@@ -920,6 +964,16 @@ bool BufferizationAliasInfo::isSourceEquivalentToAMatchingExtractSliceOp(
   return false;
 }
 
+/// Apply `fun` once to every member of the equivalence class of `v`.
+void BufferizationAliasInfo::applyOnEquivalenceClass(
+    Value v, function_ref<void(Value)> fun) const {
+  for (auto it = equivalentInfo.findLeader(v),
+            eit = equivalentInfo.member_end();
+       it != eit; ++it) {
+    fun(*it);
+  }
+}
+
 void BufferizationAliasInfo::print(raw_ostream &os) const {
   os << "\n/========================== AliasInfo "
         "==========================\n";
@@ -1106,6 +1160,21 @@ bool BufferizationAliasInfo::isClobberedWriteBeforeRead(
   return existsInterleavedValueClobber(aliasingRead, aliasingWrite, domInfo);
 }
 
+//===----------------------------------------------------------------------===//
+// Forward declarations.
+//===----------------------------------------------------------------------===//
+
+/// Return the op with Allocate MemoryEffect if `value` is equivalent to such
+/// an op. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+                                     const BufferizationAliasInfo &aliasInfo);
+
+/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
+/// Return null if no such bbArg can be found.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+                                const BufferizationAliasInfo &aliasInfo);
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific MemRefType support.
 //===----------------------------------------------------------------------===//
@@ -1152,6 +1221,47 @@ static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
                          stridedLayout, addressSpace);
 }
 
+/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
+/// tensor is replaced by the corresponding buffer type.
+/// In order for all the callers to agree, this *must* bufferize to the most
+/// dynamic buffer type supported.
+/// A later pass across all CallOps in the module can decide whether to simplify
+/// the types or to version the function according to some cost model.
+static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
+                                              TypeRange argumentTypes,
+                                              TypeRange resultTypes) {
+  auto rewrite = [](Type t) -> Type {
+    // TODO: non-zero address space.
+    // TODO: layout information if relevant.
+    if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
+      return getDynamicMemRefType(rankedTensorType);
+    if (auto tensorType = t.dyn_cast<TensorType>())
+      return getContiguousOrUnrankedMemRefType(tensorType);
+    return t; // Non-tensor types are kept unchanged.
+  };
+  auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
+  auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
+  return FunctionType::get(ctx, argTypes, retTypes);
+}
+
+/// Look up `funcOp` in `bufferizedFunctionTypes`; on a hit, return the cached
+/// type and ignore `argumentTypes`/`resultTypes`. Otherwise, build the type
+/// from `argumentTypes` and `resultTypes`, cache it and return it.
+// TODO: improve the layering.
+static FunctionType getOrCreateBufferizedFunctionType(
+    FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+  auto it = bufferizedFunctionTypes.find(funcOp);
+  if (it != bufferizedFunctionTypes.end())
+    return it->second;
+
+  auto it2 = bufferizedFunctionTypes.try_emplace(
+      funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
+                                        resultTypes));
+  LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n");
+  return it2.first->second;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific scoped alloc/dealloc insertion support.
 //===----------------------------------------------------------------------===//
@@ -1159,8 +1269,10 @@ static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
 /// Create an Allocop/DeAllocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-static Value createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
-                                                     Value shapedValue) {
+static Value
+createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
+                                        Value shapedValue,
+                                        BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1189,9 +1301,12 @@ static Value createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
       dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
 
   Value allocated = b.create<memref::AllocOp>(loc, allocMemRefType, dynShape);
+  aliasInfo.createAliasInfoEntry(allocated);
   Value casted = allocated;
-  if (memRefType != allocMemRefType)
+  if (memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated);
+    aliasInfo.insertNewBufferEquivalence(casted, allocated);
+  }
   b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
   b.create<memref::DeallocOp>(loc, allocated);
   return casted;
@@ -1212,7 +1327,8 @@ static Value createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
 static LogicalResult
 allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
                           SmallVectorImpl<Value> &resultBuffers,
-                          BlockAndValueMapping &bvm) {
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1236,7 +1352,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
 
     // Otherwise, `op` is not inplaceable and we need to allocate its result.
     Value dimTensor = bvm.lookupOrDefault(output);
-    Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor);
+    Value alloc =
+        createNewAllocDeallocPairForShapedValue(b, loc, dimTensor, aliasInfo);
     b.setInsertionPointAfter(alloc.getDefiningOp());
     resultBuffers.push_back(alloc);
 
@@ -1258,7 +1375,7 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1267,8 +1384,6 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   if (!op.hasTensorSemantics())
     return failure();
 
-  LDBG("bufferize: " << *op << '\n');
-
   b.setInsertionPoint(op);
   Location loc = op.getLoc();
   SmallVector<Value> newInputBuffers;
@@ -1284,7 +1399,8 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
+  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
+                                       aliasInfo)))
     return failure();
 
   // Clone the newly bufferized op.
@@ -1301,11 +1417,153 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   return success();
 }
 
+/// In a first approximation, all the function arguments of a FuncOp are marked
+/// inplaceable. For now, it is the responsibility of the `callOp` bufferization
+/// to allow FuncOps that are inplaceable to write inPlace.
+static LogicalResult
+bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
+          BufferizationAliasInfo &aliasInfo,
+          DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+  FuncOp funcOp = getCalledFunction(callOp);
+  assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
+         "expected Callop to a FuncOp");
+
+  // If nothing to do then we are done.
+  if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
+      !llvm::any_of(funcOp.getType().getResults(), isaTensor))
+    return success();
+
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(callOp);
+
+  // 1. Filter return types:
+  //    - if the callee is bodiless / external, we cannot inspect it and we
+  //      cannot assume anything. We only check that it does not return a
+  //      tensor ("return a memref" semantics are ill-defined). NOTE(review):
+  //      `resultTypes` stays empty here, dropping any non-tensor results.
+  //    - if the callee has a body, we perform inter-procedural equivalence
+  //      analysis. When successful, a result folds onto an operand. When
+  //      unsuccessful, additional work is needed to either:
+  //        * hoist a result into an inplaceable operand or
+  //        * devise a better representation to truly return a buffer.
+  SmallVector<Type> resultTypes;
+  SmallVector<Value> hoistedArguments;
+  if (funcOp.body().empty()) {
+    if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+      return callOp->emitError()
+             << "cannot bufferize bodiless function that returns a tensor";
+  } else {
+    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+    if (!returnOp)
+      return funcOp->emitError() << "cannot bufferize a FuncOp with tensors "
+                                    "and without a unique ReturnOp";
+
+    // For each FuncOp result, keep track of which inplace argument it reuses.
+    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+      Type returnType = returnOperand.get().getType();
+      if (!isaTensor(returnType)) {
+        resultTypes.push_back(returnType);
+        continue;
+      }
+
+      // If return operand is equivalent to some bbArg, no need to return it.
+      Value returnVal = returnOperand.get();
+      if (BlockArgument bbArg =
+              getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) {
+        Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
+        int64_t idx = bbArg.getArgNumber();
+        Value buffer = bvm.lookupOrNull(callOp->getOperand(idx));
+        if (!buffer)
+          return callOp->emitError() << "operand #" << idx << " not bufferized";
+        // Add CallOp operand/result equivalence: this is interprocedural info.
+        aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+        map(bvm, oldRes, buffer);
+        // Add a TensorLoadOp to kill all uses of the CallOp return.
+        // Replace all uses of the CallOp results so we can erase the CallOp.
+        // This TensorLoadOp must fold/DCE away or bufferization should be
+        // considered failed.
+        Value tensorLoad =
+            b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
+        oldRes.replaceAllUsesWith(tensorLoad);
+        // Add new op equivalence info.
+        aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
+        map(bvm, tensorLoad, buffer);
+        continue;
+      }
+
+      // TODO: Need to hoist above function boundary and add to
+      // `hoistedArguments`.
+      if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo))
+        return allocOp->emitError()
+               << " needs hoist across function boundary\n";
+
+      // Other cases legitimately need to return a tensor; this is currently not
+      // supported (the push_back below is dead until it is). Hoisting across
+      // the function boundary may fail e.g. due to data-dependent sizes, in
+      // which case we would need a better type than memref.
+      resultTypes.push_back(returnType);
+
+      int64_t returnIdx = returnOperand.getOperandNumber();
+      return returnOp->emitError()
+             << " bufferize result #" << returnIdx << "\n";
+    }
+  }
+
+  // 2. Compute bufferized FunctionType.
+  SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
+  llvm::append_range(argumentTypes, ValueRange{hoistedArguments}.getTypes());
+  // Get the bufferized FunctionType for funcOp or construct it if not yet
+  // available.
+  FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+      funcOp, argumentTypes, resultTypes, bufferizedFunctionTypes);
+
+  // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
+  SmallVector<Value> newOperands;
+  newOperands.reserve(callOp->getNumOperands());
+  for (OpOperand &opOperand : callOp->getOpOperands()) {
+    Value tensorOperand = opOperand.get();
+    // Non-tensor operands are just copied.
+    if (!tensorOperand.getType().isa<TensorType>()) {
+      newOperands.push_back(tensorOperand);
+      continue;
+    }
+
+    // Tensor operands are guaranteed to have been bufferized.
+    int64_t idx = opOperand.getOperandNumber();
+    Value buffer = bvm.lookupOrNull(tensorOperand);
+    assert(buffer && " missing buffer for operand");
+
+    // Caller / callee type mismatch is handled with a CastOp.
+    auto memRefType = bufferizedFuncType.getInput(idx);
+    // Since we don't yet have a clear layout story, buffer_cast may
+    // conservatively turn tensors into more dynamic memref than necessary.
+    // If the buffer type does not match the callee's memref type, introduce an
+    // extra memref.cast that will either canonicalize away or fail compilation
+    // until we can do something better.
+    if (buffer.getType() != memRefType) {
+      Value castBuffer =
+          b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
+      // Add new op equivalence info.
+      aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+      map(bvm, tensorOperand, castBuffer);
+      buffer = castBuffer;
+    }
+    newOperands.push_back(buffer);
+  }
+
+  // 4. Create the new CallOp. NOTE(review): its results replace no uses here.
+  Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
+                                          resultTypes, newOperands);
+  newCallOp->setAttrs(callOp->getAttrs());
+  return success();
+}
+
 /// DimOp tensor operand is modified inplace. This allows leaving dead
 /// tensors behind that will get DCE'd.
 static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   if (dimOp.source().getType().isa<RankedTensorType>()) {
     Value v = lookup(bvm, dimOp.source());
     if (!v)
@@ -1317,13 +1575,11 @@ static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
 
 static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   Location loc = forOp.getLoc();
 
-  LLVM_DEBUG(DBGS() << "bufferize: " << *forOp << "\n");
-
   // If inPlace, just forward the buffer.
   // Otherwise alloc and copy.
   b.setInsertionPoint(forOp);
@@ -1337,11 +1593,12 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
     Value operandBuffer = lookup(bvm, operand);
     Value resultBuffer = operandBuffer;
     if (getInPlace(opResult) != InPlaceSpec::True) {
-      resultBuffer = createNewAllocDeallocPairForShapedValue(b, loc, operand);
+      resultBuffer =
+          createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
       // If the tensor comes from `linalg::InitTensorOp`, the value is
       // unitialized and we do not need to copy.
-      // TODO: if the matching bbArg does not bufferize to a read is more
-      // general.
+      // TODO: "matching bbArg does not bufferize to a read" is a more general
+      // check.
       if (!operand.getDefiningOp<linalg::InitTensorOp>())
         b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
     }
@@ -1356,7 +1613,7 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1370,9 +1627,10 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
     Type memRefType = rankedTensorType
                           ? getDynamicMemRefType(rankedTensorType)
                           : getContiguousOrUnrankedMemRefType(tensorType);
-    Value tensorToMemref =
+    Value bufferCast =
         b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
-    map(bvm, bbArg, tensorToMemref);
+    aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+    map(bvm, bbArg, bufferCast);
   }
   return success();
 }
@@ -1380,7 +1638,7 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
 /// ReturnOp always creates memref::TensorLoadOp.
 static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(returnOp);
@@ -1394,7 +1652,10 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
     Value v = lookup(bvm, operand.get());
     if (!v)
       return failure();
-    operand.set(b.create<memref::TensorLoadOp>(returnOp.getLoc(), v));
+    Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
+    operand.set(returnTensor);
+    aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+    map(bvm, returnTensor, v);
   }
   return success();
 }
@@ -1406,7 +1667,7 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
 /// isolation.
 static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   LDBG("bufferize: " << *extractSliceOp << '\n');
 
   // Take a guard before anything else.
@@ -1426,8 +1687,8 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
   Value alloc;
   auto inPlace = getInPlace(extractSliceOp->getResult(0));
   if (inPlace != InPlaceSpec::True) {
-    alloc = createNewAllocDeallocPairForShapedValue(b, loc,
-                                                    extractSliceOp.result());
+    alloc = createNewAllocDeallocPairForShapedValue(
+        b, loc, extractSliceOp.result(), aliasInfo);
     b.setInsertionPointAfter(alloc.getDefiningOp());
   }
 
@@ -1441,6 +1702,8 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
   Value subView = b.create<memref::SubViewOp>(
       loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
       extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+  // Insert new alias.
+  aliasInfo.insertNewBufferAlias(subView, srcMemref);
 
   /// If not inplaceable, copy.
   if (alloc) {
@@ -1454,7 +1717,7 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
 
 static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   LDBG("bufferize: " << *insertSliceOp << '\n');
 
   // Take a guard before anything else.
@@ -1472,8 +1735,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
     // cloning the whole tensor on every single iteration and is a symptom
     // of a catastrophically bad scheduling decision.
     // TODO: be very loud about it or even consider failing the pass.
-    Value newDstMemref =
-        createNewAllocDeallocPairForShapedValue(b, loc, insertSliceOp.result());
+    Value newDstMemref = createNewAllocDeallocPairForShapedValue(
+        b, loc, insertSliceOp.result(), aliasInfo);
     b.setInsertionPointAfter(newDstMemref.getDefiningOp());
     b.create<CopyOp>(insertSliceOp.getLoc(), dstMemref, newDstMemref);
     dstMemref = newDstMemref;
@@ -1503,6 +1766,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
     Value subView = b.create<memref::SubViewOp>(
         loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+    // Insert new alias.
+    aliasInfo.insertNewBufferAlias(subView, dstMemref);
     b.create<CopyOp>(insertSliceOp.getLoc(), srcMemref, subView);
   }
 
@@ -1513,7 +1778,7 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
 
 static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1522,8 +1787,6 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   if (op.getShapedType().isa<MemRefType>())
     return failure();
 
-  LDBG("bufferize: " << *op << '\n');
-
   /// transfer_read from buffer always reads from the bufferized
   /// op.source().
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
@@ -1540,8 +1803,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   // If transfer_write is not inPlace, allocate a new buffer.
   Value newInputBuffer;
   if (inPlace != InPlaceSpec::True) {
-    newInputBuffer =
-        createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
+    newInputBuffer = createNewAllocDeallocPairForShapedValue(
+        b, loc, writeOp.result(), aliasInfo);
     b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
     map(bvm, writeOp.result(), newInputBuffer);
   } else {
@@ -1567,7 +1830,7 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
 
 static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(yieldOp);
@@ -1618,7 +1881,7 @@ bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
   // If `extractSliceOp` were to be bufferized inplace, it cannot end up
   // aliasing a write into a non-writeable buffer.
   bool wouldCreateAliasingWriteToNonWriteableBuffer =
-      aliasInfo.aliasesInPlaceWrite(extractSliceOp) &&
+      aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) &&
       aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0));
 
   if (wouldCreateAliasingWriteToNonWriteableBuffer)
@@ -1743,7 +2006,6 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
       return extractSliceOps.push_back(extractSliceOp);
     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(op))
       return insertSliceOps.push_back(insertSliceOp);
-    auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
     // No tensors => no buffers.
     if (none_of(op->getOperandTypes(), isaTensor) &&
         none_of(op->getResultTypes(), isaTensor))
@@ -1792,12 +2054,12 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
 }
 
 //===----------------------------------------------------------------------===//
-// Bufferization entry-point.
+// Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
-static LogicalResult
-bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm,
-                         const BufferizationAliasInfo &aliasInfo) {
+static LogicalResult bufferizeFuncOpInternals(
+    FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
@@ -1805,42 +2067,54 @@ bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm,
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
   // Walk in PreOrder to ensure ops with regions are handled before their body.
-  WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
-    LogicalResult status =
-        TypeSwitch<Operation *, LogicalResult>(op)
-            // Skip BufferCast and TensorLoad ops.
-            // clang-format off
-            .Case<memref::BufferCastOp,
-                  memref::TensorLoadOp>(
-                [&](auto) { return success(); })
-            .Case<scf::ForOp,
-                  tensor::DimOp,
-                  LinalgOp,
-                  ReturnOp,
-                  ExtractSliceOp,
-                  InsertSliceOp,
-                  VectorTransferOpInterface,
-                  scf::YieldOp>(
-                [&](auto op) {
-                  LDBG("Begin buferize:\n" << op << '\n');
-                  return bufferize(b, op, bvm, aliasInfo);
-                })
-            // clang-format on
-            .Default([&](Operation *op) {
-              auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
-              if (any_of(op->getOperandTypes(), isaTensor) ||
-                  any_of(op->getResultTypes(), isaTensor))
-                return failure();
-              return success();
-            });
-    if (failed(status)) {
-      op->emitError("Failed bufferization");
-      return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
+  // Since walk has to be PreOrder, we need to erase ops that require it
+  // separately: this is the case for CallOp
+  SmallVector<Operation *> toErase;
+  WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op)
+                                                           -> WalkResult {
+    // clang-format off
+    WalkResult result =
+      TypeSwitch<Operation *, LogicalResult>(op)
+      // Skip BufferCast and TensorLoad ops.
+      .Case<memref::BufferCastOp,
+            memref::TensorLoadOp>([&](auto) { return success(); })
+      .Case<tensor::DimOp,
+            scf::ForOp,
+            LinalgOp,
+            ReturnOp,
+            ExtractSliceOp,
+            InsertSliceOp,
+            VectorTransferOpInterface,
+            scf::YieldOp>([&](auto op) {
+        LDBG("Begin bufferize:\n" << op << '\n');
+        return bufferize(b, op, bvm, aliasInfo);
+      })
+      .Case([&](CallOpInterface op) {
+        LDBG("Begin bufferize:\n" << op << '\n');
+        return bufferize(b, op, bvm, aliasInfo, bufferizedFunctionTypes);
+      })
+      .Default([&](Operation *op) {
+        auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+        if (any_of(op->getOperandTypes(), isaTensor) ||
+            any_of(op->getResultTypes(), isaTensor))
+          return failure();
+        return success();
+      });
+    // clang-format on
+
+    // Register post-walk erasure, if necessary.
+    if (isa<CallOpInterface>(op))
+      if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+          llvm::any_of(op->getResultTypes(), isaTensor))
+        toErase.push_back(op);
+
+    return result;
   });
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
+  for (Operation *op : toErase)
+    op->erase();
+
   return failure(result.wasInterrupted());
 }
 
@@ -1874,7 +2148,9 @@ void LinalgComprehensiveFuncBufferize::runOnFunction() {
 
   // Bufferization phase.
   BlockAndValueMapping bvm;
-  if (failed(bufferizeFuncOpInternals(funcOp, bvm, aliasInfo)))
+  DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
+  if (failed(bufferizeFuncOpInternals(funcOp, bvm, aliasInfo,
+                                      bufferizedFunctionTypes)))
     signalPassFailure();
 
   // Post-pass cleanup of inplaceable attributes.
@@ -1889,6 +2165,168 @@ std::unique_ptr<Pass> mlir::createLinalgComprehensiveFuncBufferizePass() {
 // Bufferization entry-point for modules.
 //===----------------------------------------------------------------------===//
 
+/// Return the first op in the buffer-equivalence class of `value` that has an
+/// Allocate MemoryEffect on its own result `v`. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+                                     const BufferizationAliasInfo &aliasInfo) {
+  Operation *res = nullptr; // Read by `if (!res)` before any assignment.
+  aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
+    if (!res)
+      if (auto interface =
+              dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
+        if (auto effect =
+                interface.getEffectOnValue<MemoryEffects::Allocate>(v))
+          res = v.getDefiningOp();
+  });
+  return res;
+}
+
+/// Scan the arguments of the FuncOp containing `v` (directly or through nested
+/// regions) and return the first one buffer-equivalent to `v`, else null.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+                                const BufferizationAliasInfo &aliasInfo) {
+  Operation *parent = v.getParentBlock()->getParentOp();
+  auto enclosingFunc = isa<FuncOp>(parent)
+                           ? cast<FuncOp>(parent)
+                           : parent->getParentOfType<FuncOp>();
+  assert(enclosingFunc && "expected non-null FuncOp");
+  for (BlockArgument candidate : enclosingFunc.getArguments())
+    if (aliasInfo.areEquivalentBufferizedValues(v, candidate))
+      return candidate;
+  return nullptr;
+}
+
+/// Rewrite the `funcOp` arguments, return values and terminator into buffer
+/// form (using the fully dynamic strided memref layout for now), according to
+/// the inPlace-bufferizable information of the function arguments.
+/// This relies on a buffer equivalence analysis of each return operand. When a
+/// result buffer is equivalent to:
+///   1. a BlockArgument of `funcOp`, it can be dropped from the return values
+///      and becomes inplaceable at all callers. This assumes every CallOp does
+///      the necessary work to clone operands so as to make them inplaceable.
+///      Reliance on this logic will need to be relaxed in the future.
+///   2. an op with an Alloc effect, this currently fails bufferization but is a
+///      candidate for hoisting and creating a new inplace operand at all caller
+///      sites.
+///   3. if such hoisting for 2. is not possible (e.g. due to data-dependent
+///      sizes), this is currently unsupported and will require a
+///      refcounted buffer type. Any other result also fails bufferization.
+static LogicalResult bufferizeFuncOpBoundary(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+  LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+
+  // Functions without tensor inputs or results need no boundary rewrite.
+  if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
+      !llvm::any_of(funcOp.getType().getResults(), isaTensor))
+    return success();
+
+  // Get the bufferized FunctionType for funcOp or construct it if not yet
+  // available.
+  // TODO: Atm we have 3 cases:
+  // 1. if a function is called from within the Module, it must have bufferized
+  //    to inplaceable tensor results.
+  // 2. if it is bodiless, it must have bufferized and is not allowed to have
+  //    result tensors.
+  // 3. if it is not called internally, it still must bufferize to inplaceable
+  //    tensor results and we construct it now (e.g. top-level function called
+  //    externally).
+  // -> Figure out a better layering.
+  TypeRange resultTypes; // Left empty: the new ReturnOp returns no values.
+  FunctionType bufferizedFuncType =
+      getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(),
+                                        resultTypes, bufferizedFunctionTypes);
+
+  // Corner case: Bodiless FuncOp
+  // ============================
+  // The body of such functions is assumed opaque and we can't know the
+  // bufferization contract they want to enforce atm.
+  // As a consequence, only support functions that don't return any tensor atm.
+  if (funcOp.getBody().empty()) {
+    if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+      return funcOp->emitError() << "cannot bufferize bodiless function that "
+                                 << "returns a tensor";
+    funcOp.setType(bufferizedFuncType);
+    LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
+    return success();
+  }
+
+  // Support only single return-terminated block in the function.
+  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  if (!returnOp)
+    return funcOp->emitError() << "cannot bufferize a FuncOp with tensors and "
+                                  "without a unique ReturnOp";
+
+  // 1. Check that each returned value is equivalent to some bbArg of funcOp.
+  SmallVector<Value> returnValues;
+  for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+    // If return operand is equivalent to some bbArg, no need to return it.
+    Value returnVal = returnOperand.get();
+    if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
+      continue;
+    // TODO: Need to hoist above function boundary. If this is not possible due
+    // to data-dependent sizes, we need a better type than memref.
+    if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo))
+      return allocOp->emitError() << " needs hoist across function boundary\n";
+    int64_t returnIdx = returnOperand.getOperandNumber();
+    return returnOp->emitError() << " bufferize result #" << returnIdx << "\n";
+  }
+
+  // 2. Rewrite the terminator without the inPlace bufferizable values.
+  OpBuilder(returnOp).create<ReturnOp>(returnOp.getLoc(), returnValues);
+  returnOp->erase();
+
+  // 3. Rewrite the bbArgs.
+  // Iterate on the original `numArgs` and replace them in order.
+  // This guarantees the argument order still matches after the rewrite.
+  Block &frontBlock = funcOp.body().front();
+  unsigned numArgs = frontBlock.getNumArguments();
+  for (unsigned idx = 0; idx < numArgs; ++idx) {
+    auto bbArg = frontBlock.getArgument(0);
+    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+    // Non-tensor types are just forwarded.
+    if (!tensorType) {
+      frontBlock.addArgument(bbArg.getType());
+      bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
+      frontBlock.eraseArgument(0);
+      continue;
+    }
+
+    // Get the buffer type from the bufferized function type.
+    Type memrefType = bufferizedFuncType.getInput(idx);
+    Value memref = frontBlock.addArgument(memrefType);
+    OpBuilder b(funcOp->getContext());
+    b.setInsertionPointToStart(&frontBlock);
+    // Replace all uses of bbArg through a BufferCastOp by a memref::CastOp.
+    for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
+      if (auto bufferCastOp = dyn_cast<memref::BufferCastOp>(use.getOwner())) {
+        auto castOp = b.create<memref::CastOp>(
+            funcOp.getLoc(), bufferCastOp.memref().getType(), memref);
+        bufferCastOp.memref().replaceAllUsesWith(castOp);
+        aliasInfo.insertNewBufferEquivalence(castOp.dest(),
+                                             bufferCastOp.memref());
+      }
+    }
+    // Replace remaining uses (incl. now-dead BufferCastOps) by a tensor_load.
+    if (!bbArg.use_empty()) {
+      auto tensorLoadOp =
+          b.create<memref::TensorLoadOp>(funcOp.getLoc(), memref);
+      aliasInfo.insertNewBufferEquivalence(tensorLoadOp, bbArg);
+      bbArg.replaceAllUsesWith(tensorLoadOp);
+    }
+    frontBlock.eraseArgument(0);
+    // TODO: add support to erase aliasInfo entries if deemed necessary.
+  }
+
+  // 4. Rewrite the FuncOp type to buffer form.
+  funcOp.setType(bufferizedFuncType);
+
+  LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
+
+  return success();
+}
+
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e. callees without callers first).
 /// Store the map of FuncOp to all its callers in `callerMap`.
@@ -1905,10 +2343,12 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](FuncOp funcOp) {
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](CallOpInterface callOp) {
+    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+      // Only support CallOp for now.
+      if (!isa<CallOp>(callOp.getOperation()))
+        return callOp->emitError() << "expected a CallOp";
       FuncOp calledFunction = getCalledFunction(callOp);
-      if (!calledFunction)
-        return WalkResult::interrupt();
+      assert(calledFunction && "could not retrieved called FuncOp");
       auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
       it.first->getSecond().insert(callOp);
       if (calledBy[calledFunction].count(funcOp) == 0) {
@@ -1954,6 +2394,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
 
   SmallVector<FuncOp> orderedFuncOps;
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
+  DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return signalPassFailure();
 
@@ -1985,12 +2426,30 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
       return;
     }
 
-    // TODO: Bufferization phase.
+    // Bufferization phase.
+    if (!testAnalysisOnly) {
+      BlockAndValueMapping tensorToBufferMap;
+      if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
+                                          bufferizedFunctionTypes))) {
+        signalPassFailure();
+        return;
+      }
+    }
   }
   // Don't drop the attributes if we only want to report the analysis.
   if (testAnalysisOnly)
     return;
 
+  for (FuncOp funcOp : orderedFuncOps) {
+    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
+    // would be invalidated.
+    if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
+                                       bufferizedFunctionTypes))) {
+      signalPassFailure();
+      return;
+    }
+  }
+
   // Post-pass cleanup of inplaceable attributes.
   moduleOp.walk(
       [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
@@ -1998,6 +2457,12 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     for (BlockArgument bbArg : op.getArguments())
       removeInPlaceFuncArgument(bbArg);
   });
+
+  OpPassManager cleanupPipeline(OpPassManager("module"));
+  cleanupPipeline.addPass(createCanonicalizerPass());
+  cleanupPipeline.addPass(createCSEPass());
+  cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
+  (void)runPipeline(cleanupPipeline, moduleOp);
 }
 
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 0e378a89ef58..d6a6d7c67f6c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -1,5 +1,36 @@
 // RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file -verify-diagnostics
 
+func private @foo() -> tensor<?xf32>
+
+func @bar() -> tensor<?xf32> {
+  %foo = constant @foo : () -> (tensor<?xf32>)
+// expected-error @+1 {{expected a CallOp}}
+  %res = call_indirect %foo() : () -> (tensor<?xf32>)
+  return %res : tensor<?xf32>
+}
+
+// -----
+
+// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
+func private @foo() -> tensor<?xf32>
+
+// -----
+
+// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+func @switch(%flag : i32, %caseOperand : i32, %t1 : tensor<f32>, %t2 : tensor<f32>)
+    -> (tensor<f32>) 
+{
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return %t1 : tensor<f32>
+  ^bb2(%bb2arg : i32):
+    return %t2 : tensor<f32>
+}
+
 // -----
 
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
new file mode 100644
index 000000000000..7756587560ea
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s
+
+//      CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+//      CHECK:  func private @some_external_func(memref<?xf32, #[[$DYN_1D_MAP]]>)
+func private @some_external_func(tensor<?xf32>)
+
+//      CHECK:  func @scf_for_with_tensor_insert_slice(
+// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:    %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$DYN_1D_MAP]]>
+func @scf_for_with_tensor_insert_slice(
+    %A : tensor<?xf32>, %B : tensor<?xf32>, %C : tensor<4xf32>,
+    %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK-NEXT: scf.for
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    // CHECK-NEXT:   %[[SVA:.*]] = memref.subview %[[A]]
+    // CHECK-NEXT:   linalg.copy(%[[C]], %[[SVA]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+    %ttA = tensor.insert_slice %C into %tA[%i][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // CHECK-NEXT:   %[[SVB:.*]] = memref.subview %[[B]]
+    // CHECK-NEXT:   linalg.copy(%[[C]], %[[SVB]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+    %ttB = tensor.insert_slice %C into %tB[%i][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // Both insert_slice ops bufferize in place, so scf.yield is empty and is elided.
+    //  CHECK-NOT:   scf.yield
+    scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  // Results are swapped: whole-function analysis matches each one to its bbArg.
+  return %r0#1, %r0#0: tensor<?xf32>, tensor<?xf32>
+}
+
+//      CHECK:  func @bar(
+// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:    %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$DYN_1D_MAP]]>
+func @bar(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32> {linalg.inplaceable = true},
+    %C : tensor<4xf32> {linalg.inplaceable = true},
+    %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+// CHECK-NEXT:   call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
+  %r0:2 = call @scf_for_with_tensor_insert_slice(%A, %B, %C, %lb, %ub, %step) :
+      (tensor<?xf32>, tensor<?xf32>, tensor<4xf32>, index, index, index)
+        -> (tensor<?xf32>, tensor<?xf32>)
+
+  // %r0#0 is actually %B after inplaceable results are swapped in the callee.
+// CHECK-NEXT:   call @some_external_func(%[[B]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
+  call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT:   return
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list