[Mlir-commits] [mlir] 01334d1 - [mlir][bufferization] Add an ownership based buffer deallocation pass (#66337)
llvmlistbot at llvm.org
Thu Sep 14 03:13:42 PDT 2023
Author: Martin Erhart
Date: 2023-09-14T12:13:37+02:00
New Revision: 01334d1abb7a54b18f5f45fe0589cc6136613023
URL: https://github.com/llvm/llvm-project/commit/01334d1abb7a54b18f5f45fe0589cc6136613023
DIFF: https://github.com/llvm/llvm-project/commit/01334d1abb7a54b18f5f45fe0589cc6136613023.diff
LOG: [mlir][bufferization] Add an ownership based buffer deallocation pass (#66337)
Add a new Buffer Deallocation pass with the intent to replace the old
one. For now, it is added as a separate pass alongside the old one in
order to allow downstream users to migrate over gradually. This new pass
has the goal of inserting fewer clone operations and supporting
additional use cases.
Please refer to the Buffer Deallocation section in the updated
Bufferization.md file for more information on how this new pass works.
Added:
mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-existing-deallocs.mlir
mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir
mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-subviews.mlir
mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/invalid-buffer-deallocation.mlir
Modified:
mlir/docs/Bufferization.md
mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md
index f03d7bb877c9c74..09bec06743c7a65 100644
--- a/mlir/docs/Bufferization.md
+++ b/mlir/docs/Bufferization.md
@@ -224,6 +224,9 @@ dialect conversion-based bufferization.
## Buffer Deallocation
+**Important: this pass is deprecated. Please use the ownership-based**
+**buffer deallocation pass instead.**
+
One-Shot Bufferize deallocates all buffers that it allocates. This is in
contrast to the dialect conversion-based bufferization that delegates this job
to the
@@ -300,6 +303,607 @@ One-Shot Bufferize can be configured to leak all memory and not generate any
buffer deallocations with `create-deallocs=0`. This can be useful for
compatibility with legacy code that has its own method of deallocating buffers.
+## Ownership-based Buffer Deallocation
+
+Recommended compilation pipeline:
+```
+one-shot-bufferize
+       |          it is recommended to perform all bufferization here, at
+       |  <-      the latest; any allocations inserted after this point have
+       V          to be handled manually
+expand-realloc
+ V
+ownership-based-buffer-deallocation
+ V
+ canonicalize <- mostly for scf.if simplifications
+ V
+buffer-deallocation-simplification
+ V <- from this point onwards no tensor values are allowed
+lower-deallocations
+ V
+ CSE
+ V
+ canonicalize
+```
+
+One-Shot Bufferize does not deallocate any buffers that it allocates. This job
+is delegated to the
+[`-ownership-based-buffer-deallocation`](https://mlir.llvm.org/docs/Passes/#-ownership-based-buffer-deallocation)
+pass, i.e., after running One-Shot Bufferize, the result IR may have a number of
+`memref.alloc` ops, but no `memref.dealloc` ops. This pass processes operations
+implementing `FunctionOpInterface` one-by-one without analyzing the call-graph.
+This means that there have to be [some rules](#function-boundary-abi) on how
+MemRefs are handled when being passed from one function to another. The rest of
+the pass revolves heavily around the `bufferization.dealloc` operation which is
+inserted at the end of each basic block with appropriate operands and should be
+optimized using the Buffer Deallocation Simplification pass
+(`--buffer-deallocation-simplification`) and the regular canonicalizer
+(`--canonicalize`). Lowering the result of the
+`-ownership-based-buffer-deallocation` pass directly using
+`--convert-bufferization-to-memref` without prior optimization is not
+recommended, as it will lead to very inefficient code (the runtime cost of
+`bufferization.dealloc` is `O(|memrefs|^2 + |memrefs| * |retained|)`).
+
+### Function boundary ABI
+
+The Buffer Deallocation pass operates on the level of operations implementing
+the `FunctionOpInterface`. Such operations can take MemRefs as arguments, but
+also return them. To ensure compatibility among all functions (including
+external ones), some rules have to be enforced:
+* When a MemRef is passed as a function argument, ownership is never acquired.
+ It is always the caller's responsibility to deallocate such MemRefs.
+* Returning a MemRef from a function always passes ownership to the caller,
+ i.e., it is also the caller's responsibility to deallocate memrefs returned
+ from a called function.
+* A function must not return a MemRef with the same allocated base buffer as
+  one of its arguments (in this case a copy has to be created; see the sketch
+  below). Note that in this context two subviews of the same buffer that don't
+  overlap are also considered to alias.
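+
+As a minimal sketch of the last rule (function and value names are made up for
+illustration), a function that would otherwise return its own argument gets a
+copy inserted:
+
+```mlir
+// Before: the returned MemRef shares its allocated base buffer with %arg0.
+func.func @return_arg(%arg0: memref<4xf32>) -> memref<4xf32> {
+  return %arg0 : memref<4xf32>
+}
+
+// After (sketch): a clone is inserted so the caller receives a MemRef it owns.
+func.func @return_arg(%arg0: memref<4xf32>) -> memref<4xf32> {
+  %clone = bufferization.clone %arg0 : memref<4xf32> to memref<4xf32>
+  return %clone : memref<4xf32>
+}
+```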
+
+For external functions (e.g., library functions written externally in C), the
+externally provided implementation has to adhere to these rules; the buffer
+deallocation pass simply assumes that it does. Functions to which the
+deallocation pass is applied and whose implementation is accessible are
+modified by the pass such that the ABI is respected (i.e., buffer copies are
+inserted where necessary).
+
+### Inserting `bufferization.dealloc` operations
+
+`bufferization.dealloc` operations are unconditionally inserted at the end of
+each basic block (just before the terminator). The majority of the pass is about
+finding the correct operands for this operation. There are three variadic
+operand lists to be populated: the first contains all MemRef values that may
+need to be deallocated, the second contains their associated ownership values
+(of `i1` type), and the third contains MemRef values that are still needed at
+a later point and should thus not be deallocated. This operation
+allows us to deal with any kind of aliasing behavior: it lowers to runtime
+aliasing checks when not enough information can be collected statically. When
+enough aliasing information is statically available, operands or the entire op
+may fold away.
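+
+For illustration, a `bufferization.dealloc` with all three operand lists
+populated might look as follows (value names are made up):
+
+```mlir
+// %m0, %m1: MemRefs that may have to be deallocated; %c0, %c1: their
+// associated ownership indicators; %r: a MemRef still needed later on.
+// The i1 result is the updated ownership indicator of the retained %r.
+%updated_ownership = bufferization.dealloc
+                       (%m0, %m1 : memref<2xf32>, memref<5xf32>)
+                       if (%c0, %c1)
+                       retain (%r : memref<1xf32>)
+```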
+
+**Ownerships**
+
+To do so, we use the concept of ownership indicators for MemRefs, which
+materialize as an `i1` value for any SSA value of `memref` type, indicating
+whether the basic block in which it was materialized has ownership of this
+MemRef. Ideally, this is a constant `true` or `false`, but it might also be a
+non-constant SSA value. To keep track of those ownership values without
+immediately materializing them (which might require inserting
+`bufferization.clone` operations or operations checking for aliasing at
+runtime at positions where we don't actually need a materialized value), we
+use the `Ownership` class. This class represents the ownership in three
+states forming a lattice on a partial order:
+```
+forall X in SSA values. uninitialized < unique(X) < unknown
+forall X, Y in SSA values.
+ unique(X) == unique(Y) iff X and Y always evaluate to the same value
+ unique(X) != unique(Y) otherwise
+```
+Intuitively, the states have the following meaning:
+* Uninitialized: the ownership is not initialized yet; this is the default
+  state. Once an operation has finished processing, the ownership of its
+  results with MemRef type should not be uninitialized anymore.
+* Unique: there is a specific SSA value that can be queried to check ownership
+  without materializing any additional IR.
+* Unknown: no specific SSA value is available without materializing additional
+  IR. Typically, this is because two ownerships in 'Unique' state would have to
+  be merged manually (e.g., the result of an `arith.select` has either the
+  ownership of the then or the else case depending on the condition value;
+  inserting another `arith.select` for the ownership values can perform the
+  merge and provide a 'Unique' ownership for the result, as sketched below).
+  However, in the general case this 'Unknown' state has to be assigned.
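+
+As a sketch of the `arith.select` case mentioned above (all value names are
+made up for illustration), one additional select over the ownership indicators
+suffices to keep the result's ownership 'Unique':
+
+```mlir
+// %m_then has 'Unique' ownership %own_then; %m_else has 'Unique' ownership
+// %own_else. The extra select merges them into a 'Unique' ownership %own.
+%m = arith.select %cond, %m_then, %m_else : memref<2xf32>
+%own = arith.select %cond, %own_then, %own_else : i1
+```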
+
+Implied by the above partial order, the pass combines two ownerships in the
+following way:
+
+| Ownership 1   | Ownership 2   | Combined Ownership |
+|:--------------|:--------------|:-------------------|
+| uninitialized | uninitialized | uninitialized      |
+| unique(X)     | uninitialized | unique(X)          |
+| unique(X)     | unique(X)     | unique(X)          |
+| unique(X)     | unique(Y)     | unknown            |
+| unknown       | unique(X)     | unknown            |
+| unknown       | uninitialized | unknown            |
+
+The remaining (symmetric) cases follow from the commutativity of the
+combination.
+
+**Collecting the list of MemRefs that potentially need to be deallocated**
+
+For a given block, the list of MemRefs that potentially need to be deallocated
+at the end of that block is computed by keeping track of all values for which
+the block potentially takes over ownership. This includes MemRefs provided as
+basic block arguments and those reported by the interface handlers for
+operations like `memref.alloc` and `func.call`, but also takes liveness
+information in regions with multiple basic blocks into account. More
+concretely, it is computed by taking the MemRefs in the 'in' set of the
+liveness analysis of the current basic block B, adding the MemRef block
+arguments and the set of MemRefs allocated in B itself (as determined by the
+interface handlers), and then subtracting the set of MemRefs deallocated in B
+(also determined by the interface handlers).
+
+Note that we don't have to take the intersection of the liveness 'in' set with
+the 'out' set of the predecessor block because a value that is in the 'in' set
+must be defined in an ancestor block that dominates all direct predecessors and
+thus the 'in' set of this block is a subset of the 'out' sets of each
+predecessor.
+
+```
+memrefs = filter((liveIn(block) U
+ allocated(block) U arguments(block)) \ deallocated(block), isMemRef)
+```
+
+The list of conditions for the second variadic operand list of
+`bufferization.dealloc` is computed by querying the stored ownership value for
+each of the MemRefs collected as described above. The ownership state is updated
+by the interface handlers while processing the basic block.
+
+**Collecting the list of MemRefs to retain**
+
+Given a basic block B, the list of MemRefs that have to be retained can be
+different for each successor block S. For the two basic blocks B and S and the
+values passed via block arguments to the destination block S, we compute the
+list of MemRefs that have to be retained in B by taking the MemRefs in the
+successor operand list of the terminator and the MemRefs in the 'out' set of the
+liveness analysis for B intersected with the 'in' set of the destination block
+S.
+
+This list of retained values makes sure that we cannot run into use-after-free
+situations even if no aliasing information is present at compile-time.
+
+```
+toRetain = filter(successorOperands + (liveOut(fromBlock) intersect
+ liveIn(toBlock)), isMemRef)
+```
+
+### Supported interfaces
+
+The pass uses liveness analysis and a few interfaces:
+* `FunctionOpInterface`
+* `CallOpInterface`
+* `MemoryEffectOpInterface`
+* `RegionBranchOpInterface`
+* `RegionBranchTerminatorOpInterface`
+
+Due to insufficient information provided by these interfaces, the pass also
+special-cases the `cf.cond_br` operation and makes some assumptions about
+operations implementing the `RegionBranchOpInterface` at the moment; improving
+the interfaces would allow us to remove those dependencies in the future.
+
+### Limitations
+
+The Buffer Deallocation pass has some requirements and limitations on the input
+IR. These are checked at the beginning of the pass, and errors are emitted
+accordingly:
+* The set of interfaces the pass operates on must be implemented (correctly).
+  E.g., if an operation with a nested region is present but does not
+  implement the `RegionBranchOpInterface`, an error is emitted because the
+  pass cannot know the semantics of the nested region (and does not make any
+  default assumptions about it).
+* No explicit control-flow loops are present. Currently, only loops using
+ structural-control-flow are supported. However, this limitation could be
+ lifted in the future.
+* Deallocation operations should not be present already. The pass likely
+  handles them correctly in most cases, but this is not officially supported
+  yet due to insufficient testing.
+* Terminators must implement either `RegionBranchTerminatorOpInterface` or
+ `BranchOpInterface`, but not both. Terminators with more than one successor
+ are not supported (except `cf.cond_br`). This is not a fundamental
+ limitation, but there is no use-case justifying the more complex
+ implementation at the moment.
+
+### Example
+
+The following example contains a few interesting cases:
+* Basic block arguments are modified to also pass along the ownership
+  indicator, but not for entry blocks of non-private functions (assuming the
+ `private-function-dynamic-ownership` pass option is disabled) where the
+ function boundary ABI is applied instead. "Private" in this context refers
+ to functions that cannot be called externally.
+* The result of `arith.select` initially has 'Unknown' assigned as ownership,
+ but once the `bufferization.dealloc` operation is inserted it is put in the
+ 'retained' list (since it has uses in a later basic block) and thus the
+ 'Unknown' ownership can be replaced with a 'Unique' ownership using the
+ corresponding result of the dealloc operation.
+* The `cf.cond_br` operation has more than one successor and thus has to
+ insert two `bufferization.dealloc` operations (one for each successor).
+ While they have the same list of MemRefs to deallocate (because they perform
+ the deallocations for the same block), it must be taken into account that
+ some MemRefs remain *live* for one branch but not the other (thus set
+ intersection is performed on the *live-out* of the current block and the
+ *live-in* of the target block). Also, `cf.cond_br` supports separate
+ forwarding operands for each successor. To make sure that no MemRef is
+ deallocated twice (because there are two `bufferization.dealloc` operations
+ with the same MemRefs to deallocate), the condition operands are adjusted to
+ take the branch condition into account. While a generic lowering for such
+ terminator operations could be implemented, a specialized implementation can
+ take all the semantics of this particular operation into account and thus
+ generate a more efficient lowering.
+
+```mlir
+func.func @example(%memref: memref<?xi8>, %select_cond: i1, %br_cond: i1) {
+ %alloc = memref.alloc() : memref<?xi8>
+ %alloca = memref.alloca() : memref<?xi8>
+ %select = arith.select %select_cond, %alloc, %alloca : memref<?xi8>
+ cf.cond_br %br_cond, ^bb1(%alloc : memref<?xi8>), ^bb1(%memref : memref<?xi8>)
+^bb1(%bbarg: memref<?xi8>):
+ test.copy(%bbarg, %select) : (memref<?xi8>, memref<?xi8>)
+ return
+}
+```
+
+After running `--ownership-based-buffer-deallocation`, it looks as follows:
+
+```mlir
+// Since this is not a private function, the signature will not be modified even
+// when private-function-dynamic-ownership is enabled. Instead the function
+// boundary ABI has to be applied which means that ownership of `%memref` will
+// never be acquired.
+func.func @example(%memref: memref<?xi8>, %select_cond: i1, %br_cond: i1) {
+ %false = arith.constant false
+ %true = arith.constant true
+
+ // The ownership of a MemRef defined by the `memref.alloc` operation is always
+ // assigned to be 'true'.
+ %alloc = memref.alloc() : memref<?xi8>
+
+ // The ownership of a MemRef defined by the `memref.alloca` operation is
+ // always assigned to be 'false'.
+ %alloca = memref.alloca() : memref<?xi8>
+
+ // The ownership of %select will be the join of the ownership of %alloc and
+ // the ownership of %alloca, i.e., of %true and %false. Because the pass does
+ // not know about the semantics of the `arith.select` operation (unless a
+ // custom handler is implemented), the ownership join will be 'Unknown'. If
+ // the materialized ownership indicator of %select is needed, either a clone
+ // has to be created for which %true is assigned as ownership or the result
+ // of a `bufferization.dealloc` where %select is in the retain list has to be
+ // used.
+ %select = arith.select %select_cond, %alloc, %alloca : memref<?xi8>
+
+ // We use `memref.extract_strided_metadata` to get the base memref since it is
+ // not allowed to pass arbitrary memrefs to `memref.dealloc`. This property is
+  // already enforced for `bufferization.dealloc`.
+ %base_buffer_memref, ... = memref.extract_strided_metadata %memref
+ : memref<?xi8> -> memref<i8>, index, index, index
+ %base_buffer_alloc, ... = memref.extract_strided_metadata %alloc
+ : memref<?xi8> -> memref<i8>, index, index, index
+ %base_buffer_alloca, ... = memref.extract_strided_metadata %alloca
+ : memref<?xi8> -> memref<i8>, index, index, index
+
+ // The deallocation conditions need to be adjusted to incorporate the branch
+ // condition. In this example, this requires only a single negation, but might
+ // also require multiple arith.andi operations.
+ %not_br_cond = arith.xori %true, %br_cond : i1
+
+ // There are two dealloc operations inserted in this basic block, one per
+ // successor. Both have the same list of MemRefs to deallocate and the
+  // conditions only differ by the branch condition conjunct.
+  // Note, however, that the retained list differs. Here, both contain the
+  // %select value because it is used in both successors (since it's the same
+  // block), but the value passed via block argument differs (%memref vs.
+  // %alloc).
+ %10:2 = bufferization.dealloc
+ (%base_buffer_memref, %base_buffer_alloc, %base_buffer_alloca
+ : memref<i8>, memref<i8>, memref<i8>)
+ if (%false, %br_cond, %false)
+ retain (%alloc, %select : memref<?xi8>, memref<?xi8>)
+
+ %11:2 = bufferization.dealloc
+ (%base_buffer_memref, %base_buffer_alloc, %base_buffer_alloca
+ : memref<i8>, memref<i8>, memref<i8>)
+ if (%false, %not_br_cond, %false)
+ retain (%memref, %select : memref<?xi8>, memref<?xi8>)
+
+ // Because %select is used in ^bb1 without passing it via block argument, we
+  // need to update its ownership value here by merging the ownership values
+ // returned by the dealloc operations
+ %new_ownership = arith.select %br_cond, %10#1, %11#1 : i1
+
+ // The terminator is modified to pass along the ownership indicator values
+ // with each MemRef value.
+ cf.cond_br %br_cond, ^bb1(%alloc, %10#0 : memref<?xi8>, i1),
+ ^bb1(%memref, %11#0 : memref<?xi8>, i1)
+
+// All non-entry basic blocks are modified to have an additional i1 argument for
+// each MemRef value in the argument list.
+^bb1(%13: memref<?xi8>, %14: i1): // 2 preds: ^bb0, ^bb0
+ test.copy(%13, %select) : (memref<?xi8>, memref<?xi8>)
+
+ %base_buffer_13, ... = memref.extract_strided_metadata %13
+ : memref<?xi8> -> memref<i8>, index, index, index
+ %base_buffer_select, ... = memref.extract_strided_metadata %select
+ : memref<?xi8> -> memref<i8>, index, index, index
+
+ // Here, we don't have a retained list, because the block has no successors
+ // and the return has no operands.
+ bufferization.dealloc (%base_buffer_13, %base_buffer_select
+ : memref<i8>, memref<i8>)
+ if (%14, %new_ownership)
+ return
+}
+```
+
+## Buffer Deallocation Simplification Pass
+
+The [semantics of the `bufferization.dealloc` operation](https://mlir.llvm.org/docs/Dialects/BufferizationOps/#bufferizationdealloc-bufferizationdeallocop)
+provide a lot of opportunities for optimizations which can be conveniently split
+into patterns using the greedy pattern rewriter. Some of those patterns need
+access to additional analyses such as an analysis that can determine whether two
+MemRef values must, may, or never originate from the same buffer allocation.
+These patterns are collected in the Buffer Deallocation Simplification pass,
+while patterns that don't need additional analyses are registered as part of the
+regular canonicalizer pass. This pass is best run after
+`--ownership-based-buffer-deallocation` followed by `--canonicalize`.
+
+The pass applies patterns for the following simplifications (the first one is
+sketched below):
+* Remove MemRefs from the retain list when they are guaranteed to not alias
+  any value in the 'memref' operand list. This avoids an additional aliasing
+  check with the removed value.
+* Split off values in the 'memref' list into new `bufferization.dealloc`
+  operations that contain only this value in the 'memref' list when it is
+  guaranteed to not alias any other value in the 'memref' list. This avoids
+  at least one aliasing check at runtime and enables using a more efficient
+  lowering for this new `bufferization.dealloc` operation.
+* Remove values from the 'memref' operand list when they are guaranteed to
+  alias at least one value in the 'retain' list and cannot alias any other
+  value in the 'retain' list.
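+
+As a sketch of the first simplification (value names are made up; the exact
+replacement values are determined by the pass), dropping a retained value that
+provably does not alias the deallocated MemRef saves a runtime check:
+
+```mlir
+// Before: alias analysis can prove that %r never aliases %m.
+%0 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
+       retain (%r : memref<4xf32>)
+
+// After (sketch): %r is removed from the retain list; since it cannot alias
+// %m, its updated ownership indicator is a constant false.
+%false = arith.constant false
+bufferization.dealloc (%m : memref<2xf32>) if (%cond)
+// ... uses of %0 are replaced by %false
+```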
+
+## Lower Deallocations Pass
+
+The `-lower-deallocations` pass transforms all `bufferization.dealloc`
+operations to `memref.dealloc` operations and may also insert operations from
+the `scf`, `func`, and `arith` dialects to make deallocations conditional and
+check whether two MemRef values come from the same allocation at runtime (when
+the `buffer-deallocation-simplification` pass wasn't able to determine it
+statically).
+
+The same lowering of the `bufferization.dealloc` operation is also part of the
+`-convert-bufferization-to-memref` conversion pass which also lowers all the
+other operations of the bufferization dialect.
+
+We distinguish multiple cases in this lowering pass to provide an overall more
+efficient lowering. In the general case, a library function is created to avoid
+quadratic code size explosion (relative to the number of operands of the dealloc
+operation). The specialized lowerings aim to avoid this library function because
+it requires allocating auxiliary MemRefs of index values.
+
+### Generic Lowering
+
+A library function is generated to avoid code-size blow-up. On a high level,
+the base MemRef of each operand is extracted as an index value, stored into
+specifically allocated MemRefs, and passed to the library function, which then
+determines which of them come from the same original allocation. This
+information is needed to avoid double-free situations and to correctly retain
+the MemRef values in the `retained` list.
+
+**Dealloc Operation Lowering**
+
+This lowering supports all features the dealloc operation has to offer. It
+computes the base pointer of each memref (as an index), stores it in a
+new memref helper structure and passes it to the helper function generated
+in `buildDeallocationLibraryFunction`. The results are stored in two lists
+(represented as MemRefs) of booleans passed as arguments. The first list
+stores whether the corresponding condition should be deallocated, the
+second list stores the ownership of the retained values which can be used
+to replace the result values of the `bufferization.dealloc` operation.
+
+Example:
+```
+%0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
+ if (%cond0, %cond1)
+ retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
+```
+lowers to (simplified):
+```
+%c0 = arith.constant 0 : index
+%c1 = arith.constant 1 : index
+%dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
+%cond_list = memref.alloc() : memref<2xi1>
+%retain_base_pointer_list = memref.alloc() : memref<2xindex>
+%m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
+memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
+%m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
+memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
+memref.store %cond0, %cond_list[%c0]
+memref.store %cond1, %cond_list[%c1]
+%r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
+memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
+%r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
+memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
+%dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
+ memref<2xindex> to memref<?xindex>
+%dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
+%dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
+ memref<2xindex> to memref<?xindex>
+%dealloc_cond_out = memref.alloc() : memref<2xi1>
+%ownership_out = memref.alloc() : memref<2xi1>
+%dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
+ memref<2xi1> to memref<?xi1>
+%dyn_ownership_out = memref.cast %ownership_out :
+ memref<2xi1> to memref<?xi1>
+call @dealloc_helper(%dyn_dealloc_base_pointer_list,
+ %dyn_retain_base_pointer_list,
+ %dyn_cond_list,
+ %dyn_dealloc_cond_out,
+ %dyn_ownership_out) : (...)
+%m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1>
+scf.if %m0_dealloc_cond {
+ memref.dealloc %m0 : memref<2xf32>
+}
+%m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1>
+scf.if %m1_dealloc_cond {
+ memref.dealloc %m1 : memref<5xf32>
+}
+%r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1>
+%r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1>
+memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
+memref.dealloc %retain_base_pointer_list : memref<2xindex>
+memref.dealloc %cond_list : memref<2xi1>
+memref.dealloc %dealloc_cond_out : memref<2xi1>
+memref.dealloc %ownership_out : memref<2xi1>
+// replace %0#0 with %r0_ownership
+// replace %0#1 with %r1_ownership
+```
+
+**Library function**
+
+A library function is built once per compilation unit. It can be called at
+`bufferization.dealloc` sites to determine whether two MemRefs come from the
+same allocation and to compute the new ownerships of the retained values.
+
+The generated function takes two MemRefs of indices and three MemRefs of
+booleans as arguments:
+ * The first argument A should contain the result of the
+   extract_aligned_pointer_as_index operation applied to the MemRefs to be
+   deallocated.
+ * The second argument B should contain the result of the
+   extract_aligned_pointer_as_index operation applied to the MemRefs to be
+   retained.
+ * The third argument C should contain the conditions as passed directly
+   to the deallocation operation.
+ * The fourth argument D is used to pass results to the caller. Those
+   represent the condition under which the MemRef at the corresponding
+   position in A should be deallocated.
+ * The fifth argument E is used to pass results to the caller. It
+   provides the ownership value corresponding to the MemRef at the same
+   position in B.
+
+This helper function is supposed to be called once for each
+`bufferization.dealloc` operation to determine whether each MemRef needs to
+be deallocated and to compute the new ownership indicators for the retained
+values; it does not perform the deallocation itself.
+
+Generated code:
+```
+func.func @dealloc_helper(
+ %dyn_dealloc_base_pointer_list: memref<?xindex>,
+ %dyn_retain_base_pointer_list: memref<?xindex>,
+ %dyn_cond_list: memref<?xi1>,
+ %dyn_dealloc_cond_out: memref<?xi1>,
+ %dyn_ownership_out: memref<?xi1>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %true = arith.constant true
+ %false = arith.constant false
+ %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0
+ %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0
+ // Zero initialize result buffer.
+ scf.for %i = %c0 to %num_retain_memrefs step %c1 {
+ memref.store %false, %dyn_ownership_out[%i] : memref<?xi1>
+ }
+ scf.for %i = %c0 to %num_dealloc_memrefs step %c1 {
+ %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i]
+ %cond = memref.load %dyn_cond_list[%i]
+ // Check for aliasing with retained memrefs.
+ %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs
+ step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) {
+ %retain_bp = memref.load %dyn_retain_base_pointer_list[%j]
+ %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index
+ scf.if %does_alias {
+ %curr_ownership = memref.load %dyn_ownership_out[%j]
+ %updated_ownership = arith.ori %curr_ownership, %cond : i1
+ memref.store %updated_ownership, %dyn_ownership_out[%j]
+ }
+ %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index
+ %updated_aggregate = arith.andi %does_not_alias_aggregated,
+ %does_not_alias : i1
+ scf.yield %updated_aggregate : i1
+ }
+ // Check for aliasing with dealloc memrefs in the list before the
+ // current one, i.e.,
+ // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j],
+ // %dyn_dealloc_base_pointer[i])`
+ %does_not_alias_any = scf.for %j = %c0 to %i step %c1
+ iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) {
+ %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j]
+ %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp
+ %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias
+ scf.yield %updated_alias_agg : i1
+ }
+ %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1
+ memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1>
+ }
+ return
+}
+```
+
+### Specialized Lowerings
+
+Currently, there are two special lowerings for common cases to avoid the library
+function and thus unnecessary memory load and store operations and function
+calls:
+
+**One memref, no retained**
+
+Lowers the simple case of a single MemRef and no retained values. Ideally,
+static analysis can provide enough information such that the
+`buffer-deallocation-simplification` pass is able to split the dealloc
+operations up into this simple case as much as possible before running this
+pass.
+
+Example:
+```mlir
+bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+```
+is lowered to
+```mlir
+scf.if %arg1 {
+ memref.dealloc %arg0 : memref<2xf32>
+}
+```
+
+In most cases, the branch condition is a constant 'true' or 'false', so the
+`scf.if` can be folded away entirely by the canonicalizer pass.
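+
+For example, if `%arg1` above is replaced by a constant `true`, the
+canonicalizer folds the `scf.if` away entirely, leaving only:
+
+```mlir
+memref.dealloc %arg0 : memref<2xf32>
+```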
+
+**One memref, arbitrarily many retained**
+
+A special case lowering for the deallocation operation with exactly one MemRef,
+but an arbitrary number of retained values. The size of the code produced by
+this lowering is linear in the number of retained values.
+
+Example:
+```mlir
+%0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
+ retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
+return %0#0, %0#1 : i1, i1
+```
+is lowered to
+```mlir
+%m_base_pointer = memref.extract_aligned_pointer_as_index %m
+%r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
+%r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
+%r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
+%r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
+%not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
+%should_dealloc = arith.andi %not_retained, %cond : i1
+scf.if %should_dealloc {
+ memref.dealloc %m : memref<2xf32>
+}
+%true = arith.constant true
+%r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
+%r0_ownership = arith.andi %r0_does_alias, %cond : i1
+%r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
+%r1_ownership = arith.andi %r1_does_alias, %cond : i1
+return %r0_ownership, %r1_ownership : i1, i1
+```
+
## Memory Layouts
One-Shot Bufferize bufferizes ops from top to bottom. This works well when all
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index 85e9c47ad5302cb..83e55fd70de6bb8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -121,6 +121,14 @@ class BufferPlacementTransformationBase {
Liveness liveness;
};
+/// Compare two SSA values in a deterministic manner. Two block arguments are
+/// ordered by argument number, block arguments are always less than operation
+/// results, and operation results are ordered by the `isBeforeInBlock` order of
+/// their defining operation.
+struct ValueComparator {
+ bool operator()(const Value &lhs, const Value &rhs) const;
+};
+
// Create a global op for the given tensor-valued constant in the program.
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
// names. Duplicates are avoided.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index b0b62acffe77a2a..23eed02a15d4801 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -4,6 +4,7 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+class FunctionOpInterface;
class ModuleOp;
class RewritePatternSet;
class OpBuilder;
@@ -27,6 +28,10 @@ struct OneShotBufferizationOptions;
/// buffers.
std::unique_ptr<Pass> createBufferDeallocationPass();
+/// Creates an instance of the OwnershipBasedBufferDeallocation pass to free all
+/// allocated buffers.
+std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass();
+
/// Creates a pass that optimizes `bufferization.dealloc` operations. For
/// example, it reduces the number of alias checks needed at runtime using
/// static alias analysis.
@@ -127,6 +132,10 @@ func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc,
/// Run buffer deallocation.
LogicalResult deallocateBuffers(Operation *op);
+/// Run ownership-based buffer deallocation.
+LogicalResult deallocateBuffersOwnershipBased(FunctionOpInterface op,
+ bool privateFuncDynamicOwnership);
+
/// Creates a pass that moves allocations upwards to reduce the number of
/// required copies that are inserted during the BufferDeallocation pass.
std::unique_ptr<Pass> createBufferHoistingPass();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index ff43cff817b64a8..f3c2a29c0589f29 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -88,6 +88,150 @@ def BufferDeallocation : Pass<"buffer-deallocation", "func::FuncOp"> {
let constructor = "mlir::bufferization::createBufferDeallocationPass()";
}
+def OwnershipBasedBufferDeallocation : Pass<
+ "ownership-based-buffer-deallocation", "func::FuncOp"> {
+ let summary = "Adds all required dealloc operations for all allocations in "
+ "the input program";
+ let description = [{
+ This pass implements an algorithm to automatically introduce all required
+ deallocation operations for all buffers in the input program. This ensures
+ that the resulting program does not have any memory leaks.
+
+ The Buffer Deallocation pass operates on the level of operations
+ implementing the FunctionOpInterface. Such operations can take MemRefs as
+ arguments, but also return them. To ensure compatibility among all functions
+ (including external ones), some rules have to be enforced. They are just
+ assumed to hold for all external functions. Functions for which the
+ definition is available ideally also already adhere to the ABI.
+ Otherwise, all MemRef write operations in the input IR must dominate all
+ MemRef read operations in the input IR. Then, the pass may modify the input
+ IR by inserting `bufferization.clone` operations such that the output IR
+ adheres to the function boundary ABI:
+ * When a MemRef is passed as a function argument, ownership is never
+ acquired. It is always the caller's responsibility to deallocate such
+ MemRefs.
+ * Returning a MemRef from a function always passes ownership to the caller,
+ i.e., it is also the caller's responsibility to deallocate MemRefs
+ returned from a called function.
+ * A function must not return a MemRef with the same allocated base buffer as
+ one of its arguments (in this case a copy has to be created). Note that in
+ this context two subviews of the same buffer that don't overlap are also
+ considered an alias.
+
+ It is recommended to bufferize all operations first such that no tensor
+ values remain in the IR once this pass is applied. That way all allocated
+ MemRefs will be properly deallocated without any additional manual work.
+  Otherwise, the pass that bufferizes the remaining tensors is responsible for
+  adding the corresponding deallocation operations. Note that this pass does not
+ consider any values of tensor type and assumes that MemRef values defined by
+ `bufferization.to_memref` do not return ownership and do not have to be
+ deallocated. `bufferization.to_tensor` operations are handled similarly to
+ `bufferization.clone` operations with the exception that the result value is
+ not handled because it's a tensor (not a MemRef).
+
+ Input
+
+ ```mlir
+ #map0 = affine_map<(d0) -> (d0)>
+ module {
+ func.func @condBranch(%arg0: i1,
+ %arg1: memref<2xf32>,
+ %arg2: memref<2xf32>) {
+ cf.cond_br %arg0, ^bb1, ^bb2
+ ^bb1:
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+ ^bb2:
+ %0 = memref.alloc() : memref<2xf32>
+ linalg.generic {
+ args_in = 1 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel"]}
+ outs(%arg1, %0 : memref<2xf32>, memref<2xf32>) {
+ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
+ %tmp1 = exp %gen1_arg0 : f32
+ linalg.yield %tmp1 : f32
+ }
+ cf.br ^bb3(%0 : memref<2xf32>)
+ ^bb3(%1: memref<2xf32>):
+ "memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ return
+ }
+ }
+ ```
+
+ Output
+
+ ```mlir
+ #map = affine_map<(d0) -> (d0)>
+ module {
+ func.func @condBranch(%arg0: i1,
+ %arg1: memref<2xf32>,
+ %arg2: memref<2xf32>) {
+ %false = arith.constant false
+ %true = arith.constant true
+ cf.cond_br %arg0, ^bb1, ^bb2
+ ^bb1: // pred: ^bb0
+ cf.br ^bb3(%arg1, %false : memref<2xf32>, i1)
+ ^bb2: // pred: ^bb0
+ %alloc = memref.alloc() : memref<2xf32>
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel"]}
+ outs(%arg1, %alloc : memref<2xf32>, memref<2xf32>)
+ attrs = {args_in = 1 : i64, args_out = 1 : i64} {
+ ^bb0(%out: f32, %out_0: f32):
+ %2 = math.exp %out : f32
+ linalg.yield %2, %out_0 : f32, f32
+ }
+ cf.br ^bb3(%alloc, %true : memref<2xf32>, i1)
+ ^bb3(%0: memref<2xf32>, %1: i1): // 2 preds: ^bb1, ^bb2
+ memref.copy %0, %arg2 : memref<2xf32> to memref<2xf32>
+ %base_buffer, %offset, %sizes, %strides =
+ memref.extract_strided_metadata %0 :
+ memref<2xf32> -> memref<f32>, index, index, index
+ bufferization.dealloc (%base_buffer : memref<f32>) if (%1)
+ return
+ }
+ }
+ ```
+
+  The `private-function-dynamic-ownership` pass option allows the pass to add
+  additional arguments to private functions to dynamically give ownership of
+  MemRefs to callees. This can enable earlier deallocations and allows the
+  pass to bypass the function boundary ABI, potentially leading to fewer
+  MemRef clones being inserted. For example, the private function
+ ```mlir
+ func.func private @passthrough(%memref: memref<2xi32>) -> memref<2xi32> {
+ return %memref : memref<2xi32>
+ }
+ ```
+ would be converted to
+ ```mlir
+ func.func private @passthrough(%memref: memref<2xi32>,
+ %ownership: i1) -> (memref<2xi32>, i1) {
+ return %memref, %ownership : memref<2xi32>, i1
+ }
+ ```
+ and thus allows the returned MemRef to alias with the MemRef passed as
+ argument (which would otherwise be forbidden according to the function
+ boundary ABI).
+ }];
+ let options = [
+ Option<"privateFuncDynamicOwnership", "private-function-dynamic-ownership",
+ "bool", /*default=*/"false",
+           "Allows adding additional arguments to private functions to "
+ "dynamically pass ownership of memrefs to callees. This can enable "
+ "earlier deallocations.">,
+ ];
+ let constructor = "mlir::bufferization::createOwnershipBasedBufferDeallocationPass()";
+
+ let dependentDialects = [
+ "mlir::bufferization::BufferizationDialect", "mlir::arith::ArithDialect",
+ "mlir::memref::MemRefDialect", "mlir::scf::SCFDialect"
+ ];
+}
+
def BufferDeallocationSimplification :
Pass<"buffer-deallocation-simplification", "func::FuncOp"> {
let summary = "Optimizes `bufferization.dealloc` operation for more "
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 119801f9cc92f32..b8fd99a5541242f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -202,3 +202,62 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
global->moveBefore(&moduleOp.front());
return global;
}
+
+//===----------------------------------------------------------------------===//
+// ValueComparator
+//===----------------------------------------------------------------------===//
+
+bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
+ if (lhs == rhs)
+ return false;
+
+ // Block arguments are less than results.
+ bool lhsIsBBArg = lhs.isa<BlockArgument>();
+ if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
+ return lhsIsBBArg;
+ }
+
+ Region *lhsRegion;
+ Region *rhsRegion;
+ if (lhsIsBBArg) {
+ auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
+ auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
+ if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
+ return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
+ }
+ lhsRegion = lhsBBArg.getParentRegion();
+ rhsRegion = rhsBBArg.getParentRegion();
+ assert(lhsRegion != rhsRegion &&
+ "lhsRegion == rhsRegion implies lhs == rhs");
+ } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
+ return llvm::cast<OpResult>(lhs).getResultNumber() <
+ llvm::cast<OpResult>(rhs).getResultNumber();
+ } else {
+ lhsRegion = lhs.getDefiningOp()->getParentRegion();
+ rhsRegion = rhs.getDefiningOp()->getParentRegion();
+ if (lhsRegion == rhsRegion) {
+ return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
+ }
+ }
+
+  // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
+  // - have different heights
+  // - or there's a spot where their region numbers differ
+  // - or their parent regions are the same and their parent ops are
+  //   different.
+ while (lhsRegion && rhsRegion) {
+ if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
+ return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
+ }
+ if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
+ return lhsRegion->getParentOp()->isBeforeInBlock(
+ rhsRegion->getParentOp());
+ }
+ lhsRegion = lhsRegion->getParentRegion();
+ rhsRegion = rhsRegion->getParentRegion();
+ }
+ if (rhsRegion)
+ return true;
+ assert(lhsRegion && "this should only happen if lhs == rhs");
+ return false;
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 16659e0e3b20366..cbbfe7a81205857 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
LowerDeallocations.cpp
OneShotAnalysis.cpp
OneShotModuleBufferize.cpp
+ OwnershipBasedBufferDeallocation.cpp
TensorCopyInsertion.cpp
ADDITIONAL_HEADER_DIRS
@@ -34,6 +35,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
MLIRPass
MLIRTensorDialect
MLIRSCFDialect
+ MLIRControlFlowDialect
MLIRSideEffectInterfaces
MLIRTransforms
MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
new file mode 100644
index 000000000000000..eaced7202f4e606
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -0,0 +1,1383 @@
+//===- OwnershipBasedBufferDeallocation.cpp - impl. for buffer dealloc. ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for computing correct `bufferization.dealloc`
+// positions. Furthermore, buffer deallocation also adds required new clone
+// operations to ensure that memrefs returned by functions never alias an
+// argument.
+//
+// TODO:
+// The current implementation does not support explicit-control-flow loops and
+// the resulting code will be invalid with respect to program semantics.
+// However, structured control-flow loops are fully supported.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/SetOperations.h"
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_OWNERSHIPBASEDBUFFERDEALLOCATION
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+//===----------------------------------------------------------------------===//
+// Helpers
+//===----------------------------------------------------------------------===//
+
+static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
+ return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
+}
+
+static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+
+//===----------------------------------------------------------------------===//
+// Backedges analysis
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// A straightforward program analysis which detects loop backedges induced by
+/// explicit control flow.
+class Backedges {
+public:
+ using BlockSetT = SmallPtrSet<Block *, 16>;
+ using BackedgeSetT = llvm::DenseSet<std::pair<Block *, Block *>>;
+
+public:
+ /// Constructs a new backedges analysis using the op provided.
+ Backedges(Operation *op) { recurse(op); }
+
+ /// Returns the number of backedges formed by explicit control flow.
+ size_t size() const { return edgeSet.size(); }
+
+ /// Returns the start iterator to loop over all backedges.
+ BackedgeSetT::const_iterator begin() const { return edgeSet.begin(); }
+
+ /// Returns the end iterator to loop over all backedges.
+ BackedgeSetT::const_iterator end() const { return edgeSet.end(); }
+
+private:
+ /// Enters the current block and inserts a backedge into the `edgeSet` if we
+ /// have already visited the current block. The inserted edge links the given
+ /// `predecessor` with the `current` block.
+  bool enter(Block &current, Block *predecessor) {
+    bool inserted = visited.insert(&current).second;
+    if (!inserted)
+      edgeSet.insert(std::make_pair(predecessor, &current));
+    return inserted;
+  }
+
+  /// Leaves the current block.
+  void exit(Block &current) { visited.erase(&current); }
+
+ /// Recurses into the given operation while taking all attached regions into
+ /// account.
+ void recurse(Operation *op) {
+ Block *current = op->getBlock();
+ // If the current op implements the `BranchOpInterface`, there can be
+ // cycles in the scope of all successor blocks.
+ if (isa<BranchOpInterface>(op)) {
+ for (Block *succ : current->getSuccessors())
+ recurse(*succ, current);
+ }
+ // Recurse into all distinct regions and check for explicit control-flow
+ // loops.
+    for (Region &region : op->getRegions()) {
+ if (!region.empty())
+ recurse(region.front(), current);
+ }
+ }
+
+ /// Recurses into explicit control-flow structures that are given by
+ /// the successor relation defined on the block level.
+ void recurse(Block &block, Block *predecessor) {
+ // Try to enter the current block. If this is not possible, we are
+ // currently processing this block and can safely return here.
+ if (!enter(block, predecessor))
+ return;
+
+ // Recurse into all operations and successor blocks.
+ for (Operation &op : block.getOperations())
+ recurse(&op);
+
+ // Leave the current block.
+ exit(block);
+ }
+
+ /// Stores all blocks that are currently visited and on the processing stack.
+ BlockSetT visited;
+
+ /// Stores all backedges in the format (source, target).
+ BackedgeSetT edgeSet;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to track the ownership of values. The ownership can
+/// either be not initialized yet ('Uninitialized' state), set to a unique SSA
+/// value which indicates the ownership at runtime (or statically if it is a
+/// constant value) ('Unique' state), or it cannot be represented in a single
+/// SSA value ('Unknown' state). An artificial example of a case where ownership
+/// cannot be represented in a single i1 SSA value could be the following:
+/// `%0 = test.non_deterministic_select %arg0, %arg1 : i32`
+/// Since the operation does not provide us a separate boolean indicator on
+/// which of the two operands was selected, we would need to either insert an
+/// alias check at runtime to determine if `%0` aliases with `%arg0` or `%arg1`,
+/// or insert a `bufferization.clone` operation to get a fresh buffer which we
+/// could assign ownership to.
+///
+/// The three states this class can represent form a lattice on a partial order:
+/// forall X in SSA values. uninitialized < unique(X) < unknown
+/// forall X, Y in SSA values.
+/// unique(X) == unique(Y) iff X and Y always evaluate to the same value
+/// unique(X) != unique(Y) otherwise
+class Ownership {
+public:
+ /// Constructor that creates an 'Uninitialized' ownership. This is needed for
+ /// default-construction when used in DenseMap.
+ Ownership() = default;
+
+ /// Constructor that creates an 'Unique' ownership. This is a non-explicit
+ /// constructor to allow implicit conversion from 'Value'.
+ Ownership(Value indicator) : indicator(indicator), state(State::Unique) {}
+
+ /// Get an ownership value in 'Unknown' state.
+ static Ownership getUnknown() {
+ Ownership unknown;
+ unknown.indicator = Value();
+ unknown.state = State::Unknown;
+ return unknown;
+ }
+ /// Get an ownership value in 'Unique' state with 'indicator' as parameter.
+ static Ownership getUnique(Value indicator) { return Ownership(indicator); }
+ /// Get an ownership value in 'Uninitialized' state.
+ static Ownership getUninitialized() { return Ownership(); }
+
+ /// Check if this ownership value is in the 'Uninitialized' state.
+ bool isUninitialized() const { return state == State::Uninitialized; }
+ /// Check if this ownership value is in the 'Unique' state.
+ bool isUnique() const { return state == State::Unique; }
+ /// Check if this ownership value is in the 'Unknown' state.
+ bool isUnknown() const { return state == State::Unknown; }
+
+ /// If this ownership value is in 'Unique' state, this function can be used to
+ /// get the indicator parameter. Using this function in any other state is UB.
+ Value getIndicator() const {
+ assert(isUnique() && "must have unique ownership to get the indicator");
+ return indicator;
+ }
+
+ /// Get the join of the two-element subset {this,other}. Does not modify
+ /// 'this'.
+ Ownership getCombined(Ownership other) const {
+ if (other.isUninitialized())
+ return *this;
+ if (isUninitialized())
+ return other;
+
+ if (!isUnique() || !other.isUnique())
+ return getUnknown();
+
+ // Since we create a new constant i1 value for (almost) each use-site, we
+ // should compare the actual value rather than just the SSA Value to avoid
+ // unnecessary invalidations.
+ if (isEqualConstantIntOrValue(indicator, other.indicator))
+ return *this;
+
+    // Return the top element of the lattice ('Unknown') if the indicators of
+    // the two ownerships cannot be merged.
+ return getUnknown();
+ }
+
+ /// Modify 'this' ownership to be the join of the current 'this' and 'other'.
+ void combine(Ownership other) { *this = getCombined(other); }
+
+private:
+ enum class State {
+ Uninitialized,
+ Unique,
+ Unknown,
+ };
+
+ // The indicator value is only relevant in the 'Unique' state.
+ Value indicator;
+ State state = State::Uninitialized;
+};
+
+/// The buffer deallocation transformation which ensures that all allocs in the
+/// program have a corresponding deallocation.
+class BufferDeallocation {
+public:
+ BufferDeallocation(Operation *op, bool privateFuncDynamicOwnership)
+ : liveness(op), privateFuncDynamicOwnership(privateFuncDynamicOwnership) {
+ }
+
+ /// Performs the actual placement/creation of all dealloc operations.
+ LogicalResult deallocate(FunctionOpInterface op);
+
+private:
+ /// The base case for the recursive template below.
+ template <typename... T>
+ typename std::enable_if<sizeof...(T) == 0, FailureOr<Operation *>>::type
+ handleOp(Operation *op) {
+ return op;
+ }
+
+ /// Applies all the handlers of the interfaces in the template list
+ /// implemented by 'op'. In particular, if an operation implements more than
+ /// one of the interfaces in the template list, all the associated handlers
+ /// will be applied to the operation in the same order as the template list
+ /// specifies. If a handler reports a failure or removes the operation without
+ /// replacement (indicated by returning 'nullptr'), no further handlers are
+ /// applied and the return value is propagated to the caller of 'handleOp'.
+ ///
+  /// The interface handlers' job is to update the deallocation state, most
+ /// importantly the ownership map and list of memrefs to potentially be
+ /// deallocated per block, but also to insert `bufferization.dealloc`
+ /// operations where needed. Obviously, no MemRefs that may be used at a later
+ /// point in the control-flow may be deallocated and the ownership map has to
+ /// be updated to reflect potential ownership changes caused by the dealloc
+ /// operation (e.g., if two interfaces on the same op insert a dealloc
+  /// operation each, the second one should query the ownership map and use
+  /// the updated ownerships as deallocation conditions such that MemRefs
+  /// already deallocated in the first dealloc operation are not deallocated
+  /// a second time (double-free)).
+ /// Note that currently only the interfaces on terminators may insert dealloc
+ /// operations and it is verified as a precondition that a terminator op must
+ /// implement exactly one of the interfaces handling dealloc insertion.
+ ///
+ /// The return value of the 'handleInterface' functions should be a
+ /// FailureOr<Operation *> indicating whether there was a failure or otherwise
+ /// returning the operation itself or a replacement operation.
+ ///
  /// Note: The difference compared to `TypeSwitch` is that all
+ /// matching cases are applied instead of just the first match.
+ template <typename InterfaceT, typename... InterfacesU>
+ FailureOr<Operation *> handleOp(Operation *op) {
+ Operation *next = op;
+ if (auto concreteOp = dyn_cast<InterfaceT>(op)) {
+ FailureOr<Operation *> result = handleInterface(concreteOp);
+ if (failed(result))
+ return failure();
+ next = *result;
+ }
+ if (!next)
+ return nullptr;
+ return handleOp<InterfacesU...>(next);
+ }
+
+ /// Apply all supported interface handlers to the given op.
+ FailureOr<Operation *> handleAllInterfaces(Operation *op) {
+ if (failed(verifyOperationPreconditions(op)))
+ return failure();
+
+ return handleOp<MemoryEffectOpInterface, RegionBranchOpInterface,
+ CallOpInterface, BranchOpInterface, cf::CondBranchOp,
+ RegionBranchTerminatorOpInterface>(op);
+ }
+
+ /// While CondBranchOp also implements the BranchOpInterface, we add a
+ /// special-case implementation here because the BranchOpInterface does not
+ /// offer all of the functionality we need to insert dealloc operations in an
+ /// efficient way. More precisely, there is no way to extract the branch
+ /// condition without casting to CondBranchOp specifically. It would still be
+ /// possible to implement deallocation for cases where we don't know to which
+ /// successor the terminator branches before the actual branch happens, by
+ /// inserting auxiliary blocks and putting the dealloc op there; however, this
+ /// can lead to less efficient code.
+ /// This function inserts two dealloc operations (one for each successor) and
+ /// adjusts the dealloc conditions according to the branch condition, then the
+ /// ownerships of the retained MemRefs are updated by combining the result
+ /// values of the two dealloc operations.
+ ///
+ /// Example:
+ /// ```
+ /// ^bb1:
+ /// <more ops...>
+ /// cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb3>)
+ /// ```
+ /// becomes
+ /// ```
+ /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
+ /// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
+ /// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>)
+ /// ^bb1:
+ /// <more ops...>
+ /// let thenCond = map(c, (c) -> arith.andi cond, c)
+ /// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c)
+ /// o0 = bufferization.dealloc m if thenCond retain r0
+ /// o1 = bufferization.dealloc m if elseCond retain r1
+ /// // replace ownership(r0) with o0 element-wise
+ /// // replace ownership(r1) with o1 element-wise
+ /// // let ownership0 := (r) -> o in o0 corresponding to r
+ /// // let ownership1 := (r) -> o in o1 corresponding to r
+ /// // let cmn := intersection(r0, r1)
+ /// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)):
+ /// forall r in r0: replace ownership0(r) with arith.select(cond, a, b)
+ /// forall r in r1: replace ownership1(r) with arith.select(cond, a, b)
+ /// cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1)
+ /// ```
+ FailureOr<Operation *> handleInterface(cf::CondBranchOp op);
+
+ /// Make sure that for each forwarded MemRef value, an ownership indicator
+ /// `i1` value is forwarded as well such that the successor block knows
+ /// whether the MemRef has to be deallocated.
+ ///
+ /// Example:
+ /// ```
+ /// ^bb1:
+ /// <more ops...>
+ /// cf.br ^bb2(<forward-to-bb2>)
+ /// ```
+ /// becomes
+ /// ```
+ /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
+ /// // let r = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
+ /// ^bb1:
+ /// <more ops...>
+ /// o = bufferization.dealloc m if c retain r
+ /// // replace ownership(r) with o element-wise
+ /// cf.br ^bb2(<forward-to-bb2>, o)
+ /// ```
+ FailureOr<Operation *> handleInterface(BranchOpInterface op);
+
+ /// Add an ownership indicator for every forwarding MemRef operand and result.
+ /// Nested regions never take ownership of MemRefs owned by a parent region
+ /// (neither via forwarding operand nor when captured implicitly when the
+ /// region is not isolated from above). Ownerships will only be passed to peer
+ /// regions (when an operation has multiple regions, such as scf.while), or to
+ /// parent regions.
+ /// Note that the block arguments in the nested region are currently handled
+ /// centrally in the 'deallocate' function, but better interface support could
+ /// allow us to do this here for the nested region specifically, reducing the
+ /// number of assumptions we make on the structure of ops implementing this
+ /// interface.
+ ///
+ /// Example:
+ /// ```
+ /// %ret = scf.for %i = %c0 to %c10 step %c1 iter_args(%m = %memref) {
+ /// <more ops...>
+ /// scf.yield %m : memref<2xi32>
+ /// }
+ /// ```
+ /// becomes
+ /// ```
+ /// %ret:2 = scf.for %i = %c0 to %c10 step %c1
+ /// iter_args(%m = %memref, %own = %false) {
+ /// <more ops...>
+ /// // Note that the scf.yield is handled by the
+ /// // RegionBranchTerminatorOpInterface (not this handler)
+ /// // let o = getMemrefWithUniqueOwnership(%own)
+ /// scf.yield %m, o : memref<2xi32>, i1
+ /// }
+ /// ```
+ FailureOr<Operation *> handleInterface(RegionBranchOpInterface op);
+
+ /// If the private-function-dynamic-ownership pass option is enabled and the
+ /// called function is private, additional arguments and results are added for
+ /// each MemRef argument/result to pass the dynamic ownership indicator along.
+ /// Otherwise, updates the ownership map and list of memrefs to be deallocated
+ /// according to the function boundary ABI, i.e., assume ownership of all
+ /// returned MemRefs.
+ ///
+ /// Example (assume `private-function-dynamic-ownership` is enabled):
+ /// ```
+ /// func.func @f(%arg0: memref<2xi32>) -> memref<2xi32> {...}
+ /// func.func private @g(%arg0: memref<2xi32>) -> memref<2xi32> {...}
+ ///
+ /// %ret_f = func.call @f(%memref) : (memref<2xi32>) -> memref<2xi32>
+ /// %ret_g = func.call @g(%memref) : (memref<2xi32>) -> memref<2xi32>
+ /// ```
+ /// becomes
+ /// ```
+ /// func.func @f(%arg0: memref<2xi32>) -> memref<2xi32> {...}
+ /// func.func private @g(%arg0: memref<2xi32>) -> memref<2xi32> {...}
+ ///
+ /// %ret_f = func.call @f(%memref) : (memref<2xi32>) -> memref<2xi32>
+ /// // set ownership(%ret_f) := true
+ /// // remember to deallocate %ret_f
+ ///
+ /// // (new_memref, own) = getMemrefWithUniqueOwnership(%memref)
+ /// %ret_g:2 = func.call @g(new_memref, own) :
+ /// (memref<2xi32>, i1) -> (memref<2xi32>, i1)
+ /// // set ownership(%ret_g#0) := %ret_g#1
+ /// // remember to deallocate %ret_g
+ /// ```
+ FailureOr<Operation *> handleInterface(CallOpInterface op);
+
+ /// Takes care of allocation and free side effects. It collects allocated
+ /// MemRefs that have to be deallocated manually, but also removes values
+ /// from that list again if they are already deallocated before the end of
+ /// the block. It also updates the ownership map accordingly.
+ ///
+ /// Example:
+ /// ```
+ /// %alloc = memref.alloc()
+ /// %alloca = memref.alloca()
+ /// ```
+ /// becomes
+ /// ```
+ /// %alloc = memref.alloc()
+ /// %alloca = memref.alloca()
+ /// // set ownership(alloc) := true
+ /// // set ownership(alloca) := false
+ /// // remember to deallocate %alloc
+ /// ```
+ FailureOr<Operation *> handleInterface(MemoryEffectOpInterface op);
+
+ /// Takes care that the function boundary ABI is adhered to if the parent
+ /// operation implements FunctionOpInterface, inserting a
+ /// `bufferization.clone` if necessary, and inserts the
+ /// `bufferization.dealloc` operation according to the op's operands.
+ ///
+ /// Example:
+ /// ```
+ /// ^bb1:
+ /// <more ops...>
+ /// func.return <return-vals>
+ /// ```
+ /// becomes
+ /// ```
+ /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
+ /// // let r = getMemrefsToRetain(bb1, nullptr, <return-vals>)
+ /// ^bb1:
+ /// <more ops...>
+ /// o = bufferization.dealloc m if c retain r
+ /// func.return <return-vals>
+ /// (if !isFunctionWithoutDynamicOwnership: append o)
+ /// ```
+ FailureOr<Operation *> handleInterface(RegionBranchTerminatorOpInterface op);
+
+ /// Construct a new operation which is exactly the same as the passed 'op'
+ /// except that the OpResults list is appended by new results of the passed
+ /// 'types'.
+ /// TODO: ideally, this would be implemented using an OpInterface because it
+ /// is used to append function results, loop iter_args, etc., and thus
+ /// assumes that the variadic list of those is at the end of the OpResults
+ /// range.
+ Operation *appendOpResults(Operation *op, ArrayRef<Type> types);
+
+ /// A convenience template for the generic 'appendOpResults' function above to
+ /// avoid manual casting of the result.
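+ ///
+ /// For example (a sketch; `forOp` and `builder` are hypothetical locals):
+ /// ```
+ /// auto newForOp = appendOpResults(forOp, {builder.getI1Type()});
+ /// ```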
+ template <typename OpTy>
+ OpTy appendOpResults(OpTy op, ArrayRef<Type> types) {
+ return cast<OpTy>(appendOpResults(op.getOperation(), types));
+ }
+
+ /// Performs deallocation of a single basic block. This is a private function
+ /// because some internal data structures have to be set up beforehand and
+ /// this function has to be called on blocks in a region in dominance order.
+ LogicalResult deallocate(Block *block);
+
+ /// Small helper function to update the ownership map by taking the current
+ /// ownership ('Uninitialized' state if not yet present), computing the join
+ /// with the passed ownership and storing this new value in the map. By
+ /// default, it will be performed for the block where 'owned' is defined. If
+ /// the ownership of the given value should be updated for another block, the
+ /// 'block' argument can be explicitly passed.
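+ ///
+ /// For illustration, assuming the Ownership states used by this pass
+ /// ('Uninitialized', 'Unique' with an `i1` indicator, and 'Unknown'), the
+ /// join behaves roughly as follows:
+ /// ```
+ /// join(Uninitialized, Unique(%a)) == Unique(%a)
+ /// join(Unique(%a), Unique(%a)) == Unique(%a)
+ /// join(Unique(%a), Unique(%b)) == Unknown // differing indicators
+ /// ```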
+ void joinOwnership(Value owned, Ownership ownership, Block *block = nullptr);
+
+ /// Removes ownerships associated with all values in the passed range for
+ /// 'block'.
+ void clearOwnershipOf(ValueRange values, Block *block);
+
+ /// After all relevant interfaces of an operation have been processed by the
+ /// 'handleInterface' functions, this function sets the ownership of operation
+ /// results that have not been set yet by the 'handleInterface' functions. It
+ /// generally assumes that each result can alias with every operand of the
+ /// operation; if there are MemRef typed results but no MemRef operands, it
+ /// assigns 'false' as ownership. This happens, e.g., for the
+ /// memref.get_global operation. It would also be possible to query some alias
+ /// analysis to get more precise ownerships, however, the analysis would have
+ /// to be updated according to the IR modifications this pass performs (e.g.,
+ /// re-building operations to have more result values, inserting clone
+ /// operations, etc.).
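+ ///
+ /// Example (`@some_global` is a hypothetical symbol for illustration):
+ /// ```
+ /// %g = memref.get_global @some_global : memref<2xi32>
+ /// // set ownership(%g) := false (no MemRef operands to alias with)
+ /// ```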
+ void populateRemainingOwnerships(Operation *op);
+
+ /// Given two basic blocks and the values passed via block arguments to the
+ /// destination block, compute the list of MemRefs that have to be retained in
+ /// the 'fromBlock' to not run into a use-after-free situation.
+ /// This list consists of the MemRefs in the successor operand list of the
+ /// terminator and the MemRefs in the 'out' set of the liveness analysis
+ /// intersected with the 'in' set of the destination block.
+ ///
+ /// toRetain = filter(successorOperands + (liveOut(fromBlock) intersect
+ /// liveIn(toBlock)), isMemRef)
+ void getMemrefsToRetain(Block *fromBlock, Block *toBlock,
+ ValueRange destOperands,
+ SmallVectorImpl<Value> &toRetain) const;
+
+ /// For a given block, computes the list of MemRefs that potentially need to
+ /// be deallocated at the end of that block. This list also contains values
+ /// that have to be retained (and are thus part of the list returned by
+ /// `getMemrefsToRetain`) and is computed by taking the MemRefs in the 'in'
+ /// set of the liveness analysis of 'block', adding the set of MemRefs
+ /// allocated in 'block' itself, and subtracting the set of MemRefs
+ /// deallocated in 'block'.
+ /// Note that we don't have to take the intersection of the liveness 'in' set
+ /// with the 'out' set of the predecessor block because a value that is in the
+ /// 'in' set must be defined in an ancestor block that dominates all direct
+ /// predecessors and thus the 'in' set of this block is a subset of the 'out'
+ /// sets of each predecessor.
+ ///
+ /// memrefs = filter((liveIn(block) U
+ /// allocated(block) U arguments(block)) \ deallocated(block), isMemRef)
+ ///
+ /// The list of conditions is then populated by querying the internal
+ /// data structures for the ownership value of that MemRef.
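+ ///
+ /// Note that, when emitting the dealloc operation, the base buffer of each
+ /// MemRef is passed rather than the value itself, since passing subviews,
+ /// etc. to a dealloc operation is not allowed; roughly:
+ /// ```
+ /// %base, ... = memref.extract_strided_metadata %memref
+ /// // pass %base (not %memref) to bufferization.dealloc
+ /// ```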
+ LogicalResult
+ getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc,
+ Block *block,
+ SmallVectorImpl<Value> &memrefs,
+ SmallVectorImpl<Value> &conditions) const;
+
+ /// Given an SSA value of MemRef type, this function queries the ownership
+ /// and, if it is not already in the 'Unique' state, potentially inserts IR
+ /// to get a new SSA value (returned as the first element of the pair) which
+ /// has 'Unique' ownership and can be used instead of the passed Value, with
+ /// the ownership indicator returned as the second element of the pair.
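+ ///
+ /// For example, for a value with 'Unknown' ownership this roughly inserts
+ /// (the MemRef type is illustrative):
+ /// ```
+ /// %clone = bufferization.clone %memref : memref<2xi32> to memref<2xi32>
+ /// // set ownership(%clone) := true
+ /// // remember to deallocate %clone
+ /// ```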
+ std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
+ Value memref);
+
+ /// Given an SSA value of MemRef type, returns the same or a new SSA value
+ /// which has 'Unique' ownership where the ownership indicator is guaranteed
+ /// to always be 'true'.
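+ ///
+ /// For example, given a 'Unique' ownership indicator %cond that is not
+ /// statically known to be 'true', this roughly inserts (the MemRef type is
+ /// illustrative):
+ /// ```
+ /// %res = scf.if %cond -> (memref<2xi32>) {
+ /// scf.yield %memref : memref<2xi32>
+ /// } else {
+ /// %clone = bufferization.clone %memref : memref<2xi32> to memref<2xi32>
+ /// scf.yield %clone : memref<2xi32>
+ /// }
+ /// // set ownership(%res) := true
+ /// // remember to deallocate %res
+ /// ```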
+ Value getMemrefWithGuaranteedOwnership(OpBuilder &builder, Value memref);
+
+ /// Returns whether the given operation implements FunctionOpInterface and
+ /// does not use dynamic ownership, i.e., the
+ /// private-function-dynamic-ownership pass option is disabled or the
+ /// function does not have private visibility.
+ bool isFunctionWithoutDynamicOwnership(Operation *op);
+
+ /// Checks all the preconditions for operations implementing the
+ /// FunctionOpInterface that have to hold for the deallocation to be
+ /// applicable:
+ /// (1) Checks that there are no explicit control-flow loops.
+ static LogicalResult verifyFunctionPreconditions(FunctionOpInterface op);
+
+ /// Checks all the preconditions for operations inside the region of
+ /// operations implementing the FunctionOpInterface that have to hold for the
+ /// deallocation to be applicable:
+ /// (1) Checks if all operations that have at least one attached region
+ /// implement the RegionBranchOpInterface. This is not required in edge cases,
+ /// where we have a single attached region and the parent operation has no
+ /// results.
+ /// (2) Checks that no deallocations already exist. Especially deallocations
+ /// in nested regions are not properly supported yet since this requires
+ /// ownership of the memref to be transferred to the nested region, which does
+ /// not happen by default. This constraint can be lifted in the future.
+ /// (3) Checks that terminators with more than one successor except
+ /// `cf.cond_br` are not present and that either BranchOpInterface or
+ /// RegionBranchTerminatorOpInterface is implemented.
+ static LogicalResult verifyOperationPreconditions(Operation *op);
+
+ /// When the 'private-function-dynamic-ownership' pass option is enabled,
+ /// additional `i1` arguments and return values are added for each MemRef
+ /// value in the function signature. This function takes care of updating the
+ /// `function_type` attribute of the function according to the actually
+ /// returned values from the terminators.
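+ ///
+ /// Example (assuming `private-function-dynamic-ownership` is enabled):
+ /// ```
+ /// func.func private @f(%arg0: memref<2xi32>) -> memref<2xi32>
+ /// ```
+ /// gets the updated function type
+ /// ```
+ /// (memref<2xi32>, i1) -> (memref<2xi32>, i1)
+ /// ```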
+ static LogicalResult updateFunctionSignature(FunctionOpInterface op);
+
+private:
+ // Mapping from each SSA value with MemRef type to the associated ownership in
+ // each block.
+ DenseMap<std::pair<Value, Block *>, Ownership> ownershipMap;
+
+ // Collects the list of MemRef values that potentially need to be deallocated
+ // per block. It is also fine (albeit not efficient) to add MemRef values that
+ // don't have to be deallocated, but only when the ownership is not 'Unknown'.
+ DenseMap<Block *, SmallVector<Value>> memrefsToDeallocatePerBlock;
+
+ // Symbol cache to look up functions from call operations in order to check
+ // attributes on the function operation.
+ SymbolTableCollection symbolTable;
+
+ // The underlying liveness analysis to compute fine grained information about
+ // alloc and dealloc positions.
+ Liveness liveness;
+
+ // A pass option indicating whether private functions should be modified to
+ // pass the ownership of MemRef values instead of adhering to the function
+ // boundary ABI.
+ bool privateFuncDynamicOwnership;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocation Implementation
+//===----------------------------------------------------------------------===//
+
+void BufferDeallocation::joinOwnership(Value owned, Ownership ownership,
+ Block *block) {
+ // In most cases we care about the block where the value is defined.
+ if (block == nullptr)
+ block = owned.getParentBlock();
+
+ // Update ownership of current memref itself.
+ ownershipMap[{owned, block}].combine(ownership);
+}
+
+void BufferDeallocation::clearOwnershipOf(ValueRange values, Block *block) {
+ for (Value val : values) {
+ ownershipMap[{val, block}] = Ownership::getUninitialized();
+ }
+}
+
+static bool regionOperatesOnMemrefValues(Region &region) {
+ WalkResult result = region.walk([](Block *block) {
+ if (llvm::any_of(block->getArguments(), isMemref))
+ return WalkResult::interrupt();
+ for (Operation &op : *block) {
+ if (llvm::any_of(op.getOperands(), isMemref))
+ return WalkResult::interrupt();
+ if (llvm::any_of(op.getResults(), isMemref))
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return result.wasInterrupted();
+}
+
+LogicalResult
+BufferDeallocation::verifyFunctionPreconditions(FunctionOpInterface op) {
+ // (1) Ensure that there are supported loops only (no explicit control flow
+ // loops).
+ Backedges backedges(op);
+ if (backedges.size()) {
+ op->emitError("Only structured control-flow loops are supported.");
+ return failure();
+ }
+
+ return success();
+}
+
+LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
+ // (1) Check that the control flow structures are supported.
+ auto regions = op->getRegions();
+ // Check that if the operation has at least one region it implements the
+ // RegionBranchOpInterface. If there is an operation that does not fulfill
+ // this condition, we cannot apply the deallocation steps. Furthermore, we
+ // accept the case of a single attached region when the parent operation has
+ // no results, since then the intra-region control flow does not affect the
+ // transformation.
+ size_t size = regions.size();
+ if (((size == 1 && !op->getResults().empty()) || size > 1) &&
+ !dyn_cast<RegionBranchOpInterface>(op)) {
+ if (llvm::any_of(regions, regionOperatesOnMemrefValues))
+ return op->emitError("All operations with attached regions need to "
+ "implement the RegionBranchOpInterface.");
+ }
+
+ // (2) The pass does not work properly when deallocations are already present.
+ // Alternatively, we could also remove all deallocations as a pre-pass.
+ if (isa<DeallocOp>(op))
+ return op->emitError(
+ "No deallocation operations must be present when running this pass!");
+
+ // (3) Check that terminators with more than one successor except `cf.cond_br`
+ // are not present and that either BranchOpInterface or
+ // RegionBranchTerminatorOpInterface is implemented.
+ if (op->hasTrait<OpTrait::NoTerminator>())
+ return op->emitError("NoTerminator trait is not supported");
+
+ if (op->hasTrait<OpTrait::IsTerminator>()) {
+ // Either one of those interfaces has to be implemented on terminators, but
+ // not both.
+ if (!isa<BranchOpInterface, RegionBranchTerminatorOpInterface>(op) ||
+ (isa<BranchOpInterface>(op) &&
+ isa<RegionBranchTerminatorOpInterface>(op)))
+ return op->emitError(
+ "Terminators must implement either BranchOpInterface or "
+ "RegionBranchTerminatorOpInterface (but not both)!");
+
+ // We only support terminators with 0 or 1 successors for now and
+ // special-case the conditional branch op.
+ if (op->getSuccessors().size() > 1 && !isa<cf::CondBranchOp>(op))
+ return op->emitError("Terminators with more than one successor "
+ "are not supported (except cf.cond_br)!");
+ }
+
+ return success();
+}
+
+LogicalResult
+BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) {
+ SmallVector<TypeRange> returnOperandTypes(llvm::map_range(
+ op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(),
+ [](RegionBranchTerminatorOpInterface op) {
+ return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes();
+ }));
+ if (!llvm::all_equal(returnOperandTypes))
+ return op->emitError(
+ "there are multiple return operations with
diff erent operand types");
+
+ TypeRange resultTypes = op.getResultTypes();
+ // Check if we found a return operation because that doesn't necessarily
+ // always have to be the case, e.g., consider a function with one block that
+ // has a cf.br at the end branching to itself again (i.e., an infinite loop).
+ // In that case we don't want to crash but just not update the return types.
+ if (!returnOperandTypes.empty())
+ resultTypes = returnOperandTypes[0];
+
+ // TODO: it would be nice if the FunctionOpInterface had a method to not only
+ // get the function type but also set it.
+ op->setAttr(
+ "function_type",
+ TypeAttr::get(FunctionType::get(
+ op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
+ resultTypes)));
+
+ return success();
+}
+
+LogicalResult BufferDeallocation::deallocate(FunctionOpInterface op) {
+ // Stop and emit a proper error message if we don't support the input IR.
+ if (failed(verifyFunctionPreconditions(op)))
+ return failure();
+
+ // Process the function block by block.
+ auto result = op->walk<WalkOrder::PostOrder, ForwardDominanceIterator<>>(
+ [&](Block *block) {
+ if (failed(deallocate(block)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ if (result.wasInterrupted())
+ return failure();
+
+ // Update the function signature according to the types actually returned by
+ // the terminators; this only has an effect if the function is private,
+ // dynamic ownership is enabled, and the function has MemRef arguments or
+ // results.
+ return updateFunctionSignature(op);
+}
+
+void BufferDeallocation::getMemrefsToRetain(
+ Block *fromBlock, Block *toBlock, ValueRange destOperands,
+ SmallVectorImpl<Value> &toRetain) const {
+ for (Value operand : destOperands) {
+ if (!isMemref(operand))
+ continue;
+ toRetain.push_back(operand);
+ }
+
+ SmallPtrSet<Value, 16> liveOut;
+ for (auto val : liveness.getLiveOut(fromBlock))
+ if (isMemref(val))
+ liveOut.insert(val);
+
+ if (toBlock)
+ llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
+
+ // liveOut has non-deterministic order because it was constructed by iterating
+ // over a hash-set.
+ SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
+ std::sort(retainedByLiveness.begin(), retainedByLiveness.end(),
+ ValueComparator());
+ toRetain.append(retainedByLiveness);
+}
+
+LogicalResult BufferDeallocation::getMemrefsAndConditionsToDeallocate(
+ OpBuilder &builder, Location loc, Block *block,
+ SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
+
+ for (auto [i, memref] :
+ llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
+ Ownership ownership = ownershipMap.lookup({memref, block});
+ assert(ownership.isUnique() && "MemRef value must have valid ownership");
+
+ // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
+ // that we can call extract_strided_metadata on it.
+ if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
+ memref = builder.create<memref::ReinterpretCastOp>(
+ loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
+ 0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+
+ // Use the `memref.extract_strided_metadata` operation to get the base
+ // memref. This is needed because the same MemRef that was produced by the
+ // alloc operation has to be passed to the dealloc operation. Passing
+ // subviews, etc. to a dealloc operation is not allowed.
+ memrefs.push_back(
+ builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
+ .getResult(0));
+ conditions.push_back(ownership.getIndicator());
+ }
+
+ return success();
+}
+
+LogicalResult BufferDeallocation::deallocate(Block *block) {
+ OpBuilder builder = OpBuilder::atBlockBegin(block);
+
+ // Compute liveness transfers of ownership to this block.
+ for (auto li : liveness.getLiveIn(block)) {
+ if (!isMemref(li))
+ continue;
+
+ // Ownership of implicitly captured memrefs from other regions is never
+ // taken, but ownership of memrefs in the same region (but different
+ // block) is taken.
+ if (li.getParentRegion() == block->getParent()) {
+ joinOwnership(li, ownershipMap[{li, li.getParentBlock()}], block);
+ memrefsToDeallocatePerBlock[block].push_back(li);
+ continue;
+ }
+
+ if (li.getParentRegion()->isProperAncestor(block->getParent())) {
+ Value falseVal = buildBoolValue(builder, li.getLoc(), false);
+ joinOwnership(li, falseVal, block);
+ }
+ }
+
+ for (unsigned i = 0, e = block->getNumArguments(); i < e; ++i) {
+ BlockArgument arg = block->getArgument(i);
+ if (!isMemref(arg))
+ continue;
+
+ // Adhere to function boundary ABI: no ownership of function argument
+ // MemRefs is taken.
+ if (isFunctionWithoutDynamicOwnership(block->getParentOp()) &&
+ block->isEntryBlock()) {
+ Value newArg = buildBoolValue(builder, arg.getLoc(), false);
+ joinOwnership(arg, newArg);
+ continue;
+ }
+
+ // Pass MemRef ownerships along via `i1` values.
+ Value newArg = block->addArgument(builder.getI1Type(), arg.getLoc());
+ joinOwnership(arg, newArg);
+ memrefsToDeallocatePerBlock[block].push_back(arg);
+ }
+
+ // For each operation in the block, handle the interfaces that affect aliasing
+ // and ownership of memrefs.
+ for (Operation &op : llvm::make_early_inc_range(*block)) {
+ FailureOr<Operation *> result = handleAllInterfaces(&op);
+ if (failed(result))
+ return failure();
+
+ populateRemainingOwnerships(*result);
+ }
+
+ // TODO: if block has no terminator, handle dealloc insertion here.
+ return success();
+}
+
+Operation *BufferDeallocation::appendOpResults(Operation *op,
+ ArrayRef<Type> types) {
+ SmallVector<Type> newTypes(op->getResultTypes());
+ newTypes.append(types.begin(), types.end());
+ auto *newOp = Operation::create(op->getLoc(), op->getName(), newTypes,
+ op->getOperands(), op->getAttrDictionary(),
+ op->getPropertiesStorage(),
+ op->getSuccessors(), op->getNumRegions());
+ for (auto [oldRegion, newRegion] :
+ llvm::zip(op->getRegions(), newOp->getRegions()))
+ newRegion.takeBody(oldRegion);
+
+ OpBuilder(op).insert(newOp);
+ op->replaceAllUsesWith(newOp->getResults().take_front(op->getNumResults()));
+ op->erase();
+
+ return newOp;
+}
+
+FailureOr<Operation *>
+BufferDeallocation::handleInterface(cf::CondBranchOp op) {
+ OpBuilder builder(op);
+
+ // The list of memrefs to pass to the `bufferization.dealloc` op as "memrefs
+ // to deallocate" in this block is independent of which branch is taken.
+ SmallVector<Value> memrefs, ownerships;
+ if (failed(getMemrefsAndConditionsToDeallocate(
+ builder, op.getLoc(), op->getBlock(), memrefs, ownerships)))
+ return failure();
+
+ // Helper lambda to factor out common logic for inserting the dealloc
+ // operations for each successor.
+ auto insertDeallocForBranch =
+ [&](Block *target, MutableOperandRange destOperands,
+ ArrayRef<Value> conditions,
+ DenseMap<Value, Value> &ownershipMapping) -> DeallocOp {
+ SmallVector<Value> toRetain;
+ getMemrefsToRetain(op->getBlock(), target, OperandRange(destOperands),
+ toRetain);
+ auto deallocOp = builder.create<bufferization::DeallocOp>(
+ op.getLoc(), memrefs, conditions, toRetain);
+ clearOwnershipOf(deallocOp.getRetained(), op->getBlock());
+ for (auto [retained, ownership] :
+ llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
+ joinOwnership(retained, ownership, op->getBlock());
+ ownershipMapping[retained] = ownership;
+ }
+ SmallVector<Value> replacements, ownerships;
+ for (Value operand : destOperands) {
+ replacements.push_back(operand);
+ if (isMemref(operand)) {
+ assert(ownershipMapping.contains(operand) &&
+ "Should be contained at this point");
+ ownerships.push_back(ownershipMapping[operand]);
+ }
+ }
+ replacements.append(ownerships);
+ destOperands.assign(replacements);
+ return deallocOp;
+ };
+
+ // Call the helper lambda and make sure the dealloc conditions are properly
+ // modified to reflect the branch condition as well.
+ DenseMap<Value, Value> thenOwnershipMap, elseOwnershipMap;
+
+ // Retain `trueDestOperands` if "true" branch is taken.
+ SmallVector<Value> thenOwnerships(
+ llvm::map_range(ownerships, [&](Value cond) {
+ return builder.create<arith::AndIOp>(op.getLoc(), cond,
+ op.getCondition());
+ }));
+ DeallocOp thenTakenDeallocOp =
+ insertDeallocForBranch(op.getTrueDest(), op.getTrueDestOperandsMutable(),
+ thenOwnerships, thenOwnershipMap);
+
+ // Retain `elseDestOperands` if "false" branch is taken.
+ SmallVector<Value> elseOwnerships(
+ llvm::map_range(ownerships, [&](Value cond) {
+ Value trueVal = builder.create<arith::ConstantOp>(
+ op.getLoc(), builder.getBoolAttr(true));
+ Value negation = builder.create<arith::XOrIOp>(op.getLoc(), trueVal,
+ op.getCondition());
+ return builder.create<arith::AndIOp>(op.getLoc(), cond, negation);
+ }));
+ DeallocOp elseTakenDeallocOp = insertDeallocForBranch(
+ op.getFalseDest(), op.getFalseDestOperandsMutable(), elseOwnerships,
+ elseOwnershipMap);
+
+ // We specifically need to update the ownerships of values that are retained
+ // in both dealloc operations again to get a combined 'Unique' ownership
+ // instead of an 'Unknown' ownership.
+ SmallPtrSet<Value, 16> thenValues(thenTakenDeallocOp.getRetained().begin(),
+ thenTakenDeallocOp.getRetained().end());
+ SetVector<Value> commonValues;
+ for (Value val : elseTakenDeallocOp.getRetained()) {
+ if (thenValues.contains(val))
+ commonValues.insert(val);
+ }
+
+ for (Value retained : commonValues) {
+ clearOwnershipOf(retained, op->getBlock());
+ Value combinedOwnership = builder.create<arith::SelectOp>(
+ op.getLoc(), op.getCondition(), thenOwnershipMap[retained],
+ elseOwnershipMap[retained]);
+ joinOwnership(retained, combinedOwnership, op->getBlock());
+ }
+
+ return op.getOperation();
+}
+
+FailureOr<Operation *>
+BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
+ OpBuilder builder = OpBuilder::atBlockBegin(op->getBlock());
+
+ // TODO: the RegionBranchOpInterface does not provide all the necessary
+ // methods to perform this transformation without additional assumptions on
+ // the structure. In particular, that
+ // * additional values to be passed to the next region can be added to the end
+ // of the operand list, the end of the block argument list, and the end of
+ // the result value list. However, it seems to be the general guideline for
+ // operations implementing this interface to follow this structure.
+ // * and that the block arguments and result values match the forwarded
+ // operands one-to-one (i.e., that there are no other values appended to the
+ // front).
+ // These assumptions are satisfied by the `scf.if`, `scf.for`, and `scf.while`
+ // operations.
+
+ SmallVector<RegionSuccessor> regions;
+ op.getSuccessorRegions(RegionBranchPoint::parent(), regions);
+ assert(!regions.empty() && "Must have at least one successor region");
+ SmallVector<Value> entryOperands(
+ op.getEntrySuccessorOperands(regions.front()));
+ unsigned numMemrefOperands = llvm::count_if(entryOperands, isMemref);
+
+ // No ownership is acquired for any MemRefs that are passed to the region from
+ // the outside.
+ Value falseVal = buildBoolValue(builder, op.getLoc(), false);
+ op->insertOperands(op->getNumOperands(),
+ SmallVector<Value>(numMemrefOperands, falseVal));
+
+ int counter = op->getNumResults();
+ unsigned numMemrefResults = llvm::count_if(op->getResults(), isMemref);
+ SmallVector<Type> ownershipResults(numMemrefResults, builder.getI1Type());
+ RegionBranchOpInterface newOp = appendOpResults(op, ownershipResults);
+
+ for (auto result : llvm::make_filter_range(newOp->getResults(), isMemref)) {
+ joinOwnership(result, newOp->getResult(counter++));
+ memrefsToDeallocatePerBlock[newOp->getBlock()].push_back(result);
+ }
+
+ return newOp.getOperation();
+}
+
+std::pair<Value, Value>
+BufferDeallocation::getMemrefWithUniqueOwnership(OpBuilder &builder,
+ Value memref) {
+ auto iter = ownershipMap.find({memref, memref.getParentBlock()});
+ assert(iter != ownershipMap.end() &&
+ "Value must already have been registered in the ownership map");
+
+ Ownership ownership = iter->second;
+ if (ownership.isUnique())
+ return {memref, ownership.getIndicator()};
+
+ // Instead of inserting a clone operation we could also insert a dealloc
+ // operation earlier in the block and use the updated ownerships returned by
+ // the op for the retained values. Alternatively, we could insert code to
+ // check aliasing at runtime and use this information to combine two unique
+ // ownerships more intelligently to not end up with an 'Unknown' ownership in
+ // the first place.
+ auto cloneOp =
+ builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
+ Value condition = buildBoolValue(builder, memref.getLoc(), true);
+ Value newMemref = cloneOp.getResult();
+ joinOwnership(newMemref, condition);
+ memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
+ return {newMemref, condition};
+}
+
+Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
+ Value memref) {
+ // First, make sure we at least have 'Unique' ownership already.
+ std::pair<Value, Value> newMemrefAndOwnership =
+ getMemrefWithUniqueOwnership(builder, memref);
+ Value newMemref = newMemrefAndOwnership.first;
+ Value condition = newMemrefAndOwnership.second;
+
+ // Avoid inserting additional IR if ownership is already guaranteed. In
+ // particular, this is already the case when we had 'Unknown' ownership
+ // initially and a clone was inserted to get to 'Unique' ownership.
+ if (matchPattern(condition, m_One()))
+ return newMemref;
+
+ // Insert a runtime check and only clone if we still don't have ownership at
+ // runtime.
+ Value maybeClone =
+ builder
+ .create<scf::IfOp>(
+ memref.getLoc(), condition,
+ [&](OpBuilder &builder, Location loc) {
+ builder.create<scf::YieldOp>(loc, newMemref);
+ },
+ [&](OpBuilder &builder, Location loc) {
+ Value clone =
+ builder.create<bufferization::CloneOp>(loc, newMemref);
+ builder.create<scf::YieldOp>(loc, clone);
+ })
+ .getResult(0);
+ Value trueVal = buildBoolValue(builder, memref.getLoc(), true);
+ joinOwnership(maybeClone, trueVal);
+ memrefsToDeallocatePerBlock[maybeClone.getParentBlock()].push_back(
+ maybeClone);
+ return maybeClone;
+}
+
+FailureOr<Operation *>
+BufferDeallocation::handleInterface(BranchOpInterface op) {
+ // Skip conditional branches since we special case them for now.
+ if (isa<cf::CondBranchOp>(op.getOperation()))
+ return op.getOperation();
+
+ if (op->getNumSuccessors() != 1)
+ return emitError(op.getLoc(),
+ "only BranchOpInterface operations with exactly "
+ "one successor are supported yet");
+
+ if (op.getSuccessorOperands(0).getProducedOperandCount() > 0)
+ return op.emitError("produced operands are not supported");
+
+ // Collect the values to deallocate and retain and use them to create the
+ // dealloc operation.
+ Block *block = op->getBlock();
+ OpBuilder builder(op);
+ SmallVector<Value> memrefs, conditions, toRetain;
+ if (failed(getMemrefsAndConditionsToDeallocate(builder, op.getLoc(), block,
+ memrefs, conditions)))
+ return failure();
+
+ OperandRange forwardedOperands =
+ op.getSuccessorOperands(0).getForwardedOperands();
+ getMemrefsToRetain(block, op->getSuccessor(0), forwardedOperands, toRetain);
+
+ auto deallocOp = builder.create<bufferization::DeallocOp>(
+ op.getLoc(), memrefs, conditions, toRetain);
+
+ // We want to replace the current ownership of the retained values with the
+ // result values of the dealloc operation as they are always unique.
+ clearOwnershipOf(deallocOp.getRetained(), block);
+ for (auto [retained, ownership] :
+ llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
+ joinOwnership(retained, ownership, block);
+ }
+
+ unsigned numAdditionalReturns = llvm::count_if(forwardedOperands, isMemref);
+ SmallVector<Value> newOperands(forwardedOperands);
+ auto additionalConditions =
+ deallocOp.getUpdatedConditions().take_front(numAdditionalReturns);
+ newOperands.append(additionalConditions.begin(), additionalConditions.end());
+ op.getSuccessorOperands(0).getMutableForwardedOperands().assign(newOperands);
+
+ return op.getOperation();
+}
+
+FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
+ OpBuilder builder(op);
+
+ // Lookup the function operation and check if it has private visibility. If
+ // the function is referenced by SSA value instead of a Symbol, it's assumed
+ // to be always private.
+ Operation *funcOp = op.resolveCallable(&symbolTable);
+ bool isPrivate = true;
+ if (auto symbol = dyn_cast<SymbolOpInterface>(funcOp))
+ isPrivate &= (symbol.getVisibility() == SymbolTable::Visibility::Private);
+
+ // If the private-function-dynamic-ownership option is enabled and we are
+ // calling a private function, we need to add an additional `i1`
+ // argument/result for each MemRef argument/result to dynamically pass the
+ // current ownership indicator rather than adhering to the function boundary
+ // ABI.
+ if (privateFuncDynamicOwnership && isPrivate) {
+ SmallVector<Value> newOperands, ownershipIndicatorsToAdd;
+ for (Value operand : op.getArgOperands()) {
+ if (!isMemref(operand)) {
+ newOperands.push_back(operand);
+ continue;
+ }
+ auto [memref, condition] = getMemrefWithUniqueOwnership(builder, operand);
+ newOperands.push_back(memref);
+ ownershipIndicatorsToAdd.push_back(condition);
+ }
+ newOperands.append(ownershipIndicatorsToAdd.begin(),
+ ownershipIndicatorsToAdd.end());
+ op.getArgOperandsMutable().assign(newOperands);
+
+ unsigned numMemrefs = llvm::count_if(op->getResults(), isMemref);
+ SmallVector<Type> ownershipTypesToAppend(numMemrefs, builder.getI1Type());
+ unsigned ownershipCounter = op->getNumResults();
+ op = appendOpResults(op, ownershipTypesToAppend);
+
+ for (auto result : llvm::make_filter_range(op->getResults(), isMemref)) {
+ joinOwnership(result, op->getResult(ownershipCounter++));
+ memrefsToDeallocatePerBlock[result.getParentBlock()].push_back(result);
+ }
+
+ return op.getOperation();
+ }
+
+ // According to the function boundary ABI we are guaranteed to get ownership
+ // of all MemRefs returned by the function. Thus we set ownership to constant
+ // 'true' and remember to deallocate it.
+ Value trueVal = buildBoolValue(builder, op.getLoc(), true);
+ for (auto result : llvm::make_filter_range(op->getResults(), isMemref)) {
+ joinOwnership(result, trueVal);
+ memrefsToDeallocatePerBlock[result.getParentBlock()].push_back(result);
+ }
+
+ return op.getOperation();
+}
+
+FailureOr<Operation *>
+BufferDeallocation::handleInterface(MemoryEffectOpInterface op) {
+ auto *block = op->getBlock();
+
+ for (auto operand : llvm::make_filter_range(op->getOperands(), isMemref))
+ if (op.getEffectOnValue<MemoryEffects::Free>(operand).has_value())
+ return op->emitError(
+ "memory free side-effect on MemRef value not supported!");
+
+ OpBuilder builder = OpBuilder::atBlockBegin(block);
+ for (auto res : llvm::make_filter_range(op->getResults(), isMemref)) {
+ auto allocEffect = op.getEffectOnValue<MemoryEffects::Allocate>(res);
+ if (allocEffect.has_value()) {
+ if (isa<SideEffects::AutomaticAllocationScopeResource>(
+ allocEffect->getResource())) {
+ // Make sure that the ownership of auto-managed allocations is set to
+ // false. This is important for operations that have at least one memref
+ // typed operand. E.g., consider an operation like `bufferization.clone`
+ // that lowers to a `memref.alloca + memref.copy` instead of a
+ // `memref.alloc`. If we wouldn't set the ownership of the result here,
+ // the default ownership population in `populateRemainingOwnerships`
+ // would assume aliasing with the MemRef operand.
+ clearOwnershipOf(res, block);
+ joinOwnership(res, buildBoolValue(builder, op.getLoc(), false));
+ continue;
+ }
+
+ joinOwnership(res, buildBoolValue(builder, op.getLoc(), true));
+ memrefsToDeallocatePerBlock[block].push_back(res);
+ }
+ }
+
+ return op.getOperation();
+}
+
+FailureOr<Operation *>
+BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
+ OpBuilder builder(op);
+
+ // If this is a return operation of a function that is not private, or if
+ // dynamic function boundary ownership is disabled, we need to return MemRef
+ // values for which we have guaranteed ownership in order to adhere to the
+ // function boundary ABI.
+ bool funcWithoutDynamicOwnership =
+ isFunctionWithoutDynamicOwnership(op->getParentOp());
+ if (funcWithoutDynamicOwnership) {
+ for (OpOperand &val : op->getOpOperands()) {
+ if (!isMemref(val.get()))
+ continue;
+
+ val.set(getMemrefWithGuaranteedOwnership(builder, val.get()));
+ }
+ }
+
+ // TODO: getSuccessorRegions is not implemented by all operations we care
+ // about, but we would need to check how many successors there are and under
+ // which condition they are taken, etc.
+
+ MutableOperandRange operands =
+ op.getMutableSuccessorOperands(RegionBranchPoint::parent());
+
+ // Collect the values to deallocate and retain and use them to create the
+ // dealloc operation.
+ Block *block = op->getBlock();
+ SmallVector<Value> memrefs, conditions, toRetain;
+ if (failed(getMemrefsAndConditionsToDeallocate(builder, op.getLoc(), block,
+ memrefs, conditions)))
+ return failure();
+
+ getMemrefsToRetain(block, nullptr, OperandRange(operands), toRetain);
+ if (memrefs.empty() && toRetain.empty())
+ return op.getOperation();
+
+ auto deallocOp = builder.create<bufferization::DeallocOp>(
+ op.getLoc(), memrefs, conditions, toRetain);
+
+ // We want to replace the current ownership of the retained values with the
+ // result values of the dealloc operation as they are always unique.
+ clearOwnershipOf(deallocOp.getRetained(), block);
+ for (auto [retained, ownership] :
+ llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
+ joinOwnership(retained, ownership, block);
+
+ // Add an additional operand for every MemRef for the ownership indicator.
+ if (!funcWithoutDynamicOwnership) {
+ unsigned numMemRefs = llvm::count_if(operands, isMemref);
+ SmallVector<Value> newOperands{OperandRange(operands)};
+ auto ownershipValues =
+ deallocOp.getUpdatedConditions().take_front(numMemRefs);
+ newOperands.append(ownershipValues.begin(), ownershipValues.end());
+ operands.assign(newOperands);
+ }
+
+ return op.getOperation();
+}
+
+bool BufferDeallocation::isFunctionWithoutDynamicOwnership(Operation *op) {
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ return funcOp && (!privateFuncDynamicOwnership ||
+ funcOp.getVisibility() != SymbolTable::Visibility::Private);
+}
+
+void BufferDeallocation::populateRemainingOwnerships(Operation *op) {
+ for (auto res : op->getResults()) {
+ if (!isMemref(res))
+ continue;
+ if (ownershipMap.count({res, op->getBlock()}))
+ continue;
+
+ // Don't take ownership of a returned memref if no allocate side-effect is
+ // present, relevant for memref.get_global, for example.
+ if (op->getNumOperands() == 0) {
+ OpBuilder builder(op);
+ joinOwnership(res, buildBoolValue(builder, op->getLoc(), false));
+ continue;
+ }
+
+ // Assume the result may alias with any operand and thus combine all their
+ // ownerships.
+ for (auto operand : op->getOperands()) {
+ if (!isMemref(operand))
+ continue;
+
+ ownershipMap[{res, op->getBlock()}].combine(
+ ownershipMap[{operand, operand.getParentBlock()}]);
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// OwnershipBasedBufferDeallocationPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// The actual buffer deallocation pass that inserts and moves dealloc nodes
+/// into the right positions. Furthermore, it inserts additional clones if
+/// necessary. It uses the algorithm described at the top of the file.
+struct OwnershipBasedBufferDeallocationPass
+ : public bufferization::impl::OwnershipBasedBufferDeallocationBase<
+ OwnershipBasedBufferDeallocationPass> {
+ void runOnOperation() override {
+ func::FuncOp func = getOperation();
+ if (func.isExternal())
+ return;
+
+ if (failed(
+ deallocateBuffersOwnershipBased(func, privateFuncDynamicOwnership)))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Implement bufferization API
+//===----------------------------------------------------------------------===//
+
+LogicalResult bufferization::deallocateBuffersOwnershipBased(
+ FunctionOpInterface op, bool privateFuncDynamicOwnership) {
+ // Gather all required allocation nodes and prepare the deallocation phase.
+ BufferDeallocation deallocation(op, privateFuncDynamicOwnership);
+
+ // Place all required temporary clone and dealloc nodes.
+ return deallocation.deallocate(op);
+}
+
+//===----------------------------------------------------------------------===//
+// OwnershipBasedBufferDeallocationPass construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass>
+mlir::bufferization::createOwnershipBasedBufferDeallocationPass() {
+ return std::make_unique<OwnershipBasedBufferDeallocationPass>();
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
new file mode 100644
index 000000000000000..23a628cc2b83d99
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -0,0 +1,589 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true -split-input-file %s > /dev/null
+
+// Test Case:
+// bb0
+// / \
+// bb1 bb2 <- Initial position of AllocOp
+// \ /
+// bb3
+// BufferDeallocation expected behavior: bb2 contains an AllocOp which is
+// passed to bb3. In the latter block, there should be a deallocation.
+// Since bb1 does not contain an adequate alloc, the deallocation has to be
+// made conditional on the branch taken in bb0.
+
+func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ cf.cond_br %arg0, ^bb2(%arg1 : memref<2xf32>), ^bb1
+^bb1:
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
+ cf.br ^bb2(%0 : memref<2xf32>)
+^bb2(%1: memref<2xf32>):
+ test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @condBranch
+// CHECK-SAME: ([[ARG0:%.+]]: i1,
+// CHECK-SAME: [[ARG1:%.+]]: memref<2xf32>,
+// CHECK-SAME: [[ARG2:%.+]]: memref<2xf32>)
+// CHECK-NOT: bufferization.dealloc
+// CHECK: cf.cond_br{{.*}}, ^bb2([[ARG1]], %false{{[0-9_]*}} :{{.*}}), ^bb1
+// CHECK: ^bb1:
+// CHECK: %[[ALLOC1:.*]] = memref.alloc
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: cf.br ^bb2(%[[ALLOC1]], %true
+// CHECK-NEXT: ^bb2([[ALLOC2:%.+]]: memref<2xf32>, [[COND1:%.+]]: i1):
+// CHECK: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC2]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND1]])
+// CHECK-NEXT: return
+
+// -----
+
+// Test Case:
+// bb0
+// / \
+// bb1 bb2 <- Initial position of AllocOp
+// \ /
+// bb3
+// BufferDeallocation expected behavior: The existing AllocOp has a dynamic
+// dependency on block argument %0 in bb1. Since the allocated buffer is
+// passed to bb2 via the block argument %2, an ownership indicator is
+// forwarded alongside %2 and the deallocation in bb2 is made conditional on
+// it; no temporary copy buffer is required.
+
+func.func @condBranchDynamicType(
+ %arg0: i1,
+ %arg1: memref<?xf32>,
+ %arg2: memref<?xf32>,
+ %arg3: index) {
+ cf.cond_br %arg0, ^bb2(%arg1 : memref<?xf32>), ^bb1(%arg3: index)
+^bb1(%0: index):
+ %1 = memref.alloc(%0) : memref<?xf32>
+ test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
+ cf.br ^bb2(%1 : memref<?xf32>)
+^bb2(%2: memref<?xf32>):
+ test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>)
+ return
+}
+
+// CHECK-LABEL: func @condBranchDynamicType
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<?xf32>, [[ARG2:%.+]]: memref<?xf32>, [[ARG3:%.+]]: index)
+// CHECK-NOT: bufferization.dealloc
+// CHECK: cf.cond_br{{.*}}^bb2(%arg1, %false{{[0-9_]*}} :{{.*}}), ^bb1
+// CHECK: ^bb1([[IDX:%.*]]:{{.*}})
+// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: cf.br ^bb2([[ALLOC1]], %true
+// CHECK-NEXT: ^bb2([[ALLOC3:%.*]]:{{.*}}, [[COND:%.+]]:{{.*}})
+// CHECK: test.copy([[ALLOC3]],
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC3]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND]])
+// CHECK-NEXT: return
+
+// -----
+
+// Test case: See above.
+
+func.func @condBranchUnrankedType(
+ %arg0: i1,
+ %arg1: memref<*xf32>,
+ %arg2: memref<*xf32>,
+ %arg3: index) {
+ cf.cond_br %arg0, ^bb2(%arg1 : memref<*xf32>), ^bb1(%arg3: index)
+^bb1(%0: index):
+ %1 = memref.alloc(%0) : memref<?xf32>
+ %2 = memref.cast %1 : memref<?xf32> to memref<*xf32>
+ test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>)
+ cf.br ^bb2(%2 : memref<*xf32>)
+^bb2(%3: memref<*xf32>):
+ test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>)
+ return
+}
+
+// CHECK-LABEL: func @condBranchUnrankedType
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<*xf32>, [[ARG2:%.+]]: memref<*xf32>, [[ARG3:%.+]]: index)
+// CHECK-NOT: bufferization.dealloc
+// CHECK: cf.cond_br{{.*}}^bb2([[ARG1]], %false{{[0-9_]*}} :{{.*}}), ^bb1
+// CHECK: ^bb1([[IDX:%.*]]:{{.*}})
+// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
+// CHECK-NEXT: [[CAST:%.+]] = memref.cast [[ALLOC1]]
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: cf.br ^bb2([[CAST]], %true
+// CHECK-NEXT: ^bb2([[ALLOC3:%.*]]:{{.*}}, [[COND:%.+]]:{{.*}})
+// CHECK: test.copy([[ALLOC3]],
+// CHECK-NEXT: [[CAST:%.+]] = memref.reinterpret_cast [[ALLOC3]]
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[CAST]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND]])
+// CHECK-NEXT: return
+
+// TODO: we can get rid of first dealloc by doing some must-alias analysis
+
+// -----
+
+// Test Case:
+// bb0
+// / \
+// bb1 bb2 <- Initial position of AllocOp
+// | / \
+// | bb3 bb4
+// | \ /
+// \ bb5
+// \ /
+// bb6
+// |
+// bb7
+// BufferDeallocation expected behavior: The existing AllocOp has a dynamic
+// dependency to block argument %0 in bb2. Since the dynamic type is passed to
+// bb5 via the block argument %2 and to bb6 via block argument %3, it is
+// currently required to pass along the condition under which the newly
+// allocated buffer should be deallocated, since the path via bb1 does not
+// allocate a buffer.
+
+func.func @condBranchDynamicTypeNested(
+ %arg0: i1,
+ %arg1: memref<?xf32>,
+ %arg2: memref<?xf32>,
+ %arg3: index) {
+ cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
+^bb1:
+ cf.br ^bb6(%arg1 : memref<?xf32>)
+^bb2(%0: index):
+ %1 = memref.alloc(%0) : memref<?xf32>
+ test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
+ cf.cond_br %arg0, ^bb3, ^bb4
+^bb3:
+ cf.br ^bb5(%1 : memref<?xf32>)
+^bb4:
+ cf.br ^bb5(%1 : memref<?xf32>)
+^bb5(%2: memref<?xf32>):
+ cf.br ^bb6(%2 : memref<?xf32>)
+^bb6(%3: memref<?xf32>):
+ cf.br ^bb7(%3 : memref<?xf32>)
+^bb7(%4: memref<?xf32>):
+ test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>)
+ return
+}
+
+// CHECK-LABEL: func @condBranchDynamicTypeNested
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<?xf32>, [[ARG2:%.+]]: memref<?xf32>, [[ARG3:%.+]]: index)
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.cond_br{{.*}}
+// CHECK-NEXT: ^bb1
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
+// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
+// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
+// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
+// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
+// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
+// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+// CHECK: test.copy
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
+// CHECK-NEXT: return
+
+// TODO: the dealloc in bb5 can be optimized away by adding another
+// canonicalization pattern
+
+// -----
+
+// Test Case:
+// bb0
+// / \
+// | bb1 <- Initial position of AllocOp
+// \ /
+// bb2
+// BufferDeallocation expected behavior: It should insert a DeallocOp at the
+// exit block after CopyOp since %1 is an alias for %0 and %arg1.
+
+func.func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
+^bb1:
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
+ cf.br ^bb2(%0 : memref<2xf32>)
+^bb2(%1: memref<2xf32>):
+ test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @criticalEdge
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xf32>, [[ARG2:%.+]]: memref<2xf32>)
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.cond_br{{.*}}, ^bb1, ^bb2([[ARG1]], %false
+// CHECK: [[ALLOC1:%.*]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: cf.br ^bb2([[ALLOC1]], %true
+// CHECK-NEXT: ^bb2([[ALLOC2:%.+]]:{{.*}}, [[COND:%.+]]: {{.*}})
+// CHECK: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC2]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND]])
+// CHECK-NEXT: return
+
+// -----
+
+// Test Case:
+// bb0 <- Initial position of AllocOp
+// / \
+// | bb1
+// \ /
+// bb2
+// BufferDeallocation expected behavior: It only inserts a DeallocOp at the
+// exit block after CopyOp since %1 is an alias for %0 and %arg1.
+
+func.func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
+ cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
+^bb1:
+ cf.br ^bb2(%0 : memref<2xf32>)
+^bb2(%1: memref<2xf32>):
+ test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @invCriticalEdge
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xf32>, [[ARG2:%.+]]: memref<2xf32>)
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : {{.*}}) if ([[NOT_ARG0]])
+// CHECK-NEXT: cf.cond_br{{.*}}^bb1, ^bb2([[ARG1]], %false
+// CHECK-NEXT: ^bb1:
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.br ^bb2([[ALLOC]], [[ARG0]]
+// CHECK-NEXT: ^bb2([[ALLOC1:%.+]]:{{.*}}, [[COND:%.+]]:{{.*}})
+// CHECK: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC1]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND]])
+// CHECK-NEXT: return
+
+// -----
+
+// Test Case:
+// bb0 <- Initial position of the first AllocOp
+// / \
+// bb1 bb2
+// \ /
+// bb3 <- Initial position of the second AllocOp
+// BufferDeallocation expected behavior: It only inserts two missing
+// DeallocOps in the exit block. %5 is an alias for %0. Therefore, the
+// DeallocOp for %0 should occur after the last BufferBasedOp. The Dealloc for
+// %7 should happen after CopyOp.
+
+func.func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
+ cf.cond_br %arg0,
+ ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
+ ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
+^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
+ cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
+^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
+ cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
+^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
+ %7 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
+ test.copy(%7, %arg2) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @ifElse
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xf32>, [[ARG2:%.+]]: memref<2xf32>)
+// CHECK: [[ALLOC0:%.+]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
+// CHECK-NEXT: cf.cond_br {{.*}}^bb1([[ARG1]], [[ALLOC0]], %false{{[0-9_]*}}, [[ARG0]] : {{.*}}), ^bb2([[ALLOC0]], [[ARG1]], [[NOT_ARG0]], %false{{[0-9_]*}} : {{.*}})
+// CHECK: ^bb3([[A0:%.+]]:{{.*}}, [[A1:%.+]]:{{.*}}, [[COND0:%.+]]: i1, [[COND1:%.+]]: i1):
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A1]]
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC1]] : {{.*}}) if (%true
+// CHECK-NOT: retain
+// CHECK-NEXT: bufferization.dealloc ([[BASE0]], [[BASE1]] : {{.*}}) if ([[COND0]], [[COND1]])
+// CHECK-NOT: retain
+// CHECK-NEXT: return
+
+// TODO: Instead of deallocating the bbarg memrefs, a slightly better analysis
+// could do an unconditional deallocation on ALLOC0 and move it before the
+// test.copy (dealloc of ALLOC1 would remain after the copy)
+
+// -----
+
+// Test Case: No users for the buffer in the if-else CFG
+// bb0 <- Initial position of AllocOp
+// / \
+// bb1 bb2
+// \ /
+// bb3
+// BufferDeallocation expected behavior: It only inserts a missing DeallocOp
+// in the exit block, since %5 and %6 are the last aliases of %0.
+
+func.func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
+ cf.cond_br %arg0,
+ ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
+ ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
+^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
+ cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
+^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
+ cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
+^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
+ test.copy(%arg1, %arg2) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @ifElseNoUsers
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xf32>, [[ARG2:%.+]]: memref<2xf32>)
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
+// CHECK-NEXT: cf.cond_br {{.*}}^bb1([[ARG1]], [[ALLOC]], %false{{[0-9_]*}}, [[ARG0]] : {{.*}}), ^bb2([[ALLOC]], [[ARG1]], [[NOT_ARG0]], %false{{[0-9_]*}} : {{.*}})
+// CHECK: ^bb3([[A0:%.+]]:{{.*}}, [[A1:%.+]]:{{.*}}, [[COND0:%.+]]: i1, [[COND1:%.+]]: i1):
+// CHECK: test.copy
+// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A1]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE0]], [[BASE1]] : {{.*}}) if ([[COND0]], [[COND1]])
+// CHECK-NOT: retain
+// CHECK-NEXT: return
+
+// TODO: a slightly better analysis could just insert an unconditional dealloc on %0
+
+// -----
+
+// Test Case:
+// bb0 <- Initial position of the first AllocOp
+// / \
+// bb1 bb2
+// | / \
+// | bb3 bb4
+// \ \ /
+// \ /
+// bb5 <- Initial position of the second AllocOp
+// BufferDeallocation expected behavior: Two missing DeallocOps should be
+// inserted in the exit block.
+
+func.func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
+ cf.cond_br %arg0,
+ ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
+ ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
+^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
+ cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
+^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
+ cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
+^bb3(%5: memref<2xf32>):
+ cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
+^bb4(%6: memref<2xf32>):
+ cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
+^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
+ %9 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
+ test.copy(%9, %arg2) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @ifElseNested
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xf32>, [[ARG2:%.+]]: memref<2xf32>)
+// CHECK: [[ALLOC0:%.+]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
+// CHECK-NEXT: cf.cond_br {{.*}}^bb1([[ARG1]], [[ALLOC0]], %false{{[0-9_]*}}, [[ARG0]] : {{.*}}), ^bb2([[ALLOC0]], [[ARG1]], [[NOT_ARG0]], %false{{[0-9_]*}} :
+// CHECK: ^bb5([[A0:%.+]]: memref<2xf32>, [[A1:%.+]]: memref<2xf32>, [[COND0:%.+]]: i1, [[COND1:%.+]]: i1):
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A1]]
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC1]] : {{.*}}) if (%true
+// CHECK-NOT: retain
+// CHECK-NEXT: bufferization.dealloc ([[BASE0]], [[BASE1]] : {{.*}}) if ([[COND0]], [[COND1]])
+// CHECK-NOT: retain
+// CHECK-NEXT: return
+
+// TODO: Instead of deallocating the bbarg memrefs, a slightly better analysis
+// could do an unconditional deallocation on ALLOC0 and move it before the
+// test.copy (dealloc of ALLOC1 would remain after the copy)
+
+// -----
+
+// Test Case:
+// bb0
+// / \
+// Initial pos of the 1st AllocOp -> bb1 bb2 <- Initial pos of the 2nd AllocOp
+// \ /
+// bb3
+// BufferDeallocation expected behavior: Since the allocated buffers are
+// passed to bb3, no DeallocOp can be placed in their defining blocks.
+// Instead, each buffer is forwarded to bb3 together with an ownership
+// indicator, and the missing DeallocOp is inserted in the exit block after
+// the CopyOp.
+
+func.func @moving_alloc_and_inserting_missing_dealloc(
+ %cond: i1,
+ %arg0: memref<2xf32>,
+ %arg1: memref<2xf32>) {
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
+ cf.br ^exit(%0 : memref<2xf32>)
+^bb2:
+ %1 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%1: memref<2xf32>) out(%arg0: memref<2xf32>)
+ cf.br ^exit(%1 : memref<2xf32>)
+^exit(%arg2: memref<2xf32>):
+ test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @moving_alloc_and_inserting_missing_dealloc
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xf32>, [[ARG2:%.+]]: memref<2xf32>)
+// CHECK: ^bb1:
+// CHECK: [[ALLOC0:%.+]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: cf.br ^bb3([[ALLOC0]], %true
+// CHECK: ^bb2:
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: cf.br ^bb3([[ALLOC1]], %true
+// CHECK: ^bb3([[A0:%.+]]: memref<2xf32>, [[COND0:%.+]]: i1):
+// CHECK: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND0]])
+// CHECK-NEXT: return
+
+// -----
+
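+// Test Case: two allocations merged via arith.select.
+// Expected behavior: since both select operands are owned allocations, both
+// can be deallocated unconditionally after the CopyOp and no ownership needs
+// to be materialized for the select result.
+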
+func.func @select_aliases(%arg0: index, %arg1: memref<?xi8>, %arg2: i1) {
+ %0 = memref.alloc(%arg0) : memref<?xi8>
+ %1 = memref.alloc(%arg0) : memref<?xi8>
+ %2 = arith.select %arg2, %0, %1 : memref<?xi8>
+ test.copy(%2, %arg1) : (memref<?xi8>, memref<?xi8>)
+ return
+}
+
+// CHECK-LABEL: func @select_aliases
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK: arith.select
+// CHECK: test.copy
+// CHECK: bufferization.dealloc ([[ALLOC0]] : {{.*}}) if (%true
+// CHECK-NOT: retain
+// CHECK: bufferization.dealloc ([[ALLOC1]] : {{.*}}) if (%true
+// CHECK-NOT: retain
+
+// -----
+
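+// Test Case: arith.select between an owned alloc and a non-owned alloca.
+// Expected behavior: the ownership of the select result is unknown, so the
+// dealloc of the alloc retains the select result and forwards the resulting
+// ownership indicator to the successor block.
+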
+func.func @select_aliases_not_same_ownership(%arg0: index, %arg1: memref<?xi8>, %arg2: i1) {
+ %0 = memref.alloc(%arg0) : memref<?xi8>
+ %1 = memref.alloca(%arg0) : memref<?xi8>
+ %2 = arith.select %arg2, %0, %1 : memref<?xi8>
+ cf.br ^bb1(%2 : memref<?xi8>)
+^bb1(%arg3: memref<?xi8>):
+ test.copy(%arg3, %arg1) : (memref<?xi8>, memref<?xi8>)
+ return
+}
+
+// CHECK-LABEL: func @select_aliases_not_same_ownership
+// CHECK: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: memref<?xi8>, [[ARG2:%.+]]: i1)
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK: [[ALLOC1:%.+]] = memref.alloca(
+// CHECK: [[SELECT:%.+]] = arith.select
+// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[ALLOC0]] :{{.*}}) if (%true{{[0-9_]*}}) retain ([[SELECT]] :
+// CHECK: cf.br ^bb1([[SELECT]], [[OWN]] :
+// CHECK: ^bb1([[A0:%.+]]: memref<?xi8>, [[COND:%.+]]: i1)
+// CHECK: test.copy
+// CHECK: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK: bufferization.dealloc ([[BASE0]] : {{.*}}) if ([[COND]])
+// CHECK-NOT: retain
+
+// -----
+
+func.func @select_captured_in_next_block(%arg0: index, %arg1: memref<?xi8>, %arg2: i1, %arg3: i1) {
+ %0 = memref.alloc(%arg0) : memref<?xi8>
+ %1 = memref.alloca(%arg0) : memref<?xi8>
+ %2 = arith.select %arg2, %0, %1 : memref<?xi8>
+ cf.cond_br %arg3, ^bb1(%0 : memref<?xi8>), ^bb1(%arg1 : memref<?xi8>)
+^bb1(%arg4: memref<?xi8>):
+ test.copy(%arg4, %2) : (memref<?xi8>, memref<?xi8>)
+ return
+}
+
+// CHECK-LABEL: func @select_captured_in_next_block
+// CHECK: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: memref<?xi8>, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: i1)
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK: [[ALLOC1:%.+]] = memref.alloca(
+// CHECK: [[SELECT:%.+]] = arith.select
+// CHECK: [[OWN0:%.+]]:2 = bufferization.dealloc ([[ALLOC0]] :{{.*}}) if ([[ARG3]]) retain ([[ALLOC0]], [[SELECT]] :
+// CHECK: [[NOT_ARG3:%.+]] = arith.xori [[ARG3]], %true
+// CHECK: [[OWN1:%.+]] = bufferization.dealloc ([[ALLOC0]] :{{.*}}) if ([[NOT_ARG3]]) retain ([[SELECT]] :
+// CHECK: [[MERGED_OWN:%.+]] = arith.select [[ARG3]], [[OWN0]]#1, [[OWN1]]
+// CHECK: cf.cond_br{{.*}}^bb1([[ALLOC0]], [[OWN0]]#0 :{{.*}}), ^bb1([[ARG1]], %false
+// CHECK: ^bb1([[A0:%.+]]: memref<?xi8>, [[COND:%.+]]: i1)
+// CHECK: test.copy
+// CHECK: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[SELECT]]
+// CHECK: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK: bufferization.dealloc ([[BASE0]], [[BASE1]] : {{.*}}) if ([[MERGED_OWN]], [[COND]])
+
+// There are two interesting parts here:
+// * The dealloc condition of %0 in the second block should be the corresponding
+//   result of the dealloc operation of the first block, because %0 has unknown
+//   ownership status and would otherwise require a clone in the first block.
+// * The dealloc of the first block must make sure that the branch condition and
+// respective retained values are handled correctly, i.e., only the ones for the
+// actual branch taken have to be retained.
+
+// -----
+
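+// Test Case: the block defining the allocation appears after its user in the
+// block list. Expected behavior: the pass processes blocks according to
+// dominance rather than their textual order.
+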
+func.func @blocks_not_preordered_by_dominance() {
+ cf.br ^bb1
+^bb2:
+ "test.memref_user"(%alloc) : (memref<2xi32>) -> ()
+ return
+^bb1:
+ %alloc = memref.alloc() : memref<2xi32>
+ cf.br ^bb2
+}
+
+// CHECK-LABEL: func @blocks_not_preordered_by_dominance
+// CHECK-NEXT: [[TRUE:%.+]] = arith.constant true
+// CHECK-NEXT: cf.br [[BB1:\^.+]]
+// CHECK-NEXT: [[BB2:\^[a-zA-Z0-9_]+]]:
+// CHECK-NEXT: "test.memref_user"([[ALLOC:%[a-zA-Z0-9_]+]])
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : {{.*}}) if ([[TRUE]])
+// CHECK-NOT: retain
+// CHECK-NEXT: return
+// CHECK-NEXT: [[BB1]]:
+// CHECK-NEXT: [[ALLOC]] = memref.alloc()
+// CHECK-NEXT: cf.br [[BB2]]
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
new file mode 100644
index 000000000000000..67128fee3dfe0ab
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=false \
+// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true \
+// RUN: --buffer-deallocation-simplification -split-input-file %s | FileCheck %s --check-prefix=CHECK-DYNAMIC
+
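+// Test Case: buffer passed to and returned from a private function.
+// Expected behavior: without dynamic ownership, the call result is treated
+// as owned and deallocated unconditionally; with dynamic ownership, the
+// private function and its call sites are extended with an i1 ownership
+// value per memref.
+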
+func.func private @f(%arg0: memref<f64>) -> memref<f64> {
+ return %arg0 : memref<f64>
+}
+
+func.func @function_call() {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloc() : memref<f64>
+ %ret = call @f(%alloc) : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func @function_call()
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK-NEXT: [[RET:%.+]] = call @f([[ALLOC0]]) : (memref<f64>) -> memref<f64>
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]
+// COM: the following dealloc operation should be split into three since we can
+// COM: be sure that the memrefs will never alias according to the buffer
+// COM: deallocation ABI; however, the local alias analysis is not yet powerful
+// COM: enough to detect this.
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, %true{{[0-9_]*}})
+
+// CHECK-DYNAMIC-LABEL: func @function_call()
+// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[ALLOC0]], %true{{[0-9_]*}}) : (memref<f64>, i1) -> (memref<f64>, i1)
+// CHECK-DYNAMIC-NEXT: test.copy
+// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
+
+// -----
+
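+// Test Case: same as above, but the callee is a public function. Public
+// functions never receive dynamic ownership arguments, so both
+// configurations produce the same output.
+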
+func.func @f(%arg0: memref<f64>) -> memref<f64> {
+ return %arg0 : memref<f64>
+}
+
+func.func @function_call_non_private() {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloc() : memref<f64>
+ %ret = call @f(%alloc) : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func @function_call_non_private
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK: [[RET:%.+]] = call @f([[ALLOC0]]) : (memref<f64>) -> memref<f64>
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, %true{{[0-9_]*}})
+// CHECK-NEXT: return
+
+// CHECK-DYNAMIC-LABEL: func @function_call_non_private
+// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-DYNAMIC: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK-DYNAMIC: [[RET:%.+]] = call @f([[ALLOC0]]) : (memref<f64>) -> memref<f64>
+// CHECK-DYNAMIC-NEXT: test.copy
+// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, %true{{[0-9_]*}})
+// CHECK-DYNAMIC-NEXT: return
+
+// -----
+
+func.func private @f(%arg0: memref<f64>) -> memref<f64> {
+ return %arg0 : memref<f64>
+}
+
+func.func @function_call_requires_merged_ownership_mid_block(%arg0: i1) {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloca() : memref<f64>
+ %0 = arith.select %arg0, %alloc, %alloc2 : memref<f64>
+ %ret = call @f(%0) : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func @function_call_requires_merged_ownership_mid_block
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloca(
+// CHECK-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
+// CHECK-NEXT: [[RET:%.+]] = call @f([[SELECT]])
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :
+// CHECK-SAME: if (%true{{[0-9_]*}}, %true{{[0-9_]*}})
+// CHECK-NOT: retain
+// CHECK-NEXT: return
+
+// CHECK-DYNAMIC-LABEL: func @function_call_requires_merged_ownership_mid_block
+// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloca(
+// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
+// CHECK-DYNAMIC-NEXT: [[CLONE:%.+]] = bufferization.clone [[SELECT]]
+// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[CLONE]], %true{{[0-9_]*}})
+// CHECK-DYNAMIC-NEXT: test.copy
+// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[CLONE]], [[BASE]] :
+// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
+// CHECK-DYNAMIC-NOT: retain
+// CHECK-DYNAMIC-NEXT: return
+
+// TODO: the inserted clone is not necessary; we just have to know which of the
+// two allocations was selected, either by checking aliasing of the result at
+// runtime, by extracting the select condition via an OpInterface, or by
+// hardcoding the select op
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-existing-deallocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-existing-deallocs.mlir
new file mode 100644
index 000000000000000..bf4eabd31a81241
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-existing-deallocs.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -verify-diagnostics -expand-realloc=emit-deallocs=false -ownership-based-buffer-deallocation \
+// RUN: --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+
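+// Test Case: memref.realloc expanded beforehand with emit-deallocs=false.
+// Expected behavior: the original allocation is deallocated unconditionally,
+// while the buffer yielded by the expanded scf.if is deallocated based on
+// its forwarded ownership indicator.
+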
+func.func @auto_dealloc() {
+ %c10 = arith.constant 10 : index
+ %c100 = arith.constant 100 : index
+ %alloc = memref.alloc(%c10) : memref<?xi32>
+ %realloc = memref.realloc %alloc(%c100) : memref<?xi32> to memref<?xi32>
+ "test.memref_user"(%realloc) : (memref<?xi32>) -> ()
+ return
+}
+
+// CHECK-LABEL: func @auto_dealloc
+// CHECK: [[ALLOC:%.*]] = memref.alloc(
+// CHECK-NOT: bufferization.dealloc
+// CHECK: [[V0:%.+]]:2 = scf.if
+// CHECK-NOT: bufferization.dealloc
+// CHECK: test.memref_user
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1)
+// CHECK-NEXT: return
+
+// -----
+
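+// Test Case: realloc of a function argument inside a nested region.
+// Expected behavior: only the scf.if result is deallocated, conditioned on
+// its ownership indicator; the non-owned function argument is never
+// deallocated.
+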
+func.func @auto_dealloc_inside_nested_region(%arg0: memref<?xi32>, %arg1: i1) {
+ %c100 = arith.constant 100 : index
+ %0 = scf.if %arg1 -> memref<?xi32> {
+ %realloc = memref.realloc %arg0(%c100) : memref<?xi32> to memref<?xi32>
+ scf.yield %realloc : memref<?xi32>
+ } else {
+ scf.yield %arg0 : memref<?xi32>
+ }
+ "test.memref_user"(%0) : (memref<?xi32>) -> ()
+ return
+}
+
+// CHECK-LABEL: func @auto_dealloc_inside_nested_region
+// CHECK-SAME: (%arg0: memref<?xi32>, %arg1: i1)
+// CHECK-NOT: dealloc
+// CHECK: "test.memref_user"([[V0:%.+]]#0)
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : memref<i32>) if ([[V0]]#1)
+// CHECK-NEXT: return
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir
new file mode 100644
index 000000000000000..44f3e20c5009309
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt --allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=false \
+// RUN: --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true \
+// RUN: --buffer-deallocation-simplification -split-input-file %s | FileCheck %s --check-prefix=CHECK-DYNAMIC
+
+// Test Case: Existing AllocOp whose buffer is only used by an opaque test op
+// and does not escape the function.
+// BufferDeallocation expected behavior: It should insert a DeallocOp right
+// before ReturnOp.
+
+func.func private @emptyUsesValue(%arg0: memref<4xf32>) {
+ %0 = memref.alloc() : memref<4xf32>
+ "test.memref_user"(%0) : (memref<4xf32>) -> ()
+ return
+}
+
+// CHECK-LABEL: func private @emptyUsesValue(
+// CHECK: [[ALLOC:%.*]] = memref.alloc()
+// CHECK: bufferization.dealloc ([[ALLOC]] :
+// CHECK-SAME: if (%true{{[0-9_]*}})
+// CHECK-NOT: retain
+// CHECK-NEXT: return
+
+// CHECK-DYNAMIC-LABEL: func private @emptyUsesValue(
+// CHECK-DYNAMIC-SAME: [[ARG0:%.+]]: memref<4xf32>, [[ARG1:%.+]]: i1)
+// CHECK-DYNAMIC: [[ALLOC:%.*]] = memref.alloc()
+// CHECK-DYNAMIC: [[BASE:%[a-zA-Z0-9_]+]], {{.*}} = memref.extract_strided_metadata [[ARG0]]
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG1]])
+// CHECK-DYNAMIC-NOT: retain
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-DYNAMIC-NOT: retain
+// CHECK-DYNAMIC-NEXT: return
+
+// -----
+
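+// Test Case: same allocation as above, but in a public function. No dynamic
+// ownership arguments are added even when the option is enabled.
+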
+func.func @emptyUsesValue(%arg0: memref<4xf32>) {
+ %0 = memref.alloc() : memref<4xf32>
+ "test.memref_user"(%0) : (memref<4xf32>) -> ()
+ return
+}
+
+// CHECK-LABEL: func @emptyUsesValue(
+
+// CHECK-DYNAMIC-LABEL: func @emptyUsesValue(
+// CHECK-DYNAMIC: [[ALLOC:%.*]] = memref.alloc()
+// CHECK-DYNAMIC: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-DYNAMIC-NOT: retain
+// CHECK-DYNAMIC-NEXT: return
+
+// -----
+
+// Test Case: Dead operations in a single block.
+// BufferDeallocation expected behavior: It only inserts the two missing
+// DeallocOps after the last BufferBasedOp.
+
+func.func private @redundantOperations(%arg0: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
+ %1 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%0: memref<2xf32>) out(%1: memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func private @redundantOperations
+// CHECK: (%[[ARG0:.*]]: {{.*}})
+// CHECK: %[[FIRST_ALLOC:.*]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK: %[[SECOND_ALLOC:.*]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based
+// CHECK-NEXT: bufferization.dealloc (%[[FIRST_ALLOC]] : {{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NEXT: bufferization.dealloc (%[[SECOND_ALLOC]] : {{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NEXT: return
+
+// CHECK-DYNAMIC-LABEL: func private @redundantOperations
+// CHECK-DYNAMIC: (%[[ARG0:.*]]: memref{{.*}}, %[[ARG1:.*]]: i1)
+// CHECK-DYNAMIC: %[[FIRST_ALLOC:.*]] = memref.alloc()
+// CHECK-DYNAMIC-NEXT: test.buffer_based
+// CHECK-DYNAMIC: %[[SECOND_ALLOC:.*]] = memref.alloc()
+// CHECK-DYNAMIC-NEXT: test.buffer_based
+// CHECK-DYNAMIC-NEXT: %[[BASE:[a-zA-Z0-9_]+]], {{.*}} = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[BASE]] : {{.*}}) if (%[[ARG1]])
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[FIRST_ALLOC]] : {{.*}}) if (%true{{[0-9_]*}})
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[SECOND_ALLOC]] : {{.*}}) if (%true{{[0-9_]*}})
+// CHECK-DYNAMIC-NEXT: return
+
+// -----
+
+// Test Case: buffers escaping via the function results
+// BufferDeallocation expected behavior: It must not dealloc %arg1 and %x
+// since they are operands of the return operation and thus escape
+// deallocation. It should dealloc %y after the CopyOp.
+
+func.func private @memref_in_function_results(
+ %arg0: memref<5xf32>,
+ %arg1: memref<10xf32>,
+ %arg2: memref<5xf32>) -> (memref<10xf32>, memref<15xf32>) {
+ %x = memref.alloc() : memref<15xf32>
+ %y = memref.alloc() : memref<5xf32>
+ test.buffer_based in(%arg0: memref<5xf32>) out(%y: memref<5xf32>)
+ test.copy(%y, %arg2) : (memref<5xf32>, memref<5xf32>)
+ return %arg1, %x : memref<10xf32>, memref<15xf32>
+}
+
+// CHECK-LABEL: func private @memref_in_function_results
+// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>,
+// CHECK-SAME: %[[RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[X:.*]] = memref.alloc()
+// CHECK: %[[Y:.*]] = memref.alloc()
+// CHECK: test.copy
+// CHECK-NEXT: %[[V0:.+]] = scf.if %false
+// CHECK-NEXT: scf.yield %[[ARG1]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[CLONE:.+]] = bufferization.clone %[[ARG1]]
+// CHECK-NEXT: scf.yield %[[CLONE]]
+// CHECK-NEXT: }
+// CHECK: bufferization.dealloc (%[[Y]] : {{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NOT: retain
+// CHECK: return %[[V0]], %[[X]]
+
+// CHECK-DYNAMIC-LABEL: func private @memref_in_function_results
+// CHECK-DYNAMIC: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>,
+// CHECK-DYNAMIC-SAME: %[[RESULT:.*]]: memref<5xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1, %[[ARG5:.*]]: i1)
+// CHECK-DYNAMIC: %[[X:.*]] = memref.alloc()
+// CHECK-DYNAMIC: %[[Y:.*]] = memref.alloc()
+// CHECK-DYNAMIC: test.copy
+// CHECK-DYNAMIC: %[[BASE0:[a-zA-Z0-9_]+]], {{.+}} = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-DYNAMIC: %[[BASE1:[a-zA-Z0-9_]+]], {{.+}} = memref.extract_strided_metadata %[[RESULT]]
+// CHECK-DYNAMIC: bufferization.dealloc (%[[Y]] : {{.*}}) if (%true{{[0-9_]*}})
+// CHECK-DYNAMIC-NOT: retain
+// CHECK-DYNAMIC: [[OWN:%.+]] = bufferization.dealloc (%[[BASE0]], %[[BASE1]] : {{.*}}) if (%[[ARG3]], %[[ARG5]]) retain (%[[ARG1]] :
+// CHECK-DYNAMIC: [[OR:%.+]] = arith.ori [[OWN]], %[[ARG4]]
+// CHECK-DYNAMIC: return %[[ARG1]], %[[X]], [[OR]], %true
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
new file mode 100644
index 000000000000000..460e37aa03059ff
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
@@ -0,0 +1,124 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN: --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true -split-input-file %s > /dev/null
+
+// Test Case: Dead operations in a single block.
+// BufferDeallocation expected behavior: It only inserts the two missing
+// DeallocOps after the last BufferBasedOp.
+
+// CHECK-LABEL: func @redundantOperations
+func.func @redundantOperations(%arg0: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
+ %1 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%0: memref<2xf32>) out(%1: memref<2xf32>)
+ return
+}
+
+// CHECK: (%[[ARG0:.*]]: {{.*}})
+// CHECK: %[[FIRST_ALLOC:.*]] = memref.alloc()
+// CHECK-NOT: bufferization.dealloc
+// CHECK: test.buffer_based in(%[[ARG0]]{{.*}}out(%[[FIRST_ALLOC]]
+// CHECK-NOT: bufferization.dealloc
+// CHECK: %[[SECOND_ALLOC:.*]] = memref.alloc()
+// CHECK-NOT: bufferization.dealloc
+// CHECK: test.buffer_based in(%[[FIRST_ALLOC]]{{.*}}out(%[[SECOND_ALLOC]]
+// CHECK: bufferization.dealloc (%[[FIRST_ALLOC]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK: bufferization.dealloc (%[[SECOND_ALLOC]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NEXT: return
+
+// TODO: The dealloc could be split in two to avoid runtime aliasing checks
+// since we can be sure at compile time that they will never alias.
+
+// -----
+
+// CHECK-LABEL: func @allocaIsNotDeallocated
+func.func @allocaIsNotDeallocated(%arg0: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
+ %1 = memref.alloca() : memref<2xf32>
+ test.buffer_based in(%0: memref<2xf32>) out(%1: memref<2xf32>)
+ return
+}
+
+// CHECK: (%[[ARG0:.*]]: {{.*}})
+// CHECK: %[[FIRST_ALLOC:.*]] = memref.alloc()
+// CHECK-NEXT: test.buffer_based in(%[[ARG0]]{{.*}}out(%[[FIRST_ALLOC]]
+// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = memref.alloca()
+// CHECK-NEXT: test.buffer_based in(%[[FIRST_ALLOC]]{{.*}}out(%[[SECOND_ALLOC]]
+// CHECK: bufferization.dealloc (%[[FIRST_ALLOC]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NEXT: return
+
+// -----
+
+// Test Case: Inserting missing DeallocOp in a single block.
+
+// CHECK-LABEL: func @inserting_missing_dealloc_simple
+func.func @inserting_missing_dealloc_simple(
+ %arg0 : memref<2xf32>,
+ %arg1: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
+ test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK: %[[ALLOC0:.*]] = memref.alloc()
+// CHECK: test.copy
+// CHECK: bufferization.dealloc (%[[ALLOC0]] :{{.*}}) if (%true{{[0-9_]*}})
+
+// -----
+
+// Test Case: The ownership indicator is set to false for alloca
+
+// CHECK-LABEL: func @alloca_ownership_indicator_is_false
+func.func @alloca_ownership_indicator_is_false() {
+ %0 = memref.alloca() : memref<2xf32>
+ cf.br ^bb1(%0: memref<2xf32>)
+^bb1(%arg0 : memref<2xf32>):
+ return
+}
+
+// CHECK: %[[ALLOC0:.*]] = memref.alloca()
+// CHECK-NEXT: cf.br ^bb1(%[[ALLOC0]], %false :
+// CHECK-NEXT: ^bb1([[A0:%.+]]: memref<2xf32>, [[COND0:%.+]]: i1):
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND0]])
+// CHECK-NEXT: return
+
+// -----
+
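+// Test Case: pre-existing bufferization.clone operations. Clones are treated
+// as owned allocations: the clone that is not returned is deallocated, while
+// ownership of the returned clone is transferred to the caller.
+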
+func.func @dealloc_existing_clones(%arg0: memref<?x?xf64>, %arg1: memref<?x?xf64>) -> memref<?x?xf64> {
+ %0 = bufferization.clone %arg0 : memref<?x?xf64> to memref<?x?xf64>
+ %1 = bufferization.clone %arg1 : memref<?x?xf64> to memref<?x?xf64>
+ return %0 : memref<?x?xf64>
+}
+
+// CHECK-LABEL: func @dealloc_existing_clones
+// CHECK: (%[[ARG0:.*]]: memref<?x?xf64>, %[[ARG1:.*]]: memref<?x?xf64>)
+// CHECK: %[[RES0:.*]] = bufferization.clone %[[ARG0]]
+// CHECK: %[[RES1:.*]] = bufferization.clone %[[ARG1]]
+// CHECK-NEXT: bufferization.dealloc (%[[RES1]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NOT: retain
+// CHECK-NEXT: return %[[RES0]]
+
+// TODO: The retain operand could be dropped to avoid runtime aliasing checks
+// since we can guarantee at compile time that it will never alias with the
+// dealloc operand
+
+// -----
+
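+// Test Case: returning a buffer that is neither allocated nor aliased inside
+// the function. memref.get_global yields a non-owned buffer, so a
+// conditional clone materializes ownership of the returned value.
+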
+memref.global "private" constant @__constant_4xf32 : memref<4xf32> = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]>
+
+func.func @op_without_aliasing_and_allocation() -> memref<4xf32> {
+ %0 = memref.get_global @__constant_4xf32 : memref<4xf32>
+ return %0 : memref<4xf32>
+}
+
+// CHECK-LABEL: func @op_without_aliasing_and_allocation
+// CHECK: [[GLOBAL:%.+]] = memref.get_global @__constant_4xf32
+// CHECK: [[RES:%.+]] = scf.if %false
+// CHECK: scf.yield [[GLOBAL]] :
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[GLOBAL]]
+// CHECK: scf.yield [[CLONE]] :
+// CHECK: return [[RES]] :
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
new file mode 100644
index 000000000000000..d8090591c70513f
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -0,0 +1,695 @@
+// RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN: --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true -split-input-file %s > /dev/null
+
+// Test Case: Nested regions - This test defines a BufferBasedOp inside the
+// region of a RegionBufferBasedOp.
+// BufferDeallocation expected behavior: The AllocOp for the BufferBasedOp
+// should remain inside the region of the RegionBufferBasedOp, and the pass
+// should insert the missing DeallocOp in that same region. The DeallocOp for
+// the outer buffer should be inserted after the CopyOp.
+
+func.func @nested_regions_and_cond_branch(
+ %arg0: i1,
+ %arg1: memref<2xf32>,
+ %arg2: memref<2xf32>) {
+ cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+ %0 = memref.alloc() : memref<2xf32>
+ test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
+ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
+ %1 = memref.alloc() : memref<2xf32>
+ test.buffer_based in(%arg1: memref<2xf32>) out(%1: memref<2xf32>)
+ %tmp1 = math.exp %gen1_arg0 : f32
+ test.region_yield %tmp1 : f32
+ }
+ cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+ test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @nested_regions_and_cond_branch
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xf32>, [[ARG2:%.+]]: memref<2xf32>)
+// CHECK: ^bb1:
+// CHECK-NOT: bufferization.clone
+// CHECK-NOT: bufferization.dealloc
+// CHECK: cf.br ^bb3([[ARG1]], %false
+// CHECK: ^bb2:
+// CHECK: [[ALLOC0:%.+]] = memref.alloc()
+// CHECK: test.region_buffer_based
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK: test.buffer_based
+// CHECK: bufferization.dealloc ([[ALLOC1]] : memref<2xf32>) if (%true
+// CHECK-NEXT: test.region_yield
+// CHECK-NOT: bufferization.clone
+// CHECK-NOT: bufferization.dealloc
+// CHECK: cf.br ^bb3([[ALLOC0]], %true
+// CHECK: ^bb3([[A0:%.+]]: memref<2xf32>, [[COND0:%.+]]: i1):
+// CHECK: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND0]])
+// CHECK: return
+
+// -----
+
+// Test Case: nested region control flow
+// The alloc %1 flows through both if branches until it is finally returned.
+// Hence, it does not require a specific dealloc operation. However, %3
+// requires a dealloc.
+
+func.func @nested_region_control_flow(
+ %arg0 : index,
+ %arg1 : index) -> memref<?x?xf32> {
+ %0 = arith.cmpi eq, %arg0, %arg1 : index
+ %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
+ %2 = scf.if %0 -> (memref<?x?xf32>) {
+ scf.yield %1 : memref<?x?xf32>
+ } else {
+ %3 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+ "test.memref_user"(%3) : (memref<?x?xf32>) -> ()
+ scf.yield %1 : memref<?x?xf32>
+ }
+ return %2 : memref<?x?xf32>
+}
+
+// CHECK-LABEL: func @nested_region_control_flow
+// CHECK: [[ALLOC:%.+]] = memref.alloc(
+// CHECK: [[V0:%.+]]:2 = scf.if
+// CHECK: scf.yield [[ALLOC]], %false
+// CHECK: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK: bufferization.dealloc ([[ALLOC1]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NOT: retain
+// CHECK: scf.yield [[ALLOC]], %false
+// CHECK: [[V1:%.+]] = scf.if [[V0]]#1
+// CHECK: scf.yield [[V0]]#0
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[V0]]#0
+// CHECK: scf.yield [[CLONE]]
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] : {{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
+// CHECK: return [[V1]]
+
+// -----
+
+// Test Case: nested region control flow with a nested buffer allocation in a
+// divergent branch.
+// Buffer deallocation materializes ownership of the returned buffer via a
+// conditional clone, since either %1 or %3 may be returned.
+
+func.func @nested_region_control_flow_div(
+ %arg0 : index,
+ %arg1 : index) -> memref<?x?xf32> {
+ %0 = arith.cmpi eq, %arg0, %arg1 : index
+ %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
+ %2 = scf.if %0 -> (memref<?x?xf32>) {
+ scf.yield %1 : memref<?x?xf32>
+ } else {
+ %3 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+ scf.yield %3 : memref<?x?xf32>
+ }
+ return %2 : memref<?x?xf32>
+}
+
+// CHECK-LABEL: func @nested_region_control_flow_div
+// CHECK: [[ALLOC:%.+]] = memref.alloc(
+// CHECK: [[V0:%.+]]:2 = scf.if
+// CHECK: scf.yield [[ALLOC]], %false
+// CHECK: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK: scf.yield [[ALLOC1]], %true
+// CHECK: [[V1:%.+]] = scf.if [[V0]]#1
+// CHECK: scf.yield [[V0]]#0
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[V0]]#0
+// CHECK: scf.yield [[CLONE]]
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
+// CHECK: return [[V1]]
+
+// -----
+
+// Test Case: nested region control flow within a region interface.
+// The allocation finally escapes the method; ownership of the returned buffer
+// is materialized via a conditional clone before returning.
+
+func.func @inner_region_control_flow(%arg0 : index) -> memref<?x?xf32> {
+ %0 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
+ %1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
+ ^bb0(%arg1 : memref<?x?xf32>):
+ test.region_if_yield %arg1 : memref<?x?xf32>
+ } else {
+ ^bb0(%arg1 : memref<?x?xf32>):
+ test.region_if_yield %arg1 : memref<?x?xf32>
+ } join {
+ ^bb0(%arg1 : memref<?x?xf32>):
+ test.region_if_yield %arg1 : memref<?x?xf32>
+ }
+ return %1 : memref<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @inner_region_control_flow
+// CHECK: [[ALLOC:%.+]] = memref.alloc(
+// CHECK: [[V0:%.+]]:2 = test.region_if [[ALLOC]], %false
+// CHECK: ^bb0([[ARG1:%.+]]: memref<?x?xf32>, [[ARG2:%.+]]: i1):
+// CHECK: test.region_if_yield [[ARG1]], [[ARG2]]
+// CHECK: ^bb0([[ARG1:%.+]]: memref<?x?xf32>, [[ARG2:%.+]]: i1):
+// CHECK: test.region_if_yield [[ARG1]], [[ARG2]]
+// CHECK: ^bb0([[ARG1:%.+]]: memref<?x?xf32>, [[ARG2:%.+]]: i1):
+// CHECK: test.region_if_yield [[ARG1]], [[ARG2]]
+// CHECK: [[V1:%.+]] = scf.if [[V0]]#1
+// CHECK: scf.yield [[V0]]#0
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[V0]]#0
+// CHECK: scf.yield [[CLONE]]
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
+// CHECK: return [[V1]]
+
+// -----
+
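+// Test Case: variant of the nested-regions test above where the inner buffer
+// is an alloca. Expected behavior: no DeallocOp is inserted inside the
+// nested region for the alloca.
+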
+func.func @nestedRegionsAndCondBranchAlloca(
+ %arg0: i1,
+ %arg1: memref<2xf32>,
+ %arg2: memref<2xf32>) {
+ cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ cf.br ^bb3(%arg1 : memref<2xf32>)
+^bb2:
+ %0 = memref.alloc() : memref<2xf32>
+ test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
+ ^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
+ %1 = memref.alloca() : memref<2xf32>
+ test.buffer_based in(%arg1: memref<2xf32>) out(%1: memref<2xf32>)
+ %tmp1 = math.exp %gen1_arg0 : f32
+ test.region_yield %tmp1 : f32
+ }
+ cf.br ^bb3(%0 : memref<2xf32>)
+^bb3(%1: memref<2xf32>):
+ test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @nestedRegionsAndCondBranchAlloca
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xf32>, [[ARG2:%.+]]: memref<2xf32>)
+// CHECK: ^bb1:
+// CHECK: cf.br ^bb3([[ARG1]], %false
+// CHECK: ^bb2:
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: test.region_buffer_based
+// CHECK: memref.alloca()
+// CHECK: test.buffer_based
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: test.region_yield
+// CHECK: }
+// CHECK: cf.br ^bb3([[ALLOC]], %true
+// CHECK: ^bb3([[A0:%.+]]: memref<2xf32>, [[COND:%.+]]: i1):
+// CHECK: test.copy
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[A0]]
+// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND]])
+
+// -----
+
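+// Test Case: variant of @nested_region_control_flow where the buffer in the
+// divergent branch is an alloca. Expected behavior: only the returned
+// allocation %1 is tracked and deallocated.
+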
+func.func @nestedRegionControlFlowAlloca(
+ %arg0 : index, %arg1 : index, %arg2: f32) -> memref<?x?xf32> {
+ %0 = arith.cmpi eq, %arg0, %arg1 : index
+ %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
+ %2 = scf.if %0 -> (memref<?x?xf32>) {
+ scf.yield %1 : memref<?x?xf32>
+ } else {
+ %3 = memref.alloca(%arg0, %arg1) : memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ memref.store %arg2, %3[%c0, %c0] : memref<?x?xf32>
+ scf.yield %1 : memref<?x?xf32>
+ }
+ return %2 : memref<?x?xf32>
+}
+
+// CHECK-LABEL: func @nestedRegionControlFlowAlloca
+// CHECK: [[ALLOC:%.+]] = memref.alloc(
+// CHECK: [[V0:%.+]]:2 = scf.if
+// CHECK: scf.yield [[ALLOC]], %false
+// CHECK: memref.alloca(
+// CHECK: scf.yield [[ALLOC]], %false
+// CHECK: [[V1:%.+]] = scf.if [[V0]]#1
+// CHECK: scf.yield [[V0]]#0
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[V0]]#0
+// CHECK: scf.yield [[CLONE]]
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
+// CHECK: return [[V1]]
+
+// -----
+
+// Test Case: structured control-flow loop using a nested alloc.
+// The iteration argument %iterBuf has to be freed before yielding %3 to avoid
+// memory leaks.
+
+func.func @loop_alloc(
+ %lb: index,
+ %ub: index,
+ %step: index,
+ %buf: memref<2xf32>,
+ %res: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ "test.memref_user"(%0) : (memref<2xf32>) -> ()
+ %1 = scf.for %i = %lb to %ub step %step
+ iter_args(%iterBuf = %buf) -> memref<2xf32> {
+ %2 = arith.cmpi eq, %i, %ub : index
+ %3 = memref.alloc() : memref<2xf32>
+ scf.yield %3 : memref<2xf32>
+ }
+ test.copy(%1, %res) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @loop_alloc
+// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index, [[ARG2:%.+]]: index, [[ARG3:%.+]]: memref<2xf32>, [[ARG4:%.+]]: memref<2xf32>)
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: [[V0:%.+]]:2 = scf.for {{.*}} iter_args([[ARG6:%.+]] = [[ARG3]], [[ARG7:%.+]] = %false
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG6]]
+// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]]) retain ([[ALLOC1]] :
+// CHECK: scf.yield [[ALLOC1]], %true
+// CHECK: test.copy
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
+// CHECK-NOT: retain
+// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[V0]]#1)
+// CHECK-NOT: retain
+
+// -----
+
+// Test Case: structured control-flow loop with a nested if operation.
+// The loop yields buffers that have been defined outside of the loop and the
+// backedges only use the iteration arguments (or one of its aliases).
+// Therefore, we do not have to (and are not allowed to) free any buffers
+// that are passed via the backedges.
+
+func.func @loop_nested_if_no_alloc(
+ %lb: index,
+ %ub: index,
+ %step: index,
+ %buf: memref<2xf32>,
+ %res: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ %1 = scf.for %i = %lb to %ub step %step
+ iter_args(%iterBuf = %buf) -> memref<2xf32> {
+ %2 = arith.cmpi eq, %i, %ub : index
+ %3 = scf.if %2 -> (memref<2xf32>) {
+ scf.yield %0 : memref<2xf32>
+ } else {
+ scf.yield %iterBuf : memref<2xf32>
+ }
+ scf.yield %3 : memref<2xf32>
+ }
+ test.copy(%1, %res) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @loop_nested_if_no_alloc
+// CHECK-SAME: ({{.*}}, [[ARG3:%.+]]: memref<2xf32>, [[ARG4:%.+]]: memref<2xf32>)
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: [[V0:%.+]]:2 = scf.for {{.*}} iter_args([[ARG6:%.+]] = [[ARG3]], [[ARG7:%.+]] = %false
+// CHECK: [[V1:%.+]]:2 = scf.if
+// CHECK: scf.yield [[ALLOC]], %false
+// CHECK: scf.yield [[ARG6]], %false
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG6]]
+// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]]) retain ([[V1]]#0 :
+// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[V1]]#1
+// CHECK: scf.yield [[V1]]#0, [[OWN_AGG]]
+// CHECK: test.copy
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1)
+
+// TODO: we know statically that the inner dealloc will never deallocate
+// anything, i.e., we can optimize it away
+
+// -----
+
+// Test Case: structured control-flow loop with a nested if operation using
+// a deeply nested buffer allocation.
+
+func.func @loop_nested_if_alloc(
+ %lb: index,
+ %ub: index,
+ %step: index,
+ %buf: memref<2xf32>) -> memref<2xf32> {
+ %0 = memref.alloc() : memref<2xf32>
+ %1 = scf.for %i = %lb to %ub step %step
+ iter_args(%iterBuf = %buf) -> memref<2xf32> {
+ %2 = arith.cmpi eq, %i, %ub : index
+ %3 = scf.if %2 -> (memref<2xf32>) {
+ %4 = memref.alloc() : memref<2xf32>
+ scf.yield %4 : memref<2xf32>
+ } else {
+ scf.yield %0 : memref<2xf32>
+ }
+ scf.yield %3 : memref<2xf32>
+ }
+ return %1 : memref<2xf32>
+}
+
+// CHECK-LABEL: func @loop_nested_if_alloc
+// CHECK-SAME: ({{.*}}, [[ARG3:%.+]]: memref<2xf32>)
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: [[V0:%.+]]:2 = scf.for {{.*}} iter_args([[ARG5:%.+]] = [[ARG3]], [[ARG6:%.+]] = %false
+// CHECK: [[V1:%.+]]:2 = scf.if
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK: scf.yield [[ALLOC1]], %true
+// CHECK: scf.yield [[ALLOC]], %false
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG5]]
+// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG6]]) retain ([[V1]]#0 :
+// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[V1]]#1
+// CHECK: scf.yield [[V1]]#0, [[OWN_AGG]]
+// CHECK: }
+// CHECK: [[V2:%.+]] = scf.if [[V0]]#1
+// CHECK: scf.yield [[V0]]#0
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[V0]]#0
+// CHECK: scf.yield [[CLONE]]
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V2]] :
+// CHECK: return [[V2]]
+
+// -----
+
+// Test Case: several nested structured control-flow loops with a deeply nested
+// buffer allocation inside an if operation.
+
+func.func @loop_nested_alloc(
+ %lb: index,
+ %ub: index,
+ %step: index,
+ %buf: memref<2xf32>,
+ %res: memref<2xf32>) {
+ %0 = memref.alloc() : memref<2xf32>
+ "test.memref_user"(%0) : (memref<2xf32>) -> ()
+ %1 = scf.for %i = %lb to %ub step %step
+ iter_args(%iterBuf = %buf) -> memref<2xf32> {
+ %2 = scf.for %i2 = %lb to %ub step %step
+ iter_args(%iterBuf2 = %iterBuf) -> memref<2xf32> {
+ %3 = scf.for %i3 = %lb to %ub step %step
+ iter_args(%iterBuf3 = %iterBuf2) -> memref<2xf32> {
+ %4 = memref.alloc() : memref<2xf32>
+ "test.memref_user"(%4) : (memref<2xf32>) -> ()
+ %5 = arith.cmpi eq, %i, %ub : index
+ %6 = scf.if %5 -> (memref<2xf32>) {
+ %7 = memref.alloc() : memref<2xf32>
+ scf.yield %7 : memref<2xf32>
+ } else {
+ scf.yield %iterBuf3 : memref<2xf32>
+ }
+ scf.yield %6 : memref<2xf32>
+ }
+ scf.yield %3 : memref<2xf32>
+ }
+ scf.yield %2 : memref<2xf32>
+ }
+ test.copy(%1, %res) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @loop_nested_alloc
+// CHECK: ({{.*}}, [[ARG3:%.+]]: memref<2xf32>, {{.*}}: memref<2xf32>)
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: [[V0:%.+]]:2 = scf.for {{.*}} iter_args([[ARG6:%.+]] = [[ARG3]], [[ARG7:%.+]] = %false
+// CHECK: [[V1:%.+]]:2 = scf.for {{.*}} iter_args([[ARG9:%.+]] = [[ARG6]], [[ARG10:%.+]] = %false
+// CHECK: [[V2:%.+]]:2 = scf.for {{.*}} iter_args([[ARG12:%.+]] = [[ARG9]], [[ARG13:%.+]] = %false
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK: [[V3:%.+]]:2 = scf.if
+// CHECK: [[ALLOC2:%.+]] = memref.alloc()
+// CHECK: scf.yield [[ALLOC2]], %true
+// CHECK: } else {
+// CHECK: scf.yield [[ARG12]], %false
+// CHECK: }
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG12]]
+// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG13]]) retain ([[V3]]#0 :
+// CHECK: bufferization.dealloc ([[ALLOC1]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NOT: retain
+// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[V3]]#1
+// CHECK: scf.yield [[V3]]#0, [[OWN_AGG]]
+// CHECK: }
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG9]]
+// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG10]]) retain ([[V2]]#0 :
+// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[V2]]#1
+// CHECK: scf.yield [[V2]]#0, [[OWN_AGG]]
+// CHECK: }
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG6]]
+// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]]) retain ([[V1]]#0 :
+// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[V1]]#1
+// CHECK: scf.yield [[V1]]#0, [[OWN_AGG]]
+// CHECK: }
+// CHECK: test.copy
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
+// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[V0]]#1)
+
+// TODO: all the retain operands could be removed with a more thorough analysis
+
+// -----
+
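+// Test Case: the buffer is only read inside the affine.for loop.
+// Expected behavior: a single unconditional DeallocOp after the loop.
+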
+func.func @affine_loop() -> f32 {
+ %buffer = memref.alloc() : memref<1024xf32>
+ %sum_init_0 = arith.constant 0.0 : f32
+ %res = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_init_0) -> f32 {
+ %t = affine.load %buffer[%i] : memref<1024xf32>
+ %sum_next = arith.addf %sum_iter, %t : f32
+ affine.yield %sum_next : f32
+ }
+ return %res : f32
+}
+
+// CHECK-LABEL: func @affine_loop
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: affine.for {{.*}} iter_args(%arg1 = %cst)
+// CHECK: affine.yield
+// CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
+
+// -----
+
+func.func @assumingOp(
+ %arg0: !shape.witness,
+ %arg2: memref<2xf32>,
+ %arg3: memref<2xf32>) {
+ // Confirm the alloc will be dealloc'ed in the block.
+ %1 = shape.assuming %arg0 -> memref<2xf32> {
+ %0 = memref.alloc() : memref<2xf32>
+ "test.memref_user"(%0) : (memref<2xf32>) -> ()
+ shape.assuming_yield %arg2 : memref<2xf32>
+ }
+ // Confirm the alloc will be returned and dealloc'ed after its use.
+ %3 = shape.assuming %arg0 -> memref<2xf32> {
+ %2 = memref.alloc() : memref<2xf32>
+ shape.assuming_yield %2 : memref<2xf32>
+ }
+ test.copy(%3, %arg3) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// CHECK-LABEL: func @assumingOp
+// CHECK: ({{.*}}, [[ARG1:%.+]]: memref<2xf32>, {{.*}}: memref<2xf32>)
+// CHECK: [[V0:%.+]]:2 = shape.assuming
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true{{[0-9_]*}})
+// CHECK-NOT: retain
+// CHECK: shape.assuming_yield [[ARG1]], %false
+// CHECK: }
+// CHECK: [[V1:%.+]]:2 = shape.assuming
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: shape.assuming_yield [[ALLOC]], %true
+// CHECK: }
+// CHECK: test.copy
+// CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V1]]#0
+// CHECK: bufferization.dealloc ([[BASE0]] :{{.*}}) if ([[V0]]#1)
+// CHECK-NOT: retain
+// CHECK: bufferization.dealloc ([[BASE1]] :{{.*}}) if ([[V1]]#1)
+// CHECK-NOT: retain
+// CHECK: return
+
+// -----
+
+// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
+// This is only allowed in buffer deallocation because the operation's region
+// does not deal with any MemRef values.
+
+func.func @noRegionBranchOpInterface() {
+ %0 = "test.bar"() ({
+ %1 = "test.bar"() ({
+ "test.yield"() : () -> ()
+ }) : () -> (i32)
+ "test.yield"() : () -> ()
+ }) : () -> (i32)
+ "test.terminator"() : () -> ()
+}
+
+// -----
+
+// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
+// This is not allowed in buffer deallocation.
+
+func.func @noRegionBranchOpInterface() {
+ // expected-error at +1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
+ %0 = "test.bar"() ({
+ %1 = "test.bar"() ({
+ %2 = "test.get_memref"() : () -> memref<2xi32>
+ "test.yield"(%2) : (memref<2xi32>) -> ()
+ }) : () -> (memref<2xi32>)
+ "test.yield"() : () -> ()
+ }) : () -> (i32)
+ "test.terminator"() : () -> ()
+}
+
+// -----
+
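+// Test Case: scf.while with two memref iter_args. Each memref iter_arg gets
+// an additional i1 ownership iter_arg; the buffer replaced on the backedge
+// is deallocated inside the body while retaining the yielded values.
+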
+func.func @while_two_arg(%arg0: index) {
+ %a = memref.alloc(%arg0) : memref<?xf32>
+ scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
+ %0 = "test.make_condition"() : () -> i1
+ scf.condition(%0) %arg1, %arg2 : memref<?xf32>, memref<?xf32>
+ } do {
+ ^bb0(%arg1: memref<?xf32>, %arg2: memref<?xf32>):
+ %b = memref.alloc(%arg0) : memref<?xf32>
+ scf.yield %arg1, %b : memref<?xf32>, memref<?xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @while_two_arg
+// CHECK: [[ALLOC:%.+]] = memref.alloc(
+// CHECK: [[V0:%.+]]:4 = scf.while ({{.*}} = [[ALLOC]], {{.*}} = [[ALLOC]], {{.*}} = %false{{[0-9_]*}}, {{.*}} = %false{{[0-9_]*}})
+// CHECK: scf.condition
+// CHECK: ^bb0([[ARG1:%.+]]: memref<?xf32>, [[ARG2:%.+]]: memref<?xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1):
+// CHECK: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
+// CHECK: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG4]]) retain ([[ARG1]], [[ALLOC1]] :
+// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]]#0, [[ARG3]]
+// CHECK: scf.yield [[ARG1]], [[ALLOC1]], [[OWN_AGG]], %true
+// CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#1
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE0]], [[BASE1]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#2, [[V0]]#3)
+
+// -----
+
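+// Test Case: scf.while with three memref iter_args whose yielded values are
+// permuted on the backedge, exercising a dealloc with multiple retained
+// values.
+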
+func.func @while_three_arg(%arg0: index) {
+ %a = memref.alloc(%arg0) : memref<?xf32>
+ scf.while (%arg1 = %a, %arg2 = %a, %arg3 = %a) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
+ %0 = "test.make_condition"() : () -> i1
+ scf.condition(%0) %arg1, %arg2, %arg3 : memref<?xf32>, memref<?xf32>, memref<?xf32>
+ } do {
+ ^bb0(%arg1: memref<?xf32>, %arg2: memref<?xf32>, %arg3: memref<?xf32>):
+ %b = memref.alloc(%arg0) : memref<?xf32>
+ %q = memref.alloc(%arg0) : memref<?xf32>
+ scf.yield %q, %b, %arg2: memref<?xf32>, memref<?xf32>, memref<?xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @while_three_arg
+// CHECK: [[ALLOC:%.+]] = memref.alloc(
+// CHECK: [[V0:%.+]]:6 = scf.while ({{.*}} = [[ALLOC]], {{.*}} = [[ALLOC]], {{.*}} = [[ALLOC]], {{.*}} = %false{{[0-9_]*}}, {{.*}} = %false{{[0-9_]*}}, {{.*}} = %false
+// CHECK: scf.condition
+// CHECK: ^bb0([[ARG1:%.+]]: memref<?xf32>, [[ARG2:%.+]]: memref<?xf32>, [[ARG3:%.+]]: memref<?xf32>, [[ARG4:%.+]]: i1, [[ARG5:%.+]]: i1, [[ARG6:%.+]]: i1):
+// CHECK: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK: [[ALLOC2:%.+]] = memref.alloc(
+// CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG1]]
+// CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
+// CHECK: [[BASE2:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG3]]
+// CHECK: [[OWN:%.+]]:3 = bufferization.dealloc ([[BASE0]], [[BASE1]], [[BASE2]], [[ALLOC1]] :{{.*}}) if ([[ARG4]], [[ARG5]], [[ARG6]], %true{{[0-9_]*}}) retain ([[ALLOC2]], [[ALLOC1]], [[ARG2]] :
+// CHECK: scf.yield [[ALLOC2]], [[ALLOC1]], [[ARG2]], %true{{[0-9_]*}}, %true{{[0-9_]*}}, [[OWN]]#2 :
+// CHECK: }
+// CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#1
+// CHECK: [[BASE2:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#2
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE0]], [[BASE1]], [[BASE2]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#3, [[V0]]#4, [[V0]]#5)
+
+// TODO: better alias analysis could simplify the dealloc inside the body further
+
+// -----
+
+// Memref allocated in `then` region and passed back to the parent if op.
+#set = affine_set<() : (0 >= 0)>
+func.func @test_affine_if_1(%arg0: memref<10xf32>) -> memref<10xf32> {
+ %0 = affine.if #set() -> memref<10xf32> {
+ %alloc = memref.alloc() : memref<10xf32>
+ affine.yield %alloc : memref<10xf32>
+ } else {
+ affine.yield %arg0 : memref<10xf32>
+ }
+ return %0 : memref<10xf32>
+}
+
+// CHECK-LABEL: func @test_affine_if_1
+// CHECK-SAME: ([[ARG0:%.*]]: memref<10xf32>)
+// CHECK: [[V0:%.+]]:2 = affine.if
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: affine.yield [[ALLOC]], %true
+// CHECK: affine.yield [[ARG0]], %false
+// CHECK: [[V1:%.+]] = scf.if [[V0]]#1
+// CHECK: scf.yield [[V0]]#0
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[V0]]#0
+// CHECK: scf.yield [[CLONE]]
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[V0]]#1) retain ([[V1]] :
+// CHECK: return [[V1]]
+
+// TODO: the dealloc could be optimized away since the memref to be deallocated
+// either aliases with V1 or the condition is false
+
+// -----
+
+// Memref allocated before parent IfOp and used in `then` region.
+// Expected result: deallocation should happen after affine.if op.
+#set = affine_set<() : (0 >= 0)>
+func.func @test_affine_if_2() -> memref<10xf32> {
+ %alloc0 = memref.alloc() : memref<10xf32>
+ %0 = affine.if #set() -> memref<10xf32> {
+ affine.yield %alloc0 : memref<10xf32>
+ } else {
+ %alloc = memref.alloc() : memref<10xf32>
+ affine.yield %alloc : memref<10xf32>
+ }
+ return %0 : memref<10xf32>
+}
+// CHECK-LABEL: func @test_affine_if_2
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: [[V0:%.+]]:2 = affine.if
+// CHECK: affine.yield [[ALLOC]], %false
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK: affine.yield [[ALLOC1]], %true
+// CHECK: [[V1:%.+]] = scf.if [[V0]]#1
+// CHECK: scf.yield [[V0]]#0
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[V0]]#0
+// CHECK: scf.yield [[CLONE]]
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]] :
+// CHECK: return [[V1]]
+
+// -----
+
+// Memref allocated before the parent affine.if op and yielded from its
+// `else` region. Expected result: the deallocation should happen after the
+// affine.if op.
+#set = affine_set<() : (0 >= 0)>
+func.func @test_affine_if_3() -> memref<10xf32> {
+ %alloc0 = memref.alloc() : memref<10xf32>
+ %0 = affine.if #set() -> memref<10xf32> {
+ %alloc = memref.alloc() : memref<10xf32>
+ affine.yield %alloc : memref<10xf32>
+ } else {
+ affine.yield %alloc0 : memref<10xf32>
+ }
+ return %0 : memref<10xf32>
+}
+
+// CHECK-LABEL: func @test_affine_if_3
+// CHECK: [[ALLOC:%.+]] = memref.alloc()
+// CHECK: [[V0:%.+]]:2 = affine.if
+// CHECK: [[ALLOC1:%.+]] = memref.alloc()
+// CHECK: affine.yield [[ALLOC1]], %true
+// CHECK: affine.yield [[ALLOC]], %false
+// CHECK: [[V1:%.+]] = scf.if [[V0]]#1
+// CHECK: scf.yield [[V0]]#0
+// CHECK: [[CLONE:%.+]] = bufferization.clone [[V0]]#0
+// CHECK: scf.yield [[CLONE]]
+// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
+// CHECK: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1) retain ([[V1]]
+// CHECK: return [[V1]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-subviews.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-subviews.mlir
new file mode 100644
index 000000000000000..666b3b08995e8ad
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-subviews.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN: --buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation=private-function-dynamic-ownership=true -split-input-file %s > /dev/null
+
+// CHECK-LABEL: func @subview
+func.func @subview(%arg0 : index, %arg1 : index, %arg2 : memref<?x?xf32>) {
+ %0 = memref.alloc() : memref<64x4xf32, strided<[4, 1], offset: 0>>
+ %1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
+ memref<64x4xf32, strided<[4, 1], offset: 0>>
+ to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ test.copy(%1, %arg2) :
+ (memref<?x?xf32, strided<[?, ?], offset: ?>>, memref<?x?xf32>)
+ return
+}
+
+// CHECK: %[[ALLOC:.*]] = memref.alloc()
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: bufferization.dealloc (%[[ALLOC]] :
+// CHECK-SAME: if (%true)
+// CHECK-NEXT: return
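+// Note that the subview itself needs no deallocation: it is only a view into
+// %[[ALLOC]], so ownership is tracked and released on the root allocation
+// alone.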
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/invalid-buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/invalid-buffer-deallocation.mlir
new file mode 100644
index 000000000000000..c623891e48362fa
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/invalid-buffer-deallocation.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation -split-input-file %s
+
+// Test Case: explicit control-flow loop with a dynamically allocated buffer.
+// The BufferDeallocation transformation should fail on this explicit
+// control-flow loop since such loops are not supported; a structured
+// equivalent that is supported is sketched below.
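+// For reference, a structured form of the same loop that the pass does
+// support might look like this (an illustrative sketch, not part of the
+// test input):
+//   %res:2 = scf.while (%i = %const0, %buf = %arg2)
+//       : (i32, memref<?xf32>) -> (i32, memref<?xf32>) {
+//     %lt = arith.cmpi slt, %i, %arg1 : i32
+//     scf.condition(%lt) %i, %buf : i32, memref<?xf32>
+//   } do {
+//   ^bb0(%i2: i32, %buf2: memref<?xf32>):
+//     %c1 = arith.constant 1 : i32
+//     %inc = arith.addi %i2, %c1 : i32
+//     %sz = arith.index_cast %inc : i32 to index
+//     %a = memref.alloc(%sz) : memref<?xf32>
+//     scf.yield %inc, %a : i32, memref<?xf32>
+//   }
+//   test.copy(%res#1, %arg3) : (memref<?xf32>, memref<?xf32>)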
+
+// expected-error @+1 {{Only structured control-flow loops are supported}}
+func.func @loop_dynalloc(
+ %arg0 : i32,
+ %arg1 : i32,
+ %arg2: memref<?xf32>,
+ %arg3: memref<?xf32>) {
+ %const0 = arith.constant 0 : i32
+ cf.br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>)
+
+^loopHeader(%i : i32, %buff : memref<?xf32>):
+ %lessThan = arith.cmpi slt, %i, %arg1 : i32
+ cf.cond_br %lessThan,
+ ^loopBody(%i, %buff : i32, memref<?xf32>),
+ ^exit(%buff : memref<?xf32>)
+
+^loopBody(%val : i32, %buff2: memref<?xf32>):
+ %const1 = arith.constant 1 : i32
+ %inc = arith.addi %val, %const1 : i32
+ %size = arith.index_cast %inc : i32 to index
+ %alloc1 = memref.alloc(%size) : memref<?xf32>
+ cf.br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>)
+
+^exit(%buff3 : memref<?xf32>):
+ test.copy(%buff3, %arg3) : (memref<?xf32>, memref<?xf32>)
+ return
+}
+
+// -----
+
+// Test Case: explicit control-flow, do-while-style loop with a buffer
+// allocated inside the loop body. The BufferDeallocation transformation
+// should fail on this explicit control-flow loop since such loops are not
+// supported.
+
+// expected-error @+1 {{Only structured control-flow loops are supported}}
+func.func @do_loop_alloc(
+ %arg0 : i32,
+ %arg1 : i32,
+ %arg2: memref<2xf32>,
+ %arg3: memref<2xf32>) {
+ %const0 = arith.constant 0 : i32
+ cf.br ^loopBody(%const0, %arg2 : i32, memref<2xf32>)
+
+^loopBody(%val : i32, %buff2: memref<2xf32>):
+ %const1 = arith.constant 1 : i32
+ %inc = arith.addi %val, %const1 : i32
+ %alloc1 = memref.alloc() : memref<2xf32>
+ cf.br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>)
+
+^loopHeader(%i : i32, %buff : memref<2xf32>):
+ %lessThan = arith.cmpi slt, %i, %arg1 : i32
+ cf.cond_br %lessThan,
+ ^loopBody(%i, %buff : i32, memref<2xf32>),
+ ^exit(%buff : memref<2xf32>)
+
+^exit(%buff3 : memref<2xf32>):
+ test.copy(%buff3, %arg3) : (memref<2xf32>, memref<2xf32>)
+ return
+}
+
+// -----
+
+func.func @free_effect() {
+ %alloc = memref.alloc() : memref<2xi32>
+ // expected-error @below {{memory free side-effect on MemRef value not supported!}}
+ %new_alloc = memref.realloc %alloc : memref<2xi32> to memref<4xi32>
+ return
+}
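+// Note: a realloc is typically lowered before this pass runs (e.g. via the
+// expand-realloc pass), so that no free side-effect on a MemRef value
+// remains; this test only checks the diagnostic.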
+
+// -----
+
+func.func @free_effect() {
+ %alloc = memref.alloc() : memref<2xi32>
+ // expected-error @below {{memory free side-effect on MemRef value not supported!}}
+ memref.dealloc %alloc : memref<2xi32>
+ return
+}
+
+// -----
+
+func.func @free_effect() {
+ %true = arith.constant true
+ %alloc = memref.alloc() : memref<2xi32>
+ // expected-error @below {{No deallocation operations must be present when running this pass!}}
+ bufferization.dealloc (%alloc : memref<2xi32>) if (%true)
+ return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2263414388b0811..d4390e7651be0f0 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -12137,6 +12137,7 @@ cc_library(
":BufferizationDialect",
":BufferizationEnumsIncGen",
":BufferizationPassIncGen",
+ ":ControlFlowDialect",
":ControlFlowInterfaces",
":FuncDialect",
":IR",