[Mlir-commits] [mlir] [mlir][bufferization] BufferDeallocationOpInterface: support custom ownership update logic (PR #66350)

Martin Erhart llvmlistbot at llvm.org
Thu Sep 14 05:01:35 PDT 2023


https://github.com/maerhart updated https://github.com/llvm/llvm-project/pull/66350:

>From c03efd8768c97c04dff41b1590a6048b8c3f0f4a Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Tue, 12 Sep 2023 15:08:55 +0000
Subject: [PATCH 1/2] [mlir][bufferization][NFC] Introduce
 BufferDeallocationOpInterface

This new interface allows operations to implement custom handling of ownership
values and insertion of dealloc operations which is useful when an op cannot
implement the interfaces supported by default by the buffer deallocation pass
(e.g., because they are not exactly compatible or because there are some
additional semantics to it that would render the default implementations in
buffer deallocation invalid, or because no interfaces exist for this kind of
behavior and it's not worth introducing one plus a default implementation in
buffer deallocation). Additionally, it can also be used to provide more
efficient handling for a specific op than the interface based default
implementations can.
---
 .../IR/BufferDeallocationOpInterface.h        | 217 ++++++++
 .../IR/BufferDeallocationOpInterface.td       |  46 ++
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../Bufferization/Transforms/BufferUtils.h    |   8 -
 .../BufferDeallocationOpInterfaceImpl.h       |  22 +
 mlir/include/mlir/InitAllDialects.h           |   2 +
 .../IR/BufferDeallocationOpInterface.cpp      | 274 ++++++++++
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../Bufferization/Transforms/BufferUtils.cpp  |  59 --
 .../Bufferization/Transforms/CMakeLists.txt   |   1 -
 .../OwnershipBasedBufferDeallocation.cpp      | 515 +++---------------
 .../BufferDeallocationOpInterfaceImpl.cpp     | 163 ++++++
 .../ControlFlow/Transforms/CMakeLists.txt     |   3 +-
 .../dealloc-region-branchop-interface.mlir    |   4 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   4 +
 15 files changed, 798 insertions(+), 522 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
 create mode 100644 mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
 create mode 100644 mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
new file mode 100644
index 000000000000000..b88270f1c150a27
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
@@ -0,0 +1,217 @@
+//===- BufferDeallocationOpInterface.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_
+
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// Compare two SSA values in a deterministic manner. Two block arguments are
+/// ordered by argument number, block arguments are always less than operation
+/// results, and operation results are ordered by the `isBeforeInBlock` order of
+/// their defining operation.
+struct ValueComparator {
+  bool operator()(const Value &lhs, const Value &rhs) const;
+};
+
+/// This class is used to track the ownership of values. The ownership can
+/// either be not initialized yet ('Uninitialized' state), set to a unique SSA
+/// value which indicates the ownership at runtime (or statically if it is a
+/// constant value) ('Unique' state), or it cannot be represented in a single
+/// SSA value ('Unknown' state). An artificial example of a case where ownership
+/// cannot be represented in a single i1 SSA value could be the following:
+/// `%0 = test.non_deterministic_select %arg0, %arg1 : i32`
+/// Since the operation does not provide us a separate boolean indicator on
+/// which of the two operands was selected, we would need to either insert an
+/// alias check at runtime to determine if `%0` aliases with `%arg0` or `%arg1`,
+/// or insert a `bufferization.clone` operation to get a fresh buffer which we
+/// could assign ownership to.
+///
+/// The three states this class can represent form a lattice on a partial order:
+/// forall X in SSA values. uninitialized < unique(X) < unknown
+/// forall X, Y in SSA values.
+///   unique(X) == unique(Y) iff X and Y always evaluate to the same value
+///   unique(X) != unique(Y) otherwise
+class Ownership {
+public:
+  /// Constructor that creates an 'Uninitialized' ownership. This is needed for
+  /// default-construction when used in DenseMap.
+  Ownership() = default;
+
+  /// Constructor that creates a 'Unique' ownership. This is a non-explicit
+  /// constructor to allow implicit conversion from 'Value'.
+  Ownership(Value indicator);
+
+  /// Get an ownership value in 'Unknown' state.
+  static Ownership getUnknown();
+  /// Get an ownership value in 'Unique' state with 'indicator' as parameter.
+  static Ownership getUnique(Value indicator);
+  /// Get an ownership value in 'Uninitialized' state.
+  static Ownership getUninitialized();
+
+  /// Check if this ownership value is in the 'Uninitialized' state.
+  bool isUninitialized() const;
+  /// Check if this ownership value is in the 'Unique' state.
+  bool isUnique() const;
+  /// Check if this ownership value is in the 'Unknown' state.
+  bool isUnknown() const;
+
+  /// If this ownership value is in 'Unique' state, this function can be used to
+  /// get the indicator parameter. Using this function in any other state is UB.
+  Value getIndicator() const;
+
+  /// Get the join of the two-element subset {this,other}. Does not modify
+  /// 'this'.
+  Ownership getCombined(Ownership other) const;
+
+  /// Modify 'this' ownership to be the join of the current 'this' and 'other'.
+  void combine(Ownership other);
+
+private:
+  enum class State {
+    Uninitialized,
+    Unique,
+    Unknown,
+  };
+
+  // The indicator value is only relevant in the 'Unique' state.
+  Value indicator;
+  State state = State::Uninitialized;
+};
+
+/// Options for BufferDeallocationOpInterface-based buffer deallocation.
+struct DeallocationOptions {
+  // A pass option indicating whether private functions should be modified to
+  // pass the ownership of MemRef values instead of adhering to the function
+  // boundary ABI.
+  bool privateFuncDynamicOwnership = false;
+};
+
+/// This class collects all the state that we need to perform the buffer
+/// deallocation pass with associated helper functions such that we have easy
+/// access to it in the BufferDeallocationOpInterface implementations and the
+/// BufferDeallocation pass.
+class DeallocationState {
+public:
+  DeallocationState(Operation *op);
+
+  // The state should always be passed by reference.
+  DeallocationState(const DeallocationState &) = delete;
+
+  /// Small helper function to update the ownership map by taking the current
+  /// ownership ('Uninitialized' state if not yet present), computing the join
+  /// with the passed ownership and storing this new value in the map. By
+  /// default, it will be performed for the block where 'memref' is defined. If
+  /// the ownership of the given value should be updated for another block, the
+  /// 'block' argument can be explicitly passed.
+  void updateOwnership(Value memref, Ownership ownership,
+                       Block *block = nullptr);
+
+  /// Removes ownerships associated with all values in the passed range for
+  /// 'block'.
+  void resetOwnerships(ValueRange memrefs, Block *block);
+
+  /// Returns the ownership of 'memref' for the given basic block.
+  Ownership getOwnership(Value memref, Block *block) const;
+
+  /// Remember the given 'memref' to deallocate it at the end of the 'block'.
+  void addMemrefToDeallocate(Value memref, Block *block);
+
+  /// Forget about a MemRef that we originally wanted to deallocate at the end
+  /// of 'block', possibly because it already gets deallocated before the end of
+  /// the block.
+  void dropMemrefToDeallocate(Value memref, Block *block);
+
+  /// Append to 'memrefs' the sorted list of MemRef values which are live at
+  /// the start of the given block.
+  void getLiveMemrefsIn(Block *block, SmallVectorImpl<Value> &memrefs);
+
+  /// Given an SSA value of MemRef type, this function queries the ownership and
+  /// if it is not already in the 'Unique' state, potentially inserts IR to get
+  /// a new SSA value, returned as the first element of the pair, which has
+  /// 'Unique' ownership and can be used instead of the passed Value, with the
+  /// ownership indicator returned as the second element of the pair.
+  std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
+                                                       Value memref);
+
+  /// Given two basic blocks and the values passed via block arguments to the
+  /// destination block, compute the list of MemRefs that have to be retained in
+  /// the 'fromBlock' to not run into a use-after-free situation.
+  /// This list consists of the MemRefs in the successor operand list of the
+  /// terminator and the MemRefs in the 'out' set of the liveness analysis
+  /// intersected with the 'in' set of the destination block.
+  ///
+  /// toRetain = filter(successorOperands + (liveOut(fromBlock) intersect
+  ///   liveIn(toBlock)), isMemRef)
+  void getMemrefsToRetain(Block *fromBlock, Block *toBlock,
+                          ValueRange destOperands,
+                          SmallVectorImpl<Value> &toRetain) const;
+
+  /// For a given block, computes the list of MemRefs that potentially need to
+  /// be deallocated at the end of that block. This list also contains values
+  /// that have to be retained (and are thus part of the list returned by
+  /// `getMemrefsToRetain`) and is computed by taking the MemRefs in the 'in'
+  /// set of the liveness analysis of 'block' appended by the set of MemRefs
+  /// allocated in 'block' itself and subtracted by the set of MemRefs
+  /// deallocated in 'block'.
+  /// Note that we don't have to take the intersection of the liveness 'in' set
+  /// with the 'out' set of the predecessor block because a value that is in the
+  /// 'in' set must be defined in an ancestor block that dominates all direct
+  /// predecessors and thus the 'in' set of this block is a subset of the 'out'
+  /// sets of each predecessor.
+  ///
+  /// memrefs = filter((liveIn(block) U
+  ///   allocated(block) U arguments(block)) \ deallocated(block), isMemRef)
+  ///
+  /// The list of conditions is then populated by querying the internal
+  /// data structures for the ownership value of that MemRef.
+  LogicalResult
+  getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc,
+                                      Block *block,
+                                      SmallVectorImpl<Value> &memrefs,
+                                      SmallVectorImpl<Value> &conditions) const;
+
+  /// Returns the symbol cache to look up functions from call operations to
+  /// check attributes on the function operation.
+  SymbolTableCollection *getSymbolTable() { return &symbolTable; }
+
+private:
+  // Symbol cache to look up functions from call operations to check attributes
+  // on the function operation.
+  SymbolTableCollection symbolTable;
+
+  // Mapping from each SSA value with MemRef type to the associated ownership in
+  // each block.
+  DenseMap<std::pair<Value, Block *>, Ownership> ownershipMap;
+
+  // Collects the list of MemRef values that potentially need to be deallocated
+  // per block. It is also fine (albeit not efficient) to add MemRef values that
+  // don't have to be deallocated, but only when the ownership is not 'Unknown'.
+  DenseMap<Block *, SmallVector<Value>> memrefsToDeallocatePerBlock;
+
+  // The underlying liveness analysis to compute fine-grained information about
+  // alloc and dealloc positions.
+  Liveness liveness;
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Buffer Deallocation Interface
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
new file mode 100644
index 000000000000000..c35fe417184ffd4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
@@ -0,0 +1,46 @@
+//===-- BufferDeallocationOpInterface.td -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFER_DEALLOCATION_OP_INTERFACE
+#define BUFFER_DEALLOCATION_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def BufferDeallocationOpInterface :
+    OpInterface<"BufferDeallocationOpInterface"> {
+  let description = [{
+    An op interface for Buffer Deallocation. Ops that implement this interface
+    can provide custom logic for computing the ownership of OpResults, modify
+    the operation to properly pass the ownership values around, and insert
+    `bufferization.dealloc` operations when necessary.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          This method takes the current deallocation state and transformation
+          options and updates the deallocation state as necessary for the
+          operation implementing this interface. It may also insert
+          `bufferization.dealloc` operations and rebuild itself with different
+          result types. For operations implementing this interface all other
+          interface handlers (e.g., default handlers for interfaces like
+          RegionBranchOpInterface, CallOpInterface, etc.) are skipped by the
+          deallocation pass. On success, either the current operation or one of
+          the newly inserted operations is returned; the driver continues
+          processing from that operation. On failure, the deallocation pass
+          will terminate. It is recommended to emit a useful error message in
+          that case.
+        }],
+        /*retType=*/"FailureOr<Operation *>",
+        /*methodName=*/"process",
+        /*args=*/(ins "DeallocationState &":$state,
+                      "const DeallocationOptions &":$options)>
+  ];
+}
+
+#endif  // BUFFER_DEALLOCATION_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 440125031b1acc5..38057d4910d2958 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect(BufferizationOps bufferization)
 add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
+add_mlir_interface(BufferDeallocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
 add_mlir_interface(SubsetInsertionOpInterface)
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index 83e55fd70de6bb8..85e9c47ad5302cb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -121,14 +121,6 @@ class BufferPlacementTransformationBase {
   Liveness liveness;
 };
 
-/// Compare two SSA values in a deterministic manner. Two block arguments are
-/// ordered by argument number, block arguments are always less than operation
-/// results, and operation results are ordered by the `isBeforeInBlock` order of
-/// their defining operation.
-struct ValueComparator {
-  bool operator()(const Value &lhs, const Value &rhs) const;
-};
-
 // Create a global op for the given tensor-valued constant in the program.
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h
new file mode 100644
index 000000000000000..c34ebd0494fec89
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h
@@ -0,0 +1,22 @@
+//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace cf {
+void registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 6eaa0cc0d46aadb..ee91bfa57d12a39 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
@@ -138,6 +139,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
       registry);
   builtin::registerCastOpInterfaceExternalModels(registry);
   cf::registerBufferizableOpInterfaceExternalModels(registry);
+  cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
   linalg::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
new file mode 100644
index 000000000000000..2314cee2ff2c158
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -0,0 +1,274 @@
+//===- BufferDeallocationOpInterface.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/SetOperations.h"
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocationOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+//===----------------------------------------------------------------------===//
+// Helpers
+//===----------------------------------------------------------------------===//
+
+static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
+  return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
+}
+
+static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+
+//===----------------------------------------------------------------------===//
+// Ownership
+//===----------------------------------------------------------------------===//
+
+Ownership::Ownership(Value indicator)
+    : indicator(indicator), state(State::Unique) {}
+
+Ownership Ownership::getUnknown() {
+  Ownership unknown;
+  unknown.indicator = Value();
+  unknown.state = State::Unknown;
+  return unknown;
+}
+Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
+Ownership Ownership::getUninitialized() { return Ownership(); }
+
+bool Ownership::isUninitialized() const {
+  return state == State::Uninitialized;
+}
+bool Ownership::isUnique() const { return state == State::Unique; }
+bool Ownership::isUnknown() const { return state == State::Unknown; }
+
+Value Ownership::getIndicator() const {
+  assert(isUnique() && "must have unique ownership to get the indicator");
+  return indicator;
+}
+
+Ownership Ownership::getCombined(Ownership other) const {
+  if (other.isUninitialized())
+    return *this;
+  if (isUninitialized())
+    return other;
+
+  if (!isUnique() || !other.isUnique())
+    return getUnknown();
+
+  // Since we create a new constant i1 value for (almost) each use-site, we
+  // should compare the actual value rather than just the SSA Value to avoid
+  // unnecessary invalidations.
+  if (isEqualConstantIntOrValue(indicator, other.indicator))
+    return *this;
+
+  // Return the join of the lattice if the indicator of both ownerships cannot
+  // be merged.
+  return getUnknown();
+}
+
+void Ownership::combine(Ownership other) { *this = getCombined(other); }
+
+//===----------------------------------------------------------------------===//
+// DeallocationState
+//===----------------------------------------------------------------------===//
+
+DeallocationState::DeallocationState(Operation *op) : liveness(op) {}
+
+void DeallocationState::updateOwnership(Value memref, Ownership ownership,
+                                        Block *block) {
+  // In most cases we care about the block where the value is defined.
+  if (block == nullptr)
+    block = memref.getParentBlock();
+
+  // Update ownership of current memref itself.
+  ownershipMap[{memref, block}].combine(ownership);
+}
+
+void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) {
+  for (Value val : memrefs)
+    ownershipMap[{val, block}] = Ownership::getUninitialized();
+}
+
+Ownership DeallocationState::getOwnership(Value memref, Block *block) const {
+  return ownershipMap.lookup({memref, block});
+}
+
+void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) {
+  memrefsToDeallocatePerBlock[block].push_back(memref);
+}
+
+void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) {
+  llvm::erase_if(memrefsToDeallocatePerBlock[block],
+                 [&](const auto &mr) { return mr == memref; });
+}
+
+void DeallocationState::getLiveMemrefsIn(Block *block,
+                                         SmallVectorImpl<Value> &memrefs) {
+  SmallVector<Value> liveMemrefs(
+      llvm::make_filter_range(liveness.getLiveIn(block), isMemref));
+  llvm::sort(liveMemrefs, ValueComparator());
+  memrefs.append(liveMemrefs);
+}
+
+std::pair<Value, Value>
+DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
+                                                Value memref) {
+  auto iter = ownershipMap.find({memref, memref.getParentBlock()});
+  assert(iter != ownershipMap.end() &&
+         "Value must already have been registered in the ownership map");
+
+  Ownership ownership = iter->second;
+  if (ownership.isUnique())
+    return {memref, ownership.getIndicator()};
+
+  // Instead of inserting a clone operation we could also insert a dealloc
+  // operation earlier in the block and use the updated ownerships returned by
+  // the op for the retained values. Alternatively, we could insert code to
+  // check aliasing at runtime and use this information to combine two unique
+  // ownerships more intelligently to not end up with an 'Unknown' ownership in
+  // the first place.
+  auto cloneOp =
+      builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
+  Value condition = buildBoolValue(builder, memref.getLoc(), true);
+  Value newMemref = cloneOp.getResult();
+  updateOwnership(newMemref, condition);
+  memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
+  return {newMemref, condition};
+}
+
+void DeallocationState::getMemrefsToRetain(
+    Block *fromBlock, Block *toBlock, ValueRange destOperands,
+    SmallVectorImpl<Value> &toRetain) const {
+  for (Value operand : destOperands) {
+    if (!isMemref(operand))
+      continue;
+    toRetain.push_back(operand);
+  }
+
+  SmallPtrSet<Value, 16> liveOut;
+  for (auto val : liveness.getLiveOut(fromBlock))
+    if (isMemref(val))
+      liveOut.insert(val);
+
+  if (toBlock)
+    llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
+
+  // liveOut has non-deterministic order because it was constructed by iterating
+  // over a hash-set.
+  SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
+  std::sort(retainedByLiveness.begin(), retainedByLiveness.end(),
+            ValueComparator());
+  toRetain.append(retainedByLiveness);
+}
+
+LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
+    OpBuilder &builder, Location loc, Block *block,
+    SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
+
+  for (auto [i, memref] :
+       llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
+    Ownership ownership = ownershipMap.lookup({memref, block});
+    if (!ownership.isUnique())
+      return emitError(memref.getLoc(),
+                       "MemRef value does not have valid ownership");
+
+    // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
+    // that we can call extract_strided_metadata on it.
+    if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
+      memref = builder.create<memref::ReinterpretCastOp>(
+          loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
+          0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+
+    // Use the `memref.extract_strided_metadata` operation to get the base
+    // memref. This is needed because the same MemRef that was produced by the
+    // alloc operation has to be passed to the dealloc operation. Passing
+    // subviews, etc. to a dealloc operation is not allowed.
+    memrefs.push_back(
+        builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
+            .getResult(0));
+    conditions.push_back(ownership.getIndicator());
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ValueComparator
+//===----------------------------------------------------------------------===//
+
+bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
+  if (lhs == rhs)
+    return false;
+
+  // Block arguments are less than results.
+  bool lhsIsBBArg = lhs.isa<BlockArgument>();
+  if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
+    return lhsIsBBArg;
+  }
+
+  Region *lhsRegion;
+  Region *rhsRegion;
+  if (lhsIsBBArg) {
+    auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
+    auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
+    if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
+      return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
+    }
+    lhsRegion = lhsBBArg.getParentRegion();
+    rhsRegion = rhsBBArg.getParentRegion();
+    assert(lhsRegion != rhsRegion &&
+           "lhsRegion == rhsRegion implies lhs == rhs");
+  } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
+    return llvm::cast<OpResult>(lhs).getResultNumber() <
+           llvm::cast<OpResult>(rhs).getResultNumber();
+  } else {
+    lhsRegion = lhs.getDefiningOp()->getParentRegion();
+    rhsRegion = rhs.getDefiningOp()->getParentRegion();
+    if (lhsRegion == rhsRegion) {
+      return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
+    }
+  }
+
+  // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
+  // - have different heights
+  // - or there's a spot where their region numbers differ
+  // - or their parent regions are the same and their parent ops are
+  //   different.
+  while (lhsRegion && rhsRegion) {
+    if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
+      return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
+    }
+    if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
+      return lhsRegion->getParentOp()->isBeforeInBlock(
+          rhsRegion->getParentOp());
+    }
+    lhsRegion = lhsRegion->getParentRegion();
+    rhsRegion = rhsRegion->getParentRegion();
+  }
+  if (rhsRegion)
+    return true;
+  assert(lhsRegion && "this should only happen if lhs == rhs");
+  return false;
+}
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 3fd9221624d0f88..b1940e40ba34114 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRBufferizationDialect
   AllocationOpInterface.cpp
   BufferizableOpInterface.cpp
+  BufferDeallocationOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
   SubsetInsertionOpInterface.cpp
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index b8fd99a5541242f..119801f9cc92f32 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -202,62 +202,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
   global->moveBefore(&moduleOp.front());
   return global;
 }
-
-//===----------------------------------------------------------------------===//
-// ValueComparator
-//===----------------------------------------------------------------------===//
-
-bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
-  if (lhs == rhs)
-    return false;
-
-  // Block arguments are less than results.
-  bool lhsIsBBArg = lhs.isa<BlockArgument>();
-  if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
-    return lhsIsBBArg;
-  }
-
-  Region *lhsRegion;
-  Region *rhsRegion;
-  if (lhsIsBBArg) {
-    auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
-    auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
-    if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
-      return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
-    }
-    lhsRegion = lhsBBArg.getParentRegion();
-    rhsRegion = rhsBBArg.getParentRegion();
-    assert(lhsRegion != rhsRegion &&
-           "lhsRegion == rhsRegion implies lhs == rhs");
-  } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
-    return llvm::cast<OpResult>(lhs).getResultNumber() <
-           llvm::cast<OpResult>(rhs).getResultNumber();
-  } else {
-    lhsRegion = lhs.getDefiningOp()->getParentRegion();
-    rhsRegion = rhs.getDefiningOp()->getParentRegion();
-    if (lhsRegion == rhsRegion) {
-      return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
-    }
-  }
-
-  // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
-  // - have different heights
-  // - or there's a spot where their region numbers differ
-  // - or their parent regions are the same and their parent ops are
-  //   different.
-  while (lhsRegion && rhsRegion) {
-    if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
-      return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
-    }
-    if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
-      return lhsRegion->getParentOp()->isBeforeInBlock(
-          rhsRegion->getParentOp());
-    }
-    lhsRegion = lhsRegion->getParentRegion();
-    rhsRegion = rhsRegion->getParentRegion();
-  }
-  if (rhsRegion)
-    return true;
-  assert(lhsRegion && "this should only happen if lhs == rhs");
-  return false;
-}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index cbbfe7a81205857..ed8dbd57bf40ba1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -35,7 +35,6 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   MLIRPass
   MLIRTensorDialect
   MLIRSCFDialect
-  MLIRControlFlowDialect
   MLIRSideEffectInterfaces
   MLIRTransforms
   MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index eaced7202f4e606..d4b8e0dff67bae4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -18,16 +18,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "llvm/ADT/SetOperations.h"
 
 namespace mlir {
 namespace bufferization {
@@ -137,103 +135,13 @@ class Backedges {
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// This class is used to track the ownership of values. The ownership can
-/// either be not initialized yet ('Uninitialized' state), set to a unique SSA
-/// value which indicates the ownership at runtime (or statically if it is a
-/// constant value) ('Unique' state), or it cannot be represented in a single
-/// SSA value ('Unknown' state). An artificial example of a case where ownership
-/// cannot be represented in a single i1 SSA value could be the following:
-/// `%0 = test.non_deterministic_select %arg0, %arg1 : i32`
-/// Since the operation does not provide us a separate boolean indicator on
-/// which of the two operands was selected, we would need to either insert an
-/// alias check at runtime to determine if `%0` aliases with `%arg0` or `%arg1`,
-/// or insert a `bufferization.clone` operation to get a fresh buffer which we
-/// could assign ownership to.
-///
-/// The three states this class can represent form a lattice on a partial order:
-/// forall X in SSA values. uninitialized < unique(X) < unknown
-/// forall X, Y in SSA values.
-///   unique(X) == unique(Y) iff X and Y always evaluate to the same value
-///   unique(X) != unique(Y) otherwise
-class Ownership {
-public:
-  /// Constructor that creates an 'Uninitialized' ownership. This is needed for
-  /// default-construction when used in DenseMap.
-  Ownership() = default;
-
-  /// Constructor that creates an 'Unique' ownership. This is a non-explicit
-  /// constructor to allow implicit conversion from 'Value'.
-  Ownership(Value indicator) : indicator(indicator), state(State::Unique) {}
-
-  /// Get an ownership value in 'Unknown' state.
-  static Ownership getUnknown() {
-    Ownership unknown;
-    unknown.indicator = Value();
-    unknown.state = State::Unknown;
-    return unknown;
-  }
-  /// Get an ownership value in 'Unique' state with 'indicator' as parameter.
-  static Ownership getUnique(Value indicator) { return Ownership(indicator); }
-  /// Get an ownership value in 'Uninitialized' state.
-  static Ownership getUninitialized() { return Ownership(); }
-
-  /// Check if this ownership value is in the 'Uninitialized' state.
-  bool isUninitialized() const { return state == State::Uninitialized; }
-  /// Check if this ownership value is in the 'Unique' state.
-  bool isUnique() const { return state == State::Unique; }
-  /// Check if this ownership value is in the 'Unknown' state.
-  bool isUnknown() const { return state == State::Unknown; }
-
-  /// If this ownership value is in 'Unique' state, this function can be used to
-  /// get the indicator parameter. Using this function in any other state is UB.
-  Value getIndicator() const {
-    assert(isUnique() && "must have unique ownership to get the indicator");
-    return indicator;
-  }
-
-  /// Get the join of the two-element subset {this,other}. Does not modify
-  /// 'this'.
-  Ownership getCombined(Ownership other) const {
-    if (other.isUninitialized())
-      return *this;
-    if (isUninitialized())
-      return other;
-
-    if (!isUnique() || !other.isUnique())
-      return getUnknown();
-
-    // Since we create a new constant i1 value for (almost) each use-site, we
-    // should compare the actual value rather than just the SSA Value to avoid
-    // unnecessary invalidations.
-    if (isEqualConstantIntOrValue(indicator, other.indicator))
-      return *this;
-
-    // Return the join of the lattice if the indicator of both ownerships cannot
-    // be merged.
-    return getUnknown();
-  }
-
-  /// Modify 'this' ownership to be the join of the current 'this' and 'other'.
-  void combine(Ownership other) { *this = getCombined(other); }
-
-private:
-  enum class State {
-    Uninitialized,
-    Unique,
-    Unknown,
-  };
-
-  // The indicator value is only relevant in the 'Unique' state.
-  Value indicator;
-  State state = State::Uninitialized;
-};
-
 /// The buffer deallocation transformation which ensures that all allocs in the
 /// program have a corresponding de-allocation.
 class BufferDeallocation {
 public:
   BufferDeallocation(Operation *op, bool privateFuncDynamicOwnership)
-      : liveness(op), privateFuncDynamicOwnership(privateFuncDynamicOwnership) {
+      : state(op) {
+    options.privateFuncDynamicOwnership = privateFuncDynamicOwnership;
   }
 
   /// Performs the actual placement/creation of all dealloc operations.
@@ -291,57 +199,17 @@ class BufferDeallocation {
 
   /// Apply all supported interface handlers to the given op.
   FailureOr<Operation *> handleAllInterfaces(Operation *op) {
+    if (auto deallocOpInterface = dyn_cast<BufferDeallocationOpInterface>(op))
+      return deallocOpInterface.process(state, options);
+
     if (failed(verifyOperationPreconditions(op)))
       return failure();
 
     return handleOp<MemoryEffectOpInterface, RegionBranchOpInterface,
-                    CallOpInterface, BranchOpInterface, cf::CondBranchOp,
+                    CallOpInterface, BranchOpInterface,
                     RegionBranchTerminatorOpInterface>(op);
   }
 
-  /// While CondBranchOp also implements the BranchOpInterface, we add a
-  /// special-case implementation here because the BranchOpInterface does not
-  /// offer all of the functionality we need to insert dealloc operations in an
-  /// efficient way. More precisely, there is no way to extract the branch
-  /// condition without casting to CondBranchOp specifically. It would still be
-  /// possible to implement deallocation for cases where we don't know to which
-  /// successor the terminator branches before the actual branch happens by
-  /// inserting auxiliary blocks and putting the dealloc op there, however, this
-  /// can lead to less efficient code.
-  /// This function inserts two dealloc operations (one for each successor) and
-  /// adjusts the dealloc conditions according to the branch condition, then the
-  /// ownerships of the retained MemRefs are updated by combining the result
-  /// values of the two dealloc operations.
-  ///
-  /// Example:
-  /// ```
-  /// ^bb1:
-  ///   <more ops...>
-  ///   cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb2>)
-  /// ```
-  /// becomes
-  /// ```
-  /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
-  /// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
-  /// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>)
-  /// ^bb1:
-  ///   <more ops...>
-  ///   let thenCond = map(c, (c) -> arith.andi cond, c)
-  ///   let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c)
-  ///   o0 = bufferization.dealloc m if thenCond retain r0
-  ///   o1 = bufferization.dealloc m if elseCond retain r1
-  ///   // replace ownership(r0) with o0 element-wise
-  ///   // replace ownership(r1) with o1 element-wise
-  ///   // let ownership0 := (r) -> o in o0 corresponding to r
-  ///   // let ownership1 := (r) -> o in o1 corresponding to r
-  ///   // let cmn := intersection(r0, r1)
-  ///   foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)):
-  ///     forall r in r0: replace ownership0(r) with arith.select cond, a, b)
-  ///     forall r in r1: replace ownership1(r) with arith.select cond, a, b)
-  ///   cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1)
-  /// ```
-  FailureOr<Operation *> handleInterface(cf::CondBranchOp op);
-
   /// Make sure that for each forwarded MemRef value, an ownership indicator
   /// `i1` value is forwarded as well such that the successor block knows
   /// whether the MemRef has to be deallocated.
@@ -492,18 +360,6 @@ class BufferDeallocation {
   /// this function has to be called on blocks in a region in dominance order.
   LogicalResult deallocate(Block *block);
 
-  /// Small helper function to update the ownership map by taking the current
-  /// ownership ('Uninitialized' state if not yet present), computing the join
-  /// with the passed ownership and storing this new value in the map. By
-  /// default, it will be performed for the block where 'owned' is defined. If
-  /// the ownership of the given value should be updated for another block, the
-  /// 'block' argument can be explicitly passed.
-  void joinOwnership(Value owned, Ownership ownership, Block *block = nullptr);
-
-  /// Removes ownerships associated with all values in the passed range for
-  /// 'block'.
-  void clearOwnershipOf(ValueRange values, Block *block);
-
   /// After all relevant interfaces of an operation have been processed by the
   /// 'handleInterface' functions, this function sets the ownership of operation
   /// results that have not been set yet by the 'handleInterface' functions. It
@@ -517,51 +373,6 @@ class BufferDeallocation {
   /// operations, etc.).
   void populateRemainingOwnerships(Operation *op);
 
-  /// Given two basic blocks and the values passed via block arguments to the
-  /// destination block, compute the list of MemRefs that have to be retained in
-  /// the 'fromBlock' to not run into a use-after-free situation.
-  /// This list consists of the MemRefs in the successor operand list of the
-  /// terminator and the MemRefs in the 'out' set of the liveness analysis
-  /// intersected with the 'in' set of the destination block.
-  ///
-  /// toRetain = filter(successorOperands + (liveOut(fromBlock) insersect
-  ///   liveIn(toBlock)), isMemRef)
-  void getMemrefsToRetain(Block *fromBlock, Block *toBlock,
-                          ValueRange destOperands,
-                          SmallVectorImpl<Value> &toRetain) const;
-
-  /// For a given block, computes the list of MemRefs that potentially need to
-  /// be deallocated at the end of that block. This list also contains values
-  /// that have to be retained (and are thus part of the list returned by
-  /// `getMemrefsToRetain`) and is computed by taking the MemRefs in the 'in'
-  /// set of the liveness analysis of 'block'  appended by the set of MemRefs
-  /// allocated in 'block' itself and subtracted by the set of MemRefs
-  /// deallocated in 'block'.
-  /// Note that we don't have to take the intersection of the liveness 'in' set
-  /// with the 'out' set of the predecessor block because a value that is in the
-  /// 'in' set must be defined in an ancestor block that dominates all direct
-  /// predecessors and thus the 'in' set of this block is a subset of the 'out'
-  /// sets of each predecessor.
-  ///
-  /// memrefs = filter((liveIn(block) U
-  ///   allocated(block) U arguments(block)) \ deallocated(block), isMemRef)
-  ///
-  /// The list of conditions is then populated by querying the internal
-  /// datastructures for the ownership value of that MemRef.
-  LogicalResult
-  getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc,
-                                      Block *block,
-                                      SmallVectorImpl<Value> &memrefs,
-                                      SmallVectorImpl<Value> &conditions) const;
-
-  /// Given an SSA value of MemRef type, this function queries the ownership and
-  /// if it is not already in the 'Unique' state, potentially inserts IR to get
-  /// a new SSA value, returned as the first element of the pair, which has
-  /// 'Unique' ownership and can be used instead of the passed Value with the
-  /// the ownership indicator returned as the second element of the pair.
-  std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                       Value memref);
-
   /// Given an SSA value of MemRef type, returns the same of a new SSA value
   /// which has 'Unique' ownership where the ownership indicator is guaranteed
   /// to be always 'true'.
@@ -602,27 +413,13 @@ class BufferDeallocation {
   static LogicalResult updateFunctionSignature(FunctionOpInterface op);
 
 private:
-  // Mapping from each SSA value with MemRef type to the associated ownership in
-  // each block.
-  DenseMap<std::pair<Value, Block *>, Ownership> ownershipMap;
-
-  // Collects the list of MemRef values that potentially need to be deallocated
-  // per block. It is also fine (albeit not efficient) to add MemRef values that
-  // don't have to be deallocated, but only when the ownership is not 'Unknown'.
-  DenseMap<Block *, SmallVector<Value>> memrefsToDeallocatePerBlock;
-
-  // Symbol cache to lookup functions from call operations to check attributes
-  // on the function operation.
-  SymbolTableCollection symbolTable;
-
-  // The underlying liveness analysis to compute fine grained information about
-  // alloc and dealloc positions.
-  Liveness liveness;
-
-  // A pass option indicating whether private functions should be modified to
-  // pass the ownership of MemRef values instead of adhering to the function
-  // boundary ABI.
-  bool privateFuncDynamicOwnership;
+  /// Collects all analysis state, including liveness, caches, ownerships of
+  /// already processed values and operations, and the MemRefs that have to be
+  /// deallocated at the end of each block.
+  DeallocationState state;
+
+  /// Collects all pass options in a single place.
+  DeallocationOptions options;
 };
 
 } // namespace
@@ -631,22 +428,6 @@ class BufferDeallocation {
 // BufferDeallocation Implementation
 //===----------------------------------------------------------------------===//
 
-void BufferDeallocation::joinOwnership(Value owned, Ownership ownership,
-                                       Block *block) {
-  // In most cases we care about the block where the value is defined.
-  if (block == nullptr)
-    block = owned.getParentBlock();
-
-  // Update ownership of current memref itself.
-  ownershipMap[{owned, block}].combine(ownership);
-}
-
-void BufferDeallocation::clearOwnershipOf(ValueRange values, Block *block) {
-  for (Value val : values) {
-    ownershipMap[{val, block}] = Ownership::getUninitialized();
-  }
-}
-
 static bool regionOperatesOnMemrefValues(Region &region) {
   WalkResult result = region.walk([](Block *block) {
     if (llvm::any_of(block->getArguments(), isMemref))
@@ -717,10 +498,10 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
 
     // We only support terminators with 0 or 1 successors for now and
     // special-case the conditional branch op.
-    if (op->getSuccessors().size() > 1 && !isa<cf::CondBranchOp>(op))
+    if (op->getSuccessors().size() > 1)
 
       return op->emitError("Terminators with more than one successor "
-                           "are not supported (except cf.cond_br)!");
+                           "are not supported!");
   }
 
   return success();
@@ -776,80 +557,26 @@ LogicalResult BufferDeallocation::deallocate(FunctionOpInterface op) {
   return updateFunctionSignature(op);
 }
 
-void BufferDeallocation::getMemrefsToRetain(
-    Block *fromBlock, Block *toBlock, ValueRange destOperands,
-    SmallVectorImpl<Value> &toRetain) const {
-  for (Value operand : destOperands) {
-    if (!isMemref(operand))
-      continue;
-    toRetain.push_back(operand);
-  }
-
-  SmallPtrSet<Value, 16> liveOut;
-  for (auto val : liveness.getLiveOut(fromBlock))
-    if (isMemref(val))
-      liveOut.insert(val);
-
-  if (toBlock)
-    llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
-
-  // liveOut has non-deterministic order because it was constructed by iterating
-  // over a hash-set.
-  SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
-  std::sort(retainedByLiveness.begin(), retainedByLiveness.end(),
-            ValueComparator());
-  toRetain.append(retainedByLiveness);
-}
-
-LogicalResult BufferDeallocation::getMemrefsAndConditionsToDeallocate(
-    OpBuilder &builder, Location loc, Block *block,
-    SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
-
-  for (auto [i, memref] :
-       llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
-    Ownership ownership = ownershipMap.lookup({memref, block});
-    assert(ownership.isUnique() && "MemRef value must have valid ownership");
-
-    // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
-    // that we can call extract_strided_metadata on it.
-    if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
-      memref = builder.create<memref::ReinterpretCastOp>(
-          loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
-          0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
-
-    // Use the `memref.extract_strided_metadata` operation to get the base
-    // memref. This is needed because the same MemRef that was produced by the
-    // alloc operation has to be passed to the dealloc operation. Passing
-    // subviews, etc. to a dealloc operation is not allowed.
-    memrefs.push_back(
-        builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
-            .getResult(0));
-    conditions.push_back(ownership.getIndicator());
-  }
-
-  return success();
-}
-
 LogicalResult BufferDeallocation::deallocate(Block *block) {
   OpBuilder builder = OpBuilder::atBlockBegin(block);
 
   // Compute liveness transfers of ownership to this block.
-  for (auto li : liveness.getLiveIn(block)) {
-    if (!isMemref(li))
-      continue;
-
+  SmallVector<Value> liveMemrefs;
+  state.getLiveMemrefsIn(block, liveMemrefs);
+  for (auto li : liveMemrefs) {
     // Ownership of implicitly captured memrefs from other regions is never
     // taken, but ownership of memrefs in the same region (but different block)
     // is taken.
     if (li.getParentRegion() == block->getParent()) {
-      joinOwnership(li, ownershipMap[{li, li.getParentBlock()}], block);
-      memrefsToDeallocatePerBlock[block].push_back(li);
+      state.updateOwnership(li, state.getOwnership(li, li.getParentBlock()),
+                            block);
+      state.addMemrefToDeallocate(li, block);
       continue;
     }
 
     if (li.getParentRegion()->isProperAncestor(block->getParent())) {
       Value falseVal = buildBoolValue(builder, li.getLoc(), false);
-      joinOwnership(li, falseVal, block);
+      state.updateOwnership(li, falseVal, block);
     }
   }
 
@@ -863,14 +590,15 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {
     if (isFunctionWithoutDynamicOwnership(block->getParentOp()) &&
         block->isEntryBlock()) {
       Value newArg = buildBoolValue(builder, arg.getLoc(), false);
-      joinOwnership(arg, newArg);
+      state.updateOwnership(arg, newArg);
+      state.addMemrefToDeallocate(arg, block);
       continue;
     }
 
     // Pass MemRef ownerships along via `i1` values.
     Value newArg = block->addArgument(builder.getI1Type(), arg.getLoc());
-    joinOwnership(arg, newArg);
-    memrefsToDeallocatePerBlock[block].push_back(arg);
+    state.updateOwnership(arg, newArg);
+    state.addMemrefToDeallocate(arg, block);
   }
 
   // For each operation in the block, handle the interfaces that affect aliasing
@@ -906,97 +634,6 @@ Operation *BufferDeallocation::appendOpResults(Operation *op,
   return newOp;
 }
 
-FailureOr<Operation *>
-BufferDeallocation::handleInterface(cf::CondBranchOp op) {
-  OpBuilder builder(op);
-
-  // The list of memrefs to pass to the `bufferization.dealloc` op as "memrefs
-  // to deallocate" in this block is independent of which branch is taken.
-  SmallVector<Value> memrefs, ownerships;
-  if (failed(getMemrefsAndConditionsToDeallocate(
-          builder, op.getLoc(), op->getBlock(), memrefs, ownerships)))
-    return failure();
-
-  // Helper lambda to factor out common logic for inserting the dealloc
-  // operations for each successor.
-  auto insertDeallocForBranch =
-      [&](Block *target, MutableOperandRange destOperands,
-          ArrayRef<Value> conditions,
-          DenseMap<Value, Value> &ownershipMapping) -> DeallocOp {
-    SmallVector<Value> toRetain;
-    getMemrefsToRetain(op->getBlock(), target, OperandRange(destOperands),
-                       toRetain);
-    auto deallocOp = builder.create<bufferization::DeallocOp>(
-        op.getLoc(), memrefs, conditions, toRetain);
-    clearOwnershipOf(deallocOp.getRetained(), op->getBlock());
-    for (auto [retained, ownership] :
-         llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
-      joinOwnership(retained, ownership, op->getBlock());
-      ownershipMapping[retained] = ownership;
-    }
-    SmallVector<Value> replacements, ownerships;
-    for (Value operand : destOperands) {
-      replacements.push_back(operand);
-      if (isMemref(operand)) {
-        assert(ownershipMapping.contains(operand) &&
-               "Should be contained at this point");
-        ownerships.push_back(ownershipMapping[operand]);
-      }
-    }
-    replacements.append(ownerships);
-    destOperands.assign(replacements);
-    return deallocOp;
-  };
-
-  // Call the helper lambda and make sure the dealloc conditions are properly
-  // modified to reflect the branch condition as well.
-  DenseMap<Value, Value> thenOwnershipMap, elseOwnershipMap;
-
-  // Retain `trueDestOperands` if "true" branch is taken.
-  SmallVector<Value> thenOwnerships(
-      llvm::map_range(ownerships, [&](Value cond) {
-        return builder.create<arith::AndIOp>(op.getLoc(), cond,
-                                             op.getCondition());
-      }));
-  DeallocOp thenTakenDeallocOp =
-      insertDeallocForBranch(op.getTrueDest(), op.getTrueDestOperandsMutable(),
-                             thenOwnerships, thenOwnershipMap);
-
-  // Retain `elseDestOperands` if "false" branch is taken.
-  SmallVector<Value> elseOwnerships(
-      llvm::map_range(ownerships, [&](Value cond) {
-        Value trueVal = builder.create<arith::ConstantOp>(
-            op.getLoc(), builder.getBoolAttr(true));
-        Value negation = builder.create<arith::XOrIOp>(op.getLoc(), trueVal,
-                                                       op.getCondition());
-        return builder.create<arith::AndIOp>(op.getLoc(), cond, negation);
-      }));
-  DeallocOp elseTakenDeallocOp = insertDeallocForBranch(
-      op.getFalseDest(), op.getFalseDestOperandsMutable(), elseOwnerships,
-      elseOwnershipMap);
-
-  // We specifically need to update the ownerships of values that are retained
-  // in both dealloc operations again to get a combined 'Unique' ownership
-  // instead of an 'Unknown' ownership.
-  SmallPtrSet<Value, 16> thenValues(thenTakenDeallocOp.getRetained().begin(),
-                                    thenTakenDeallocOp.getRetained().end());
-  SetVector<Value> commonValues;
-  for (Value val : elseTakenDeallocOp.getRetained()) {
-    if (thenValues.contains(val))
-      commonValues.insert(val);
-  }
-
-  for (Value retained : commonValues) {
-    clearOwnershipOf(retained, op->getBlock());
-    Value combinedOwnership = builder.create<arith::SelectOp>(
-        op.getLoc(), op.getCondition(), thenOwnershipMap[retained],
-        elseOwnershipMap[retained]);
-    joinOwnership(retained, combinedOwnership, op->getBlock());
-  }
-
-  return op.getOperation();
-}
-
 FailureOr<Operation *>
 BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
   OpBuilder builder = OpBuilder::atBlockBegin(op->getBlock());
@@ -1033,44 +670,18 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
   RegionBranchOpInterface newOp = appendOpResults(op, ownershipResults);
 
   for (auto result : llvm::make_filter_range(newOp->getResults(), isMemref)) {
-    joinOwnership(result, newOp->getResult(counter++));
-    memrefsToDeallocatePerBlock[newOp->getBlock()].push_back(result);
+    state.updateOwnership(result, newOp->getResult(counter++));
+    state.addMemrefToDeallocate(result, newOp->getBlock());
   }
 
   return newOp.getOperation();
 }
 
-std::pair<Value, Value>
-BufferDeallocation::getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                 Value memref) {
-  auto iter = ownershipMap.find({memref, memref.getParentBlock()});
-  assert(iter != ownershipMap.end() &&
-         "Value must already have been registered in the ownership map");
-
-  Ownership ownership = iter->second;
-  if (ownership.isUnique())
-    return {memref, ownership.getIndicator()};
-
-  // Instead of inserting a clone operation we could also insert a dealloc
-  // operation earlier in the block and use the updated ownerships returned by
-  // the op for the retained values. Alternatively, we could insert code to
-  // check aliasing at runtime and use this information to combine two unique
-  // ownerships more intelligently to not end up with an 'Unknown' ownership in
-  // the first place.
-  auto cloneOp =
-      builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
-  Value condition = buildBoolValue(builder, memref.getLoc(), true);
-  Value newMemref = cloneOp.getResult();
-  joinOwnership(newMemref, condition);
-  memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
-  return {newMemref, condition};
-}
-
 Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
                                                            Value memref) {
   // First, make sure we at least have 'Unique' ownership already.
   std::pair<Value, Value> newMemrefAndOnwership =
-      getMemrefWithUniqueOwnership(builder, memref);
+      state.getMemrefWithUniqueOwnership(builder, memref);
   Value newMemref = newMemrefAndOnwership.first;
   Value condition = newMemrefAndOnwership.second;
 
@@ -1096,17 +707,16 @@ Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
               })
           .getResult(0);
   Value trueVal = buildBoolValue(builder, memref.getLoc(), true);
-  joinOwnership(maybeClone, trueVal);
-  memrefsToDeallocatePerBlock[maybeClone.getParentBlock()].push_back(
-      maybeClone);
+  state.updateOwnership(maybeClone, trueVal);
+  state.addMemrefToDeallocate(maybeClone, maybeClone.getParentBlock());
   return maybeClone;
 }
 
 FailureOr<Operation *>
 BufferDeallocation::handleInterface(BranchOpInterface op) {
-  // Skip conditional branches since we special case them for now.
-  if (isa<cf::CondBranchOp>(op.getOperation()))
-    return op.getOperation();
+  if (op->getNumSuccessors() > 1)
+    return op->emitError("BranchOpInterface operations with multiple "
+                         "successors are not supported yet");
 
   if (op->getNumSuccessors() != 1)
     return emitError(op.getLoc(),
@@ -1121,23 +731,24 @@ BufferDeallocation::handleInterface(BranchOpInterface op) {
   Block *block = op->getBlock();
   OpBuilder builder(op);
   SmallVector<Value> memrefs, conditions, toRetain;
-  if (failed(getMemrefsAndConditionsToDeallocate(builder, op.getLoc(), block,
-                                                 memrefs, conditions)))
+  if (failed(state.getMemrefsAndConditionsToDeallocate(
+          builder, op.getLoc(), block, memrefs, conditions)))
     return failure();
 
   OperandRange forwardedOperands =
       op.getSuccessorOperands(0).getForwardedOperands();
-  getMemrefsToRetain(block, op->getSuccessor(0), forwardedOperands, toRetain);
+  state.getMemrefsToRetain(block, op->getSuccessor(0), forwardedOperands,
+                           toRetain);
 
   auto deallocOp = builder.create<bufferization::DeallocOp>(
       op.getLoc(), memrefs, conditions, toRetain);
 
   // We want to replace the current ownership of the retained values with the
   // result values of the dealloc operation as they are always unique.
-  clearOwnershipOf(deallocOp.getRetained(), block);
+  state.resetOwnerships(deallocOp.getRetained(), block);
   for (auto [retained, ownership] :
        llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
-    joinOwnership(retained, ownership, block);
+    state.updateOwnership(retained, ownership, block);
   }
 
   unsigned numAdditionalReturns = llvm::count_if(forwardedOperands, isMemref);
@@ -1156,7 +767,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
   // Lookup the function operation and check if it has private visibility. If
   // the function is referenced by SSA value instead of a Symbol, it's assumed
   // to be always private.
-  Operation *funcOp = op.resolveCallable(&symbolTable);
+  Operation *funcOp = op.resolveCallable(state.getSymbolTable());
   bool isPrivate = true;
   if (auto symbol = dyn_cast<SymbolOpInterface>(funcOp))
     isPrivate &= (symbol.getVisibility() == SymbolTable::Visibility::Private);
@@ -1166,14 +777,15 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
   // argument/result for each MemRef argument/result to dynamically pass the
   // current ownership indicator rather than adhering to the function boundary
   // ABI.
-  if (privateFuncDynamicOwnership && isPrivate) {
+  if (options.privateFuncDynamicOwnership && isPrivate) {
     SmallVector<Value> newOperands, ownershipIndicatorsToAdd;
     for (Value operand : op.getArgOperands()) {
       if (!isMemref(operand)) {
         newOperands.push_back(operand);
         continue;
       }
-      auto [memref, condition] = getMemrefWithUniqueOwnership(builder, operand);
+      auto [memref, condition] =
+          state.getMemrefWithUniqueOwnership(builder, operand);
       newOperands.push_back(memref);
       ownershipIndicatorsToAdd.push_back(condition);
     }
@@ -1187,8 +799,8 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
     op = appendOpResults(op, ownershipTypesToAppend);
 
     for (auto result : llvm::make_filter_range(op->getResults(), isMemref)) {
-      joinOwnership(result, op->getResult(ownershipCounter++));
-      memrefsToDeallocatePerBlock[result.getParentBlock()].push_back(result);
+      state.updateOwnership(result, op->getResult(ownershipCounter++));
+      state.addMemrefToDeallocate(result, result.getParentBlock());
     }
 
     return op.getOperation();
@@ -1199,8 +811,8 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
   // 'true' and remember to deallocate it.
   Value trueVal = buildBoolValue(builder, op.getLoc(), true);
   for (auto result : llvm::make_filter_range(op->getResults(), isMemref)) {
-    joinOwnership(result, trueVal);
-    memrefsToDeallocatePerBlock[result.getParentBlock()].push_back(result);
+    state.updateOwnership(result, trueVal);
+    state.addMemrefToDeallocate(result, result.getParentBlock());
   }
 
   return op.getOperation();
@@ -1228,13 +840,13 @@ BufferDeallocation::handleInterface(MemoryEffectOpInterface op) {
         // `memref.alloc`. If we wouldn't set the ownership of the result here,
         // the default ownership population in `populateRemainingOwnerships`
         // would assume aliasing with the MemRef operand.
-        clearOwnershipOf(res, block);
-        joinOwnership(res, buildBoolValue(builder, op.getLoc(), false));
+        state.resetOwnerships(res, block);
+        state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), false));
         continue;
       }
 
-      joinOwnership(res, buildBoolValue(builder, op.getLoc(), true));
-      memrefsToDeallocatePerBlock[block].push_back(res);
+      state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), true));
+      state.addMemrefToDeallocate(res, block);
     }
   }
 
@@ -1271,11 +883,11 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
   // dealloc operation.
   Block *block = op->getBlock();
   SmallVector<Value> memrefs, conditions, toRetain;
-  if (failed(getMemrefsAndConditionsToDeallocate(builder, op.getLoc(), block,
-                                                 memrefs, conditions)))
+  if (failed(state.getMemrefsAndConditionsToDeallocate(
+          builder, op.getLoc(), block, memrefs, conditions)))
     return failure();
 
-  getMemrefsToRetain(block, nullptr, OperandRange(operands), toRetain);
+  state.getMemrefsToRetain(block, nullptr, OperandRange(operands), toRetain);
   if (memrefs.empty() && toRetain.empty())
     return op.getOperation();
 
@@ -1284,10 +896,10 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
 
   // We want to replace the current ownership of the retained values with the
   // result values of the dealloc operation as they are always unique.
-  clearOwnershipOf(deallocOp.getRetained(), block);
+  state.resetOwnerships(deallocOp.getRetained(), block);
   for (auto [retained, ownership] :
        llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
-    joinOwnership(retained, ownership, block);
+    state.updateOwnership(retained, ownership, block);
 
   // Add an additional operand for every MemRef for the ownership indicator.
   if (!funcWithoutDynamicOwnership) {
@@ -1304,7 +916,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
 
 bool BufferDeallocation::isFunctionWithoutDynamicOwnership(Operation *op) {
   auto funcOp = dyn_cast<FunctionOpInterface>(op);
-  return funcOp && (!privateFuncDynamicOwnership ||
+  return funcOp && (!options.privateFuncDynamicOwnership ||
                     funcOp.getVisibility() != SymbolTable::Visibility::Private);
 }
 
@@ -1312,14 +924,14 @@ void BufferDeallocation::populateRemainingOwnerships(Operation *op) {
   for (auto res : op->getResults()) {
     if (!isMemref(res))
       continue;
-    if (ownershipMap.count({res, op->getBlock()}))
+    if (!state.getOwnership(res, op->getBlock()).isUninitialized())
       continue;
 
     // Don't take ownership of a returned memref if no allocate side-effect is
     // present, relevant for memref.get_global, for example.
     if (op->getNumOperands() == 0) {
       OpBuilder builder(op);
-      joinOwnership(res, buildBoolValue(builder, op->getLoc(), false));
+      state.updateOwnership(res, buildBoolValue(builder, op->getLoc(), false));
       continue;
     }
 
@@ -1329,8 +941,9 @@ void BufferDeallocation::populateRemainingOwnerships(Operation *op) {
       if (!isMemref(operand))
         continue;
 
-      ownershipMap[{res, op->getBlock()}].combine(
-          ownershipMap[{operand, operand.getParentBlock()}]);
+      state.updateOwnership(
+          res, state.getOwnership(operand, operand.getParentBlock()),
+          op->getBlock());
     }
   }
 }
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..e847e946eef1b5d
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -0,0 +1,163 @@
+//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+
+namespace {
+/// While CondBranchOp also implements the BranchOpInterface, we add a
+/// special-case implementation here because the BranchOpInterface does not
+/// offer all of the functionality we need to insert dealloc operations in an
+/// efficient way. More precisely, there is no way to extract the branch
+/// condition without casting to CondBranchOp specifically. It is still
+/// possible to implement deallocation for cases where we don't know to which
+/// successor the terminator branches before the actual branch happens by
+/// inserting auxiliary blocks and putting the dealloc op there, however, this
+/// can lead to less efficient code.
+/// This function inserts two dealloc operations (one for each successor) and
+/// adjusts the dealloc conditions according to the branch condition, then the
+/// ownerships of the retained MemRefs are updated by combining the result
+/// values of the two dealloc operations.
+///
+/// Example:
+/// ```
+/// ^bb1:
+///   <more ops...>
+///   cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb3>)
+/// ```
+/// becomes
+/// ```
+/// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
+/// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
+/// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>)
+/// ^bb1:
+///   <more ops...>
+///   let thenCond = map(c, (c) -> arith.andi cond, c)
+///   let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c)
+///   o0 = bufferization.dealloc m if thenCond retain r0
+///   o1 = bufferization.dealloc m if elseCond retain r1
+///   // replace ownership(r0) with o0 element-wise
+///   // replace ownership(r1) with o1 element-wise
+///   // let ownership0 := (r) -> o in o0 corresponding to r
+///   // let ownership1 := (r) -> o in o1 corresponding to r
+///   // let cmn := intersection(r0, r1)
+///   foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)):
+///     forall r in r0: replace ownership0(r) with arith.select cond, a, b
+///     forall r in r1: replace ownership1(r) with arith.select cond, a, b
+///   cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1)
+/// ```
+struct CondBranchOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface,
+                                                          cf::CondBranchOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    OpBuilder builder(op);
+    auto condBr = cast<cf::CondBranchOp>(op);
+
+    // The list of memrefs to deallocate in this block is independent of which
+    // branch is taken.
+    SmallVector<Value> memrefs, conditions;
+    if (failed(state.getMemrefsAndConditionsToDeallocate(
+            builder, condBr.getLoc(), condBr->getBlock(), memrefs, conditions)))
+      return failure();
+
+    // Helper lambda to factor out common logic for inserting the dealloc
+    // operations for each successor.
+    auto insertDeallocForBranch =
+        [&](Block *target, MutableOperandRange destOperands,
+            const std::function<Value(Value)> &conditionModifier,
+            DenseMap<Value, Value> &mapping) -> DeallocOp {
+      SmallVector<Value> toRetain;
+      state.getMemrefsToRetain(condBr->getBlock(), target,
+                               OperandRange(destOperands), toRetain);
+      SmallVector<Value> adaptedConditions(
+          llvm::map_range(conditions, conditionModifier));
+      auto deallocOp = builder.create<bufferization::DeallocOp>(
+          condBr.getLoc(), memrefs, adaptedConditions, toRetain);
+      state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock());
+      for (auto [retained, ownership] : llvm::zip(
+               deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
+        state.updateOwnership(retained, ownership, condBr->getBlock());
+        mapping[retained] = ownership;
+      }
+      SmallVector<Value> replacements, ownerships;
+      for (Value operand : destOperands) {
+        replacements.push_back(operand);
+        if (isMemref(operand)) {
+          assert(mapping.contains(operand) &&
+                 "Should be contained at this point");
+          ownerships.push_back(mapping[operand]);
+        }
+      }
+      replacements.append(ownerships);
+      destOperands.assign(replacements);
+      return deallocOp;
+    };
+
+    // Call the helper lambda and make sure the dealloc conditions are properly
+    // modified to reflect the branch condition as well.
+    DenseMap<Value, Value> thenMapping, elseMapping;
+    DeallocOp thenTakenDeallocOp = insertDeallocForBranch(
+        condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(),
+        [&](Value cond) {
+          return builder.create<arith::AndIOp>(condBr.getLoc(), cond,
+                                               condBr.getCondition());
+        },
+        thenMapping);
+    DeallocOp elseTakenDeallocOp = insertDeallocForBranch(
+        condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(),
+        [&](Value cond) {
+          Value trueVal = builder.create<arith::ConstantOp>(
+              condBr.getLoc(), builder.getBoolAttr(true));
+          Value negation = builder.create<arith::XOrIOp>(
+              condBr.getLoc(), trueVal, condBr.getCondition());
+          return builder.create<arith::AndIOp>(condBr.getLoc(), cond, negation);
+        },
+        elseMapping);
+
+    // We specifically need to update the ownerships of values that are retained
+    // in both dealloc operations again to get a combined 'Unique' ownership
+    // instead of an 'Unknown' ownership.
+    SmallPtrSet<Value, 16> thenValues(thenTakenDeallocOp.getRetained().begin(),
+                                      thenTakenDeallocOp.getRetained().end());
+    SetVector<Value> commonValues;
+    for (Value val : elseTakenDeallocOp.getRetained()) {
+      if (thenValues.contains(val))
+        commonValues.insert(val);
+    }
+
+    for (Value retained : commonValues) {
+      state.resetOwnerships(retained, condBr->getBlock());
+      Value combinedOwnership = builder.create<arith::SelectOp>(
+          condBr.getLoc(), condBr.getCondition(), thenMapping[retained],
+          elseMapping[retained]);
+      state.updateOwnership(retained, combinedOwnership, condBr->getBlock());
+    }
+
+    return condBr.getOperation();
+  }
+};
+
+} // namespace
+
+void mlir::cf::registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, ControlFlowDialect *dialect) {
+    CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index b2ef59887515e74..37b4cfc893879b3 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRControlFlowTransforms
+  BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -10,4 +11,4 @@ add_mlir_dialect_library(MLIRControlFlowTransforms
   MLIRControlFlowDialect
   MLIRMemRefDialect
   MLIRIR
-  )
+)
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index d8090591c70513f..66449aa2ffdb60e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -500,10 +500,10 @@ func.func @assumingOp(
 //       CHECK: test.copy
 //       CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
 //       CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V1]]#0
-//       CHECK: bufferization.dealloc ([[BASE0]] :{{.*}}) if ([[V0]]#1)
-//   CHECK-NOT: retain
 //       CHECK: bufferization.dealloc ([[BASE1]] :{{.*}}) if ([[V1]]#1)
 //   CHECK-NOT: retain
+//       CHECK: bufferization.dealloc ([[BASE0]] :{{.*}}) if ([[V0]]#1)
+//   CHECK-NOT: retain
 //       CHECK: return
 
 // -----
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index d4390e7651be0f0..2447f63bab29afb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -12066,6 +12066,7 @@ gentbl_cc_library(
 cc_library(
     name = "BufferizationDialect",
     srcs = [
+        "lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
@@ -12073,6 +12074,7 @@ cc_library(
         "lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp",
     ],
     hdrs = [
+        "include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/Bufferization.h",
         "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
@@ -12083,7 +12085,9 @@ cc_library(
     deps = [
         ":AffineDialect",
         ":AllocationOpInterface",
+        ":Analysis",
         ":ArithDialect",
+        ":BufferDeallocationOpInterfaceIncGen",
         ":BufferizableOpInterfaceIncGen",
         ":BufferizationBaseIncGen",
         ":BufferizationEnumsIncGen",

>From 2d43cfed65bb076440b4babe9e8dbbeefda765c9 Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Tue, 12 Sep 2023 15:20:05 +0000
Subject: [PATCH 2/2] [mlir][bufferization] BufferDeallocationOpInterface:
 support custom ownership update logic

Add a method to the BufferDeallocationOpInterface that allows operations that
implement the interface to provide custom logic to compute the ownership
indicators of the values they define. As a demonstration, this new method is
implemented by the `arith.select` operation.
---
 .../BufferDeallocationOpInterfaceImpl.h       | 22 +++++
 .../IR/BufferDeallocationOpInterface.h        |  4 +-
 .../IR/BufferDeallocationOpInterface.td       | 29 ++++++-
 mlir/include/mlir/InitAllDialects.h           |  2 +
 .../BufferDeallocationOpInterfaceImpl.cpp     | 85 +++++++++++++++++++
 .../Dialect/Arith/Transforms/CMakeLists.txt   |  1 +
 .../IR/BufferDeallocationOpInterface.cpp      |  4 +-
 .../OwnershipBasedBufferDeallocation.cpp      | 46 ++++++++--
 .../dealloc-callop-interface.mlir             | 10 +--
 9 files changed, 187 insertions(+), 16 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
new file mode 100644
index 000000000000000..16cec1a82b5c86c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
@@ -0,0 +1,22 @@
+//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace arith {
+void registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
index b88270f1c150a27..7ac4592de7875fb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
@@ -142,8 +142,8 @@ class DeallocationState {
   /// a new SSA value, returned as the first element of the pair, which has
   /// 'Unique' ownership and can be used instead of the passed Value with the
   /// the ownership indicator returned as the second element of the pair.
-  std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                       Value memref);
+  std::pair<Value, Value>
+  getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
 
   /// Given two basic blocks and the values passed via block arguments to the
   /// destination block, compute the list of MemRefs that have to be retained in
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
index c35fe417184ffd4..3e11432c65c5f08 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
@@ -39,7 +39,34 @@ def BufferDeallocationOpInterface :
         /*retType=*/"FailureOr<Operation *>",
         /*methodName=*/"process",
         /*args=*/(ins "DeallocationState &":$state,
-                      "const DeallocationOptions &":$options)>
+                      "const DeallocationOptions &":$options)>,
+      InterfaceMethod<
+        /*desc=*/[{
+          This method allows the implementing operation to specify custom logic
+          to materialize an ownership indicator value for the given MemRef typed
+          value it defines (including block arguments of nested regions). Since
+          the operation itself has more information about its semantics, the
+          materialized IR can be more efficient compared to the default
+          implementation and avoid cloning MemRefs and/or doing alias checking
+          at runtime.
+          Note that the same logic could also be implemented in the 'process'
+          method above, however, the IR is always materialized then. If
+          it's desirable to only materialize the IR to compute an updated
+          ownership indicator when needed, it should be implemented using this
+          method (which is especially important if operations are created that
+          cannot be easily canonicalized away anymore).
+        }],
+        /*retType=*/"std::pair<Value, Value>",
+        /*methodName=*/"materializeUniqueOwnershipForMemref",
+        /*args=*/(ins "DeallocationState &":$state,
+                      "const DeallocationOptions &":$options,
+                      "OpBuilder &":$builder,
+                      "Value":$memref),
+        /*methodBody=*/[{}],
+        /*defaultImplementation=*/[{
+          return state.getMemrefWithUniqueOwnership(
+            builder, memref, memref.getParentBlock());
+        }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index ee91bfa57d12a39..0182ab93929cb8c 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -133,6 +134,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
 
   // Register all external models.
   affine::registerValueBoundsOpInterfaceExternalModels(registry);
+  arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..f2e7732e8ea4aa3
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -0,0 +1,85 @@
+//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+/// Provides custom logic to materialize ownership indicator values for the
+/// result value of 'arith.select'. Instead of cloning or runtime alias
+/// checking, this implementation inserts another `arith.select` to choose the
+/// ownership indicator of the operand in the same way the original
+/// `arith.select` chooses the MemRef operand. If at least one of the operands'
+/// ownerships is not 'Unique', fall back to the default implementation.
+///
+/// Example:
+/// ```mlir
+/// // let ownership(%m0) := %o0
+/// // let ownership(%m1) := %o1
+/// %res = arith.select %cond, %m0, %m1
+/// ```
+/// The default implementation would insert a clone and replace all uses of the
+/// result of `arith.select` with that clone:
+/// ```mlir
+/// %res = arith.select %cond, %m0, %m1
+/// %clone = bufferization.clone %res
+/// // let ownership(%res) := 'Unknown'
+/// // let ownership(%clone) := %true
+/// // replace all uses of %res with %clone
+/// ```
+/// This implementation, on the other hand, materializes the following:
+/// ```mlir
+/// %res = arith.select %cond, %m0, %m1
+/// %res_ownership = arith.select %cond, %o0, %o1
+/// // let ownership(%res) := %res_ownership
+/// ```
+struct SelectOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<SelectOpInterface,
+                                                          arith::SelectOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    return op; // nothing to do
+  }
+
+  std::pair<Value, Value>
+  materializeUniqueOwnershipForMemref(Operation *op, DeallocationState &state,
+                                      const DeallocationOptions &options,
+                                      OpBuilder &builder, Value value) const {
+    auto selectOp = cast<arith::SelectOp>(op);
+    assert(value == selectOp.getResult() &&
+           "Value not defined by this operation");
+
+    Block *block = value.getParentBlock();
+    if (!state.getOwnership(selectOp.getTrueValue(), block).isUnique() ||
+        !state.getOwnership(selectOp.getFalseValue(), block).isUnique())
+      return state.getMemrefWithUniqueOwnership(builder, value,
+                                                value.getParentBlock());
+
+    Value ownership = builder.create<arith::SelectOp>(
+        op->getLoc(), selectOp.getCondition(),
+        state.getOwnership(selectOp.getTrueValue(), block).getIndicator(),
+        state.getOwnership(selectOp.getFalseValue(), block).getIndicator());
+    return {selectOp.getResult(), ownership};
+  }
+};
+
+} // namespace
+
+void mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+    SelectOp::attachInterface<SelectOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index a9b86b4d99256c0..02240601bcd35a1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRArithTransforms
+  BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   EmulateUnsupportedFloats.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index 2314cee2ff2c158..407d75e2426e9f9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -134,8 +134,8 @@ void DeallocationState::getLiveMemrefsIn(Block *block,
 
 std::pair<Value, Value>
 DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                Value memref) {
-  auto iter = ownershipMap.find({memref, memref.getParentBlock()});
+                                                Value memref, Block *block) {
+  auto iter = ownershipMap.find({memref, block});
   assert(iter != ownershipMap.end() &&
          "Value must already have been registered in the ownership map");
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index d4b8e0dff67bae4..02fb4d3c42fa521 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -376,13 +376,24 @@ class BufferDeallocation {
   /// Given an SSA value of MemRef type, returns the same of a new SSA value
   /// which has 'Unique' ownership where the ownership indicator is guaranteed
   /// to be always 'true'.
-  Value getMemrefWithGuaranteedOwnership(OpBuilder &builder, Value memref);
+  Value materializeMemrefWithGuaranteedOwnership(OpBuilder &builder,
+                                                 Value memref, Block *block);
 
   /// Returns whether the given operation implements FunctionOpInterface, has
   /// private visibility, and the private-function-dynamic-ownership pass option
   /// is enabled.
   bool isFunctionWithoutDynamicOwnership(Operation *op);
 
+  /// Given an SSA value of MemRef type, this function queries the
+  /// BufferDeallocationOpInterface of the defining operation of 'memref' for a
+  /// materialized ownership indicator for 'memref'.  If the op does not
+  /// implement the interface or if the block for which the materialized value
+  /// is requested does not match the block in which 'memref' is defined, the
+  /// default implementation in
+  /// `DeallocationState::getMemrefWithUniqueOwnership` is queried instead.
+  std::pair<Value, Value>
+  materializeUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
+
   /// Checks all the preconditions for operations implementing the
   /// FunctionOpInterface that have to hold for the deallocation to be
   /// applicable:
@@ -428,6 +439,28 @@ class BufferDeallocation {
 // BufferDeallocation Implementation
 //===----------------------------------------------------------------------===//
 
+std::pair<Value, Value>
+BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
+                                               Block *block) {
+  // The interface can only materialize ownership indicators in the same block
+  // as the defining op.
+  if (memref.getParentBlock() != block)
+    return state.getMemrefWithUniqueOwnership(builder, memref, block);
+
+  Operation *owner = memref.getDefiningOp();
+  if (!owner)
+    owner = memref.getParentBlock()->getParentOp();
+
+  // If the op implements the interface, query it for a materialized ownership
+  // value.
+  if (auto deallocOpInterface = dyn_cast<BufferDeallocationOpInterface>(owner))
+    return deallocOpInterface.materializeUniqueOwnershipForMemref(
+        state, options, builder, memref);
+
+  // Otherwise use the default implementation.
+  return state.getMemrefWithUniqueOwnership(builder, memref, block);
+}
+
 static bool regionOperatesOnMemrefValues(Region &region) {
   WalkResult result = region.walk([](Block *block) {
     if (llvm::any_of(block->getArguments(), isMemref))
@@ -677,11 +710,11 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
   return newOp.getOperation();
 }
 
-Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
-                                                           Value memref) {
+Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership(
+    OpBuilder &builder, Value memref, Block *block) {
   // First, make sure we at least have 'Unique' ownership already.
   std::pair<Value, Value> newMemrefAndOnwership =
-      state.getMemrefWithUniqueOwnership(builder, memref);
+      materializeUniqueOwnership(builder, memref, block);
   Value newMemref = newMemrefAndOnwership.first;
   Value condition = newMemrefAndOnwership.second;
 
@@ -785,7 +818,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
         continue;
       }
       auto [memref, condition] =
-          state.getMemrefWithUniqueOwnership(builder, operand);
+          materializeUniqueOwnership(builder, operand, op->getBlock());
       newOperands.push_back(memref);
       ownershipIndicatorsToAdd.push_back(condition);
     }
@@ -868,7 +901,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
       if (!isMemref(val.get()))
         continue;
 
-      val.set(getMemrefWithGuaranteedOwnership(builder, val.get()));
+      val.set(materializeMemrefWithGuaranteedOwnership(builder, val.get(),
+                                                       op->getBlock()));
     }
   }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
index 67128fee3dfe0ab..bff06d4499938df 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
@@ -95,15 +95,15 @@ func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
 //  CHECK-NEXT:   return
 
 // CHECK-DYNAMIC-LABEL: func @function_call_requries_merged_ownership_mid_block
+//  CHECK-DYNAMIC-SAME: ([[ARG0:%.+]]: i1)
 //       CHECK-DYNAMIC:   [[ALLOC0:%.+]] = memref.alloc(
 //  CHECK-DYNAMIC-NEXT:   [[ALLOC1:%.+]] = memref.alloca(
-//  CHECK-DYNAMIC-NEXT:   [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
-//  CHECK-DYNAMIC-NEXT:   [[CLONE:%.+]] = bufferization.clone [[SELECT]]
-//  CHECK-DYNAMIC-NEXT:   [[RET:%.+]]:2 = call @f([[CLONE]], %true{{[0-9_]*}})
+//  CHECK-DYNAMIC-NEXT:   [[SELECT:%.+]] = arith.select [[ARG0]], [[ALLOC0]], [[ALLOC1]]
+//  CHECK-DYNAMIC-NEXT:   [[RET:%.+]]:2 = call @f([[SELECT]], [[ARG0]])
 //  CHECK-DYNAMIC-NEXT:   test.copy
 //  CHECK-DYNAMIC-NEXT:   [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
-//  CHECK-DYNAMIC-NEXT:   bufferization.dealloc ([[ALLOC0]], [[CLONE]], [[BASE]] :
-//  CHECK-DYNAMIC-SAME:     if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
+//  CHECK-DYNAMIC-NEXT:   bufferization.dealloc ([[ALLOC0]], [[BASE]] :
+//  CHECK-DYNAMIC-SAME:     if (%true{{[0-9_]*}}, [[RET]]#1)
 //   CHECK-DYNAMIC-NOT:     retain
 //  CHECK-DYNAMIC-NEXT:   return
 



More information about the Mlir-commits mailing list