[llvm] [mlir][bufferization] Add option to LowerDeallocations to choose the kind of dealloc op to build (PR #67565)

Martin Erhart via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 27 07:53:28 PDT 2023


https://github.com/maerhart created https://github.com/llvm/llvm-project/pull/67565

* Add option to LowerDeallocations to choose the kind of dealloc op to build
* Use alloca instead of alloc in the generic `bufferization.dealloc` op lowering.

From 7026a8c65d347534e30c8e3491fd7e7a2659a01c Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Wed, 27 Sep 2023 07:51:00 +0000
Subject: [PATCH 1/3] [mlir][bufferization] Don't clone on unknown ownership
 and verify function boundary ABI

Inserting clones requires a lot of assumptions to hold on the input IR, e.g.,
all writes to a buffer need to dominate all reads. This is not guaranteed by
one-shot bufferization and isn't easy to verify, thus it could quickly lead to
incorrect results that are hard to debug. This commit changes the mechanism of
how an ownership indicator is materialized when there is not already a unique
ownership present. Additionally, we don't create copies of returned memrefs
anymore when we don't have ownership. Instead, we insert assert operations to
make sure we have ownership at runtime, or otherwise report to the user that
correctness could not be guaranteed.
---
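
Note: the runtime aliasing check that replaces clone insertion can be modeled
outside of MLIR. The block below is a minimal, self-contained sketch and not
the code in this patch: `OperandInfo` and `materializeOwnership` are
hypothetical names, plain integers stand in for the results of
`memref.extract_aligned_pointer_as_index`, and `bool` stands in for the `i1`
ownership indicator. It mirrors the chain of `arith.cmpi eq` and
`arith.select` operations built in
`DeallocationState::getMemrefWithUniqueOwnership`: start from the ownership of
the first memref operand, then override it with the ownership of any later
operand whose pointer equals the result pointer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one memref operand of the defining op: its
// aligned base pointer and its (already unique) ownership indicator.
struct OperandInfo {
  std::uintptr_t alignedPtr;
  bool ownership;
};

// Models the select chain: the result inherits the ownership of the operand
// whose pointer it matches. If no later operand matches, the ownership of the
// first operand is kept, exactly as in the chain of selects.
bool materializeOwnership(std::uintptr_t resultPtr,
                          const std::vector<OperandInfo> &operands) {
  assert(!operands.empty() && "expected at least one memref operand");
  bool ownership = operands.front().ownership;
  for (std::size_t i = 1; i < operands.size(); ++i) {
    // Corresponds to `arith.cmpi eq, %resultPtr, %operandPtr`.
    bool isSameBuffer = resultPtr == operands[i].alignedPtr;
    // Corresponds to `arith.select %isSameBuffer, %newOwnership, %ownership`.
    ownership = isSameBuffer ? operands[i].ownership : ownership;
  }
  return ownership;
}
```

For a select-like op with operands `{0x1000, true}` and `{0x2000, false}`, a
result pointer of `0x2000` yields `false` and a result pointer of `0x1000`
yields `true`. The model assumes, as the NOTE in the patch does, that the
result aliases at least one operand whenever the op does not allocate.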
 .../IR/BufferDeallocationOpInterface.h        |  24 ++--
 .../IR/BufferDeallocationOpInterface.td       |  34 +++++-
 .../Dialect/Bufferization/Pipelines/Passes.h  |  19 +++-
 .../Dialect/Bufferization/Transforms/Passes.h |   8 +-
 .../Bufferization/Transforms/Passes.td        |  12 +-
 .../BufferDeallocationOpInterfaceImpl.cpp     |  13 ++-
 .../IR/BufferDeallocationOpInterface.cpp      |  97 +++++++++++++----
 .../Pipelines/BufferizationPipelines.cpp      |  14 ++-
 .../Bufferization/Pipelines/CMakeLists.txt    |   1 +
 .../Bufferization/Transforms/CMakeLists.txt   |   1 +
 .../OwnershipBasedBufferDeallocation.cpp      | 103 ++++++------------
 .../dealloc-branchop-interface.mlir           |   2 +-
 .../dealloc-callop-interface.mlir             |   2 +-
 .../dealloc-function-boundaries.mlir          |  26 ++---
 .../dealloc-memoryeffect-interface.mlir       |  21 +---
 .../dealloc-region-branchop-interface.mlir    |  71 +++---------
 .../dealloc-runtime-verification.mlir         |  13 +++
 .../dealloc-subviews.mlir                     |   2 +-
 .../dealloc-unknown-ops.mlir                  |  29 +++++
 .../Linalg/CPU/test-collapse-tensor.mlir      |   2 +-
 .../Linalg/CPU/test-expand-tensor.mlir        |   2 +-
 .../Linalg/CPU/test-one-shot-bufferize.mlir   |   2 +-
 .../Dialect/Linalg/CPU/test-tensor-e2e.mlir   |   2 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 24 files changed, 289 insertions(+), 212 deletions(-)
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-runtime-verification.mlir
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-unknown-ops.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
index 752a4a2c6f42a2f..838641db20cbbc3 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
@@ -95,7 +95,17 @@ struct DeallocationOptions {
   // A pass option indicating whether private functions should be modified to
   // pass the ownership of MemRef values instead of adhering to the function
   // boundary ABI.
-  bool privateFuncDynamicOwnership = false;
+  bool privateFuncDynamicOwnership = true;
+
+  /// Inserts `cf.assert` operations to verify the function boundary ABI at
+  /// runtime. Currently, it is only checked that the ownership of returned
+  /// MemRefs is 'true'. This makes sure that ownership is yielded and the
+  /// returned MemRef does not originate from the same allocation as a function
+  /// argument. TODO: check that returned MemRefs don't alias each other.
+  /// If it can be determined statically that the ABI is not adhered
+  /// to, an error will already be emitted at compile time. This cannot be
+  /// changed with this option.
+  bool verifyFunctionBoundaryABI = true;
 };
 
 /// This class collects all the state that we need to perform the buffer
@@ -138,12 +148,12 @@ class DeallocationState {
   void getLiveMemrefsIn(Block *block, SmallVectorImpl<Value> &memrefs);
 
   /// Given an SSA value of MemRef type, this function queries the ownership and
-  /// if it is not already in the 'Unique' state, potentially inserts IR to get
-  /// a new SSA value, returned as the first element of the pair, which has
-  /// 'Unique' ownership and can be used instead of the passed Value with the
-  /// the ownership indicator returned as the second element of the pair.
-  std::pair<Value, Value>
-  getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
+  /// if it is not already in the 'Unique' state, potentially inserts IR to
+  /// determine the ownership (which might involve expensive aliasing checks at
+  /// runtime).
+  Value getMemrefWithUniqueOwnership(const DeallocationOptions &options,
+                                     OpBuilder &builder, Value memref,
+                                     Block *block);
 
   /// Given two basic blocks and the values passed via block arguments to the
   /// destination block, compute the list of MemRefs that have to be retained in
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
index 3e11432c65c5f08..8fe143eea5bf8ec 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
@@ -55,8 +55,38 @@ def BufferDeallocationOpInterface :
           ownership indicator when needed, it should be implemented using this
           method (which is especially important if operations are created that
           cannot be easily canonicalized away anymore).
+          Ownership indicators have to be materialized when
+            * needed for the condition operands of a `bufferization.dealloc` op
+            * passed along with MemRefs to successor blocks via additional
+              forwarded operands of terminator ops
+            * passing them as additional operands to nested regions (e.g.,
+              init_args of `scf.for`)
+            * passing them as additional operands to a call operation when
+              `private-function-dynamic-ownership` is enabled
+            * a copy is made conditionally on the current ownership, etc.
+
+          In the following example, the deallocation pass would add an
+          additional block argument to `^bb1` for passing the ownership of `%0`
+          along and thus the ownership indicator has to be materialized before
+          the `cf.br` operation and added as a forwarded operand.
+          ```mlir
+            %0 = arith.select %cond, %m1, %m2 : memref<f32>
+            cf.br ^bb1(%0 : memref<f32>)
+          ^bb1(%arg0: memref<f32>):
+            ...
+          ```
+          The `arith.select` operation could implement this interface method to
+          materialize another `arith.select` operation to select the
+          corresponding ownership indicator.
+          ```mlir
+            %0 = arith.select %cond, %m1, %m2 : memref<f32>
+            %0_ownership = arith.select %cond, %m1_ownership, %m2_ownership : i1
+            cf.br ^bb1(%0, %0_ownership : memref<f32>, i1)
+          ^bb1(%arg0: memref<f32>, %arg1: i1):
+            ...
+          ```
         }],
-        /*retType=*/"std::pair<Value, Value>",
+        /*retType=*/"Value",
         /*methodName=*/"materializeUniqueOwnershipForMemref",
         /*args=*/(ins "DeallocationState &":$state,
                       "const DeallocationOptions &":$options,
@@ -65,7 +95,7 @@ def BufferDeallocationOpInterface :
         /*methodBody=*/[{}],
         /*defaultImplementation=*/[{
           return state.getMemrefWithUniqueOwnership(
-            builder, memref, memref.getParentBlock());
+            options, builder, memref, memref.getParentBlock());
         }]>,
   ];
 }
diff --git a/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h
index 7acacb763cd2c18..7578351d2c4f501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h
@@ -17,6 +17,7 @@
 
 namespace mlir {
 namespace bufferization {
+struct DeallocationOptions;
 
 /// Options for the buffer deallocation pipeline.
 struct BufferDeallocationPipelineOptions
@@ -27,7 +28,23 @@ struct BufferDeallocationPipelineOptions
           "Allows to add additional arguments to private functions to "
           "dynamically pass ownership of memrefs to callees. This can enable "
           "earlier deallocations."),
-      llvm::cl::init(false)};
+      llvm::cl::init(true)};
+  PassOptions::Option<bool> verifyFunctionBoundaryABI{
+      *this, "verify-function-boundary-abi",
+      llvm::cl::desc(
+          "Inserts `cf.assert` operations to verify the function boundary ABI "
+          "at runtime. Currently, it is only checked that the ownership of "
+          "returned MemRefs is 'true'. This makes sure that ownership is "
+          "yielded and the returned MemRef does not originate from the same "
+          "allocation as a function argument. If it can be determined "
+          "statically that the ABI is not adhered to, an error will already be "
+          "emitted at compile time. This cannot be changed with this option."),
+      llvm::cl::init(true)};
+
+  /// Convert this BufferDeallocationPipelineOptions struct to a
+  /// DeallocationOptions struct to be passed to the
+  /// OwnershipBasedBufferDeallocationPass.
+  DeallocationOptions asDeallocationOptions() const;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index a6f668b26aa10e4..2bf82dd6f88c81c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -1,6 +1,7 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
 
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -31,7 +32,7 @@ std::unique_ptr<Pass> createBufferDeallocationPass();
 /// Creates an instance of the OwnershipBasedBufferDeallocation pass to free all
 /// allocated buffers.
 std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
-    bool privateFuncDynamicOwnership = false);
+    const DeallocationOptions &options = DeallocationOptions());
 
 /// Creates a pass that optimizes `bufferization.dealloc` operations. For
 /// example, it reduces the number of alias checks needed at runtime using
@@ -134,8 +135,9 @@ func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc,
 LogicalResult deallocateBuffers(Operation *op);
 
 /// Run ownership basedbuffer deallocation.
-LogicalResult deallocateBuffersOwnershipBased(FunctionOpInterface op,
-                                              bool privateFuncDynamicOwnership);
+LogicalResult deallocateBuffersOwnershipBased(
+    FunctionOpInterface op,
+    const DeallocationOptions &options = DeallocationOptions());
 
 /// Creates a pass that moves allocations upwards to reduce the number of
 /// required copies that are inserted during the BufferDeallocation pass.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d6..5de17cf7faa7ef6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -219,10 +219,20 @@ def OwnershipBasedBufferDeallocation : Pass<
   }];
   let options = [
     Option<"privateFuncDynamicOwnership", "private-function-dynamic-ownership",
-           "bool", /*default=*/"false",
+           "bool", /*default=*/"true",
            "Allows to add additional arguments to private functions to "
            "dynamically pass ownership of memrefs to callees. This can enable "
            "earlier deallocations.">,
+    Option<"verifyFunctionBoundaryABI", "verify-function-boundary-abi",
+           "bool", /*default=*/"true",
+           "Inserts `cf.assert` operations to verify the function boundary ABI "
+           "at runtime. Currently, it is only checked that the ownership of "
+           "returned MemRefs is 'true'. This makes sure that ownership is "
+           "yielded and the returned MemRef does not originate from the same "
+           "allocation as a function argument. "
+           "If it can be determined statically that the ABI is not adhered "
+           "to, an error will already be emitted at compile time. This cannot "
+           "be changed with this option.">,
   ];
   let constructor = "mlir::bufferization::createOwnershipBasedBufferDeallocationPass()";
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index f2e7732e8ea4aa3..29e6ba9bfcf5891 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -53,10 +53,11 @@ struct SelectOpInterface
     return op; // nothing to do
   }
 
-  std::pair<Value, Value>
-  materializeUniqueOwnershipForMemref(Operation *op, DeallocationState &state,
-                                      const DeallocationOptions &options,
-                                      OpBuilder &builder, Value value) const {
+  Value materializeUniqueOwnershipForMemref(Operation *op,
+                                            DeallocationState &state,
+                                            const DeallocationOptions &options,
+                                            OpBuilder &builder,
+                                            Value value) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() &&
            "Value not defined by this operation");
@@ -64,14 +65,14 @@ struct SelectOpInterface
     Block *block = value.getParentBlock();
     if (!state.getOwnership(selectOp.getTrueValue(), block).isUnique() ||
         !state.getOwnership(selectOp.getFalseValue(), block).isUnique())
-      return state.getMemrefWithUniqueOwnership(builder, value,
+      return state.getMemrefWithUniqueOwnership(options, builder, value,
                                                 value.getParentBlock());
 
     Value ownership = builder.create<arith::SelectOp>(
         op->getLoc(), selectOp.getCondition(),
         state.getOwnership(selectOp.getTrueValue(), block).getIndicator(),
         state.getOwnership(selectOp.getFalseValue(), block).getIndicator());
-    return {selectOp.getResult(), ownership};
+    return ownership;
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index 8d21446f1eb777e..50def5f45538bb5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -132,30 +132,79 @@ void DeallocationState::getLiveMemrefsIn(Block *block,
   memrefs.append(liveMemrefs);
 }
 
-std::pair<Value, Value>
-DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                Value memref, Block *block) {
-  auto iter = ownershipMap.find({memref, block});
-  assert(iter != ownershipMap.end() &&
-         "Value must already have been registered in the ownership map");
-
-  Ownership ownership = iter->second;
-  if (ownership.isUnique())
-    return {memref, ownership.getIndicator()};
-
-  // Instead of inserting a clone operation we could also insert a dealloc
-  // operation earlier in the block and use the updated ownerships returned by
-  // the op for the retained values. Alternatively, we could insert code to
-  // check aliasing at runtime and use this information to combine two unique
-  // ownerships more intelligently to not end up with an 'Unknown' ownership in
-  // the first place.
-  auto cloneOp =
-      builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
-  Value condition = buildBoolValue(builder, memref.getLoc(), true);
-  Value newMemref = cloneOp.getResult();
-  updateOwnership(newMemref, condition);
-  memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
-  return {newMemref, condition};
+Value DeallocationState::getMemrefWithUniqueOwnership(
+    const DeallocationOptions &options, OpBuilder &builder, Value memref,
+    Block *block) {
+  // NOTE: * if none of the operands have the same allocated pointer, a new
+  // memref was allocated and thus the operation should have the allocate
+  // side-effect defined on that result value and thus the correct unique
+  // ownership is pre-populated by the ownership pass (unless an interface
+  // implementation is incorrect).
+  //       * if exactly one operand has the same allocated pointer, this returns
+  //       the ownership of exactly that operand
+  //       * if multiple operands match the allocated pointer of the result, the
+  //       ownership indicators of all of them always have to evaluate to the
+  //       same value because no dealloc operations may be present and because
+  //       of the rules they are passed to nested regions and successor blocks.
+  //       This could be verified at runtime by inserting `cf.assert`
+  //       operations, but would require O(|operands|^2) additional operations
+  //       to check and is thus not implemented yet (would need to insert a
+  //       library function to avoid code-size explosion which would make the
+  //       deallocation pass a module pass)
+  auto ipSave = builder.saveInsertionPoint();
+  SmallVector<Value> worklist;
+  worklist.push_back(memref);
+
+  while (!worklist.empty()) {
+    Value curr = worklist.back();
+    Ownership ownership = getOwnership(curr, block);
+    if (ownership.isUnique()) {
+      worklist.pop_back();
+      continue;
+    }
+
+    Operation *defOp = curr.getDefiningOp();
+    assert(defOp &&
+           "the ownership-based deallocation pass should be written in a way "
+           "that pre-populates ownership for block arguments");
+
+    bool allKnown = true;
+    for (Value val : llvm::make_filter_range(defOp->getOperands(), isMemref)) {
+      Ownership ownership = getOwnership(val, block);
+      if (ownership.isUnique())
+        continue;
+
+      worklist.push_back(val);
+      allKnown = false;
+    }
+
+    if (allKnown) {
+      builder.setInsertionPointAfter(defOp);
+      SmallVector<Value> operands(
+          llvm::make_filter_range(defOp->getOperands(), isMemref));
+      Value resultPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
+          defOp->getLoc(), curr);
+      Value ownership = getOwnership(operands.front(), block).getIndicator();
+
+      for (Value val : ArrayRef(operands).drop_front()) {
+        Value operandPtr =
+            builder.create<memref::ExtractAlignedPointerAsIndexOp>(
+                defOp->getLoc(), val);
+        Value isSameBuffer = builder.create<arith::CmpIOp>(
+            defOp->getLoc(), arith::CmpIPredicate::eq, resultPtr, operandPtr);
+        Value newOwnership = getOwnership(val, block).getIndicator();
+        ownership = builder.create<arith::SelectOp>(
+            defOp->getLoc(), isSameBuffer, newOwnership, ownership);
+      }
+      // Ownership is already 'Unknown', so we need to override instead of
+      // joining.
+      resetOwnerships(curr, block);
+      updateOwnership(curr, ownership, block);
+    }
+  }
+
+  builder.restoreInsertionPoint(ipSave);
+  return getOwnership(memref, block).getIndicator();
 }
 
 void DeallocationState::getMemrefsToRetain(
diff --git a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
index a2878f0b80fa1cd..ea05fa29ea608eb 100644
--- a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
+++ b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
@@ -8,22 +8,34 @@
 
 #include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
 
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 
+using namespace mlir;
+using namespace bufferization;
+
 //===----------------------------------------------------------------------===//
 // Pipeline implementation.
 //===----------------------------------------------------------------------===//
 
+DeallocationOptions
+BufferDeallocationPipelineOptions::asDeallocationOptions() const {
+  DeallocationOptions opts;
+  opts.privateFuncDynamicOwnership = privateFunctionDynamicOwnership.getValue();
+  opts.verifyFunctionBoundaryABI = verifyFunctionBoundaryABI.getValue();
+  return opts;
+}
+
 void mlir::bufferization::buildBufferDeallocationPipeline(
     OpPassManager &pm, const BufferDeallocationPipelineOptions &options) {
   pm.addPass(memref::createExpandReallocPass(/*emitDeallocs=*/false));
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createOwnershipBasedBufferDeallocationPass(
-      options.privateFunctionDynamicOwnership.getValue()));
+      options.asDeallocationOptions()));
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createBufferDeallocationSimplificationPass());
   pm.addPass(createLowerDeallocationsPass());
diff --git a/mlir/lib/Dialect/Bufferization/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Pipelines/CMakeLists.txt
index 6e8dab64ba6b935..d67b28b308fa10e 100644
--- a/mlir/lib/Dialect/Bufferization/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Pipelines/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRBufferizationPipelines
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
 
   LINK_LIBS PUBLIC
+  MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRMemRefTransforms
   MLIRFuncDialect
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index ed8dbd57bf40ba1..7cef74be6177fa4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRBufferizationDialect
+  MLIRControlFlowDialect
   MLIRControlFlowInterfaces
   MLIRFuncDialect
   MLIRFunctionInterfaces
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index fd36716163d0ad4..362f054142fc019 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -139,10 +140,8 @@ namespace {
 /// program have a corresponding de-allocation.
 class BufferDeallocation {
 public:
-  BufferDeallocation(Operation *op, bool privateFuncDynamicOwnership)
-      : state(op) {
-    options.privateFuncDynamicOwnership = privateFuncDynamicOwnership;
-  }
+  BufferDeallocation(Operation *op, DeallocationOptions options)
+      : state(op), options(options) {}
 
   /// Performs the actual placement/creation of all dealloc operations.
   LogicalResult deallocate(FunctionOpInterface op);
@@ -373,12 +372,6 @@ class BufferDeallocation {
   /// operations, etc.).
   void populateRemainingOwnerships(Operation *op);
 
-  /// Given an SSA value of MemRef type, returns the same of a new SSA value
-  /// which has 'Unique' ownership where the ownership indicator is guaranteed
-  /// to be always 'true'.
-  Value materializeMemrefWithGuaranteedOwnership(OpBuilder &builder,
-                                                 Value memref, Block *block);
-
   /// Returns whether the given operation implements FunctionOpInterface, has
   /// private visibility, and the private-function-dynamic-ownership pass option
   /// is enabled.
@@ -391,8 +384,8 @@ class BufferDeallocation {
   /// is requested does not match the block in which 'memref' is defined, the
   /// default implementation in
   /// `DeallocationState::getMemrefWithUniqueOwnership` is queried instead.
-  std::pair<Value, Value>
-  materializeUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
+  Value materializeUniqueOwnership(OpBuilder &builder, Value memref,
+                                   Block *block);
 
   /// Checks all the preconditions for operations implementing the
   /// FunctionOpInterface that have to hold for the deallocation to be
@@ -430,7 +423,7 @@ class BufferDeallocation {
   DeallocationState state;
 
   /// Collects all pass options in a single place.
-  DeallocationOptions options;
+  const DeallocationOptions options;
 };
 
 } // namespace
@@ -439,13 +432,13 @@ class BufferDeallocation {
 // BufferDeallocation Implementation
 //===----------------------------------------------------------------------===//
 
-std::pair<Value, Value>
-BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
-                                               Block *block) {
+Value BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder,
+                                                     Value memref,
+                                                     Block *block) {
   // The interface can only materialize ownership indicators in the same block
   // as the defining op.
   if (memref.getParentBlock() != block)
-    return state.getMemrefWithUniqueOwnership(builder, memref, block);
+    return state.getMemrefWithUniqueOwnership(options, builder, memref, block);
 
   Operation *owner = memref.getDefiningOp();
   if (!owner)
@@ -458,7 +451,7 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
         state, options, builder, memref);
 
   // Otherwise use the default implementation.
-  return state.getMemrefWithUniqueOwnership(builder, memref, block);
+  return state.getMemrefWithUniqueOwnership(options, builder, memref, block);
 }
 
 static bool regionOperatesOnMemrefValues(Region &region) {
@@ -712,41 +705,6 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
   return newOp.getOperation();
 }
 
-Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership(
-    OpBuilder &builder, Value memref, Block *block) {
-  // First, make sure we at least have 'Unique' ownership already.
-  std::pair<Value, Value> newMemrefAndOnwership =
-      materializeUniqueOwnership(builder, memref, block);
-  Value newMemref = newMemrefAndOnwership.first;
-  Value condition = newMemrefAndOnwership.second;
-
-  // Avoid inserting additional IR if ownership is already guaranteed. In
-  // particular, this is already the case when we had 'Unknown' ownership
-  // initially and a clone was inserted to get to 'Unique' ownership.
-  if (matchPattern(condition, m_One()))
-    return newMemref;
-
-  // Insert a runtime check and only clone if we still don't have ownership at
-  // runtime.
-  Value maybeClone =
-      builder
-          .create<scf::IfOp>(
-              memref.getLoc(), condition,
-              [&](OpBuilder &builder, Location loc) {
-                builder.create<scf::YieldOp>(loc, newMemref);
-              },
-              [&](OpBuilder &builder, Location loc) {
-                Value clone =
-                    builder.create<bufferization::CloneOp>(loc, newMemref);
-                builder.create<scf::YieldOp>(loc, clone);
-              })
-          .getResult(0);
-  Value trueVal = buildBoolValue(builder, memref.getLoc(), true);
-  state.updateOwnership(maybeClone, trueVal);
-  state.addMemrefToDeallocate(maybeClone, maybeClone.getParentBlock());
-  return maybeClone;
-}
-
 FailureOr<Operation *>
 BufferDeallocation::handleInterface(BranchOpInterface op) {
   if (op->getNumSuccessors() > 1)
@@ -819,10 +777,11 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
         newOperands.push_back(operand);
         continue;
       }
-      auto [memref, condition] =
+      Value ownership =
           materializeUniqueOwnership(builder, operand, op->getBlock());
-      newOperands.push_back(memref);
-      ownershipIndicatorsToAdd.push_back(condition);
+
+      newOperands.push_back(operand);
+      ownershipIndicatorsToAdd.push_back(ownership);
     }
     newOperands.append(ownershipIndicatorsToAdd.begin(),
                        ownershipIndicatorsToAdd.end());
@@ -903,8 +862,14 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
       if (!isMemref(val.get()))
         continue;
 
-      val.set(materializeMemrefWithGuaranteedOwnership(builder, val.get(),
-                                                       op->getBlock()));
+      if (options.verifyFunctionBoundaryABI) {
+        Value ownership =
+            materializeUniqueOwnership(builder, val.get(), op->getBlock());
+        builder.create<cf::AssertOp>(
+            op->getLoc(), ownership,
+            builder.getStringAttr("Must have ownership of operand #" +
+                                  Twine(val.getOperandNumber())));
+      }
     }
   }
 
@@ -978,17 +943,22 @@ struct OwnershipBasedBufferDeallocationPass
     : public bufferization::impl::OwnershipBasedBufferDeallocationBase<
           OwnershipBasedBufferDeallocationPass> {
   OwnershipBasedBufferDeallocationPass() = default;
-  OwnershipBasedBufferDeallocationPass(bool privateFuncDynamicOwnership)
+  OwnershipBasedBufferDeallocationPass(const DeallocationOptions &options)
       : OwnershipBasedBufferDeallocationPass() {
-    this->privateFuncDynamicOwnership.setValue(privateFuncDynamicOwnership);
+    privateFuncDynamicOwnership.setValue(options.privateFuncDynamicOwnership);
+    verifyFunctionBoundaryABI.setValue(options.verifyFunctionBoundaryABI);
   }
   void runOnOperation() override {
+    DeallocationOptions options;
+    options.privateFuncDynamicOwnership =
+        privateFuncDynamicOwnership.getValue();
+    options.verifyFunctionBoundaryABI = verifyFunctionBoundaryABI.getValue();
+
     auto status = getOperation()->walk([&](func::FuncOp func) {
       if (func.isExternal())
         return WalkResult::skip();
 
-      if (failed(deallocateBuffersOwnershipBased(func,
-                                                 privateFuncDynamicOwnership)))
+      if (failed(deallocateBuffersOwnershipBased(func, options)))
         return WalkResult::interrupt();
 
       return WalkResult::advance();
@@ -1005,9 +975,9 @@ struct OwnershipBasedBufferDeallocationPass
 //===----------------------------------------------------------------------===//
 
 LogicalResult bufferization::deallocateBuffersOwnershipBased(
-    FunctionOpInterface op, bool privateFuncDynamicOwnership) {
+    FunctionOpInterface op, const DeallocationOptions &options) {
   // Gather all required allocation nodes and prepare the deallocation phase.
-  BufferDeallocation deallocation(op, privateFuncDynamicOwnership);
+  BufferDeallocation deallocation(op, options);
 
   // Place all required temporary clone and dealloc nodes.
   return deallocation.deallocate(op);
@@ -1019,7 +989,6 @@ LogicalResult bufferization::deallocateBuffersOwnershipBased(
 
 std::unique_ptr<Pass>
 mlir::bufferization::createOwnershipBasedBufferDeallocationPass(
-    bool privateFuncDynamicOwnership) {
-  return std::make_unique<OwnershipBasedBufferDeallocationPass>(
-      privateFuncDynamicOwnership);
+    const DeallocationOptions &options) {
+  return std::make_unique<OwnershipBasedBufferDeallocationPass>(options);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 3ae0529ab7d7466..0e51c2380eb429f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
 // RUN:  -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true -split-input-file %s > /dev/null
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=false -split-input-file %s > /dev/null
 
 // RUN: mlir-opt %s -buffer-deallocation-pipeline --split-input-file > /dev/null
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
index 02bf2d10e9e3f56..b310fcc4731bf30 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=false \
 // RUN:   -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true \
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
 // RUN:   --buffer-deallocation-simplification -split-input-file %s | FileCheck %s --check-prefix=CHECK-DYNAMIC
 
 // RUN: mlir-opt %s -buffer-deallocation-pipeline --split-input-file > /dev/null
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir
index 13c55d0289880ef..306ce0e098b7acc 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt --allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=false \
 // RUN:  --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
-// RUN: mlir-opt --allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true \
+// RUN: mlir-opt --allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation \
 // RUN:  --buffer-deallocation-simplification -split-input-file %s | FileCheck %s --check-prefix=CHECK-DYNAMIC
 
 // RUN: mlir-opt %s -buffer-deallocation-pipeline --split-input-file > /dev/null
@@ -94,34 +94,26 @@ func.func private @redundantOperations(%arg0: memref<2xf32>) {
 
 func.func private @memref_in_function_results(
   %arg0: memref<5xf32>,
-  %arg1: memref<10xf32>,
-  %arg2: memref<5xf32>) -> (memref<10xf32>, memref<15xf32>) {
+  %arg2: memref<5xf32>) -> (memref<15xf32>) {
   %x = memref.alloc() : memref<15xf32>
   %y = memref.alloc() : memref<5xf32>
   test.buffer_based in(%arg0: memref<5xf32>) out(%y: memref<5xf32>)
   test.copy(%y, %arg2) : (memref<5xf32>, memref<5xf32>)
-  return %arg1, %x : memref<10xf32>, memref<15xf32>
+  return %x : memref<15xf32>
 }
 
 // CHECK-LABEL: func private @memref_in_function_results
-//       CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>,
+//       CHECK: (%[[ARG0:.*]]: memref<5xf32>,
 //  CHECK-SAME: %[[RESULT:.*]]: memref<5xf32>)
 //       CHECK: %[[X:.*]] = memref.alloc()
 //       CHECK: %[[Y:.*]] = memref.alloc()
 //       CHECK: test.copy
-//  CHECK-NEXT: %[[V0:.+]] = scf.if %false
-//  CHECK-NEXT:   scf.yield %[[ARG1]]
-//  CHECK-NEXT: } else {
-//  CHECK-NEXT:   %[[CLONE:.+]] = bufferization.clone %[[ARG1]]
-//  CHECK-NEXT:   scf.yield %[[CLONE]]
-//  CHECK-NEXT: }
 //       CHECK: bufferization.dealloc (%[[Y]] : {{.*}}) if (%true{{[0-9_]*}})
 //   CHECK-NOT: retain
-//       CHECK: return %[[V0]], %[[X]]
+//       CHECK: return %[[X]]
 
 // CHECK-DYNAMIC-LABEL: func private @memref_in_function_results
-//       CHECK-DYNAMIC: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>,
-//  CHECK-DYNAMIC-SAME: %[[RESULT:.*]]: memref<5xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1, %[[ARG5:.*]]: i1)
+//  CHECK-DYNAMIC-SAME: (%[[ARG0:.*]]: memref<5xf32>, %[[RESULT:.*]]: memref<5xf32>, %[[ARG3:.*]]: i1, %[[ARG5:.*]]: i1)
 //       CHECK-DYNAMIC: %[[X:.*]] = memref.alloc()
 //       CHECK-DYNAMIC: %[[Y:.*]] = memref.alloc()
 //       CHECK-DYNAMIC: test.copy
@@ -129,6 +121,6 @@ func.func private @memref_in_function_results(
 //       CHECK-DYNAMIC: %[[BASE1:[a-zA-Z0-9_]+]], {{.+}} = memref.extract_strided_metadata %[[RESULT]]
 //       CHECK-DYNAMIC: bufferization.dealloc (%[[Y]] : {{.*}}) if (%true{{[0-9_]*}})
 //   CHECK-DYNAMIC-NOT: retain
-//       CHECK-DYNAMIC: [[OWN:%.+]] = bufferization.dealloc (%[[BASE0]], %[[BASE1]] : {{.*}}) if (%[[ARG3]], %[[ARG5]]) retain (%[[ARG1]] :
-//       CHECK-DYNAMIC: [[OR:%.+]] = arith.ori [[OWN]], %[[ARG4]]
-//       CHECK-DYNAMIC: return %[[ARG1]], %[[X]], [[OR]], %true
+//       CHECK-DYNAMIC: bufferization.dealloc (%[[BASE0]], %[[BASE1]] : {{.*}}) if (%[[ARG3]], %[[ARG5]])
+//   CHECK-DYNAMIC-NOT: retain
+//       CHECK-DYNAMIC: return %[[X]], %true
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
index 44cf16385603e07..269b8b71f7beb1b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
 // RUN:   --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true -split-input-file %s > /dev/null
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=false -split-input-file %s > /dev/null
 
 // RUN: mlir-opt %s -buffer-deallocation-pipeline --split-input-file > /dev/null
 
@@ -100,27 +100,10 @@ func.func @dealloc_existing_clones(%arg0: memref<?x?xf64>, %arg1: memref<?x?xf64
 //       CHECK: (%[[ARG0:.*]]: memref<?x?xf64>, %[[ARG1:.*]]: memref<?x?xf64>)
 //       CHECK: %[[RES0:.*]] = bufferization.clone %[[ARG0]]
 //       CHECK: %[[RES1:.*]] = bufferization.clone %[[ARG1]]
-//  CHECK-NEXT: bufferization.dealloc (%[[RES1]] :{{.*}}) if (%true{{[0-9_]*}})
+//       CHECK: bufferization.dealloc (%[[RES1]] :{{.*}}) if (%true{{[0-9_]*}})
 //   CHECK-NOT: retain
 //  CHECK-NEXT: return %[[RES0]]
 
 // TODO: The retain operand could be dropped to avoid runtime aliasing checks
 // since We can guarantee at compile-time that it will never alias with the
 // dealloc operand
-
-// -----
-
-memref.global "private" constant @__constant_4xf32 : memref<4xf32> = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]>
-
-func.func @op_without_aliasing_and_allocation() -> memref<4xf32> {
-  %0 = memref.get_global @__constant_4xf32 : memref<4xf32>
-  return %0 : memref<4xf32>
-}
-
-// CHECK-LABEL: func @op_without_aliasing_and_allocation
-//       CHECK:   [[GLOBAL:%.+]] = memref.get_global @__constant_4xf32
-//       CHECK:   [[RES:%.+]] = scf.if %false
-//       CHECK:     scf.yield [[GLOBAL]] :
-//       CHECK:     [[CLONE:%.+]] = bufferization.clone [[GLOBAL]]
-//       CHECK:     scf.yield [[CLONE]] :
-//       CHECK:   return [[RES]] :
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index dc372749fc074be..f1b753e405531e5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation \
 // RUN:  --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true -split-input-file %s > /dev/null
+// RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=false -split-input-file %s > /dev/null
 
 // RUN: mlir-opt %s -buffer-deallocation-pipeline --split-input-file --verify-diagnostics > /dev/null
 
@@ -85,13 +85,7 @@ func.func @nested_region_control_flow(
 //       CHECK:     bufferization.dealloc ([[ALLOC1]] :{{.*}}) if (%true{{[0-9_]*}})
 //   CHECK-NOT: retain
 //       CHECK:     scf.yield [[ALLOC]], %false
-//       CHECK:   [[V1:%.+]] = scf.if [[V0]]#1
-//       CHECK:     scf.yield [[V0]]#0
-//       CHECK:     [[CLONE:%.+]] = bufferization.clone [[V0]]#0
-//       CHECK:     scf.yield [[CLONE]]
-//       CHECK:   [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
-//       CHECK:   bufferization.dealloc ([[ALLOC]], [[BASE]] : {{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
-//       CHECK:   return [[V1]]
+//       CHECK:   return [[V0]]#0
 
 // -----
 
@@ -120,13 +114,8 @@ func.func @nested_region_control_flow_div(
 //       CHECK:     scf.yield [[ALLOC]], %false
 //       CHECK:     [[ALLOC1:%.+]] = memref.alloc(
 //       CHECK:     scf.yield [[ALLOC1]], %true
-//       CHECK:   [[V1:%.+]] = scf.if [[V0]]#1
-//       CHECK:     scf.yield [[V0]]#0
-//       CHECK:     [[CLONE:%.+]] = bufferization.clone [[V0]]#0
-//       CHECK:     scf.yield [[CLONE]]
-//       CHECK:   [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
-//       CHECK:   bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
-//       CHECK:   return [[V1]]
+//       CHECK:   bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true{{[0-9_]*}}) retain ([[V0]]#0 :
+//       CHECK:   return [[V0]]#0
 
 // -----
 
@@ -158,13 +147,8 @@ func.func @inner_region_control_flow(%arg0 : index) -> memref<?x?xf32> {
 //       CHECK:     test.region_if_yield [[ARG1]], [[ARG2]]
 //       CHECK:   ^bb0([[ARG1:%.+]]: memref<?x?xf32>, [[ARG2:%.+]]: i1):
 //       CHECK:     test.region_if_yield [[ARG1]], [[ARG2]]
-//       CHECK:   [[V1:%.+]] = scf.if [[V0]]#1
-//       CHECK:     scf.yield [[V0]]#0
-//       CHECK:     [[CLONE:%.+]] = bufferization.clone [[V0]]#0
-//       CHECK:     scf.yield [[CLONE]]
-//       CHECK:   [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
-//       CHECK:   bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
-//       CHECK:   return [[V1]]
+//   CHECK-NOT:   bufferization.dealloc
+//       CHECK:   return [[V0]]#0
 
 // -----
 
@@ -232,13 +216,8 @@ func.func @nestedRegionControlFlowAlloca(
 //       CHECK:   scf.yield [[ALLOC]], %false
 //       CHECK:   memref.alloca(
 //       CHECK:   scf.yield [[ALLOC]], %false
-//       CHECK: [[V1:%.+]] = scf.if [[V0]]#1
-//       CHECK:   scf.yield [[V0]]#0
-//       CHECK:   [[CLONE:%.+]] = bufferization.clone [[V0]]#0
-//       CHECK:   scf.yield [[CLONE]]
-//       CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
-//       CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
-//       CHECK: return [[V1]]
+//   CHECK-NOT: bufferization.dealloc
+//       CHECK: return [[V0]]#0
 
 // -----
 
@@ -364,13 +343,8 @@ func.func @loop_nested_if_alloc(
 //       CHECK:   [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[V1]]#1
 //       CHECK:   scf.yield [[V1]]#0, [[OWN_AGG]]
 //       CHECK: }
-//       CHECK: [[V2:%.+]] = scf.if [[V0]]#1
-//       CHECK:   scf.yield [[V0]]#0
-//       CHECK:   [[CLONE:%.+]] = bufferization.clone [[V0]]#0
-//       CHECK:   scf.yield [[CLONE]]
-//       CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
-//       CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V2]] :
-//       CHECK: return [[V2]]
+//       CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true{{[0-9_]*}}) retain ([[V0]]#0 :
+//       CHECK: return [[V0]]#0
 
 // -----
 
@@ -626,13 +600,7 @@ func.func @test_affine_if_1(%arg0: memref<10xf32>) -> memref<10xf32> {
 //       CHECK:   [[ALLOC:%.+]] = memref.alloc()
 //       CHECK:   affine.yield [[ALLOC]], %true
 //       CHECK:   affine.yield [[ARG0]], %false
-//       CHECK: [[V1:%.+]] = scf.if [[V0]]#1
-//       CHECK:   scf.yield [[V0]]#0
-//       CHECK:   [[CLONE:%.+]] = bufferization.clone [[V0]]#0
-//       CHECK:   scf.yield [[CLONE]]
-//       CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
-//       CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[V0]]#1) retain ([[V1]] :
-//       CHECK: return [[V1]]
+//       CHECK: return [[V0]]#0
 
 // TODO: the dealloc could be optimized away since the memref to be deallocated
 //       either aliases with V1 or the condition is false
@@ -652,19 +620,14 @@ func.func @test_affine_if_2() -> memref<10xf32> {
   }
   return %0 : memref<10xf32>
 }
+
 // CHECK-LABEL: func @test_affine_if_2
 //       CHECK: [[ALLOC:%.+]] = memref.alloc()
 //       CHECK: [[V0:%.+]]:2 = affine.if
 //       CHECK:   affine.yield [[ALLOC]], %false
 //       CHECK:   [[ALLOC1:%.+]] = memref.alloc()
 //       CHECK:   affine.yield [[ALLOC1]], %true
-//       CHECK: [[V1:%.+]] = scf.if [[V0]]#1
-//       CHECK:   scf.yield [[V0]]#0
-//       CHECK:   [[CLONE:%.+]] = bufferization.clone [[V0]]#0
-//       CHECK:   scf.yield [[CLONE]]
-//       CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
-//       CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
-//       CHECK: return [[V1]]
+//       CHECK: return [[V0]]#0
 
 // -----
 
@@ -688,10 +651,4 @@ func.func @test_affine_if_3() -> memref<10xf32> {
 //       CHECK:   [[ALLOC1:%.+]] = memref.alloc()
 //       CHECK:   affine.yield [[ALLOC1]], %true
 //       CHECK:   affine.yield [[ALLOC]], %false
-//       CHECK: [[V1:%.+]] = scf.if [[V0]]#1
-//       CHECK:   scf.yield [[V0]]#0
-//       CHECK:   [[CLONE:%.+]] = bufferization.clone [[V0]]#0
-//       CHECK:   scf.yield [[CLONE]]
-//       CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
-//       CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]]
-//       CHECK: return [[V1]]
+//       CHECK: return [[V0]]#0
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-runtime-verification.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-runtime-verification.mlir
new file mode 100644
index 000000000000000..9951347efa45d27
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-runtime-verification.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation -split-input-file %s | FileCheck %s
+
+memref.global "private" constant @__constant_4xf32 : memref<4xf32> = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]>
+
+func.func @op_without_aliasing_and_allocation() -> memref<4xf32> {
+  %0 = memref.get_global @__constant_4xf32 : memref<4xf32>
+  return %0 : memref<4xf32>
+}
+
+// CHECK-LABEL: func @op_without_aliasing_and_allocation
+//       CHECK:   [[GLOBAL:%.+]] = memref.get_global @__constant_4xf32
+//       CHECK:   cf.assert %false{{[0-9_]*}}, "Must have ownership of operand #0"
+//       CHECK:   return [[GLOBAL]] :
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-subviews.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-subviews.mlir
index 35523319de1548e..bf55730c8fbb380 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-subviews.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-subviews.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
 // RUN:   --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true -split-input-file %s > /dev/null
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=false -split-input-file %s > /dev/null
 
 // RUN: mlir-opt %s -buffer-deallocation-pipeline --split-input-file > /dev/null
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-unknown-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-unknown-ops.mlir
new file mode 100644
index 000000000000000..6808053500e6cab
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-unknown-ops.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt --allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN:  -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+
+func.func private @callee(%arg0: memref<f32>) -> memref<f32> {
+  return %arg0 : memref<f32>
+}
+
+func.func @generic_ownership_materialization() {
+  %a1 = memref.alloc() : memref<f32>
+  %a2 = memref.alloca() : memref<f32>
+  %0 = "my_dialect.select_randomly"(%a1, %a2, %a1) : (memref<f32>, memref<f32>, memref<f32>) -> memref<f32>
+  %1 = call @callee(%0) : (memref<f32>) -> memref<f32>
+  return
+}
+
+// CHECK-LABEL: func @generic_ownership_materialization
+//       CHECK: [[ALLOC:%.+]] = memref.alloc(
+//       CHECK: [[ALLOCA:%.+]] = memref.alloca(
+//       CHECK: [[SELECT:%.+]] = "my_dialect.select_randomly"([[ALLOC]], [[ALLOCA]], [[ALLOC]])
+//       CHECK: [[SELECT_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[SELECT]]
+//       CHECK: [[ALLOCA_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[ALLOCA]]
+//       CHECK: [[EQ1:%.+]] = arith.cmpi eq, [[SELECT_PTR]], [[ALLOCA_PTR]]
+//       CHECK: [[OWN1:%.+]] = arith.select [[EQ1]], %false{{[0-9_]*}}, %true
+//       CHECK: [[ALLOC_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[ALLOC]]
+//       CHECK: [[EQ2:%.+]] = arith.cmpi eq, [[SELECT_PTR]], [[ALLOC_PTR]]
+//       CHECK: [[OWN2:%.+]] = arith.select [[EQ2]], %true{{[0-9_]*}}, [[OWN1]]
+//       CHECK: [[CALL:%.+]]:2 = call @callee([[SELECT]], [[OWN2]])
+//       CHECK: [[BASE:%.+]],{{.*}} = memref.extract_strided_metadata [[CALL]]#0
+//       CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[CALL]]#1)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
index 43e423d4c3e8e14..a0a41d04e1a00ee 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
@@ -33,7 +33,7 @@ func.func @main() {
 
 func.func private @printMemrefF32(%ptr : tensor<*xf32>)
 
-func.func @collapse_dynamic_shape(%arg0 : tensor<2x?x?x?xf32>) -> tensor<2x?x?xf32> {
+func.func private @collapse_dynamic_shape(%arg0 : tensor<2x?x?x?xf32>) -> tensor<2x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3]]: tensor<2x?x?x?xf32> into tensor<2x?x?xf32>
   return %0 : tensor<2x?x?xf32>
 }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
index a101b76ef186b5e..0aa1b81b5a42665 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
@@ -34,7 +34,7 @@ func.func @main() {
 
 func.func private @printMemrefF32(%ptr : tensor<*xf32>)
 
-func.func @expand_dynamic_shape(%arg0 : tensor<2x?x?xf32>) -> tensor<2x2x?x1x?xf32> {
+func.func private @expand_dynamic_shape(%arg0 : tensor<2x?x?xf32>) -> tensor<2x2x?x1x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1, 2, 3], [4]]: tensor<2x?x?xf32> into tensor<2x2x?x1x?xf32>
   return %0 : tensor<2x2x?x1x?xf32>
 }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir
index 06165515d4613c6..d58414bb43cc3d1 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir
@@ -9,7 +9,7 @@
 #map0 = affine_map<(d0, d1)[s0] -> ((d1 - d0) ceildiv s0)>
 #map1 = affine_map<(d0, d1)[s0] -> ((d0 - d1) ceildiv s0)>
 
-func.func @init_and_dot(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<f32>) -> tensor<f32> {
+func.func private @init_and_dot(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<f32>) -> tensor<f32> {
   %c64 = arith.constant 64 : index
   %cst = arith.constant 0.000000e+00 : f32
   %c2 = arith.constant 2 : index
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-e2e.mlir
index 38b49cd444df3c1..dcaa484315af188 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-e2e.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -5,7 +5,7 @@
 // RUN:   -shared-libs=%mlir_runner_utils \
 // RUN: | FileCheck %s
 
-func.func @foo() -> tensor<4xf32> {
+func.func private @foo() -> tensor<4xf32> {
   %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
   return %0 : tensor<4xf32>
 }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7b88c9deae69667..40091e7212f0052 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -12367,6 +12367,7 @@ cc_library(
     hdrs = ["include/mlir/Dialect/Bufferization/Pipelines/Passes.h"],
     includes = ["include"],
     deps = [
+        ":BufferizationDialect",
         ":BufferizationToMemRef",
         ":BufferizationTransforms",
         ":FuncDialect",

>From 590efce7c0e164129879c3064b064f9459f7ab5a Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Wed, 27 Sep 2023 14:44:21 +0000
Subject: [PATCH 2/3] [mlir][bufferization] Add deallocation option to remove
 existing dealloc operations, add option to specify the kind of alloc
 operations to consider

---
 .../IR/BufferDeallocationOpInterface.h        |  47 +++++++
 .../Dialect/Bufferization/Pipelines/Passes.h  |   6 +
 .../Bufferization/Transforms/Passes.td        |   5 +
 .../Pipelines/BufferizationPipelines.cpp      |   1 +
 .../OwnershipBasedBufferDeallocation.cpp      |  34 ++++-
 .../dealloc-mixed-allocations.mlir            |  33 +++++
 .../lib/Dialect/Bufferization/CMakeLists.txt  |   4 +
 .../TestOwnershipBasedBufferDeallocation.cpp  | 128 ++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 .../mlir/test/BUILD.bazel                     |   7 +-
 10 files changed, 260 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-mixed-allocations.mlir
 create mode 100644 mlir/test/lib/Dialect/Bufferization/TestOwnershipBasedBufferDeallocation.cpp
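
Note on the new options: the hunk in handleInterface(MemoryEffectOpInterface) below combines isRelevantDeallocOp, getDeallocReplacement and removeExistingDeallocations into a three-way decision for every op that frees a memref operand. The following is a minimal, MLIR-free sketch of that decision only, written to make the interaction of the three options easier to follow. ToyOp, ToyDeallocationOptions, Action and classify are hypothetical names that do not exist in MLIR; std::optional stands in for FailureOr and a vector of int stands in for ValueRange.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Operation; not an MLIR type.
struct ToyOp {
  std::string dialect; // e.g. "memref"
  std::string name;    // e.g. "dealloc"
  bool freesMemref;    // models a MemoryEffects::Free effect on a memref operand
};

// Hypothetical model of the dealloc-related fields added to DeallocationOptions.
struct ToyDeallocationOptions {
  using DetectionFn = std::function<bool(const ToyOp &)>;
  // std::nullopt models failure(); an empty vector models an empty ValueRange.
  using ReplaceDeallocFn =
      std::function<std::optional<std::vector<int>>(const ToyOp &)>;

  // Default mirrors the patch: only memref and bufferization ops are relevant.
  DetectionFn isRelevantDeallocOp = [](const ToyOp &op) {
    return op.dialect == "memref" || op.dialect == "bufferization";
  };

  bool removeExistingDeallocations = false;

  // Default mirrors the patch: only memref.dealloc (no results) is replaceable.
  ReplaceDeallocFn getDeallocReplacement =
      [](const ToyOp &op) -> std::optional<std::vector<int>> {
    if (op.dialect == "memref" && op.name == "dealloc")
      return std::vector<int>{};
    return std::nullopt;
  };
};

enum class Action { Ignore, Remove, Error };

// Models the branch structure of the hunk: irrelevant ops are ignored,
// relevant ops are removed only if a replacement exists AND the flag is set,
// and every other relevant free is reported as an error.
Action classify(const ToyOp &op, const ToyDeallocationOptions &options) {
  assert(options.isRelevantDeallocOp && options.getDeallocReplacement);
  if (!op.freesMemref || !options.isRelevantDeallocOp(op))
    return Action::Ignore;
  auto repl = options.getDeallocReplacement(op);
  if (repl.has_value() && options.removeExistingDeallocations)
    return Action::Remove;
  return Action::Error;
}
```

Under these assumptions, a memref.dealloc is an error with the default options and is removed once removeExistingDeallocations is set, while an op from an unrelated dialect is left alone in both cases.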

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
index 838641db20cbbc3..d2e718ac93045ce 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
@@ -10,6 +10,8 @@
 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_
 
 #include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Support/LLVM.h"
@@ -92,6 +94,9 @@ class Ownership {
 
 /// Options for BufferDeallocationOpInterface-based buffer deallocation.
 struct DeallocationOptions {
+  using DetectionFn = std::function<bool(Operation *)>;
+  using ReplaceDeallocFn = std::function<FailureOr<ValueRange>(Operation *)>;
+
   // A pass option indicating whether private functions should be modified to
   // pass the ownership of MemRef values instead of adhering to the function
   // boundary ABI.
@@ -106,6 +111,48 @@ struct DeallocationOptions {
   /// to, an error will already be emitted at compile time. This cannot be
   /// changed with this option.
   bool verifyFunctionBoundaryABI = true;
+
+  /// Given an allocation side-effect on the passed operation, determine whether
+  /// this allocation operation is of relevance (i.e., should assign ownership
+  /// to the allocated value). If it is determined to not be relevant,
+  /// ownership will be set to 'false', i.e., it will be leaked. This is useful
+  /// to support deallocation of multiple different kinds of allocation ops.
+  DetectionFn isRelevantAllocOp = [](Operation *op) {
+    return isa<memref::MemRefDialect, bufferization::BufferizationDialect>(
+        op->getDialect());
+  };
+
+  /// Given a free side-effect on the passed operation, determine whether this
+  /// deallocation operation is of relevance (i.e., should be removed if the
+  /// `removeExistingDeallocations` option is enabled or otherwise an error
+  /// should be emitted because existing deallocation operations are not
+  /// supported without that flag). If it is determined to not be relevant,
+  /// the operation will be ignored. This is useful to support deallocation of
+  /// multiple different kinds of allocation ops where deallocations for some of
+  /// them are already present in the IR.
+  DetectionFn isRelevantDeallocOp = [](Operation *op) {
+    return isa<memref::MemRefDialect, bufferization::BufferizationDialect>(
+        op->getDialect());
+  };
+
+  /// When enabled, remove deallocation operations determined to be relevant
+  /// according to `isRelevantDeallocOp`. If the operation has result values,
+  /// `getDeallocReplacement` will be called to determine the SSA values that
+  /// should be used as replacements.
+  bool removeExistingDeallocations = false;
+
+  /// Provides SSA values for deallocation operations when
+  /// `removeExistingDeallocations` is enabled. May return a failure when the
+  /// given deallocation operation is not supported (e.g., because no
+  /// replacement for a result value can be determined). A failure will directly
+  /// lead to a failure emitted by the deallocation pass.
+  ReplaceDeallocFn getDeallocReplacement =
+      [](Operation *op) -> FailureOr<ValueRange> {
+    if (isa<memref::DeallocOp>(op))
+      return ValueRange{};
+    // ReallocOp has to be expanded before running the dealloc pass.
+    return failure();
+  };
 };
 
 /// This class collects all the state that we need to perform the buffer
diff --git a/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h
index 7578351d2c4f501..38883ff4588d2ee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h
@@ -40,6 +40,12 @@ struct BufferDeallocationPipelineOptions
           "statically that the ABI is not adhered to, an error will already be "
           "emitted at compile time. This cannot be changed with this option."),
       llvm::cl::init(true)};
+  PassOptions::Option<bool> removeExistingDeallocations{
+      *this, "remove-existing-deallocations",
+      llvm::cl::desc("Remove all pre-existing memref.dealloc operations and "
+                     "insert all deallocations according to the buffer "
+                     "deallocation rules."),
+      llvm::cl::init(false)};
 
   /// Convert this BufferDeallocationPipelineOptions struct to a
   /// DeallocationOptions struct to be passed to the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 5de17cf7faa7ef6..e182ae6548b95bf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -233,6 +233,11 @@ def OwnershipBasedBufferDeallocation : Pass<
            "If it can be determined statically that the ABI is not adhered "
            "to, an error will already be emitted at compile time. This cannot "
            "be changed with this option.">,
+    Option<"removeExistingDeallocations", "remove-existing-deallocations",
+           "bool", /*default=*/"false",
+           "Remove already existing MemRef deallocation operations and let the "
+           "deallocation pass insert the deallocation operations according to "
+           "its rules.">,
   ];
   let constructor = "mlir::bufferization::createOwnershipBasedBufferDeallocationPass()";
 
diff --git a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
index ea05fa29ea608eb..1fa42a2099f98c9 100644
--- a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
+++ b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
@@ -27,6 +27,7 @@ BufferDeallocationPipelineOptions::asDeallocationOptions() const {
   DeallocationOptions opts;
   opts.privateFuncDynamicOwnership = privateFunctionDynamicOwnership.getValue();
   opts.verifyFunctionBoundaryABI = verifyFunctionBoundaryABI.getValue();
+  opts.removeExistingDeallocations = removeExistingDeallocations.getValue();
   return opts;
 }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 362f054142fc019..80828676e70f22d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -816,15 +816,28 @@ FailureOr<Operation *>
 BufferDeallocation::handleInterface(MemoryEffectOpInterface op) {
   auto *block = op->getBlock();
 
-  for (auto operand : llvm::make_filter_range(op->getOperands(), isMemref))
-    if (op.getEffectOnValue<MemoryEffects::Free>(operand).has_value())
+  for (auto operand : llvm::make_filter_range(op->getOperands(), isMemref)) {
+    if (op.getEffectOnValue<MemoryEffects::Free>(operand).has_value() &&
+        options.isRelevantDeallocOp(op)) {
+      if (auto repl = options.getDeallocReplacement(op);
+          succeeded(repl) && options.removeExistingDeallocations) {
+        op->replaceAllUsesWith(repl.value());
+        op.erase();
+        return FailureOr<Operation *>(nullptr);
+      }
+
       return op->emitError(
           "memory free side-effect on MemRef value not supported!");
+    }
+  }
 
   OpBuilder builder = OpBuilder::atBlockBegin(block);
   for (auto res : llvm::make_filter_range(op->getResults(), isMemref)) {
     auto allocEffect = op.getEffectOnValue<MemoryEffects::Allocate>(res);
     if (allocEffect.has_value()) {
+      // Assuming that an alloc effect is interpreted as MUST and not MAY.
+      state.resetOwnerships(res, block);
+
       if (isa<SideEffects::AutomaticAllocationScopeResource>(
               allocEffect->getResource())) {
         // Make sure that the ownership of auto-managed allocations is set to
@@ -839,8 +852,15 @@ BufferDeallocation::handleInterface(MemoryEffectOpInterface op) {
         continue;
       }
 
-      state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), true));
-      state.addMemrefToDeallocate(res, block);
+      if (options.isRelevantAllocOp(op)) {
+        state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), true));
+        state.addMemrefToDeallocate(res, block);
+        continue;
+      }
+
+      // Alloc operations from other dialects are expected to have matching
+      // deallocation operations inserted by another pass.
+      state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), false));
     }
   }
 
@@ -943,16 +963,18 @@ struct OwnershipBasedBufferDeallocationPass
     : public bufferization::impl::OwnershipBasedBufferDeallocationBase<
           OwnershipBasedBufferDeallocationPass> {
   OwnershipBasedBufferDeallocationPass() = default;
-  OwnershipBasedBufferDeallocationPass(const DeallocationOptions &options)
-      : OwnershipBasedBufferDeallocationPass() {
+  OwnershipBasedBufferDeallocationPass(const DeallocationOptions &options) {
     privateFuncDynamicOwnership.setValue(options.privateFuncDynamicOwnership);
     verifyFunctionBoundaryABI.setValue(options.verifyFunctionBoundaryABI);
+    removeExistingDeallocations.setValue(options.removeExistingDeallocations);
   }
   void runOnOperation() override {
     DeallocationOptions options;
     options.privateFuncDynamicOwnership =
         privateFuncDynamicOwnership.getValue();
     options.verifyFunctionBoundaryABI = verifyFunctionBoundaryABI.getValue();
+    options.removeExistingDeallocations =
+        removeExistingDeallocations.getValue();
 
     auto status = getOperation()->walk([&](func::FuncOp func) {
       if (func.isExternal())
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-mixed-allocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-mixed-allocations.mlir
new file mode 100644
index 000000000000000..6902d5fe23eee09
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-mixed-allocations.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --test-ownership-based-buffer-deallocation -split-input-file | FileCheck %s
+
+func.func @mixed_allocations(%cond: i1) -> (memref<f32>, !gpu.async.token) {
+  %a1 = memref.alloc() : memref<f32>
+  %a2 = gpu.alloc() : memref<f32>
+  %0 = arith.select %cond, %a1, %a2 : memref<f32>
+  %token = gpu.dealloc async [] %a2 : memref<f32>
+  memref.dealloc %a1 : memref<f32>
+  return %0, %token : memref<f32>, !gpu.async.token
+}
+
+// CHECK: [[A1:%.+]] = memref.alloc(
+// CHECK: [[A2:%.+]] = gpu.alloc
+// CHECK: [[SELECT:%.+]] = arith.select {{.*}}, [[A1]], [[A2]]
+// CHECK: [[TOKEN:%.+]] = gpu.wait async
+// CHECK: [[A1_BASE:%.+]],{{.*}} = memref.extract_strided_metadata [[A1]]
+// CHECK: [[A1_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[A1_BASE]]
+// CHECK: [[SELECT_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[SELECT]]
+// CHECK: [[ALIAS0:%.+]] = arith.cmpi ne, [[A1_PTR]], [[SELECT_PTR]]
+// CHECK: [[COND0:%.+]] = arith.andi [[ALIAS0]], %true
+// CHECK: scf.if [[COND0]] {
+// CHECK:   memref.dealloc [[A1_BASE]]
+// CHECK: }
+// CHECK: [[A2_BASE:%.+]],{{.*}} = memref.extract_strided_metadata [[A2]]
+// CHECK: [[A2_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[A2_BASE]]
+// CHECK: [[SELECT_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[SELECT]]
+// CHECK: [[ALIAS1:%.+]] = arith.cmpi ne, [[A2_PTR]], [[SELECT_PTR]]
+// CHECK: [[COND1:%.+]] = arith.andi [[ALIAS1]], %true
+// CHECK: scf.if [[COND1]] {
+// TODO: add pass option to lower-deallocation to insert gpu.dealloc here
+// CHECK:   memref.dealloc [[A2_BASE]]
+// CHECK: }
+// CHECK: return [[SELECT]], [[TOKEN]]
diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
index a183d02cefed7b5..b0175561fd609b9 100644
--- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
@@ -5,8 +5,12 @@ add_mlir_library(MLIRBufferizationTestPasses
   EXCLUDE_FROM_LIBMLIR
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRFuncDialect
+  MLIRGPUDialect
+  MLIRSCFDialect
   MLIRIR
   MLIRPass
 )
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOwnershipBasedBufferDeallocation.cpp b/mlir/test/lib/Dialect/Bufferization/TestOwnershipBasedBufferDeallocation.cpp
new file mode 100644
index 000000000000000..2b57dd94e571003
--- /dev/null
+++ b/mlir/test/lib/Dialect/Bufferization/TestOwnershipBasedBufferDeallocation.cpp
@@ -0,0 +1,128 @@
+//===- TestOwnershipBasedBufferDeallocation.cpp -----------------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// This pass runs the ownership based deallocation pass once for `memref.alloc`
+/// operations, then lowers the `bufferization.dealloc` operations, and
+/// afterwards runs the deallocation pass again for `gpu.alloc` operations and
+/// lowers the inserted `bufferization.dealloc` operations again to the
+/// corresponding deallocation operations.
+struct TestOwnershipBasedBufferDeallocationPass
+    : public PassWrapper<TestOwnershipBasedBufferDeallocationPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestOwnershipBasedBufferDeallocationPass)
+
+  TestOwnershipBasedBufferDeallocationPass() = default;
+  TestOwnershipBasedBufferDeallocationPass(
+      const TestOwnershipBasedBufferDeallocationPass &pass)
+      : TestOwnershipBasedBufferDeallocationPass() {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    scf::SCFDialect, func::FuncDialect, arith::ArithDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-ownership-based-buffer-deallocation";
+  }
+  StringRef getDescription() const final {
+    return "Module pass to test the Ownership-based Buffer Deallocation pass";
+  }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Build the library function for the lowering of `bufferization.dealloc`.
+    OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
+    SymbolTable symbolTable(module);
+    func::FuncOp helper = bufferization::buildDeallocationLibraryFunction(
+        builder, module.getLoc(), symbolTable);
+
+    RewritePatternSet patterns(module->getContext());
+    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
+                                                               helper);
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+    WalkResult result = getOperation()->walk([&](FunctionOpInterface funcOp) {
+      // Deallocate the `memref.alloc` operations.
+      bufferization::DeallocationOptions options;
+      options.removeExistingDeallocations = true;
+      if (failed(
+              bufferization::deallocateBuffersOwnershipBased(funcOp, options)))
+        return WalkResult::interrupt();
+
+      // Lower the inserted `bufferization.dealloc` operations.
+      ConversionTarget target(getContext());
+      target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
+                             scf::SCFDialect, func::FuncDialect>();
+      target.addIllegalOp<bufferization::DeallocOp>();
+
+      if (failed(applyPartialConversion(funcOp, target, frozenPatterns)))
+        return WalkResult::interrupt();
+
+      // Deallocate the `gpu.alloc` operations.
+      options.isRelevantAllocOp = [](Operation *op) {
+        return isa<gpu::GPUDialect>(op->getDialect());
+      };
+      options.isRelevantDeallocOp = [](Operation *op) {
+        return isa<gpu::GPUDialect>(op->getDialect());
+      };
+      options.getDeallocReplacement =
+          [](Operation *op) -> FailureOr<ValueRange> {
+        if (auto gpuDealloc = dyn_cast<gpu::DeallocOp>(op)) {
+          if (gpuDealloc.getAsyncToken()) {
+            OpBuilder builder(op);
+            ValueRange token =
+                builder
+                    .create<gpu::WaitOp>(
+                        op->getLoc(),
+                        gpu::AsyncTokenType::get(builder.getContext()),
+                        ValueRange{})
+                    .getResults();
+            return token;
+          }
+          return ValueRange{};
+        }
+        return failure();
+      };
+      if (failed(
+              bufferization::deallocateBuffersOwnershipBased(funcOp, options)))
+        return WalkResult::interrupt();
+
+      // Lower the `bufferization.dealloc` operations inserted in the second
+      // deallocation run.
+      // TODO: they are currently also lowered to memref.dealloc, we need to
+      // add pass options to the lowering pass that allow us to select the
+      // dealloc operation to be inserted.
+      if (failed(applyPartialConversion(funcOp, target, frozenPatterns)))
+        return WalkResult::interrupt();
+
+      return WalkResult::advance();
+    });
+    if (result.wasInterrupted())
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestOwnershipBasedBufferDeallocationPass() {
+  PassRegistration<TestOwnershipBasedBufferDeallocationPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b7647d7de78a10e..87d8a2240be206e 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -121,6 +121,7 @@ void registerTestMemRefStrideCalculation();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
+void registerTestOwnershipBasedBufferDeallocationPass();
 void registerTestPadFusion();
 void registerTestPDLByteCodePass();
 void registerTestPDLLPasses();
@@ -241,6 +242,7 @@ void registerTestPasses() {
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();
+  mlir::test::registerTestOwnershipBasedBufferDeallocationPass();
   mlir::test::registerTestPadFusion();
   mlir::test::registerTestPDLByteCodePass();
   mlir::test::registerTestPDLLPasses();
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 192c9d156e5781e..2214e241693b43b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -3,9 +3,9 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
+load("//llvm:lit_test.bzl", "package_path")
 load("//mlir:build_defs.bzl", "if_cuda_available")
 load("//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
-load("//llvm:lit_test.bzl", "package_path")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -843,10 +843,15 @@ cc_library(
     defines = ["MLIR_CUDA_CONVERSIONS_ENABLED"],
     includes = ["lib/Dialect/Test"],
     deps = [
+        "//mlir:ArithDialect",
         "//mlir:BufferizationDialect",
         "//mlir:BufferizationTransforms",
+        "//mlir:FuncDialect",
+        "//mlir:GPUDialect",
         "//mlir:IR",
         "//mlir:Pass",
+        "//mlir:SCFDialect",
+        "//mlir:Transforms",
     ],
 )
 

>From 80bfb75cc1b543d1c0ae8f463e5f628fea68d8fa Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Wed, 27 Sep 2023 14:44:39 +0000
Subject: [PATCH 3/3] [mlir][bufferization] Add option to LowerDeallocations to
 choose the kind of dealloc op to build

---
 .../Dialect/Bufferization/Transforms/Passes.h | 16 +++++-
 .../Transforms/LowerDeallocations.cpp         | 54 ++++++++++---------
 .../dealloc-mixed-allocations.mlir            |  4 +-
 .../Transforms/lower-deallocations.mlir       | 15 ++----
 .../TestOwnershipBasedBufferDeallocation.cpp  | 33 ++++++++----
 5 files changed, 73 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 2bf82dd6f88c81c..59b3ab5c84e80de 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -18,6 +18,16 @@ class FuncOp;
 namespace bufferization {
 struct OneShotBufferizationOptions;
 
+/// Options for the LowerDeallocations pass and rewrite patterns.
+struct LowerDeallocationOptions {
+  /// Given a MemRef value, build the operation(s) necessary to properly
+  /// deallocate the value.
+  std::function<void(OpBuilder &, Location, Value)> buildDeallocOp =
+      [](OpBuilder &builder, Location loc, Value memref) {
+        builder.create<memref::DeallocOp>(loc, memref);
+      };
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -41,12 +51,14 @@ std::unique_ptr<Pass> createBufferDeallocationSimplificationPass();
 
 /// Creates an instance of the LowerDeallocations pass to lower
 /// `bufferization.dealloc` operations to the `memref` dialect.
-std::unique_ptr<Pass> createLowerDeallocationsPass();
+std::unique_ptr<Pass> createLowerDeallocationsPass(
+    const LowerDeallocationOptions &options = LowerDeallocationOptions());
 
 /// Adds the conversion pattern of the `bufferization.dealloc` operation to the
 /// given pattern set for use in other transformation passes.
 void populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc,
+    const LowerDeallocationOptions &options = LowerDeallocationOptions());
 
 /// Construct the library function needed for the fully generic
 /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
index 982d9558d313260..c2922faac00ed5c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -66,7 +66,7 @@ class DeallocOpConversion
 
     rewriter.replaceOpWithNewOp<scf::IfOp>(
         op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
-          builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
+          options.buildDeallocOp(builder, loc, adaptor.getMemrefs()[0]);
           builder.create<scf::YieldOp>(loc);
         });
     return success();
@@ -133,7 +133,7 @@ class DeallocOpConversion
 
     rewriter.create<scf::IfOp>(
         op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
-          builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
+          options.buildDeallocOp(builder, loc, adaptor.getMemrefs()[0]);
           builder.create<scf::YieldOp>(loc);
         });
 
@@ -232,13 +232,13 @@ class DeallocOpConversion
     // Without storing them to memrefs, we could not use for-loops but only a
     // completely unrolled version of it, potentially leading to code-size
     // blow-up.
-    Value toDeallocMemref = rewriter.create<memref::AllocOp>(
+    Value toDeallocMemref = rewriter.create<memref::AllocaOp>(
         op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
                                      rewriter.getIndexType()));
-    Value conditionMemref = rewriter.create<memref::AllocOp>(
+    Value conditionMemref = rewriter.create<memref::AllocaOp>(
         op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
                                      rewriter.getI1Type()));
-    Value toRetainMemref = rewriter.create<memref::AllocOp>(
+    Value toRetainMemref = rewriter.create<memref::AllocaOp>(
         op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
                                      rewriter.getIndexType()));
 
@@ -285,10 +285,10 @@ class DeallocOpConversion
         MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
         toRetainMemref);
 
-    Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
+    Value deallocCondsMemref = rewriter.create<memref::AllocaOp>(
         op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
                                      rewriter.getI1Type()));
-    Value retainCondsMemref = rewriter.create<memref::AllocOp>(
+    Value retainCondsMemref = rewriter.create<memref::AllocaOp>(
         op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
                                      rewriter.getI1Type()));
 
@@ -313,7 +313,7 @@ class DeallocOpConversion
           op.getLoc(), deallocCondsMemref, idxValue);
       rewriter.create<scf::IfOp>(
           op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
-            builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
+            options.buildDeallocOp(builder, loc, adaptor.getMemrefs()[i]);
             builder.create<scf::YieldOp>(loc);
           });
     }
@@ -326,22 +326,15 @@ class DeallocOpConversion
       replacements.push_back(ownership);
     }
 
-    // Deallocate above allocated memrefs again to avoid memory leaks.
-    // Deallocation will not be run on code after this stage.
-    rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
-    rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
-    rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
-    rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
-    rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
-
     rewriter.replaceOp(op, replacements);
     return success();
   }
 
 public:
-  DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+  DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc,
+                      const bufferization::LowerDeallocationOptions &options)
       : OpConversionPattern<bufferization::DeallocOp>(context),
-        deallocHelperFunc(deallocHelperFunc) {}
+        deallocHelperFunc(deallocHelperFunc), options(options) {}
 
   LogicalResult
   matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -371,6 +364,7 @@ class DeallocOpConversion
 
 private:
   func::FuncOp deallocHelperFunc;
+  const bufferization::LowerDeallocationOptions options;
 };
 } // namespace
 
@@ -378,6 +372,13 @@ namespace {
 struct LowerDeallocationsPass
     : public bufferization::impl::LowerDeallocationsBase<
           LowerDeallocationsPass> {
+  LowerDeallocationsPass() = default;
+  LowerDeallocationsPass(const LowerDeallocationsPass &other)
+      : LowerDeallocationsPass(other.options) {}
+  explicit LowerDeallocationsPass(
+      const bufferization::LowerDeallocationOptions &options)
+      : options(options) {}
+
   void runOnOperation() override {
     if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
       emitError(getOperation()->getLoc(),
@@ -404,8 +405,8 @@ struct LowerDeallocationsPass
     }
 
     RewritePatternSet patterns(&getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, helperFuncOp, options);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
@@ -416,6 +417,8 @@ struct LowerDeallocationsPass
                                       std::move(patterns))))
       signalPassFailure();
   }
+
+  const bufferization::LowerDeallocationOptions options;
 };
 } // namespace
 
@@ -536,10 +539,13 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
 }
 
 void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
-  patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
+    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc,
+    const LowerDeallocationOptions &options) {
+  patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc,
+                                    options);
 }
 
-std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
-  return std::make_unique<LowerDeallocationsPass>();
+std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass(
+    const LowerDeallocationOptions &options) {
+  return std::make_unique<LowerDeallocationsPass>(options);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-mixed-allocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-mixed-allocations.mlir
index 6902d5fe23eee09..186e161d2160f3f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-mixed-allocations.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-mixed-allocations.mlir
@@ -27,7 +27,7 @@ func.func @mixed_allocations(%cond: i1) -> (memref<f32>, !gpu.async.token) {
 // CHECK: [[ALIAS1:%.+]] = arith.cmpi ne, [[A2_PTR]], [[SELECT_PTR]]
 // CHECK: [[COND1:%.+]] = arith.andi [[ALIAS1]], %true
 // CHECK: scf.if [[COND1]] {
-// TODO: add pass option to lower-deallocation to insert gpu.dealloc here
-// CHECK:   memref.dealloc [[A2_BASE]]
+// CHECK:   [[T:%.+]] = gpu.dealloc async [[A2_BASE]]
+// CHECK:   gpu.wait [[[T]]]
 // CHECK: }
 // CHECK: return [[SELECT]], [[TOKEN]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
index 2c69fcab08a8d6a..b9d82502e6c4f14 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
@@ -72,9 +72,9 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
 // CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<5xf32>,
 // CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1,
 // CHECK-SAME: [[ARG5:%.+]]: memref<2xf32>)
-//      CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex>
-//      CHECK: [[CONDS:%.+]] = memref.alloc() : memref<2xi1>
-//      CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<2xindex>
+//      CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloca() : memref<2xindex>
+//      CHECK: [[CONDS:%.+]] = memref.alloca() : memref<2xi1>
+//      CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloca() : memref<2xindex>
 //  CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
 //  CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
 //  CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]]
@@ -94,8 +94,8 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
 //  CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref<?xindex>
 //  CHECK-DAG: [[CAST_CONDS:%.+]] = memref.cast [[CONDS]] : memref<2xi1> to memref<?xi1>
 //  CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<2xindex> to memref<?xindex>
-//      CHECK: [[DEALLOC_CONDS:%.+]] = memref.alloc() : memref<2xi1>
-//      CHECK: [[RETAIN_CONDS:%.+]] = memref.alloc() : memref<2xi1>
+//      CHECK: [[DEALLOC_CONDS:%.+]] = memref.alloca() : memref<2xi1>
+//      CHECK: [[RETAIN_CONDS:%.+]] = memref.alloca() : memref<2xi1>
 //      CHECK: [[CAST_DEALLOC_CONDS:%.+]] = memref.cast [[DEALLOC_CONDS]] : memref<2xi1> to memref<?xi1>
 //      CHECK: [[CAST_RETAIN_CONDS:%.+]] = memref.cast [[RETAIN_CONDS]] : memref<2xi1> to memref<?xi1>
 //      CHECK: call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[CAST_CONDS]], [[CAST_DEALLOC_CONDS]], [[CAST_RETAIN_CONDS]])
@@ -113,11 +113,6 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
 //      CHECK: [[OWNERSHIP0:%.+]] = memref.load [[RETAIN_CONDS]][[[C0]]]
 //      CHECK: [[C1:%.+]] = arith.constant 1 : index
 //      CHECK: [[OWNERSHIP1:%.+]] = memref.load [[RETAIN_CONDS]][[[C1]]]
-//      CHECK: memref.dealloc [[TO_DEALLOC_MR]]
-//      CHECK: memref.dealloc [[TO_RETAIN_MR]]
-//      CHECK: memref.dealloc [[CONDS]]
-//      CHECK: memref.dealloc [[DEALLOC_CONDS]]
-//      CHECK: memref.dealloc [[RETAIN_CONDS]]
 //      CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]]
 
 //      CHECK: func private @dealloc_helper
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOwnershipBasedBufferDeallocation.cpp b/mlir/test/lib/Dialect/Bufferization/TestOwnershipBasedBufferDeallocation.cpp
index 2b57dd94e571003..05fe2d52adad804 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestOwnershipBasedBufferDeallocation.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestOwnershipBasedBufferDeallocation.cpp
@@ -54,11 +54,6 @@ struct TestOwnershipBasedBufferDeallocationPass
     func::FuncOp helper = bufferization::buildDeallocationLibraryFunction(
         builder, module.getLoc(), symbolTable);
 
-    RewritePatternSet patterns(module->getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helper);
-    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-
     WalkResult result = getOperation()->walk([&](FunctionOpInterface funcOp) {
       // Deallocate the `memref.alloc` operations.
       bufferization::DeallocationOptions options;
@@ -70,10 +65,16 @@ struct TestOwnershipBasedBufferDeallocationPass
       // Lower the inserted `bufferization.dealloc` operations.
       ConversionTarget target(getContext());
       target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
-                             scf::SCFDialect, func::FuncDialect>();
+                             scf::SCFDialect, func::FuncDialect,
+                             gpu::GPUDialect>();
       target.addIllegalOp<bufferization::DeallocOp>();
 
-      if (failed(applyPartialConversion(funcOp, target, frozenPatterns)))
+      RewritePatternSet memrefPatterns(module->getContext());
+      bufferization::LowerDeallocationOptions loweringOptions;
+      bufferization::populateBufferizationDeallocLoweringPattern(memrefPatterns,
+                                                                 helper);
+      if (failed(applyPartialConversion(funcOp, target,
+                                        std::move(memrefPatterns))))
         return WalkResult::interrupt();
 
       // Deallocate the `gpu.alloc` operations.
@@ -107,10 +108,20 @@ struct TestOwnershipBasedBufferDeallocationPass
 
       // Lower the `bufferization.dealloc` operations inserted in the second
       // deallocation run.
-      // TODO: they are currently also lowered to memref.dealloc, we need to
-      // add pass options to the lowering pass that allow us to select the
-      // dealloc operation to be inserted.
-      if (failed(applyPartialConversion(funcOp, target, frozenPatterns)))
+      RewritePatternSet gpuPatterns(module->getContext());
+      loweringOptions.buildDeallocOp = [](OpBuilder &builder, Location loc,
+                                          Value memref) {
+        Value token =
+            builder
+                .create<gpu::DeallocOp>(
+                    loc, gpu::AsyncTokenType::get(builder.getContext()), memref)
+                .getResult(0);
+        builder.create<gpu::WaitOp>(loc, Type(), token);
+      };
+      bufferization::populateBufferizationDeallocLoweringPattern(
+          gpuPatterns, helper, loweringOptions);
+      if (failed(
+              applyPartialConversion(funcOp, target, std::move(gpuPatterns))))
         return WalkResult::interrupt();
 
       return WalkResult::advance();
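
Note on the callback design in PATCH 3/3: the sketch below is a standalone illustration, not MLIR code, for readers who want to try the pattern outside an MLIR build. `Builder`, `Value`, `LoweringOptions`, `lowerDealloc`, `lowerWithDefault` and `lowerWithGpuOverride` are hypothetical stand-ins. They only mirror the shape of `LowerDeallocationOptions::buildDeallocOp`: a `std::function` member whose default emits the memref-style dealloc and which a caller may replace, as the test pass does for `gpu.dealloc`. Ops are recorded as strings so the effect of swapping the callback is easy to see.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-ins for OpBuilder and Value; these are NOT MLIR types.
struct Builder {
  std::vector<std::string> ops; // records "created" ops in creation order
  void create(const std::string &name, const std::string &operand) {
    ops.push_back(name + " " + operand);
  }
};
using Value = std::string;

// Mirrors the shape of LowerDeallocationOptions: a callback with a default
// that builds the memref-style deallocation.
struct LoweringOptions {
  std::function<void(Builder &, const Value &)> buildDeallocOp =
      [](Builder &b, const Value &memref) {
        b.create("memref.dealloc", memref);
      };
};

// Stand-in for the lowering pattern: it emits the guard and delegates the
// choice of dealloc op to the options instead of hard-coding it.
void lowerDealloc(Builder &b, const Value &memref,
                  const LoweringOptions &options) {
  b.create("scf.if", "%cond");
  options.buildDeallocOp(b, memref);
}

// Lowers one buffer with the default callback and returns the recorded ops.
std::vector<std::string> lowerWithDefault() {
  Builder b;
  lowerDealloc(b, "%buf", LoweringOptions());
  return b.ops;
}

// Lowers one buffer with a GPU-style override: an async dealloc followed by
// a wait on its token, analogous to what the test pass installs.
std::vector<std::string> lowerWithGpuOverride() {
  LoweringOptions options;
  options.buildDeallocOp = [](Builder &b, const Value &memref) {
    b.create("gpu.dealloc async", memref);
    b.create("gpu.wait", "%token");
  };
  Builder b;
  lowerDealloc(b, "%buf", options);
  return b.ops;
}
```

Keeping the default inside the options struct means existing callers of `createLowerDeallocationsPass()` and `populateBufferizationDeallocLoweringPattern()` need no changes, while a pipeline with mixed allocators can build one pattern set per allocator kind.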