[Mlir-commits] [mlir] 950f094 - [mlir][bufferization] Factor out bufferization.dealloc lowering into separate pass
Martin Erhart
llvmlistbot at llvm.org
Thu Aug 31 00:10:48 PDT 2023
Author: Martin Erhart
Date: 2023-08-31T07:10:31Z
New Revision: 950f0944c95a5416415786487a1263de2dfcec13
URL: https://github.com/llvm/llvm-project/commit/950f0944c95a5416415786487a1263de2dfcec13
DIFF: https://github.com/llvm/llvm-project/commit/950f0944c95a5416415786487a1263de2dfcec13.diff
LOG: [mlir][bufferization] Factor out bufferization.dealloc lowering into separate pass
Moves the lowering of `bufferization.dealloc` to memref into a separate pass,
but still registers the pattern in the conversion pass. This is helpful when
some tensor values (and thus `to_memref` or `to_tensor` operations) still
remain, e.g., when the function boundaries are not converted, or when constant
tensors are converted to memref.get_global at a later point.
However, it is still recommended to perform all bufferization before
deallocation to avoid memory leaks, as any memref allocations inserted after
the deallocation pass has run have to be handled manually.
Note: The buffer deallocation pass assumes that memref values defined by
`bufferization.to_memref` don't return ownership and don't have to be
deallocated. `bufferization.to_tensor` operations are handled similarly to
`bufferization.clone` operations with the exception that the result value is
not handled because it's a tensor (not a memref).
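For illustration, the simplest case handled by the new pass (a single memref,
no retained values, reproduced from the pattern documentation in this patch)
lowers as follows:

```mlir
// Input:
bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)

// After -bufferization-lower-deallocations (no helper function needed):
scf.if %arg1 {
  memref.dealloc %arg0 : memref<2xf32>
}
```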
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D159180
Added:
mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
mlir/test/Dialect/Bufferization/Transforms/lower-deallocations-func.mlir
mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 39dd075bc46f17..eaf016bde69e3b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -195,12 +195,14 @@ def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
This pass converts bufferization operations into memref operations.
In the current state, this pass only transforms a `bufferization.clone`
- operation into `memref.alloc` and `memref.copy` operations. This conversion
- is needed, since some clone operations could remain after applying several
- transformation processes. Currently, only `canonicalize` transforms clone
- operations or even eliminates them. This can lead to errors if any clone op
- survived after all conversion passes (starting from the bufferization
- dialect) are performed.
+ operation into `memref.alloc` and `memref.copy` operations, and lowers
+ `bufferization.dealloc` operations (in the same way as the
+ `-bufferization-lower-deallocations` pass). The conversion of `clone`
+ operations is needed because some clone operations may remain after
+ several transformation passes have been applied. Currently, only
+ `canonicalize` transforms clone operations or even eliminates them. This
+ can lead to errors if any clone op survives after all conversion passes
+ (starting from the bufferization dialect) have been performed.
See:
https://llvm.discourse.group/t/bufferization-error-related-to-memref-clone/4665
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 8c09ffe17bebd0..b0b62acffe77a2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -5,6 +5,9 @@
namespace mlir {
class ModuleOp;
+class RewritePatternSet;
+class OpBuilder;
+class SymbolTable;
namespace func {
class FuncOp;
@@ -29,6 +32,98 @@ std::unique_ptr<Pass> createBufferDeallocationPass();
/// static alias analysis.
std::unique_ptr<Pass> createBufferDeallocationSimplificationPass();
+/// Creates an instance of the LowerDeallocations pass to lower
+/// `bufferization.dealloc` operations to the `memref` dialect.
+std::unique_ptr<Pass> createLowerDeallocationsPass();
+
+/// Adds the conversion pattern of the `bufferization.dealloc` operation to the
+/// given pattern set for use in other transformation passes.
+void populateBufferizationDeallocLoweringPattern(
+ RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+
+/// Construct the library function needed for the fully generic
+/// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
+/// The function can then be called at bufferization dealloc sites to determine
+/// aliasing and ownership.
+///
+/// The generated function takes two memrefs of indices and three memrefs of
+/// booleans as arguments:
+/// * The first argument A should contain the result of the
+/// extract_aligned_pointer_as_index operation applied to the memrefs to be
+/// deallocated
+/// * The second argument B should contain the result of the
+/// extract_aligned_pointer_as_index operation applied to the memrefs to be
+/// retained
+/// * The third argument C should contain the conditions as passed directly
+/// to the deallocation operation.
+/// * The fourth argument D is used to pass results to the caller. Those
+/// represent the condition under which the memref at the corresponding
+/// position in A should be deallocated.
+/// * The fifth argument E is used to pass results to the caller. It
+/// provides the ownership value corresponding to the memref at the same
+/// position in B.
+///
+/// This helper function is supposed to be called once for each
+/// `bufferization.dealloc` operation to determine the deallocation need and new
+/// ownership indicator for the retained values, but does not perform the
+/// deallocation itself.
+///
+/// Generated code:
+/// ```
+/// func.func @dealloc_helper(
+/// %dyn_dealloc_base_pointer_list: memref<?xindex>,
+/// %dyn_retain_base_pointer_list: memref<?xindex>,
+/// %dyn_cond_list: memref<?xi1>,
+/// %dyn_dealloc_cond_out: memref<?xi1>,
+/// %dyn_ownership_out: memref<?xi1>) {
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %true = arith.constant true
+/// %false = arith.constant false
+/// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0
+/// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0
+/// // Zero initialize result buffer.
+/// scf.for %i = %c0 to %num_retain_memrefs step %c1 {
+/// memref.store %false, %dyn_ownership_out[%i] : memref<?xi1>
+/// }
+/// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 {
+/// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i]
+/// %cond = memref.load %dyn_cond_list[%i]
+/// // Check for aliasing with retained memrefs.
+/// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs
+/// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) {
+/// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j]
+/// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index
+/// scf.if %does_alias {
+/// %curr_ownership = memref.load %dyn_ownership_out[%j]
+/// %updated_ownership = arith.ori %curr_ownership, %cond : i1
+/// memref.store %updated_ownership, %dyn_ownership_out[%j]
+/// }
+/// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index
+/// %updated_aggregate = arith.andi %does_not_alias_aggregated,
+/// %does_not_alias : i1
+/// scf.yield %updated_aggregate : i1
+/// }
+/// // Check for aliasing with dealloc memrefs in the list before the
+/// // current one, i.e.,
+/// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j],
+/// // %dyn_dealloc_base_pointer[i])`
+/// %does_not_alias_any = scf.for %j = %c0 to %i step %c1
+/// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) {
+/// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j]
+/// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp
+/// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias
+/// scf.yield %updated_alias_agg : i1
+/// }
+/// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1
+/// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1>
+/// }
+/// return
+/// }
+/// ```
+func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc,
+ SymbolTable &symbolTable);
+
/// Run buffer deallocation.
LogicalResult deallocateBuffers(Operation *op);
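To make the semantics of the generated `@dealloc_helper` concrete, here is a
small Python model of the computation it performs over the base-pointer lists.
This is an illustrative sketch, not part of the patch; the function and
parameter names are made up:

```python
def dealloc_helper(dealloc_ptrs, retain_ptrs, conds):
    """Model of the generated @dealloc_helper library function: given the
    base pointers of the memrefs to deallocate (A), the base pointers of
    the memrefs to retain (B), and the dealloc conditions (C), compute the
    per-memref dealloc conditions (D) and the new ownership bits for the
    retained memrefs (E)."""
    ownership = [False] * len(retain_ptrs)       # E, zero-initialized
    dealloc_conds = [False] * len(dealloc_ptrs)  # D
    for i, (ptr, cond) in enumerate(zip(dealloc_ptrs, conds)):
        # First loop: check aliasing with the retained memrefs; an aliasing
        # retained memref takes over ownership instead of being freed.
        not_retained = True
        for j, retain_ptr in enumerate(retain_ptrs):
            if retain_ptr == ptr:
                ownership[j] = ownership[j] or cond
                not_retained = False
        # Second loop: check aliasing with earlier entries of the dealloc
        # list so that each buffer is freed at most once.
        no_earlier_alias = all(dealloc_ptrs[j] != ptr for j in range(i))
        dealloc_conds[i] = not_retained and no_earlier_alias and cond
    return dealloc_conds, ownership
```

Note how a memref that aliases a retained value is never freed but transfers
its condition into the retained value's ownership bit, and how a duplicated
entry in the dealloc list is freed only once.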
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 7f1474e26be432..df9bfcbfc54880 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -108,6 +108,29 @@ def BufferDeallocationSimplification :
];
}
+def LowerDeallocations : Pass<"bufferization-lower-deallocations"> {
+ let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc` "
+               "operations";
+ let description = [{
+ This pass lowers `bufferization.dealloc` operations to the `memref` dialect.
+ It can be applied to a `builtin.module` or operations implementing the
+ `FunctionOpInterface`. For the latter, only simple `dealloc` operations can
+ be lowered because the library function necessary for the fully generic
+ lowering cannot be inserted. In this case, an error will be emitted.
+ In addition to `memref.dealloc` operations, it may also emit operations from
+ the `arith`, `scf`, and `func` dialects to build conditional deallocations
+ and library functions that avoid code-size blow-up.
+ }];
+
+ let constructor =
+ "mlir::bufferization::createLowerDeallocationsPass()";
+
+ let dependentDialects = [
+ "arith::ArithDialect", "memref::MemRefDialect", "scf::SCFDialect",
+ "func::FuncDialect"
+ ];
+}
+
def BufferHoisting : Pass<"buffer-hoisting", "func::FuncOp"> {
let summary = "Optimizes placement of allocation operations by moving them "
"into common dominators and out of nested regions";
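The special case with exactly one memref but multiple retained values (whose
IR-level lowering is documented on the `DeallocOpConversion` pattern below)
boils down to the following decision logic, sketched here in Python for
illustration; the function name is made up:

```python
def one_memref_retain_lowering(m_ptr, retain_ptrs, cond):
    """Model of the single-memref, multiple-retained-values lowering:
    decide whether to free %m and compute the new ownership bit for each
    retained memref, given their base pointers."""
    # Does %m alias a retained memref? The lowering compares base pointers
    # obtained via memref.extract_aligned_pointer_as_index.
    does_not_alias = [r != m_ptr for r in retain_ptrs]
    # Free %m only if it aliases no retained memref and the condition holds.
    should_dealloc = all(does_not_alias) and cond
    # A retained memref that aliases %m inherits ownership if %cond is set.
    ownership = [(not dna) and cond for dna in does_not_alias]
    return should_dealloc, ownership
```

Code size of this lowering is linear in the number of retained values, which
is why it avoids the helper function required by the general case.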
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 5c83ae2b36bb28..3069f6e0732400 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -80,543 +81,6 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
}
};
-/// The DeallocOpConversion transforms all bufferization dealloc operations into
-/// memref dealloc operations potentially guarded by scf if operations.
-/// Additionally, memref extract_aligned_pointer_as_index and arith operations
-/// are inserted to compute the guard conditions. We distinguish multiple cases
-/// to provide an overall more efficient lowering. In the general case, a helper
-/// func is created to avoid quadratic code size explosion (relative to the
-/// number of operands of the dealloc operation). For examples of each case,
-/// refer to the documentation of the member functions of this class.
-class DeallocOpConversion
- : public OpConversionPattern<bufferization::DeallocOp> {
-
- /// Lower a simple case without any retained values and a single memref to
- /// avoiding the helper function. Ideally, static analysis can provide enough
- /// aliasing information to split the dealloc operations up into this simple
- /// case as much as possible before running this pass.
- ///
- /// Example:
- /// ```
- /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
- /// ```
- /// is lowered to
- /// ```
- /// scf.if %arg1 {
- /// memref.dealloc %arg0 : memref<2xf32>
- /// }
- /// ```
- LogicalResult
- rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
- assert(adaptor.getRetained().empty() && "expected no retained memrefs");
-
- rewriter.replaceOpWithNewOp<scf::IfOp>(
- op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
- builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
- builder.create<scf::YieldOp>(loc);
- });
- return success();
- }
-
- /// A special case lowering for the deallocation operation with exactly one
- /// memref, but arbitrary number of retained values. This avoids the helper
- /// function that the general case needs and thus also avoids storing indices
- /// to specifically allocated memrefs. The size of the code produced by this
- /// lowering is linear to the number of retained values.
- ///
- /// Example:
- /// ```mlir
- /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
- // retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
- /// return %0#0, %0#1 : i1, i1
- /// ```
- /// ```mlir
- /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
- /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
- /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
- /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
- /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
- /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
- /// %should_dealloc = arith.andi %not_retained, %cond : i1
- /// scf.if %should_dealloc {
- /// memref.dealloc %m : memref<2xf32>
- /// }
- /// %true = arith.constant true
- /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
- /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
- /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
- /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
- /// return %r0_ownership, %r1_ownership : i1, i1
- /// ```
- LogicalResult rewriteOneMemrefMultipleRetainCase(
- bufferization::DeallocOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
-
- // Compute the base pointer indices, compare all retained indices to the
- // memref index to check if they alias.
- SmallVector<Value> doesNotAliasList;
- Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
- op->getLoc(), adaptor.getMemrefs()[0]);
- for (Value retained : adaptor.getRetained()) {
- Value retainedAsIdx =
- rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
- retained);
- Value doesNotAlias = rewriter.create<arith::CmpIOp>(
- op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
- doesNotAliasList.push_back(doesNotAlias);
- }
-
- // AND-reduce the list of booleans from above.
- Value prev = doesNotAliasList.front();
- for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
- prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
-
- // Also consider the condition given by the dealloc operation and perform a
- // conditional deallocation guarded by that value.
- Value shouldDealloc = rewriter.create<arith::AndIOp>(
- op->getLoc(), prev, adaptor.getConditions()[0]);
-
- rewriter.create<scf::IfOp>(
- op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
- builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
- builder.create<scf::YieldOp>(loc);
- });
-
- // Compute the replacement values for the dealloc operation results. This
- // inserts an already canonicalized form of
- // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
- // value r.
- SmallVector<Value> replacements;
- Value trueVal = rewriter.create<arith::ConstantOp>(
- op->getLoc(), rewriter.getBoolAttr(true));
- for (Value doesNotAlias : doesNotAliasList) {
- Value aliases =
- rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
- Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
- adaptor.getConditions()[0]);
- replacements.push_back(result);
- }
-
- rewriter.replaceOp(op, replacements);
-
- return success();
- }
-
- /// Lowering that supports all features the dealloc operation has to offer. It
- /// computes the base pointer of each memref (as an index), stores it in a
- /// new memref helper structure and passes it to the helper function generated
- /// in 'buildDeallocationHelperFunction'. The results are stored in two lists
- /// (represented as memrefs) of booleans passed as arguments. The first list
- /// stores whether the corresponding condition should be deallocated, the
- /// second list stores the ownership of the retained values which can be used
- /// to replace the result values of the `bufferization.dealloc` operation.
- ///
- /// Example:
- /// ```
- /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
- /// if (%cond0, %cond1)
- /// retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
- /// ```
- /// lowers to (simplified):
- /// ```
- /// %c0 = arith.constant 0 : index
- /// %c1 = arith.constant 1 : index
- /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
- /// %cond_list = memref.alloc() : memref<2xi1>
- /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
- /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
- /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
- /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
- /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
- /// memref.store %cond0, %cond_list[%c0]
- /// memref.store %cond1, %cond_list[%c1]
- /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
- /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
- /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
- /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
- /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
- /// memref<2xindex> to memref<?xindex>
- /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
- /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
- /// memref<2xindex> to memref<?xindex>
- /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
- /// %ownership_out = memref.alloc() : memref<2xi1>
- /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
- /// memref<2xi1> to memref<?xi1>
- /// %dyn_ownership_out = memref.cast %ownership_out :
- /// memref<2xi1> to memref<?xi1>
- /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
- /// %dyn_retain_base_pointer_list,
- /// %dyn_cond_list,
- /// %dyn_dealloc_cond_out,
- /// %dyn_ownership_out) : (...)
- /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1>
- /// scf.if %m0_dealloc_cond {
- /// memref.dealloc %m0 : memref<2xf32>
- /// }
- /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1>
- /// scf.if %m1_dealloc_cond {
- /// memref.dealloc %m1 : memref<5xf32>
- /// }
- /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1>
- /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1>
- /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
- /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
- /// memref.dealloc %cond_list : memref<2xi1>
- /// memref.dealloc %dealloc_cond_out : memref<2xi1>
- /// memref.dealloc %ownership_out : memref<2xi1>
- /// // replace %0#0 with %r0_ownership
- /// // replace %0#1 with %r1_ownership
- /// ```
- LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- // Allocate two memrefs holding the base pointer indices of the list of
- // memrefs to be deallocated and the ones to be retained. These can then be
- // passed to the helper function and the for-loops can iterate over them.
- // Without storing them to memrefs, we could not use for-loops but only a
- // completely unrolled version of it, potentially leading to code-size
- // blow-up.
- Value toDeallocMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
- rewriter.getIndexType()));
- Value conditionMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
- rewriter.getI1Type()));
- Value toRetainMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
- rewriter.getIndexType()));
-
- auto getConstValue = [&](uint64_t value) -> Value {
- return rewriter.create<arith::ConstantOp>(op.getLoc(),
- rewriter.getIndexAttr(value));
- };
-
- // Extract the base pointers of the memrefs as indices to check for aliasing
- // at runtime.
- for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
- Value memrefAsIdx =
- rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
- toDealloc);
- rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
- toDeallocMemref, getConstValue(i));
- }
-
- for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
- rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
- getConstValue(i));
-
- for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
- Value memrefAsIdx =
- rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
- toRetain);
- rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
- getConstValue(i));
- }
-
- // Cast the allocated memrefs to dynamic shape because we want only one
- // helper function no matter how many operands the bufferization.dealloc
- // has.
- Value castedDeallocMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
- MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
- toDeallocMemref);
- Value castedCondsMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
- MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
- conditionMemref);
- Value castedRetainMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
- MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
- toRetainMemref);
-
- Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
- rewriter.getI1Type()));
- Value retainCondsMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
- rewriter.getI1Type()));
-
- Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
- MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
- deallocCondsMemref);
- Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
- MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
- retainCondsMemref);
-
- rewriter.create<func::CallOp>(
- op.getLoc(), deallocHelperFunc,
- SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
- castedCondsMemref, castedDeallocCondsMemref,
- castedRetainCondsMemref});
-
- for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
- Value idxValue = getConstValue(i);
- Value shouldDealloc = rewriter.create<memref::LoadOp>(
- op.getLoc(), deallocCondsMemref, idxValue);
- rewriter.create<scf::IfOp>(
- op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
- builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
- builder.create<scf::YieldOp>(loc);
- });
- }
-
- SmallVector<Value> replacements;
- for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
- Value idxValue = getConstValue(i);
- Value ownership = rewriter.create<memref::LoadOp>(
- op.getLoc(), retainCondsMemref, idxValue);
- replacements.push_back(ownership);
- }
-
- // Deallocate above allocated memrefs again to avoid memory leaks.
- // Deallocation will not be run on code after this stage.
- rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
- rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
- rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
- rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
- rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
-
- rewriter.replaceOp(op, replacements);
- return success();
- }
-
-public:
- DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
- : OpConversionPattern<bufferization::DeallocOp>(context),
- deallocHelperFunc(deallocHelperFunc) {}
-
- LogicalResult
- matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Lower the trivial case.
- if (adaptor.getMemrefs().empty()) {
- Value falseVal = rewriter.create<arith::ConstantOp>(
- op.getLoc(), rewriter.getBoolAttr(false));
- rewriter.replaceOp(
- op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
- return success();
- }
-
- if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
- return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
-
- if (adaptor.getMemrefs().size() == 1)
- return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
-
- if (!deallocHelperFunc)
- return op->emitError(
- "library function required for generic lowering, but cannot be "
- "automatically inserted when operating on functions");
-
- return rewriteGeneralCase(op, adaptor, rewriter);
- }
-
- /// Build a helper function per compilation unit that can be called at
- /// bufferization dealloc sites to determine aliasing and ownership.
- ///
- /// The generated function takes two memrefs of indices and three memrefs of
- /// booleans as arguments:
- /// * The first argument A should contain the result of the
- /// extract_aligned_pointer_as_index operation applied to the memrefs to be
- /// deallocated
- /// * The second argument B should contain the result of the
- /// extract_aligned_pointer_as_index operation applied to the memrefs to be
- /// retained
- /// * The third argument C should contain the conditions as passed directly
- /// to the deallocation operation.
- /// * The fourth argument D is used to pass results to the caller. Those
- /// represent the condition under which the memref at the corresponding
- /// position in A should be deallocated.
- /// * The fifth argument E is used to pass results to the caller. It
- /// provides the ownership value corresponding the the memref at the same
- /// position in B
- ///
- /// This helper function is supposed to be called once for each
- /// `bufferization.dealloc` operation to determine the deallocation need and
- /// new ownership indicator for the retained values, but does not perform the
- /// deallocation itself.
- ///
- /// Generated code:
- /// ```
- /// func.func @dealloc_helper(
- /// %dyn_dealloc_base_pointer_list: memref<?xindex>,
- /// %dyn_retain_base_pointer_list: memref<?xindex>,
- /// %dyn_cond_list: memref<?xi1>,
- /// %dyn_dealloc_cond_out: memref<?xi1>,
- /// %dyn_ownership_out: memref<?xi1>) {
- /// %c0 = arith.constant 0 : index
- /// %c1 = arith.constant 1 : index
- /// %true = arith.constant true
- /// %false = arith.constant false
- /// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0
- /// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0
- /// // Zero initialize result buffer.
- /// scf.for %i = %c0 to %num_retain_memrefs step %c1 {
- /// memref.store %false, %dyn_ownership_out[%i] : memref<?xi1>
- /// }
- /// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 {
- /// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i]
- /// %cond = memref.load %dyn_cond_list[%i]
- /// // Check for aliasing with retained memrefs.
- /// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs
- /// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) {
- /// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j]
- /// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index
- /// scf.if %does_alias {
- /// %curr_ownership = memref.load %dyn_ownership_out[%j]
- /// %updated_ownership = arith.ori %curr_ownership, %cond : i1
- /// memref.store %updated_ownership, %dyn_ownership_out[%j]
- /// }
- /// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index
- /// %updated_aggregate = arith.andi %does_not_alias_aggregated,
- /// %does_not_alias : i1
- /// scf.yield %updated_aggregate : i1
- /// }
- /// // Check for aliasing with dealloc memrefs in the list before the
- /// // current one, i.e.,
- /// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j],
- /// // %dyn_dealloc_base_pointer[i])`
- /// %does_not_alias_any = scf.for %j = %c0 to %i step %c1
- /// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) {
- /// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j]
- /// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp
- /// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias
- /// scf.yield %updated_alias_agg : i1
- /// }
- /// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1
- /// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1>
- /// }
- /// return
- /// }
- /// ```
- static func::FuncOp
- buildDeallocationHelperFunction(OpBuilder &builder, Location loc,
- SymbolTable &symbolTable) {
- Type indexMemrefType =
- MemRefType::get({ShapedType::kDynamic}, builder.getIndexType());
- Type boolMemrefType =
- MemRefType::get({ShapedType::kDynamic}, builder.getI1Type());
- SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
- boolMemrefType, boolMemrefType};
- builder.clearInsertionPoint();
-
- // Generate the func operation itself.
- auto helperFuncOp = func::FuncOp::create(
- loc, "dealloc_helper", builder.getFunctionType(argTypes, {}));
- symbolTable.insert(helperFuncOp);
- auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
- block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
-
- builder.setInsertionPointToStart(&block);
- Value toDeallocMemref = helperFuncOp.getArguments()[0];
- Value toRetainMemref = helperFuncOp.getArguments()[1];
- Value conditionMemref = helperFuncOp.getArguments()[2];
- Value deallocCondsMemref = helperFuncOp.getArguments()[3];
- Value retainCondsMemref = helperFuncOp.getArguments()[4];
-
- // Insert some prerequisites.
- Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
- Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
- Value trueValue =
- builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
- Value falseValue =
- builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
- Value toDeallocSize =
- builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
- Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
-
- builder.create<scf::ForOp>(
- loc, c0, toRetainSize, c1, std::nullopt,
- [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
- builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref,
- i);
- builder.create<scf::YieldOp>(loc);
- });
-
- builder.create<scf::ForOp>(
- loc, c0, toDeallocSize, c1, std::nullopt,
- [&](OpBuilder &builder, Location loc, Value outerIter,
- ValueRange iterArgs) {
- Value toDealloc =
- builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
- Value cond =
- builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);
-
- // Build the first for loop that computes aliasing with retained
- // memrefs.
- Value noRetainAlias =
- builder
- .create<scf::ForOp>(
- loc, c0, toRetainSize, c1, trueValue,
- [&](OpBuilder &builder, Location loc, Value i,
- ValueRange iterArgs) {
- Value retainValue = builder.create<memref::LoadOp>(
- loc, toRetainMemref, i);
- Value doesAlias = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, retainValue,
- toDealloc);
- builder.create<scf::IfOp>(
- loc, doesAlias,
- [&](OpBuilder &builder, Location loc) {
- Value retainCondValue =
- builder.create<memref::LoadOp>(
- loc, retainCondsMemref, i);
- Value aggregatedRetainCond =
- builder.create<arith::OrIOp>(
- loc, retainCondValue, cond);
- builder.create<memref::StoreOp>(
- loc, aggregatedRetainCond, retainCondsMemref,
- i);
- builder.create<scf::YieldOp>(loc);
- });
- Value doesntAlias = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, retainValue,
- toDealloc);
- Value yieldValue = builder.create<arith::AndIOp>(
- loc, iterArgs[0], doesntAlias);
- builder.create<scf::YieldOp>(loc, yieldValue);
- })
- .getResult(0);
-
- // Build the second for loop that adds aliasing with previously
- // deallocated memrefs.
- Value noAlias =
- builder
- .create<scf::ForOp>(
- loc, c0, outerIter, c1, noRetainAlias,
- [&](OpBuilder &builder, Location loc, Value i,
- ValueRange iterArgs) {
- Value prevDeallocValue = builder.create<memref::LoadOp>(
- loc, toDeallocMemref, i);
- Value doesntAlias = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, prevDeallocValue,
- toDealloc);
- Value yieldValue = builder.create<arith::AndIOp>(
- loc, iterArgs[0], doesntAlias);
- builder.create<scf::YieldOp>(loc, yieldValue);
- })
- .getResult(0);
-
- Value shouldDealoc =
- builder.create<arith::AndIOp>(loc, noAlias, cond);
- builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
- outerIter);
- builder.create<scf::YieldOp>(loc);
- });
-
- builder.create<func::ReturnOp>(loc);
- return helperFuncOp;
- }
-
-private:
- func::FuncOp deallocHelperFunc;
-};
} // namespace
namespace {
@@ -641,7 +105,7 @@ struct BufferizationToMemRefPass
// Build dealloc helper function if there are deallocs.
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
if (deallocOp.getMemrefs().size() > 1) {
- helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
+ helperFuncOp = bufferization::buildDeallocationLibraryFunction(
builder, getOperation()->getLoc(), symbolTable);
return WalkResult::interrupt();
}
@@ -651,7 +115,8 @@ struct BufferizationToMemRefPass
RewritePatternSet patterns(&getContext());
patterns.add<CloneOpConversion>(patterns.getContext());
- patterns.add<DeallocOpConversion>(patterns.getContext(), helperFuncOp);
+ bufferization::populateBufferizationDeallocLoweringPattern(patterns,
+ helperFuncOp);
ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
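For readers following along, the special-case lowering for a single deallocated memref with retained values (see `rewriteOneMemrefMultipleRetainCase` in the new LowerDeallocations.cpp below) boils down to boolean logic over base-pointer equality. A minimal Python sketch of that logic, not part of the commit; memrefs are modeled here by integer base pointers and all names are illustrative:

```python
def lower_one_memref_multiple_retained(memref_ptr, cond, retained_ptrs):
    """Model of the single-memref lowering: decide whether the memref
    should be freed and compute the ownership flag returned for each
    retained value."""
    # Compare the memref's base pointer against every retained value.
    does_not_alias = [memref_ptr != r for r in retained_ptrs]
    # Free only if no retained value aliases and the condition holds.
    should_dealloc = all(does_not_alias) and cond
    # A retained value takes over ownership iff it aliases the memref
    # and the memref's ownership condition is set.
    ownership = [(not dna) and cond for dna in does_not_alias]
    return should_dealloc, ownership
```

This mirrors why the emitted code is linear in the number of retained values: one compare, one `xori`, and one `andi` per retained operand.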
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
index 28d9f3b6e6ac32..fc32d3dcf2cb64 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
+++ b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_conversion_library(MLIRBufferizationToMemRef
LINK_LIBS PUBLIC
MLIRBufferizationDialect
+ MLIRBufferizationTransforms
MLIRSCFDialect
MLIRFuncDialect
MLIRArithDialect
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 4c6731f6aec117..af3cc98274dbb5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
EmptyTensorElimination.cpp
EmptyTensorToAllocTensor.cpp
FuncBufferizableOpInterfaceImpl.cpp
+ LowerDeallocations.cpp
OneShotAnalysis.cpp
OneShotModuleBufferize.cpp
TensorCopyInsertion.cpp
@@ -22,6 +23,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
MLIRBufferizationEnumsIncGen
LINK_LIBS PUBLIC
+ MLIRArithDialect
MLIRBufferizationDialect
MLIRControlFlowInterfaces
MLIRFuncDialect
@@ -30,8 +32,10 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
MLIRMemRefDialect
MLIRPass
MLIRTensorDialect
+ MLIRSCFDialect
MLIRSideEffectInterfaces
MLIRTransforms
MLIRViewLikeInterface
+ MLIRSupport
)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
new file mode 100644
index 00000000000000..c7052434f2c3be
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -0,0 +1,544 @@
+//===- LowerDeallocations.cpp - Bufferization Deallocs to MemRef pass -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert `bufferization.dealloc` operations
+// to the MemRef dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_LOWERDEALLOCATIONS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// The DeallocOpConversion transforms all bufferization dealloc operations into
+/// memref dealloc operations potentially guarded by scf if operations.
+/// Additionally, memref extract_aligned_pointer_as_index and arith operations
+/// are inserted to compute the guard conditions. We distinguish multiple cases
+/// to provide an overall more efficient lowering. In the general case, a helper
+/// func is created to avoid quadratic code size explosion (relative to the
+/// number of operands of the dealloc operation). For examples of each case,
+/// refer to the documentation of the member functions of this class.
+class DeallocOpConversion
+ : public OpConversionPattern<bufferization::DeallocOp> {
+
+  /// Lower the simple case of a single memref without any retained values,
+  /// avoiding the helper function. Ideally, static analysis can provide enough
+ /// aliasing information to split the dealloc operations up into this simple
+ /// case as much as possible before running this pass.
+ ///
+ /// Example:
+ /// ```
+ /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+ /// ```
+ /// is lowered to
+ /// ```
+ /// scf.if %arg1 {
+ /// memref.dealloc %arg0 : memref<2xf32>
+ /// }
+ /// ```
+ LogicalResult
+ rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
+ assert(adaptor.getRetained().empty() && "expected no retained memrefs");
+
+ rewriter.replaceOpWithNewOp<scf::IfOp>(
+ op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
+ builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
+ builder.create<scf::YieldOp>(loc);
+ });
+ return success();
+ }
+
+  /// A special-case lowering for the deallocation operation with exactly one
+  /// memref but an arbitrary number of retained values. This avoids the helper
+  /// function that the general case needs and thus also avoids storing indices
+  /// in specifically allocated memrefs. The size of the code produced by this
+  /// lowering is linear in the number of retained values.
+ ///
+ /// Example:
+ /// ```mlir
+ /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
+  ///           retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
+ /// return %0#0, %0#1 : i1, i1
+ /// ```
+  /// is lowered to
+  /// ```mlir
+ /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
+ /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
+ /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
+ /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
+ /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
+ /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
+ /// %should_dealloc = arith.andi %not_retained, %cond : i1
+ /// scf.if %should_dealloc {
+ /// memref.dealloc %m : memref<2xf32>
+ /// }
+ /// %true = arith.constant true
+ /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
+ /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
+ /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
+ /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
+ /// return %r0_ownership, %r1_ownership : i1, i1
+ /// ```
+ LogicalResult rewriteOneMemrefMultipleRetainCase(
+ bufferization::DeallocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
+
+ // Compute the base pointer indices, compare all retained indices to the
+ // memref index to check if they alias.
+ SmallVector<Value> doesNotAliasList;
+ Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
+ op->getLoc(), adaptor.getMemrefs()[0]);
+ for (Value retained : adaptor.getRetained()) {
+ Value retainedAsIdx =
+ rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
+ retained);
+ Value doesNotAlias = rewriter.create<arith::CmpIOp>(
+ op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
+ doesNotAliasList.push_back(doesNotAlias);
+ }
+
+ // AND-reduce the list of booleans from above.
+ Value prev = doesNotAliasList.front();
+ for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
+ prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
+
+ // Also consider the condition given by the dealloc operation and perform a
+ // conditional deallocation guarded by that value.
+ Value shouldDealloc = rewriter.create<arith::AndIOp>(
+ op->getLoc(), prev, adaptor.getConditions()[0]);
+
+ rewriter.create<scf::IfOp>(
+ op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
+ builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
+ builder.create<scf::YieldOp>(loc);
+ });
+
+ // Compute the replacement values for the dealloc operation results. This
+ // inserts an already canonicalized form of
+ // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
+ // value r.
+ SmallVector<Value> replacements;
+ Value trueVal = rewriter.create<arith::ConstantOp>(
+ op->getLoc(), rewriter.getBoolAttr(true));
+ for (Value doesNotAlias : doesNotAliasList) {
+ Value aliases =
+ rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
+ Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
+ adaptor.getConditions()[0]);
+ replacements.push_back(result);
+ }
+
+ rewriter.replaceOp(op, replacements);
+
+ return success();
+ }
+
+ /// Lowering that supports all features the dealloc operation has to offer. It
+ /// computes the base pointer of each memref (as an index), stores it in a
+ /// new memref helper structure and passes it to the helper function generated
+  /// in 'buildDeallocationLibraryFunction'. The results are stored in two lists
+  /// (represented as memrefs) of booleans passed as arguments. The first list
+  /// stores whether the corresponding memref should be deallocated, the second
+  /// list stores the ownership of the retained values, which can be used
+ /// to replace the result values of the `bufferization.dealloc` operation.
+ ///
+ /// Example:
+ /// ```
+ /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
+ /// if (%cond0, %cond1)
+ /// retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
+ /// ```
+ /// lowers to (simplified):
+ /// ```
+ /// %c0 = arith.constant 0 : index
+ /// %c1 = arith.constant 1 : index
+ /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
+ /// %cond_list = memref.alloc() : memref<2xi1>
+ /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
+ /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
+ /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
+ /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
+ /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
+ /// memref.store %cond0, %cond_list[%c0]
+ /// memref.store %cond1, %cond_list[%c1]
+ /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
+ /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
+ /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
+ /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
+ /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
+ /// memref<2xindex> to memref<?xindex>
+ /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
+ /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
+ /// memref<2xindex> to memref<?xindex>
+ /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
+ /// %ownership_out = memref.alloc() : memref<2xi1>
+ /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
+ /// memref<2xi1> to memref<?xi1>
+ /// %dyn_ownership_out = memref.cast %ownership_out :
+ /// memref<2xi1> to memref<?xi1>
+ /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
+ /// %dyn_retain_base_pointer_list,
+ /// %dyn_cond_list,
+ /// %dyn_dealloc_cond_out,
+ /// %dyn_ownership_out) : (...)
+ /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1>
+ /// scf.if %m0_dealloc_cond {
+ /// memref.dealloc %m0 : memref<2xf32>
+ /// }
+ /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1>
+ /// scf.if %m1_dealloc_cond {
+ /// memref.dealloc %m1 : memref<5xf32>
+ /// }
+ /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1>
+ /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1>
+ /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
+ /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
+ /// memref.dealloc %cond_list : memref<2xi1>
+ /// memref.dealloc %dealloc_cond_out : memref<2xi1>
+ /// memref.dealloc %ownership_out : memref<2xi1>
+ /// // replace %0#0 with %r0_ownership
+ /// // replace %0#1 with %r1_ownership
+ /// ```
+ LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Allocate two memrefs holding the base pointer indices of the list of
+ // memrefs to be deallocated and the ones to be retained. These can then be
+ // passed to the helper function and the for-loops can iterate over them.
+    // Without storing them in memrefs, we could not use for-loops but would
+    // need a fully unrolled version, potentially leading to code-size
+    // blow-up.
+ Value toDeallocMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+ rewriter.getIndexType()));
+ Value conditionMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
+ rewriter.getI1Type()));
+ Value toRetainMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
+ rewriter.getIndexType()));
+
+ auto getConstValue = [&](uint64_t value) -> Value {
+ return rewriter.create<arith::ConstantOp>(op.getLoc(),
+ rewriter.getIndexAttr(value));
+ };
+
+ // Extract the base pointers of the memrefs as indices to check for aliasing
+ // at runtime.
+ for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
+ Value memrefAsIdx =
+ rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+ toDealloc);
+ rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
+ toDeallocMemref, getConstValue(i));
+ }
+
+ for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
+ rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
+ getConstValue(i));
+
+ for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
+ Value memrefAsIdx =
+ rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+ toRetain);
+ rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
+ getConstValue(i));
+ }
+
+ // Cast the allocated memrefs to dynamic shape because we want only one
+ // helper function no matter how many operands the bufferization.dealloc
+ // has.
+ Value castedDeallocMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
+ toDeallocMemref);
+ Value castedCondsMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+ conditionMemref);
+ Value castedRetainMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
+ toRetainMemref);
+
+ Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+ rewriter.getI1Type()));
+ Value retainCondsMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
+ rewriter.getI1Type()));
+
+ Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+ deallocCondsMemref);
+ Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+ retainCondsMemref);
+
+ rewriter.create<func::CallOp>(
+ op.getLoc(), deallocHelperFunc,
+ SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
+ castedCondsMemref, castedDeallocCondsMemref,
+ castedRetainCondsMemref});
+
+ for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
+ Value idxValue = getConstValue(i);
+ Value shouldDealloc = rewriter.create<memref::LoadOp>(
+ op.getLoc(), deallocCondsMemref, idxValue);
+ rewriter.create<scf::IfOp>(
+ op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
+ builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
+ builder.create<scf::YieldOp>(loc);
+ });
+ }
+
+ SmallVector<Value> replacements;
+ for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
+ Value idxValue = getConstValue(i);
+ Value ownership = rewriter.create<memref::LoadOp>(
+ op.getLoc(), retainCondsMemref, idxValue);
+ replacements.push_back(ownership);
+ }
+
+    // Deallocate the memrefs allocated above to avoid memory leaks; the
+    // buffer deallocation pass does not run again after this lowering.
+ rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
+ rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
+ rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
+ rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
+ rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
+
+ rewriter.replaceOp(op, replacements);
+ return success();
+ }
+
+public:
+ DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+ : OpConversionPattern<bufferization::DeallocOp>(context),
+ deallocHelperFunc(deallocHelperFunc) {}
+
+ LogicalResult
+ matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Lower the trivial case.
+ if (adaptor.getMemrefs().empty()) {
+ Value falseVal = rewriter.create<arith::ConstantOp>(
+ op.getLoc(), rewriter.getBoolAttr(false));
+ rewriter.replaceOp(
+ op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
+ return success();
+ }
+
+ if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
+ return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
+
+ if (adaptor.getMemrefs().size() == 1)
+ return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
+
+ if (!deallocHelperFunc)
+ return op->emitError(
+ "library function required for generic lowering, but cannot be "
+ "automatically inserted when operating on functions");
+
+ return rewriteGeneralCase(op, adaptor, rewriter);
+ }
+
+private:
+ func::FuncOp deallocHelperFunc;
+};
+} // namespace
+
+namespace {
+struct LowerDeallocationsPass
+ : public bufferization::impl::LowerDeallocationsBase<
+ LowerDeallocationsPass> {
+ void runOnOperation() override {
+ if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
+ emitError(getOperation()->getLoc(),
+ "root operation must be a builtin.module or a function");
+ signalPassFailure();
+ return;
+ }
+
+ func::FuncOp helperFuncOp;
+ if (auto module = dyn_cast<ModuleOp>(getOperation())) {
+ OpBuilder builder =
+ OpBuilder::atBlockBegin(&module.getBodyRegion().front());
+ SymbolTable symbolTable(module);
+
+ // Build dealloc helper function if there are deallocs.
+ getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
+ if (deallocOp.getMemrefs().size() > 1) {
+ helperFuncOp = bufferization::buildDeallocationLibraryFunction(
+ builder, getOperation()->getLoc(), symbolTable);
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ }
+
+ RewritePatternSet patterns(&getContext());
+ bufferization::populateBufferizationDeallocLoweringPattern(patterns,
+ helperFuncOp);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
+ scf::SCFDialect, func::FuncDialect>();
+ target.addIllegalOp<bufferization::DeallocOp>();
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
+ OpBuilder &builder, Location loc, SymbolTable &symbolTable) {
+ Type indexMemrefType =
+ MemRefType::get({ShapedType::kDynamic}, builder.getIndexType());
+ Type boolMemrefType =
+ MemRefType::get({ShapedType::kDynamic}, builder.getI1Type());
+ SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
+ boolMemrefType, boolMemrefType};
+ builder.clearInsertionPoint();
+
+ // Generate the func operation itself.
+ auto helperFuncOp = func::FuncOp::create(
+ loc, "dealloc_helper", builder.getFunctionType(argTypes, {}));
+ symbolTable.insert(helperFuncOp);
+ auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
+ block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
+
+ builder.setInsertionPointToStart(&block);
+ Value toDeallocMemref = helperFuncOp.getArguments()[0];
+ Value toRetainMemref = helperFuncOp.getArguments()[1];
+ Value conditionMemref = helperFuncOp.getArguments()[2];
+ Value deallocCondsMemref = helperFuncOp.getArguments()[3];
+ Value retainCondsMemref = helperFuncOp.getArguments()[4];
+
+ // Insert some prerequisites.
+ Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
+ Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
+ Value trueValue =
+ builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
+ Value falseValue =
+ builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
+ Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
+ Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
+
+ builder.create<scf::ForOp>(
+ loc, c0, toRetainSize, c1, std::nullopt,
+ [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
+ builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
+ builder.create<scf::YieldOp>(loc);
+ });
+
+ builder.create<scf::ForOp>(
+ loc, c0, toDeallocSize, c1, std::nullopt,
+ [&](OpBuilder &builder, Location loc, Value outerIter,
+ ValueRange iterArgs) {
+ Value toDealloc =
+ builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
+ Value cond =
+ builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);
+
+ // Build the first for loop that computes aliasing with retained
+ // memrefs.
+ Value noRetainAlias =
+ builder
+ .create<scf::ForOp>(
+ loc, c0, toRetainSize, c1, trueValue,
+ [&](OpBuilder &builder, Location loc, Value i,
+ ValueRange iterArgs) {
+ Value retainValue = builder.create<memref::LoadOp>(
+ loc, toRetainMemref, i);
+ Value doesAlias = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, retainValue,
+ toDealloc);
+ builder.create<scf::IfOp>(
+ loc, doesAlias,
+ [&](OpBuilder &builder, Location loc) {
+ Value retainCondValue =
+ builder.create<memref::LoadOp>(
+ loc, retainCondsMemref, i);
+ Value aggregatedRetainCond =
+ builder.create<arith::OrIOp>(
+ loc, retainCondValue, cond);
+ builder.create<memref::StoreOp>(
+ loc, aggregatedRetainCond, retainCondsMemref,
+ i);
+ builder.create<scf::YieldOp>(loc);
+ });
+ Value doesntAlias = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, retainValue,
+ toDealloc);
+ Value yieldValue = builder.create<arith::AndIOp>(
+ loc, iterArgs[0], doesntAlias);
+ builder.create<scf::YieldOp>(loc, yieldValue);
+ })
+ .getResult(0);
+
+ // Build the second for loop that adds aliasing with previously
+ // deallocated memrefs.
+ Value noAlias =
+ builder
+ .create<scf::ForOp>(
+ loc, c0, outerIter, c1, noRetainAlias,
+ [&](OpBuilder &builder, Location loc, Value i,
+ ValueRange iterArgs) {
+ Value prevDeallocValue = builder.create<memref::LoadOp>(
+ loc, toDeallocMemref, i);
+ Value doesntAlias = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, prevDeallocValue,
+ toDealloc);
+ Value yieldValue = builder.create<arith::AndIOp>(
+ loc, iterArgs[0], doesntAlias);
+ builder.create<scf::YieldOp>(loc, yieldValue);
+ })
+ .getResult(0);
+
+        Value shouldDealloc = builder.create<arith::AndIOp>(loc, noAlias, cond);
+        builder.create<memref::StoreOp>(loc, shouldDealloc, deallocCondsMemref,
+ outerIter);
+ builder.create<scf::YieldOp>(loc);
+ });
+
+ builder.create<func::ReturnOp>(loc);
+ return helperFuncOp;
+}
+
+void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
+ RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
+ patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
+}
+
+std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
+ return std::make_unique<LowerDeallocationsPass>();
+}
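The generated `@dealloc_helper` library function above encodes a quadratic aliasing check as two nested loops over base-pointer lists. A rough Python model of that algorithm (illustrative only; function and variable names are assumptions, and memrefs are represented by integer base pointers):

```python
def dealloc_helper(to_dealloc, to_retain, conds):
    """Model of the generated @dealloc_helper: for each memref to
    deallocate, compute whether it should actually be freed; for each
    retained memref, compute whether it takes over ownership."""
    retain_conds = [False] * len(to_retain)
    dealloc_conds = [False] * len(to_dealloc)
    for i, (ptr, cond) in enumerate(zip(to_dealloc, conds)):
        no_retain_alias = True
        for j, r in enumerate(to_retain):
            if r == ptr:
                # Aggregate ownership over all aliasing dealloc operands.
                retain_conds[j] = retain_conds[j] or cond
            no_retain_alias = no_retain_alias and (r != ptr)
        # Among aliasing dealloc operands, only the first one frees.
        no_alias = no_retain_alias
        for prev in to_dealloc[:i]:
            no_alias = no_alias and (prev != ptr)
        dealloc_conds[i] = no_alias and cond
    return dealloc_conds, retain_conds
```

Because the loops iterate over dynamically-shaped memrefs, a single helper serves every `bufferization.dealloc` in the module, avoiding the quadratic code-size blow-up an unrolled lowering would incur.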
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index 95deb165741973..1eb387ce0e5b77 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -68,26 +68,7 @@ func.func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, strided<[10]
}
// -----
-
-// CHECK-LABEL: func @conversion_dealloc_empty
-func.func @conversion_dealloc_empty() {
- // CHECK-NOT: bufferization.dealloc
- bufferization.dealloc
- return
-}
-
-// -----
-
-func.func @conversion_dealloc_empty_but_retains(%arg0: memref<2xi32>, %arg1: memref<2xi32>) -> (i1, i1) {
- %0:2 = bufferization.dealloc retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
- return %0#0, %0#1 : i1, i1
-}
-
-// CHECK-LABEL: func @conversion_dealloc_empty
-// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
-// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
-
-// -----
+// Test: check that the dealloc lowering pattern is registered.
// CHECK-NOT: func @deallocHelper
// CHECK-LABEL: func @conversion_dealloc_simple
@@ -102,124 +83,3 @@ func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) {
// CHECK-NEXT: memref.dealloc [[ARG0]] : memref<2xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
-
-// -----
-
-func.func @conversion_dealloc_one_memref_and_multiple_retained(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
- %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
- return %0#0, %0#1 : i1, i1
-}
-
-// CHECK-LABEL: func @conversion_dealloc_one_memref_and_multiple_retained
-// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<1xf32>, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xf32>)
-// CHECK-DAG: [[M0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
-// CHECK-DAG: [[R0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
-// CHECK-DAG: [[R1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG3]]
-// CHECK-DAG: [[DOES_NOT_ALIAS_R0:%.+]] = arith.cmpi ne, [[M0]], [[R0]] : index
-// CHECK-DAG: [[DOES_NOT_ALIAS_R1:%.+]] = arith.cmpi ne, [[M0]], [[R1]] : index
-// CHECK: [[NOT_RETAINED:%.+]] = arith.andi [[DOES_NOT_ALIAS_R0]], [[DOES_NOT_ALIAS_R1]]
-// CHECK: [[SHOULD_DEALLOC:%.+]] = arith.andi [[NOT_RETAINED]], [[ARG2]]
-// CHECK: scf.if [[SHOULD_DEALLOC]]
-// CHECK: memref.dealloc [[ARG0]]
-// CHECK: }
-// CHECK-DAG: [[ALIASES_R0:%.+]] = arith.xori [[DOES_NOT_ALIAS_R0]], %true
-// CHECK-DAG: [[ALIASES_R1:%.+]] = arith.xori [[DOES_NOT_ALIAS_R1]], %true
-// CHECK-DAG: [[RES0:%.+]] = arith.andi [[ALIASES_R0]], [[ARG2]]
-// CHECK-DAG: [[RES1:%.+]] = arith.andi [[ALIASES_R1]], [[ARG2]]
-// CHECK: return [[RES0]], [[RES1]]
-
-// CHECK-NOT: func @dealloc_helper
-
-// -----
-
-func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
- %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
- return %0#0, %0#1 : i1, i1
-}
-
-// CHECK-LABEL: func @conversion_dealloc_multiple_memrefs_and_retained
-// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<5xf32>,
-// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1,
-// CHECK-SAME: [[ARG5:%.+]]: memref<2xf32>)
-// CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex>
-// CHECK: [[CONDS:%.+]] = memref.alloc() : memref<2xi1>
-// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<2xindex>
-// CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
-// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-// CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]]
-// CHECK-DAG: [[V1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
-// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
-// CHECK-DAG: memref.store [[V1]], [[TO_DEALLOC_MR]][[[C1]]]
-// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-// CHECK-DAG: memref.store [[ARG3]], [[CONDS]][[[C0]]]
-// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
-// CHECK-DAG: memref.store [[ARG4]], [[CONDS]][[[C1]]]
-// CHECK-DAG: [[V2:%.+]] = memref.extract_aligned_pointer_as_index [[ARG2]]
-// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-// CHECK-DAG: memref.store [[V2]], [[TO_RETAIN_MR]][[[C0]]]
-// CHECK-DAG: [[V3:%.+]] = memref.extract_aligned_pointer_as_index [[ARG5]]
-// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
-// CHECK-DAG: memref.store [[V3]], [[TO_RETAIN_MR]][[[C1]]]
-// CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: [[CAST_CONDS:%.+]] = memref.cast [[CONDS]] : memref<2xi1> to memref<?xi1>
-// CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<2xindex> to memref<?xindex>
-// CHECK: [[DEALLOC_CONDS:%.+]] = memref.alloc() : memref<2xi1>
-// CHECK: [[RETAIN_CONDS:%.+]] = memref.alloc() : memref<2xi1>
-// CHECK: [[CAST_DEALLOC_CONDS:%.+]] = memref.cast [[DEALLOC_CONDS]] : memref<2xi1> to memref<?xi1>
-// CHECK: [[CAST_RETAIN_CONDS:%.+]] = memref.cast [[RETAIN_CONDS]] : memref<2xi1> to memref<?xi1>
-// CHECK: call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[CAST_CONDS]], [[CAST_DEALLOC_CONDS]], [[CAST_RETAIN_CONDS]])
-// CHECK: [[C0:%.+]] = arith.constant 0 : index
-// CHECK: [[SHOULD_DEALLOC_0:%.+]] = memref.load [[DEALLOC_CONDS]][[[C0]]]
-// CHECK: scf.if [[SHOULD_DEALLOC_0]] {
-// CHECK: memref.dealloc %arg0
-// CHECK: }
-// CHECK: [[C1:%.+]] = arith.constant 1 : index
-// CHECK: [[SHOULD_DEALLOC_1:%.+]] = memref.load [[DEALLOC_CONDS]][[[C1]]]
-// CHECK: scf.if [[SHOULD_DEALLOC_1]]
-// CHECK: memref.dealloc [[ARG1]]
-// CHECK: }
-// CHECK: [[C0:%.+]] = arith.constant 0 : index
-// CHECK: [[OWNERSHIP0:%.+]] = memref.load [[RETAIN_CONDS]][[[C0]]]
-// CHECK: [[C1:%.+]] = arith.constant 1 : index
-// CHECK: [[OWNERSHIP1:%.+]] = memref.load [[RETAIN_CONDS]][[[C1]]]
-// CHECK: memref.dealloc [[TO_DEALLOC_MR]]
-// CHECK: memref.dealloc [[TO_RETAIN_MR]]
-// CHECK: memref.dealloc [[CONDS]]
-// CHECK: memref.dealloc [[DEALLOC_CONDS]]
-// CHECK: memref.dealloc [[RETAIN_CONDS]]
-// CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]]
-
-// CHECK: func @dealloc_helper
-// CHECK-SAME: ([[TO_DEALLOC_MR:%.+]]: memref<?xindex>, [[TO_RETAIN_MR:%.+]]: memref<?xindex>,
-// CHECK-SAME: [[CONDS:%.+]]: memref<?xi1>, [[DEALLOC_CONDS_OUT:%.+]]: memref<?xi1>,
-// CHECK-SAME: [[RETAIN_CONDS_OUT:%.+]]: memref<?xi1>)
-// CHECK: [[TO_DEALLOC_SIZE:%.+]] = memref.dim [[TO_DEALLOC_MR]], %c0
-// CHECK: [[TO_RETAIN_SIZE:%.+]] = memref.dim [[TO_RETAIN_MR]], %c0
-// CHECK: scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 {
-// CHECK-NEXT: memref.store %false, [[RETAIN_CONDS_OUT]][[[ITER]]]
-// CHECK-NEXT: }
-// CHECK: scf.for [[OUTER_ITER:%.+]] = %c0 to [[TO_DEALLOC_SIZE]] step %c1 {
-// CHECK: [[TO_DEALLOC:%.+]] = memref.load [[TO_DEALLOC_MR]][[[OUTER_ITER]]]
-// CHECK-NEXT: [[COND:%.+]] = memref.load [[CONDS]][[[OUTER_ITER]]]
-// CHECK-NEXT: [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) {
-// CHECK-NEXT: [[RETAIN_VAL:%.+]] = memref.load [[TO_RETAIN_MR]][[[ITER]]] : memref<?xindex>
-// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi eq, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
-// CHECK-NEXT: scf.if [[DOES_ALIAS]]
-// CHECK-NEXT: [[RETAIN_COND:%.+]] = memref.load [[RETAIN_CONDS_OUT]][[[ITER]]]
-// CHECK-NEXT: [[AGG_RETAIN_COND:%.+]] = arith.ori [[RETAIN_COND]], [[COND]] : i1
-// CHECK-NEXT: memref.store [[AGG_RETAIN_COND]], [[RETAIN_CONDS_OUT]][[[ITER]]]
-// CHECK-NEXT: }
-// CHECK-NEXT: [[DOES_NOT_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
-// CHECK-NEXT: [[AGG_DOES_NOT_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_NOT_ALIAS]] : i1
-// CHECK-NEXT: scf.yield [[AGG_DOES_NOT_ALIAS]] : i1
-// CHECK-NEXT: }
-// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[OUTER_ITER]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) {
-// CHECK-NEXT: [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[ARG0]][[[ITER]]] : memref<?xindex>
-// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
-// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
-// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1
-// CHECK-NEXT: }
-// CHECK-NEXT: [[DEALLOC_COND:%.+]] = arith.andi [[SHOULD_DEALLOC]], [[COND]] : i1
-// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
-// CHECK-NEXT: }
-// CHECK-NEXT: return
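[Editorial aside, not part of the patch: the aliasing logic that the generated `@dealloc_helper` library function implements, as exercised by the CHECK lines above, can be sketched as a hypothetical Python model. Buffer base pointers are modeled as plain integers; the function name and signature are illustrative only, not the pass's API.]

```python
def dealloc_helper(to_dealloc, conds, to_retain):
    """Model of the helper the lowering emits for the multi-memref case.

    For each base pointer in `to_dealloc` (with its ownership condition in
    `conds`), decide whether the buffer may actually be freed, and compute
    for each pointer in `to_retain` whether it now carries ownership.
    """
    retain_conds = [False] * len(to_retain)
    dealloc_conds = [False] * len(to_dealloc)
    for i, (ptr, cond) in enumerate(zip(to_dealloc, conds)):
        # Inner loop over retained values: an aliasing retained value
        # inherits the ownership condition instead of being freed.
        no_retain_alias = True
        for j, r in enumerate(to_retain):
            if r == ptr:
                retain_conds[j] = retain_conds[j] or cond
            no_retain_alias = no_retain_alias and (r != ptr)
        # Second loop: only the first occurrence in the dealloc list
        # may free the buffer, so duplicates are suppressed.
        not_seen_before = all(to_dealloc[k] != ptr for k in range(i))
        dealloc_conds[i] = no_retain_alias and not_seen_before and cond
    return dealloc_conds, retain_conds
```

For example, deallocating pointers `[1, 2]` (both owned) while retaining `[2, 3]` frees only buffer `1` and transfers ownership of `2` to the first retained value, matching the control flow the CHECK lines above verify.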
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations-func.mlir
similarity index 92%
rename from mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir
rename to mlir/test/Dialect/Bufferization/Transforms/lower-deallocations-func.mlir
index a6dc2c76184cdb..03cf10aa0c05bc 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations-func.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics --pass-pipeline="builtin.module(func.func(convert-bufferization-to-memref))" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics --pass-pipeline="builtin.module(func.func(bufferization-lower-deallocations))" -split-input-file %s | FileCheck %s
// CHECK-NOT: func @deallocHelper
// CHECK-LABEL: func @conversion_dealloc_simple
diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
new file mode 100644
index 00000000000000..19d3bbf7089c10
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
@@ -0,0 +1,156 @@
+// RUN: mlir-opt -verify-diagnostics -bufferization-lower-deallocations -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @conversion_dealloc_empty
+func.func @conversion_dealloc_empty() {
+ // CHECK-NOT: bufferization.dealloc
+ bufferization.dealloc
+ return
+}
+
+// -----
+
+func.func @conversion_dealloc_empty_but_retains(%arg0: memref<2xi32>, %arg1: memref<2xi32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @conversion_dealloc_empty_but_retains
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
+
+// -----
+
+// CHECK-NOT: func @deallocHelper
+// CHECK-LABEL: func @conversion_dealloc_simple
+// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
+// CHECK-SAME: [[ARG1:%.+]]: i1
+func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) {
+ bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+ return
+}
+
+// CHECK: scf.if [[ARG1]] {
+// CHECK-NEXT: memref.dealloc [[ARG0]] : memref<2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// -----
+
+func.func @conversion_dealloc_one_memref_and_multiple_retained(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @conversion_dealloc_one_memref_and_multiple_retained
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<1xf32>, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xf32>)
+// CHECK-DAG: [[M0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
+// CHECK-DAG: [[R0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
+// CHECK-DAG: [[R1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG3]]
+// CHECK-DAG: [[DOES_NOT_ALIAS_R0:%.+]] = arith.cmpi ne, [[M0]], [[R0]] : index
+// CHECK-DAG: [[DOES_NOT_ALIAS_R1:%.+]] = arith.cmpi ne, [[M0]], [[R1]] : index
+// CHECK: [[NOT_RETAINED:%.+]] = arith.andi [[DOES_NOT_ALIAS_R0]], [[DOES_NOT_ALIAS_R1]]
+// CHECK: [[SHOULD_DEALLOC:%.+]] = arith.andi [[NOT_RETAINED]], [[ARG2]]
+// CHECK: scf.if [[SHOULD_DEALLOC]]
+// CHECK: memref.dealloc [[ARG0]]
+// CHECK: }
+// CHECK-DAG: [[ALIASES_R0:%.+]] = arith.xori [[DOES_NOT_ALIAS_R0]], %true
+// CHECK-DAG: [[ALIASES_R1:%.+]] = arith.xori [[DOES_NOT_ALIAS_R1]], %true
+// CHECK-DAG: [[RES0:%.+]] = arith.andi [[ALIASES_R0]], [[ARG2]]
+// CHECK-DAG: [[RES1:%.+]] = arith.andi [[ALIASES_R1]], [[ARG2]]
+// CHECK: return [[RES0]], [[RES1]]
+
+// CHECK-NOT: func @dealloc_helper
+
+// -----
+
+func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @conversion_dealloc_multiple_memrefs_and_retained
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<5xf32>,
+// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1,
+// CHECK-SAME: [[ARG5:%.+]]: memref<2xf32>)
+// CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex>
+// CHECK: [[CONDS:%.+]] = memref.alloc() : memref<2xi1>
+// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<2xindex>
+// CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]]
+// CHECK-DAG: [[V1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: memref.store [[V1]], [[TO_DEALLOC_MR]][[[C1]]]
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: memref.store [[ARG3]], [[CONDS]][[[C0]]]
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: memref.store [[ARG4]], [[CONDS]][[[C1]]]
+// CHECK-DAG: [[V2:%.+]] = memref.extract_aligned_pointer_as_index [[ARG2]]
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: memref.store [[V2]], [[TO_RETAIN_MR]][[[C0]]]
+// CHECK-DAG: [[V3:%.+]] = memref.extract_aligned_pointer_as_index [[ARG5]]
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: memref.store [[V3]], [[TO_RETAIN_MR]][[[C1]]]
+// CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: [[CAST_CONDS:%.+]] = memref.cast [[CONDS]] : memref<2xi1> to memref<?xi1>
+// CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<2xindex> to memref<?xindex>
+// CHECK: [[DEALLOC_CONDS:%.+]] = memref.alloc() : memref<2xi1>
+// CHECK: [[RETAIN_CONDS:%.+]] = memref.alloc() : memref<2xi1>
+// CHECK: [[CAST_DEALLOC_CONDS:%.+]] = memref.cast [[DEALLOC_CONDS]] : memref<2xi1> to memref<?xi1>
+// CHECK: [[CAST_RETAIN_CONDS:%.+]] = memref.cast [[RETAIN_CONDS]] : memref<2xi1> to memref<?xi1>
+// CHECK: call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[CAST_CONDS]], [[CAST_DEALLOC_CONDS]], [[CAST_RETAIN_CONDS]])
+// CHECK: [[C0:%.+]] = arith.constant 0 : index
+// CHECK: [[SHOULD_DEALLOC_0:%.+]] = memref.load [[DEALLOC_CONDS]][[[C0]]]
+// CHECK: scf.if [[SHOULD_DEALLOC_0]] {
+// CHECK: memref.dealloc %arg0
+// CHECK: }
+// CHECK: [[C1:%.+]] = arith.constant 1 : index
+// CHECK: [[SHOULD_DEALLOC_1:%.+]] = memref.load [[DEALLOC_CONDS]][[[C1]]]
+// CHECK: scf.if [[SHOULD_DEALLOC_1]]
+// CHECK: memref.dealloc [[ARG1]]
+// CHECK: }
+// CHECK: [[C0:%.+]] = arith.constant 0 : index
+// CHECK: [[OWNERSHIP0:%.+]] = memref.load [[RETAIN_CONDS]][[[C0]]]
+// CHECK: [[C1:%.+]] = arith.constant 1 : index
+// CHECK: [[OWNERSHIP1:%.+]] = memref.load [[RETAIN_CONDS]][[[C1]]]
+// CHECK: memref.dealloc [[TO_DEALLOC_MR]]
+// CHECK: memref.dealloc [[TO_RETAIN_MR]]
+// CHECK: memref.dealloc [[CONDS]]
+// CHECK: memref.dealloc [[DEALLOC_CONDS]]
+// CHECK: memref.dealloc [[RETAIN_CONDS]]
+// CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]]
+
+// CHECK: func @dealloc_helper
+// CHECK-SAME: ([[TO_DEALLOC_MR:%.+]]: memref<?xindex>, [[TO_RETAIN_MR:%.+]]: memref<?xindex>,
+// CHECK-SAME: [[CONDS:%.+]]: memref<?xi1>, [[DEALLOC_CONDS_OUT:%.+]]: memref<?xi1>,
+// CHECK-SAME: [[RETAIN_CONDS_OUT:%.+]]: memref<?xi1>)
+// CHECK: [[TO_DEALLOC_SIZE:%.+]] = memref.dim [[TO_DEALLOC_MR]], %c0
+// CHECK: [[TO_RETAIN_SIZE:%.+]] = memref.dim [[TO_RETAIN_MR]], %c0
+// CHECK: scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 {
+// CHECK-NEXT: memref.store %false, [[RETAIN_CONDS_OUT]][[[ITER]]]
+// CHECK-NEXT: }
+// CHECK: scf.for [[OUTER_ITER:%.+]] = %c0 to [[TO_DEALLOC_SIZE]] step %c1 {
+// CHECK: [[TO_DEALLOC:%.+]] = memref.load [[TO_DEALLOC_MR]][[[OUTER_ITER]]]
+// CHECK-NEXT: [[COND:%.+]] = memref.load [[CONDS]][[[OUTER_ITER]]]
+// CHECK-NEXT: [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) {
+// CHECK-NEXT: [[RETAIN_VAL:%.+]] = memref.load [[TO_RETAIN_MR]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi eq, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT: scf.if [[DOES_ALIAS]]
+// CHECK-NEXT: [[RETAIN_COND:%.+]] = memref.load [[RETAIN_CONDS_OUT]][[[ITER]]]
+// CHECK-NEXT: [[AGG_RETAIN_COND:%.+]] = arith.ori [[RETAIN_COND]], [[COND]] : i1
+// CHECK-NEXT: memref.store [[AGG_RETAIN_COND]], [[RETAIN_CONDS_OUT]][[[ITER]]]
+// CHECK-NEXT: }
+// CHECK-NEXT: [[DOES_NOT_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT: [[AGG_DOES_NOT_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_NOT_ALIAS]] : i1
+// CHECK-NEXT: scf.yield [[AGG_DOES_NOT_ALIAS]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[OUTER_ITER]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) {
+// CHECK-NEXT: [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[ARG0]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
+// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: [[DEALLOC_COND:%.+]] = arith.andi [[SHOULD_DEALLOC]], [[COND]] : i1
+// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return
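[Editorial aside, not part of the patch: for a single deallocated memref, the lowering tested in `@conversion_dealloc_one_memref_and_multiple_retained` above needs no helper call; it inlines direct pointer comparisons. A hypothetical Python model of that fast path, with pointers as ints and an illustrative function name:]

```python
def lower_single_dealloc(ptr, cond, retained):
    """Model of the inlined single-memref case: free the buffer only if it
    is owned and aliases no retained value; each retained value that does
    alias it takes over the ownership condition."""
    not_retained = all(r != ptr for r in retained)
    should_dealloc = not_retained and cond
    ownerships = [(r == ptr) and cond for r in retained]
    return should_dealloc, ownerships
```

So an owned buffer aliased by a retained value is never freed; its ownership bit is returned for that retained value instead, which is exactly the `arith.cmpi`/`arith.andi` chain the CHECK-DAG lines verify.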
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b2f9eaa7f8cf2d..c06fbe82e7362b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11976,7 +11976,9 @@ cc_library(
":MemRefDialect",
":MemRefUtils",
":Pass",
+ ":SCFDialect",
":SideEffectInterfaces",
+ ":Support",
":TensorDialect",
":Transforms",
":ViewLikeInterface",
@@ -11996,6 +11998,7 @@ cc_library(
deps = [
":ArithDialect",
":BufferizationDialect",
+ ":BufferizationTransforms",
":ConversionPassIncGen",
":FuncDialect",
":IR",