[Mlir-commits] [mlir] 4bde084 - [mlir][bufferization] Change semantics of DeallocOp result values

Martin Erhart llvmlistbot at llvm.org
Fri Aug 4 06:56:45 PDT 2023


Author: Martin Erhart
Date: 2023-08-04T13:54:46Z
New Revision: 4bde084f0c1837234a5ddab66c95c55fe4b81599

URL: https://github.com/llvm/llvm-project/commit/4bde084f0c1837234a5ddab66c95c55fe4b81599
DIFF: https://github.com/llvm/llvm-project/commit/4bde084f0c1837234a5ddab66c95c55fe4b81599.diff

LOG: [mlir][bufferization] Change semantics of DeallocOp result values

This change allows supporting operations for which we don't get precise aliasing information without the need to insert clone operations. E.g., `arith.select`.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D156992

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
    mlir/test/Dialect/Bufferization/canonicalize.mlir
    mlir/test/Dialect/Bufferization/invalid.mlir
    mlir/test/Dialect/Bufferization/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 55182f83832b47..a9e7b7215c1b43 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -485,22 +485,41 @@ def Bufferization_DeallocOp : Bufferization_Op<"dealloc", [
     deallocating that memref). If two memrefs alias each other, only one will be
     deallocated to avoid double free situations.
 
-    The memrefs to be deallocated must be the originally allocated memrefs,
-    however, the memrefs to be retained may be arbitrary memrefs.
-
-    Returns a list of conditions corresponding to the list of memrefs which
-    indicates the new ownerships, i.e., if the memref was deallocated the
-    ownership was dropped (set to 'false') and otherwise will be the same as the
-    input condition.
+    The number of variadic `memref` operands (the memrefs to be deallocated)
+    must equal the number of variadic `condition` operands and correspond to
+    each other element-wise.
+
+    The `memref` operands must be the originally allocated memrefs, however, the
+    `retained` memref operands may be arbitrary memrefs.
+
+    This operation returns a variadic number of `updatedConditions` results,
+    one updated condition per retained memref. An updated condition indicates
+    the ownership of the respective retained memref. It is computed as the
+    disjunction of all `conditions` operands where the corresponding
+    `memrefs` operand aliases with the retained memref. If the retained memref
+    has no aliases among `memrefs`, the resulting updated condition is 'false'.
+    This is because all memrefs that need to be deallocated within one basic
+    block should be added to the same `bufferization.dealloc` operation at the
+    end of the block; if no aliasing memref is present, then it does not have to
+    be deallocated and thus we don't need to claim ownership. If the memrefs to
+    be deallocated are split over multiple dealloc operations (e.g., to avoid
+    aliasing checks at runtime between the `memref` operands), then the results
+    have to be manually combined using an `arith.ori` operation and all of them
+    still require the same list of `retained` memref operands unless the
+    (potentially empty) set of aliasing memrefs can be determined statically. In
+    that case, the `updatedCondition` result can be replaced accordingly (e.g.,
+    by a canonicalizer).
 
     Example:
     ```mlir
-    %0:2 = bufferization.dealloc %a0, %a1 if %cond0, %cond1 retain %r0, %r1 :
-      memref<2xf32>, memref<4xi32> retain memref<?xf32>, memref<f64>
+    %0:3 = bufferization.dealloc (%a0, %a1 : memref<2xf32>, memref<4xi32>)
+      if (%cond0, %cond1) retain (%r0, %r1, %r2 : memref<?xf32>, memref<f64>,
+      memref<2xi32>)
     ```
-    Deallocation will be called on `%a0` if `%cond0` is 'true' and neither `%r0`
-    or `%r1` are aliases of `%a0`. `%a1` will be deallocated when `%cond1` is
-    set to 'true' and none of `%r0`, %r1` and `%a0` are aliases.
+    Deallocation will be called on `%a0` if `%cond0` is 'true' and none of
+    `%r0`, `%r1`, and `%r2` are aliases of `%a0`. `%a1` will be deallocated when
+    `%cond1` is set to 'true' and none of `%r0`, `%r1`, `%r2`, and `%a0` are
+    aliases of `%a1`.
   }];
 
   let arguments = (ins Variadic<AnyRankedOrUnrankedMemRef>:$memrefs,

diff  --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 886bcfa2f8530d..f998c8ce172a03 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -91,77 +91,100 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
 class DeallocOpConversion
     : public OpConversionPattern<bufferization::DeallocOp> {
 
-  /// Lower a simple case avoiding the helper function. Ideally, static analysis
-  /// can provide enough aliasing information to split the dealloc operations up
-  /// into this simple case as much as possible before running this pass.
+  /// Lower a simple case without any retained values and a single memref to
+  /// avoid the helper function. Ideally, static analysis can provide enough
+  /// aliasing information to split the dealloc operations up into this simple
+  /// case as much as possible before running this pass.
   ///
   /// Example:
   /// ```
-  /// %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+  /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
   /// ```
   /// is lowered to
   /// ```
   /// scf.if %arg1 {
   ///   memref.dealloc %arg0 : memref<2xf32>
   /// }
-  /// %0 = arith.constant false
   /// ```
   LogicalResult
   rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
-    rewriter.create<scf::IfOp>(op.getLoc(), adaptor.getConditions()[0],
-                               [&](OpBuilder &builder, Location loc) {
-                                 builder.create<memref::DeallocOp>(
-                                     loc, adaptor.getMemrefs()[0]);
-                                 builder.create<scf::YieldOp>(loc);
-                               });
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op,
-                                                   rewriter.getBoolAttr(false));
+    assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
+    assert(adaptor.getRetained().empty() && "expected no retained memrefs");
+
+    rewriter.replaceOpWithNewOp<scf::IfOp>(
+        op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
+          builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
+          builder.create<scf::YieldOp>(loc);
+        });
     return success();
   }
 
   /// Lowering that supports all features the dealloc operation has to offer. It
-  /// computes the base pointer of each memref (as an index), stores them in a
-  /// new memref and passes it to the helper function generated in
-  /// 'buildDeallocationHelperFunction'. The two return values are used as
-  /// condition for the scf if operation containing the memref deallocate and as
-  /// replacement for the original bufferization dealloc respectively.
+  /// computes the base pointer of each memref (as an index), stores it in a
+  /// new memref helper structure and passes it to the helper function generated
+  /// in 'buildDeallocationHelperFunction'. The results are stored in two lists
+  /// (represented as memrefs) of booleans passed as arguments. The first list
+  /// stores whether the corresponding memref should be deallocated, the
+  /// second list stores the ownership of the retained values which can be used
+  /// to replace the result values of the `bufferization.dealloc` operation.
   ///
   /// Example:
   /// ```
-  /// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>)
-  ///                           if (%arg3, %arg4) retain (%arg2 : memref<1xf32>)
+  /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
+  ///                           if (%cond0, %cond1)
+  ///                       retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
   /// ```
   /// lowers to (simplified):
   /// ```
   /// %c0 = arith.constant 0 : index
   /// %c1 = arith.constant 1 : index
-  /// %alloc = memref.alloc() : memref<2xindex>
-  /// %alloc_0 = memref.alloc() : memref<1xindex>
-  /// %intptr = memref.extract_aligned_pointer_as_index %arg0
-  /// memref.store %intptr, %alloc[%c0] : memref<2xindex>
-  /// %intptr_1 = memref.extract_aligned_pointer_as_index %arg1
-  /// memref.store %intptr_1, %alloc[%c1] : memref<2xindex>
-  /// %intptr_2 = memref.extract_aligned_pointer_as_index %arg2
-  /// memref.store %intptr_2, %alloc_0[%c0] : memref<1xindex>
-  /// %cast = memref.cast %alloc : memref<2xindex> to memref<?xindex>
-  /// %cast_4 = memref.cast %alloc_0 : memref<1xindex> to memref<?xindex>
-  /// %0:2 = call @dealloc_helper(%cast, %cast_4, %c0)
-  /// %1 = arith.andi %0#0, %arg3 : i1
-  /// %2 = arith.andi %0#1, %arg3 : i1
-  /// scf.if %1 {
-  ///   memref.dealloc %arg0 : memref<2xf32>
+  /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
+  /// %cond_list = memref.alloc() : memref<2xi1>
+  /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
+  /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
+  /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
+  /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
+  /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
+  /// memref.store %cond0, %cond_list[%c0]
+  /// memref.store %cond1, %cond_list[%c1]
+  /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
+  /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
+  /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
+  /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
+  /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
+  ///    memref<2xindex> to memref<?xindex>
+  /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
+  /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
+  ///    memref<2xindex> to memref<?xindex>
+  /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
+  /// %ownership_out = memref.alloc() : memref<2xi1>
+  /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
+  ///    memref<2xi1> to memref<?xi1>
+  /// %dyn_ownership_out = memref.cast %ownership_out :
+  ///    memref<2xi1> to memref<?xi1>
+  /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
+  ///                      %dyn_retain_base_pointer_list,
+  ///                      %dyn_cond_list,
+  ///                      %dyn_dealloc_cond_out,
+  ///                      %dyn_ownership_out) : (...)
+  /// %m0_dealloc_cond = memref.load %dealloc_cond_out[%c0] : memref<2xi1>
+  /// scf.if %m0_dealloc_cond {
+  ///   memref.dealloc %m0 : memref<2xf32>
   /// }
-  /// %3:2 = call @dealloc_helper(%cast, %cast_4, %c1)
-  /// %4 = arith.andi %3#0, %arg4 : i1
-  /// %5 = arith.andi %3#1, %arg4 : i1
-  /// scf.if %4 {
-  ///   memref.dealloc %arg1 : memref<5xf32>
+  /// %m1_dealloc_cond = memref.load %dealloc_cond_out[%c1] : memref<2xi1>
+  /// scf.if %m1_dealloc_cond {
+  ///   memref.dealloc %m1 : memref<5xf32>
   /// }
-  /// memref.dealloc %alloc : memref<2xindex>
-  /// memref.dealloc %alloc_0 : memref<1xindex>
-  /// // replace %0#0 with %2
-  /// // replace %0#1 with %5
+  /// %r0_ownership = memref.load %ownership_out[%c0] : memref<2xi1>
+  /// %r1_ownership = memref.load %ownership_out[%c1] : memref<2xi1>
+  /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
+  /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
+  /// memref.dealloc %cond_list : memref<2xi1>
+  /// memref.dealloc %dealloc_cond_out : memref<2xi1>
+  /// memref.dealloc %ownership_out : memref<2xi1>
+  /// // replace %0#0 with %r0_ownership
+  /// // replace %0#1 with %r1_ownership
   /// ```
   LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
                                    OpAdaptor adaptor,
@@ -175,6 +198,9 @@ class DeallocOpConversion
     Value toDeallocMemref = rewriter.create<memref::AllocOp>(
         op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
                                      rewriter.getIndexType()));
+    Value conditionMemref = rewriter.create<memref::AllocOp>(
+        op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
+                                     rewriter.getI1Type()));
     Value toRetainMemref = rewriter.create<memref::AllocOp>(
         op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
                                      rewriter.getIndexType()));
@@ -193,6 +219,11 @@ class DeallocOpConversion
       rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
                                        toDeallocMemref, getConstValue(i));
     }
+
+    for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
+      rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
+                                       getConstValue(i));
+
     for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
       Value memrefAsIdx =
           rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
@@ -208,22 +239,41 @@ class DeallocOpConversion
         op->getLoc(),
         MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
         toDeallocMemref);
+    Value castedCondsMemref = rewriter.create<memref::CastOp>(
+        op->getLoc(),
+        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+        conditionMemref);
     Value castedRetainMemref = rewriter.create<memref::CastOp>(
         op->getLoc(),
         MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
         toRetainMemref);
 
-    SmallVector<Value> replacements;
+    Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
+        op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+                                     rewriter.getI1Type()));
+    Value retainCondsMemref = rewriter.create<memref::AllocOp>(
+        op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
+                                     rewriter.getI1Type()));
+
+    Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
+        op->getLoc(),
+        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+        deallocCondsMemref);
+    Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
+        op->getLoc(),
+        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+        retainCondsMemref);
+
+    rewriter.create<func::CallOp>(
+        op.getLoc(), deallocHelperFunc,
+        SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
+                           castedCondsMemref, castedDeallocCondsMemref,
+                           castedRetainCondsMemref});
+
     for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
-      auto callOp = rewriter.create<func::CallOp>(
-          op.getLoc(), deallocHelperFunc,
-          SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
-                             getConstValue(i)});
-      Value shouldDealloc = rewriter.create<arith::AndIOp>(
-          op.getLoc(), callOp.getResult(0), adaptor.getConditions()[i]);
-      Value ownership = rewriter.create<arith::AndIOp>(
-          op.getLoc(), callOp.getResult(1), adaptor.getConditions()[i]);
-      replacements.push_back(ownership);
+      Value idxValue = getConstValue(i);
+      Value shouldDealloc = rewriter.create<memref::LoadOp>(
+          op.getLoc(), deallocCondsMemref, idxValue);
       rewriter.create<scf::IfOp>(
           op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
             builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
@@ -231,10 +281,21 @@ class DeallocOpConversion
           });
     }
 
+    SmallVector<Value> replacements;
+    for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
+      Value idxValue = getConstValue(i);
+      Value ownership = rewriter.create<memref::LoadOp>(
+          op.getLoc(), retainCondsMemref, idxValue);
+      replacements.push_back(ownership);
+    }
+
     // Deallocate above allocated memrefs again to avoid memory leaks.
     // Deallocation will not be run on code after this stage.
     rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
     rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
+    rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
+    rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
+    rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
 
     rewriter.replaceOp(op, replacements);
     return success();
@@ -261,70 +322,95 @@ class DeallocOpConversion
   /// Build a helper function per compilation unit that can be called at
   /// bufferization dealloc sites to determine aliasing and ownership.
   ///
-  /// The generated function takes two memrefs of indices and one index value as
-  /// arguments and returns two boolean values:
-  ///   * The first memref argument A should contain the result of the
+  /// The generated function takes two memrefs of indices and three memrefs of
+  /// booleans as arguments:
+  ///   * The first argument A should contain the result of the
   ///   extract_aligned_pointer_as_index operation applied to the memrefs to be
   ///   deallocated
-  ///   * The second memref argument B should contain the result of the
+  ///   * The second argument B should contain the result of the
   ///   extract_aligned_pointer_as_index operation applied to the memrefs to be
   ///   retained
-  ///   * The index argument I represents the currently processed index of
-  ///   memref A and is needed because aliasing with all previously deallocated
-  ///   memrefs has to be checked to avoid double deallocation
-  ///   * The first result indicates whether the memref at position I should be
-  ///   deallocated
-  ///   * The second result provides the updated ownership value corresponding
-  ///   the the memref at position I
-  ///
-  /// This helper function is supposed to be called for each element in the list
-  /// of memrefs to be deallocated to determine the deallocation need and new
-  /// ownership indicator, but does not perform the deallocation itself.
+  ///   * The third argument C should contain the conditions as passed directly
+  ///   to the deallocation operation.
+  ///   * The fourth argument D is used to pass results to the caller. Those
+  ///   represent the condition under which the memref at the corresponding
+  ///   position in A should be deallocated.
+  ///   * The fifth argument E is used to pass results to the caller. It
+  ///   provides the ownership value corresponding to the memref at the same
+  ///   position in B
   ///
-  /// The first scf for loop in the body computes whether the memref at index I
-  /// aliases with any memref in the list of retained memrefs.
-  /// The second loop additionally checks whether one of the previously
-  /// deallocated memrefs aliases with the currently processed one.
+  /// This helper function is supposed to be called once for each
+  /// `bufferization.dealloc` operation to determine the deallocation need and
+  /// new ownership indicator for the retained values, but does not perform the
+  /// deallocation itself.
   ///
   /// Generated code:
   /// ```
-  /// func.func @dealloc_helper(%arg0: memref<?xindex>,
-  ///                           %arg1: memref<?xindex>,
-  ///                           %arg2: index) -> (i1, i1) {
+  /// func.func @dealloc_helper(
+  ///     %dyn_dealloc_base_pointer_list: memref<?xindex>,
+  ///     %dyn_retain_base_pointer_list: memref<?xindex>,
+  ///     %dyn_cond_list: memref<?xi1>,
+  ///     %dyn_dealloc_cond_out: memref<?xi1>,
+  ///     %dyn_ownership_out: memref<?xi1>) {
   ///   %c0 = arith.constant 0 : index
   ///   %c1 = arith.constant 1 : index
   ///   %true = arith.constant true
-  ///   %dim = memref.dim %arg1, %c0 : memref<?xindex>
-  ///   %0 = memref.load %arg0[%arg2] : memref<?xindex>
-  ///   %1 = scf.for %i = %c0 to %dim step %c1 iter_args(%arg4 = %true) -> (i1){
-  ///     %4 = memref.load %arg1[%i] : memref<?xindex>
-  ///     %5 = arith.cmpi ne, %4, %0 : index
-  ///     %6 = arith.andi %arg4, %5 : i1
-  ///     scf.yield %6 : i1
+  ///   %false = arith.constant false
+  ///   %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0
+  ///   %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0
+  ///   // Zero initialize result buffer.
+  ///   scf.for %i = %c0 to %num_retain_memrefs step %c1 {
+  ///     memref.store %false, %dyn_ownership_out[%i] : memref<?xi1>
   ///   }
-  ///   %2 = scf.for %i = %c0 to %arg2 step %c1 iter_args(%arg4 = %1) -> (i1) {
-  ///     %4 = memref.load %arg0[%i] : memref<?xindex>
-  ///     %5 = arith.cmpi ne, %4, %0 : index
-  ///     %6 = arith.andi %arg4, %5 : i1
-  ///     scf.yield %6 : i1
+  ///   scf.for %i = %c0 to %num_dealloc_memrefs step %c1 {
+  ///     %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i]
+  ///     %cond = memref.load %dyn_cond_list[%i]
+  ///     // Check for aliasing with retained memrefs.
+  ///     %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs
+  ///         step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) {
+  ///       %retain_bp = memref.load %dyn_retain_base_pointer_list[%j]
+  ///       %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index
+  ///       scf.if %does_alias {
+  ///         %curr_ownership = memref.load %dyn_ownership_out[%j]
+  ///         %updated_ownership = arith.ori %curr_ownership, %cond : i1
+  ///         memref.store %updated_ownership, %dyn_ownership_out[%j]
+  ///       }
+  ///       %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index
+  ///       %updated_aggregate = arith.andi %does_not_alias_aggregated,
+  ///                                       %does_not_alias : i1
+  ///       scf.yield %updated_aggregate : i1
+  ///     }
+  ///     // Check for aliasing with dealloc memrefs in the list before the
+  ///     // current one, i.e.,
+  ///     // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j],
+  ///     // %dyn_dealloc_base_pointer[i])`
+  ///     %does_not_alias_any = scf.for %j = %c0 to %i step %c1
+  ///        iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) {
+  ///       %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j]
+  ///       %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp
+  ///       %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias
+  ///       scf.yield %updated_alias_agg : i1
+  ///     }
+  ///     %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1
+  ///     memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1>
   ///   }
-  ///   %3 = arith.xori %1, %true : i1
-  ///   return %2, %3 : i1, i1
+  ///   return
   /// }
   /// ```
   static func::FuncOp
   buildDeallocationHelperFunction(OpBuilder &builder, Location loc,
                                   SymbolTable &symbolTable) {
-    Type idxType = builder.getIndexType();
-    Type memrefArgType = MemRefType::get({ShapedType::kDynamic}, idxType);
-    SmallVector<Type> argTypes{memrefArgType, memrefArgType, idxType};
+    Type indexMemrefType =
+        MemRefType::get({ShapedType::kDynamic}, builder.getIndexType());
+    Type boolMemrefType =
+        MemRefType::get({ShapedType::kDynamic}, builder.getI1Type());
+    SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
+                               boolMemrefType, boolMemrefType};
     builder.clearInsertionPoint();
 
     // Generate the func operation itself.
     auto helperFuncOp = func::FuncOp::create(
-        loc, "dealloc_helper",
-        builder.getFunctionType(argTypes,
-                                {builder.getI1Type(), builder.getI1Type()}));
+        loc, "dealloc_helper", builder.getFunctionType(argTypes, {}));
     symbolTable.insert(helperFuncOp);
     auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
     block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
@@ -332,57 +418,101 @@ class DeallocOpConversion
     builder.setInsertionPointToStart(&block);
     Value toDeallocMemref = helperFuncOp.getArguments()[0];
     Value toRetainMemref = helperFuncOp.getArguments()[1];
-    Value idxArg = helperFuncOp.getArguments()[2];
+    Value conditionMemref = helperFuncOp.getArguments()[2];
+    Value deallocCondsMemref = helperFuncOp.getArguments()[3];
+    Value retainCondsMemref = helperFuncOp.getArguments()[4];
 
     // Insert some prerequisites.
     Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
     Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
     Value trueValue =
         builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
+    Value falseValue =
+        builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
+    Value toDeallocSize =
+        builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
     Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
-    Value toDealloc =
-        builder.create<memref::LoadOp>(loc, toDeallocMemref, idxArg);
-
-    // Build the first for loop that computes aliasing with retained memrefs.
-    Value noRetainAlias =
-        builder
-            .create<scf::ForOp>(
-                loc, c0, toRetainSize, c1, trueValue,
-                [&](OpBuilder &builder, Location loc, Value i,
-                    ValueRange iterArgs) {
-                  Value retainValue =
-                      builder.create<memref::LoadOp>(loc, toRetainMemref, i);
-                  Value doesntAlias = builder.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::ne, retainValue, toDealloc);
-                  Value yieldValue = builder.create<arith::AndIOp>(
-                      loc, iterArgs[0], doesntAlias);
-                  builder.create<scf::YieldOp>(loc, yieldValue);
-                })
-            .getResult(0);
-
-    // Build the second for loop that adds aliasing with previously deallocated
-    // memrefs.
-    Value noAlias =
-        builder
-            .create<scf::ForOp>(
-                loc, c0, idxArg, c1, noRetainAlias,
-                [&](OpBuilder &builder, Location loc, Value i,
-                    ValueRange iterArgs) {
-                  Value prevDeallocValue =
-                      builder.create<memref::LoadOp>(loc, toDeallocMemref, i);
-                  Value doesntAlias = builder.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::ne, prevDeallocValue,
-                      toDealloc);
-                  Value yieldValue = builder.create<arith::AndIOp>(
-                      loc, iterArgs[0], doesntAlias);
-                  builder.create<scf::YieldOp>(loc, yieldValue);
-                })
-            .getResult(0);
-
-    Value ownership =
-        builder.create<arith::XOrIOp>(loc, noRetainAlias, trueValue);
-    builder.create<func::ReturnOp>(loc, SmallVector<Value>{noAlias, ownership});
 
+    builder.create<scf::ForOp>(
+        loc, c0, toRetainSize, c1, std::nullopt,
+        [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
+          builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref,
+                                          i);
+          builder.create<scf::YieldOp>(loc);
+        });
+
+    builder.create<scf::ForOp>(
+        loc, c0, toDeallocSize, c1, std::nullopt,
+        [&](OpBuilder &builder, Location loc, Value outerIter,
+            ValueRange iterArgs) {
+          Value toDealloc =
+              builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
+          Value cond =
+              builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);
+
+          // Build the first for loop that computes aliasing with retained
+          // memrefs.
+          Value noRetainAlias =
+              builder
+                  .create<scf::ForOp>(
+                      loc, c0, toRetainSize, c1, trueValue,
+                      [&](OpBuilder &builder, Location loc, Value i,
+                          ValueRange iterArgs) {
+                        Value retainValue = builder.create<memref::LoadOp>(
+                            loc, toRetainMemref, i);
+                        Value doesAlias = builder.create<arith::CmpIOp>(
+                            loc, arith::CmpIPredicate::eq, retainValue,
+                            toDealloc);
+                        builder.create<scf::IfOp>(
+                            loc, doesAlias,
+                            [&](OpBuilder &builder, Location loc) {
+                              Value retainCondValue =
+                                  builder.create<memref::LoadOp>(
+                                      loc, retainCondsMemref, i);
+                              Value aggregatedRetainCond =
+                                  builder.create<arith::OrIOp>(
+                                      loc, retainCondValue, cond);
+                              builder.create<memref::StoreOp>(
+                                  loc, aggregatedRetainCond, retainCondsMemref,
+                                  i);
+                              builder.create<scf::YieldOp>(loc);
+                            });
+                        Value doesntAlias = builder.create<arith::CmpIOp>(
+                            loc, arith::CmpIPredicate::ne, retainValue,
+                            toDealloc);
+                        Value yieldValue = builder.create<arith::AndIOp>(
+                            loc, iterArgs[0], doesntAlias);
+                        builder.create<scf::YieldOp>(loc, yieldValue);
+                      })
+                  .getResult(0);
+
+          // Build the second for loop that adds aliasing with previously
+          // deallocated memrefs.
+          Value noAlias =
+              builder
+                  .create<scf::ForOp>(
+                      loc, c0, outerIter, c1, noRetainAlias,
+                      [&](OpBuilder &builder, Location loc, Value i,
+                          ValueRange iterArgs) {
+                        Value prevDeallocValue = builder.create<memref::LoadOp>(
+                            loc, toDeallocMemref, i);
+                        Value doesntAlias = builder.create<arith::CmpIOp>(
+                            loc, arith::CmpIPredicate::ne, prevDeallocValue,
+                            toDealloc);
+                        Value yieldValue = builder.create<arith::AndIOp>(
+                            loc, iterArgs[0], doesntAlias);
+                        builder.create<scf::YieldOp>(loc, yieldValue);
+                      })
+                  .getResult(0);
+
+          Value shouldDealoc =
+              builder.create<arith::AndIOp>(loc, noAlias, cond);
+          builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
+                                          outerIter);
+          builder.create<scf::YieldOp>(loc);
+        });
+
+    builder.create<func::ReturnOp>(loc);
     return helperFuncOp;
   }
 

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 3747ab1562ff42..e981b80a60083a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -755,7 +755,8 @@ LogicalResult DeallocOp::inferReturnTypes(
     ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
   DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
-  inferredReturnTypes = SmallVector<Type>(adaptor.getConditions().getTypes());
+  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
+                                          IntegerType::get(context, 1));
   return success();
 }
 
@@ -766,44 +767,46 @@ LogicalResult DeallocOp::verify() {
   return success();
 }
 
+static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
+                                            ArrayRef<Value> memrefs,
+                                            ArrayRef<Value> conditions,
+                                            PatternRewriter &rewriter) {
+  if (deallocOp.getMemrefs() == memrefs)
+    return failure();
+
+  rewriter.updateRootInPlace(deallocOp, [&]() {
+    deallocOp.getMemrefsMutable().assign(memrefs);
+    deallocOp.getConditionsMutable().assign(conditions);
+  });
+  return success();
+}
+
 namespace {
 
-/// Remove duplicate values in the list of retained memrefs as well as the list
-/// of memrefs to be deallocated. For the latter, we need to make sure the
-/// corresponding condition values match as well, or otherwise have to combine
-/// them (by computing the disjunction of them).
+/// Remove duplicate values in the list of memrefs to be deallocated. We need to
+/// make sure the corresponding condition value is updated accordingly since
+/// their two conditions might not cover the same set of cases. In that case, we
+/// have to combine them (by computing the disjunction of them).
 /// Example:
 /// ```mlir
-/// %0:2 = bufferization.dealloc (%arg0, %arg0 : ...)
-///                           if (%arg1, %arg2)
-///                       retain (%arg3, %arg3 : ...)
+/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
 /// ```
 /// is canonicalized to
 /// ```mlir
 /// %0 = arith.ori %arg1, %arg2 : i1
-/// %1 = bufferization.dealloc (%arg0 : memref<2xi32>)
-///                         if (%0)
-///                     retain (%arg3 : memref<2xi32>)
+/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
 /// ```
-struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
+struct DeallocRemoveDuplicateDeallocMemrefs
+    : public OpRewritePattern<DeallocOp> {
   using OpRewritePattern<DeallocOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     // Unique memrefs to be deallocated.
-    DenseSet<Value> retained(deallocOp.getRetained().begin(),
-                             deallocOp.getRetained().end());
     DenseMap<Value, unsigned> memrefToCondition;
-    SmallVector<Value> newMemrefs, newConditions, newRetained;
-    SmallVector<int32_t> resultIndices(deallocOp.getMemrefs().size(), -1);
+    SmallVector<Value> newMemrefs, newConditions;
     for (auto [i, memref, cond] :
          llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
-      if (retained.contains(memref)) {
-        rewriter.replaceAllUsesWith(deallocOp.getResult(i),
-                                    deallocOp.getConditions()[i]);
-        continue;
-      }
-
       if (memrefToCondition.count(memref)) {
         // If the dealloc conditions don't match, we need to make sure that the
         // dealloc happens on the union of cases.
@@ -816,42 +819,121 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
         newMemrefs.push_back(memref);
         newConditions.push_back(cond);
       }
-      resultIndices[i] = memrefToCondition[memref];
     }
 
+    // Return failure if we don't change anything such that we don't run into an
+    // infinite loop of pattern applications.
+    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                  rewriter);
+  }
+};
+
+/// Remove duplicate values in the list of retained memrefs. We need to make
+/// sure the corresponding result condition value is replaced properly.
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
+/// ```
+/// is canonicalized to
+/// ```mlir
+/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
+/// ```
+struct DeallocRemoveDuplicateRetainedMemrefs
+    : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
     // Unique retained values
-    DenseSet<Value> seen;
+    DenseMap<Value, unsigned> seen;
+    SmallVector<Value> newRetained;
+    SmallVector<unsigned> resultReplacementIdx;
+    unsigned i = 0;
     for (auto retained : deallocOp.getRetained()) {
-      if (!seen.contains(retained)) {
-        seen.insert(retained);
-        newRetained.push_back(retained);
+      if (seen.count(retained)) {
+        resultReplacementIdx.push_back(seen[retained]);
+        continue;
       }
+
+      seen[retained] = i;
+      newRetained.push_back(retained);
+      resultReplacementIdx.push_back(i++);
     }
 
     // Return failure if we don't change anything such that we don't run into an
     // infinite loop of pattern applications.
-    if (newConditions.size() == deallocOp.getConditions().size() &&
-        newRetained.size() == deallocOp.getRetained().size())
+    if (newRetained.size() == deallocOp.getRetained().size())
       return failure();
 
     // We need to create a new op because the number of results is always the
     // same as the number of condition operands.
-    auto newDealloc = rewriter.create<DeallocOp>(deallocOp.getLoc(), newMemrefs,
-                                                 newConditions, newRetained);
-    for (auto [i, newIdx] : llvm::enumerate(resultIndices))
-      if (newIdx != -1)
-        rewriter.replaceAllUsesWith(deallocOp.getResult(i),
-                                    newDealloc.getResult(newIdx));
-
-    rewriter.eraseOp(deallocOp);
+    auto newDeallocOp =
+        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
+                                   deallocOp.getConditions(), newRetained);
+    SmallVector<Value> replacements(
+        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
+          return newDeallocOp.getUpdatedConditions()[idx];
+        }));
+    rewriter.replaceOp(deallocOp, replacements);
     return success();
   }
 };
 
+/// Remove memrefs to be deallocated that are also present in the retained list
+/// since they will always alias and thus never actually be deallocated.
+/// Example:
+/// ```mlir
+/// %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...)
+/// ```
+/// is canonicalized to
+/// ```mlir
+/// %0 = bufferization.dealloc retain (%arg0 : ...)
+/// ```
+struct DeallocRemoveDeallocMemrefsContainedInRetained
+    : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    // Map each retained memref to its index in the retained list.
+    DenseMap<Value, unsigned> retained;
+    for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
+      retained[ret] = i;
+
+    // There must not be any duplicates in the retain list anymore because we
+    // would miss updating one of the result values otherwise.
+    if (retained.size() != deallocOp.getRetained().size())
+      return failure();
+
+    SmallVector<Value> newMemrefs, newConditions;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (retained.contains(memref)) {
+        rewriter.setInsertionPointAfter(deallocOp);
+        auto orOp = rewriter.create<arith::OrIOp>(
+            deallocOp.getLoc(),
+            deallocOp.getUpdatedConditions()[retained[memref]], cond);
+        rewriter.replaceAllUsesExcept(
+            deallocOp.getUpdatedConditions()[retained[memref]],
+            orOp.getResult(), orOp);
+        continue;
+      }
+
+      newMemrefs.push_back(memref);
+      newConditions.push_back(cond);
+    }
+
+    // Return failure if we don't change anything such that we don't run into an
+    // infinite loop of pattern applications.
+    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                  rewriter);
+  }
+};
+
 /// Erase deallocation operations where the variadic list of memrefs to
-/// deallocate is emtpy. Example:
+/// deallocate is empty. Example:
 /// ```mlir
-/// bufferization.dealloc retain (%arg0: memref<2xi32>)
+/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
 /// ```
 struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
   using OpRewritePattern<DeallocOp>::OpRewritePattern;
@@ -859,7 +941,11 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     if (deallocOp.getMemrefs().empty()) {
-      rewriter.eraseOp(deallocOp);
+      Value constFalse = rewriter.create<arith::ConstantOp>(
+          deallocOp.getLoc(), rewriter.getBoolAttr(false));
+      rewriter.replaceOp(
+          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
+                                        constFalse));
       return success();
     }
     return failure();
@@ -871,12 +957,12 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
 ///
 /// Example:
 /// ```
-/// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
 ///                           if (%arg2, %false)
 /// ```
 /// becomes
 /// ```
-/// %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
+/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
 /// ```
 struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
   using OpRewritePattern<DeallocOp>::OpRewritePattern;
@@ -884,32 +970,16 @@ struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> newMemrefs, newConditions;
-    SmallVector<Value> replacements;
-
-    for (auto [res, memref, cond] :
-         llvm::zip(deallocOp.getUpdatedConditions(), deallocOp.getMemrefs(),
-                   deallocOp.getConditions())) {
-      if (matchPattern(cond, m_Zero())) {
-        replacements.push_back(cond);
-        continue;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (!matchPattern(cond, m_Zero())) {
+        newMemrefs.push_back(memref);
+        newConditions.push_back(cond);
       }
-      newMemrefs.push_back(memref);
-      newConditions.push_back(cond);
-      replacements.push_back({});
     }
 
-    if (newMemrefs.size() == deallocOp.getMemrefs().size())
-      return failure();
-
-    auto newDeallocOp = rewriter.create<DeallocOp>(
-        deallocOp.getLoc(), newMemrefs, newConditions, deallocOp.getRetained());
-    unsigned i = 0;
-    for (auto &repl : replacements)
-      if (!repl)
-        repl = newDeallocOp.getResult(i++);
-
-    rewriter.replaceOp(deallocOp, replacements);
-    return success();
+    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                  rewriter);
   }
 };
 
@@ -917,9 +987,10 @@ struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
 
 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results
-      .add<DeallocRemoveDuplicates, EraseEmptyDealloc, EraseAlwaysFalseDealloc>(
-          context);
+  results.add<DeallocRemoveDuplicateDeallocMemrefs,
+              DeallocRemoveDuplicateRetainedMemrefs,
+              DeallocRemoveDeallocMemrefsContainedInRetained, EraseEmptyDealloc,
+              EraseAlwaysFalseDealloc>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index aed232d29d175c..0f80e9a1140e30 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -81,78 +81,106 @@ func.func @conversion_dealloc_empty() {
 // CHECK-LABEL: func @conversion_dealloc_simple
 // CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
 // CHECK-SAME: [[ARG1:%.+]]: i1
-func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) -> i1 {
-  %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
-  return %0 : i1
+func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) {
+  bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+  return
 }
 
 //      CHECk: scf.if [[ARG1]] {
 // CHECk-NEXT:   memref.dealloc [[ARG0]] : memref<2xf32>
 // CHECk-NEXT: }
-// CHECk-NEXT: [[FALSE:%.+]] = arith.constant false
-// CHECk-NEXT: return [[FALSE]] : i1
+// CHECk-NEXT: return
 
 // -----
 
-func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1) -> (i1, i1) {
-  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2 : memref<1xf32>)
+func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
   return %0#0, %0#1 : i1, i1
 }
 
 // CHECK-LABEL: func @conversion_dealloc_multiple_memrefs_and_retained
-// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>,
-// CHECK-SAME: [[ARG1:%.+]]: memref<5xf32>,
-// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>,
-// CHECK-SAME: [[ARG3:%.+]]: i1,
-// CHECK-SAME: [[ARG4:%.+]]: i1
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<5xf32>,
+// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1,
+// CHECK-SAME: [[ARG5:%.+]]: memref<2xf32>)
 //      CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex>
-//      CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<1xindex>
+//      CHECK: [[CONDS:%.+]] = memref.alloc() : memref<2xi1>
+//      CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<2xindex>
 //  CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
 //  CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
 //  CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]]
 //  CHECK-DAG: [[V1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
 //  CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
 //  CHECK-DAG: memref.store [[V1]], [[TO_DEALLOC_MR]][[[C1]]]
+//  CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+//  CHECK-DAG: memref.store [[ARG3]], [[CONDS]][[[C0]]]
+//  CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+//  CHECK-DAG: memref.store [[ARG4]], [[CONDS]][[[C1]]]
 //  CHECK-DAG: [[V2:%.+]] = memref.extract_aligned_pointer_as_index [[ARG2]]
 //  CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
 //  CHECK-DAG: memref.store [[V2]], [[TO_RETAIN_MR]][[[C0]]]
+//  CHECK-DAG: [[V3:%.+]] = memref.extract_aligned_pointer_as_index [[ARG5]]
+//  CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+//  CHECK-DAG: memref.store [[V3]], [[TO_RETAIN_MR]][[[C1]]]
 //  CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref<?xindex>
-//  CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<1xindex> to memref<?xindex>
-//  CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-//      CHECK: [[RES0:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C0]])
-//      CHECK: [[SHOULD_DEALLOC_0:%.+]] = arith.andi [[RES0]]#0, [[ARG3]]
-//      CHECK: [[OWNERSHIP0:%.+]] = arith.andi [[RES0]]#1, [[ARG3]]
+//  CHECK-DAG: [[CAST_CONDS:%.+]] = memref.cast [[CONDS]] : memref<2xi1> to memref<?xi1>
+//  CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<2xindex> to memref<?xindex>
+//      CHECK: [[DEALLOC_CONDS:%.+]] = memref.alloc() : memref<2xi1>
+//      CHECK: [[RETAIN_CONDS:%.+]] = memref.alloc() : memref<2xi1>
+//      CHECK: [[CAST_DEALLOC_CONDS:%.+]] = memref.cast [[DEALLOC_CONDS]] : memref<2xi1> to memref<?xi1>
+//      CHECK: [[CAST_RETAIN_CONDS:%.+]] = memref.cast [[RETAIN_CONDS]] : memref<2xi1> to memref<?xi1>
+//      CHECK: call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[CAST_CONDS]], [[CAST_DEALLOC_CONDS]], [[CAST_RETAIN_CONDS]])
+//      CHECK: [[C0:%.+]] = arith.constant 0 : index
+//      CHECK: [[SHOULD_DEALLOC_0:%.+]] = memref.load [[DEALLOC_CONDS]][[[C0]]]
 //      CHECK: scf.if [[SHOULD_DEALLOC_0]] {
 //      CHECK:   memref.dealloc %arg0
 //      CHECK: }
 //      CHECK: [[C1:%.+]] = arith.constant 1 : index
-//      CHECK: [[RES1:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C1]])
-//      CHECK: [[SHOULD_DEALLOC_1:%.+]] = arith.andi [[RES1:%.+]]#0, [[ARG4]]
-//      CHECK: [[OWNERSHIP1:%.+]] = arith.andi [[RES1:%.+]]#1, [[ARG4]]
+//      CHECK: [[SHOULD_DEALLOC_1:%.+]] = memref.load [[DEALLOC_CONDS]][[[C1]]]
 //      CHECK: scf.if [[SHOULD_DEALLOC_1]]
 //      CHECK:   memref.dealloc [[ARG1]]
 //      CHECK: }
+//      CHECK: [[C0:%.+]] = arith.constant 0 : index
+//      CHECK: [[OWNERSHIP0:%.+]] = memref.load [[RETAIN_CONDS]][[[C0]]]
+//      CHECK: [[C1:%.+]] = arith.constant 1 : index
+//      CHECK: [[OWNERSHIP1:%.+]] = memref.load [[RETAIN_CONDS]][[[C1]]]
 //      CHECK: memref.dealloc [[TO_DEALLOC_MR]]
 //      CHECK: memref.dealloc [[TO_RETAIN_MR]]
+//      CHECK: memref.dealloc [[CONDS]]
+//      CHECK: memref.dealloc [[DEALLOC_CONDS]]
+//      CHECK: memref.dealloc [[RETAIN_CONDS]]
 //      CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]]
 
 //      CHECK: func @dealloc_helper
-// CHECK-SAME: [[ARG0:%.+]]: memref<?xindex>, [[ARG1:%.+]]: memref<?xindex>
-// CHECK-SAME: [[ARG2:%.+]]: index
-// CHECK-SAME:  -> (i1, i1)
-//      CHECK:   [[TO_RETAIN_SIZE:%.+]] = memref.dim [[ARG1]], %c0
-//      CHECK:   [[TO_DEALLOC:%.+]] = memref.load [[ARG0]][[[ARG2]]] : memref<?xindex>
-// CHECK-NEXT:   [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) {
-// CHECK-NEXT:     [[RETAIN_VAL:%.+]] = memref.load [[ARG1]][[[ITER]]] : memref<?xindex>
-// CHECK-NEXT:     [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
-// CHECK-NEXT:     [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
-// CHECK-NEXT:     scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-SAME: ([[TO_DEALLOC_MR:%.+]]: memref<?xindex>, [[TO_RETAIN_MR:%.+]]: memref<?xindex>,
+// CHECK-SAME: [[CONDS:%.+]]: memref<?xi1>, [[DEALLOC_CONDS_OUT:%.+]]: memref<?xi1>,
+// CHECK-SAME: [[RETAIN_CONDS_OUT:%.+]]: memref<?xi1>)
+//      CHECK:   [[TO_DEALLOC_SIZE:%.+]] = memref.dim [[TO_DEALLOC_MR]], %c0
+//      CHECK:   [[TO_RETAIN_SIZE:%.+]] = memref.dim [[TO_RETAIN_MR]], %c0
+//      CHECK:   scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 {
+// CHECK-NEXT:     memref.store %false, [[RETAIN_CONDS_OUT]][[[ITER]]]
 // CHECK-NEXT:   }
-// CHECK-NEXT:   [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[ARG2]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) {
-// CHECK-NEXT:     [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[ARG0]][[[ITER]]] : memref<?xindex>
-// CHECK-NEXT:     [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
-// CHECK-NEXT:     [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
-// CHECK-NEXT:     scf.yield [[AGG_DOES_ALIAS]] : i1
+//      CHECK:   scf.for [[OUTER_ITER:%.+]] = %c0 to [[TO_DEALLOC_SIZE]] step %c1 {
+//      CHECK:     [[TO_DEALLOC:%.+]] = memref.load [[TO_DEALLOC_MR]][[[OUTER_ITER]]]
+// CHECK-NEXT:     [[COND:%.+]] = memref.load [[CONDS]][[[OUTER_ITER]]]
+// CHECK-NEXT:     [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) {
+// CHECK-NEXT:       [[RETAIN_VAL:%.+]] = memref.load [[TO_RETAIN_MR]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT:       [[DOES_ALIAS:%.+]] = arith.cmpi eq, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT:       scf.if [[DOES_ALIAS]]
+// CHECK-NEXT:         [[RETAIN_COND:%.+]] = memref.load [[RETAIN_CONDS_OUT]][[[ITER]]]
+// CHECK-NEXT:         [[AGG_RETAIN_COND:%.+]] = arith.ori [[RETAIN_COND]], [[COND]] : i1
+// CHECK-NEXT:         memref.store [[AGG_RETAIN_COND]], [[RETAIN_CONDS_OUT]][[[ITER]]]
+// CHECK-NEXT:       }
+// CHECK-NEXT:       [[DOES_NOT_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT:       [[AGG_DOES_NOT_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_NOT_ALIAS]] : i1
+// CHECK-NEXT:       scf.yield [[AGG_DOES_NOT_ALIAS]] : i1
+// CHECK-NEXT:     }
+// CHECK-NEXT:     [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[OUTER_ITER]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) {
+// CHECK-NEXT:       [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[TO_DEALLOC_MR]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT:       [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT:       [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
+// CHECK-NEXT:       scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT:     }
+// CHECK-NEXT:     [[DEALLOC_COND:%.+]] = arith.andi [[SHOULD_DEALLOC]], [[COND]] : i1
+// CHECK-NEXT:     memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
 // CHECK-NEXT:   }
-// CHECK-NEXT:   [[OWNERSHIP:%.+]] = arith.xori [[NO_RETAIN_ALIAS]], %true : i1
-// CHECK-NEXT:   return [[SHOULD_DEALLOC]], [[OWNERSHIP]] : i1, i1
+// CHECK-NEXT:   return

diff  --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 96f82f6835dd61..6a4edf1e6335f1 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -282,44 +282,44 @@ func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<
 
 // -----
 
-func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1, i1, i1) {
+func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1) {
   %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
-  %1:2 = bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
-  return %0#0, %0#1, %0#2, %1#0, %1#1 : i1, i1, i1, i1, i1
+  bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
+  return %0#0, %0#1, %0#2 : i1, i1, i1
 }
 
 // CHECK-LABEL: func @dealloc_canonicalize_duplicates
 //  CHECK-SAME:  ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>, [[ARG5:%.+]]: memref<2xi32>)
 //  CHECK-NEXT:   [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG5]] : memref<2xi32>, memref<2xi32>)
 //  CHECK-NEXT:   [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1
-//  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
-//  CHECK-NEXT:   return [[V0]]#0, [[V0]]#1, [[V0]]#1, [[V1]], [[V1]] :
+//  CHECK-NEXT:   bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
+//  CHECK-NEXT:   return [[V0]]#0, [[V0]]#1, [[V0]]#0 :
 
 // -----
 
-func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1) {
+func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1) {
   %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
-  %1:2 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+  %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
   bufferization.dealloc
   bufferization.dealloc retain (%arg0 : memref<2xi32>)
-  return %0, %1#0, %1#1 : i1, i1, i1
+  return %0, %1 : i1, i1
 }
 
 // CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated
 //  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
 //  CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
-//  CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :
+//  CHECK-NEXT: [[V1:%.+]] = arith.ori [[V0]], [[ARG1]]
+//  CHECK-NEXT: return [[ARG1]], [[V1]] :
 
 // -----
 
-func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) -> (i1, i1) {
+func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) {
   %false = arith.constant false
-  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
-  return %0#0, %0#1 : i1, i1
+  bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
+  return
 }
 
 // CHECK-LABEL: func @dealloc_always_false_condition
 //  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
-//  CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
-//  CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
-//  CHECK-NEXT: return [[FALSE]], [[V0]] :
+//  CHECK-NEXT: bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
+//  CHECK-NEXT: return

diff  --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 8dded8b4debe3e..3b4bfee5622e9b 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -106,8 +106,16 @@ func.func @invalid_tensor_copy(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
 
 // -----
 
-func.func @invalid_dealloc_memref_condition_mismatch(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) -> i1 {
+func.func @invalid_dealloc_memref_condition_mismatch(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) {
   // expected-error @below{{must have the same number of conditions as memrefs to deallocate}}
-  %0 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2)
-  return %0 : i1
+  bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2)
+  return
+}
+
+// -----
+
+func.func @invalid_dealloc_wrong_number_of_results(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) -> i1 {
+  // expected-error @below{{operation defines 1 results but was provided 2 to bind}}
+  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg2) retain (%arg1 : memref<4xi32>)
+  return %0#0 : i1
 }

diff  --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index 55af9ef37384a9..773f15c1ffcb89 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -73,8 +73,8 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
   // CHECK: bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg3) retain (%arg4, %arg5 : memref<?xf32>, memref<*xf64>)
   %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg3) retain (%arg4, %arg5 : memref<?xf32>, memref<*xf64>)
   // CHECK: bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2)
-  %1 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2)
+  bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2)
   // CHECK: bufferization.dealloc
   bufferization.dealloc
-  return %0, %1 : i1, i1
+  return %0#0, %0#1 : i1, i1
 }


        


More information about the Mlir-commits mailing list