[Mlir-commits] [mlir] 4bde084 - [mlir][bufferization] Change semantics of DeallocOp result values
Martin Erhart
llvmlistbot at llvm.org
Fri Aug 4 06:56:45 PDT 2023
Author: Martin Erhart
Date: 2023-08-04T13:54:46Z
New Revision: 4bde084f0c1837234a5ddab66c95c55fe4b81599
URL: https://github.com/llvm/llvm-project/commit/4bde084f0c1837234a5ddab66c95c55fe4b81599
DIFF: https://github.com/llvm/llvm-project/commit/4bde084f0c1837234a5ddab66c95c55fe4b81599.diff
LOG: [mlir][bufferization] Change semantics of DeallocOp result values
This change allows supporting operations for which we don't get precise aliasing information without the need to insert clone operations. E.g., `arith.select`.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D156992
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
mlir/test/Dialect/Bufferization/canonicalize.mlir
mlir/test/Dialect/Bufferization/invalid.mlir
mlir/test/Dialect/Bufferization/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 55182f83832b47..a9e7b7215c1b43 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -485,22 +485,41 @@ def Bufferization_DeallocOp : Bufferization_Op<"dealloc", [
deallocating that memref). If two memrefs alias each other, only one will be
deallocated to avoid double free situations.
- The memrefs to be deallocated must be the originally allocated memrefs,
- however, the memrefs to be retained may be arbitrary memrefs.
-
- Returns a list of conditions corresponding to the list of memrefs which
- indicates the new ownerships, i.e., if the memref was deallocated the
- ownership was dropped (set to 'false') and otherwise will be the same as the
- input condition.
+ The number of variadic `memref` operands (the memrefs to be deallocated)
+ must equal the number of variadic `condition` operands and correspond to
+ each other element-wise.
+
+ The `memref` operands must be the originally allocated memrefs, however, the
+ `retained` memref operands may be arbitrary memrefs.
+
+ This operation returns a variadic number of `updatedConditions` results,
+ one updated condition per retained memref. An updated condition indicates
+ the ownership of the respective retained memref. It is computed as the
+ disjunction of all `conditions` operands whose corresponding
+ `memrefs` operand aliases with the retained memref. If the retained memref
+ has no aliases among `memrefs`, the resulting updated condition is 'false'.
+ This is because all memrefs that need to be deallocated within one basic
+ block should be added to the same `bufferization.dealloc` operation at the
+ end of the block; if no aliasing memref is present, then it does not have to
+ be deallocated and thus we don't need to claim ownership. If the memrefs to
+ be deallocated are split over multiple dealloc operations (e.g., to avoid
+ aliasing checks at runtime between the `memref` operands), then the results
+ have to be manually combined using an `arith.ori` operation and all of them
+ still require the same list of `retained` memref operands unless the
+ (potentially empty) set of aliasing memrefs can be determined statically. In
+ that case, the `updatedConditions` result can be replaced accordingly (e.g.,
+ by a canonicalizer).
Example:
```mlir
- %0:2 = bufferization.dealloc %a0, %a1 if %cond0, %cond1 retain %r0, %r1 :
- memref<2xf32>, memref<4xi32> retain memref<?xf32>, memref<f64>
+ %0:3 = bufferization.dealloc (%a0, %a1 : memref<2xf32>, memref<4xi32>)
+ if (%cond0, %cond1) retain (%r0, %r1, %r2 : memref<?xf32>, memref<f64>,
+ memref<2xi32>)
```
- Deallocation will be called on `%a0` if `%cond0` is 'true' and neither `%r0`
- or `%r1` are aliases of `%a0`. `%a1` will be deallocated when `%cond1` is
- set to 'true' and none of `%r0`, %r1` and `%a0` are aliases.
+ Deallocation will be called on `%a0` if `%cond0` is 'true' and neither
+ `%r0`, `%r1`, or `%r2` are aliases of `%a0`. `%a1` will be deallocated when
+ `%cond1` is set to 'true' and none of `%r0`, `%r1`, `%r2`, and `%a0` are
+ aliases.
}];
let arguments = (ins Variadic<AnyRankedOrUnrankedMemRef>:$memrefs,
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 886bcfa2f8530d..f998c8ce172a03 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -91,77 +91,100 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
class DeallocOpConversion
: public OpConversionPattern<bufferization::DeallocOp> {
- /// Lower a simple case avoiding the helper function. Ideally, static analysis
- /// can provide enough aliasing information to split the dealloc operations up
- /// into this simple case as much as possible before running this pass.
+ /// Lower a simple case without any retained values and a single memref to
+ /// avoid the helper function. Ideally, static analysis can provide enough
+ /// aliasing information to split the dealloc operations up into this simple
+ /// case as much as possible before running this pass.
///
/// Example:
/// ```
- /// %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+ /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
/// ```
/// is lowered to
/// ```
/// scf.if %arg1 {
/// memref.dealloc %arg0 : memref<2xf32>
/// }
- /// %0 = arith.constant false
/// ```
LogicalResult
rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- rewriter.create<scf::IfOp>(op.getLoc(), adaptor.getConditions()[0],
- [&](OpBuilder &builder, Location loc) {
- builder.create<memref::DeallocOp>(
- loc, adaptor.getMemrefs()[0]);
- builder.create<scf::YieldOp>(loc);
- });
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(op,
- rewriter.getBoolAttr(false));
+ assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
+ assert(adaptor.getRetained().empty() && "expected no retained memrefs");
+
+ rewriter.replaceOpWithNewOp<scf::IfOp>(
+ op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
+ builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
+ builder.create<scf::YieldOp>(loc);
+ });
return success();
}
/// Lowering that supports all features the dealloc operation has to offer. It
- /// computes the base pointer of each memref (as an index), stores them in a
- /// new memref and passes it to the helper function generated in
- /// 'buildDeallocationHelperFunction'. The two return values are used as
- /// condition for the scf if operation containing the memref deallocate and as
- /// replacement for the original bufferization dealloc respectively.
+ /// computes the base pointer of each memref (as an index), stores it in a
+ /// new memref helper structure and passes it to the helper function generated
+ /// in 'buildDeallocationHelperFunction'. The results are stored in two lists
+ /// (represented as memrefs) of booleans passed as arguments. The first list
+ /// stores whether the corresponding memref should be deallocated, the
+ /// second list stores the ownership of the retained values which can be used
+ /// to replace the result values of the `bufferization.dealloc` operation.
///
/// Example:
/// ```
- /// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>)
- /// if (%arg3, %arg4) retain (%arg2 : memref<1xf32>)
+ /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
+ /// if (%cond0, %cond1)
+ /// retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
/// ```
/// lowers to (simplified):
/// ```
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
- /// %alloc = memref.alloc() : memref<2xindex>
- /// %alloc_0 = memref.alloc() : memref<1xindex>
- /// %intptr = memref.extract_aligned_pointer_as_index %arg0
- /// memref.store %intptr, %alloc[%c0] : memref<2xindex>
- /// %intptr_1 = memref.extract_aligned_pointer_as_index %arg1
- /// memref.store %intptr_1, %alloc[%c1] : memref<2xindex>
- /// %intptr_2 = memref.extract_aligned_pointer_as_index %arg2
- /// memref.store %intptr_2, %alloc_0[%c0] : memref<1xindex>
- /// %cast = memref.cast %alloc : memref<2xindex> to memref<?xindex>
- /// %cast_4 = memref.cast %alloc_0 : memref<1xindex> to memref<?xindex>
- /// %0:2 = call @dealloc_helper(%cast, %cast_4, %c0)
- /// %1 = arith.andi %0#0, %arg3 : i1
- /// %2 = arith.andi %0#1, %arg3 : i1
- /// scf.if %1 {
- /// memref.dealloc %arg0 : memref<2xf32>
+ /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
+ /// %cond_list = memref.alloc() : memref<2xi1>
+ /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
+ /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
+ /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
+ /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
+ /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
+ /// memref.store %cond0, %cond_list[%c0]
+ /// memref.store %cond1, %cond_list[%c1]
+ /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
+ /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
+ /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
+ /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
+ /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
+ /// memref<2xindex> to memref<?xindex>
+ /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
+ /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
+ /// memref<2xindex> to memref<?xindex>
+ /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
+ /// %ownership_out = memref.alloc() : memref<2xi1>
+ /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
+ /// memref<2xi1> to memref<?xi1>
+ /// %dyn_ownership_out = memref.cast %ownership_out :
+ /// memref<2xi1> to memref<?xi1>
+ /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
+ /// %dyn_retain_base_pointer_list,
+ /// %dyn_cond_list,
+ /// %dyn_dealloc_cond_out,
+ /// %dyn_ownership_out) : (...)
+ /// %m0_dealloc_cond = memref.load %dealloc_cond_out[%c0] : memref<2xi1>
+ /// scf.if %m0_dealloc_cond {
+ /// memref.dealloc %m0 : memref<2xf32>
/// }
- /// %3:2 = call @dealloc_helper(%cast, %cast_4, %c1)
- /// %4 = arith.andi %3#0, %arg4 : i1
- /// %5 = arith.andi %3#1, %arg4 : i1
- /// scf.if %4 {
- /// memref.dealloc %arg1 : memref<5xf32>
+ /// %m1_dealloc_cond = memref.load %dealloc_cond_out[%c1] : memref<2xi1>
+ /// scf.if %m1_dealloc_cond {
+ /// memref.dealloc %m1 : memref<5xf32>
/// }
- /// memref.dealloc %alloc : memref<2xindex>
- /// memref.dealloc %alloc_0 : memref<1xindex>
- /// // replace %0#0 with %2
- /// // replace %0#1 with %5
+ /// %r0_ownership = memref.load %ownership_out[%c0] : memref<2xi1>
+ /// %r1_ownership = memref.load %ownership_out[%c1] : memref<2xi1>
+ /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
+ /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
+ /// memref.dealloc %cond_list : memref<2xi1>
+ /// memref.dealloc %dealloc_cond_out : memref<2xi1>
+ /// memref.dealloc %ownership_out : memref<2xi1>
+ /// // replace %0#0 with %r0_ownership
+ /// // replace %0#1 with %r1_ownership
/// ```
LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
OpAdaptor adaptor,
@@ -175,6 +198,9 @@ class DeallocOpConversion
Value toDeallocMemref = rewriter.create<memref::AllocOp>(
op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
rewriter.getIndexType()));
+ Value conditionMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
+ rewriter.getI1Type()));
Value toRetainMemref = rewriter.create<memref::AllocOp>(
op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
rewriter.getIndexType()));
@@ -193,6 +219,11 @@ class DeallocOpConversion
rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
toDeallocMemref, getConstValue(i));
}
+
+ for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
+ rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
+ getConstValue(i));
+
for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
Value memrefAsIdx =
rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
@@ -208,22 +239,41 @@ class DeallocOpConversion
op->getLoc(),
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
toDeallocMemref);
+ Value castedCondsMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+ conditionMemref);
Value castedRetainMemref = rewriter.create<memref::CastOp>(
op->getLoc(),
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
toRetainMemref);
- SmallVector<Value> replacements;
+ Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+ rewriter.getI1Type()));
+ Value retainCondsMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
+ rewriter.getI1Type()));
+
+ Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+ deallocCondsMemref);
+ Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
+ retainCondsMemref);
+
+ rewriter.create<func::CallOp>(
+ op.getLoc(), deallocHelperFunc,
+ SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
+ castedCondsMemref, castedDeallocCondsMemref,
+ castedRetainCondsMemref});
+
for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
- auto callOp = rewriter.create<func::CallOp>(
- op.getLoc(), deallocHelperFunc,
- SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
- getConstValue(i)});
- Value shouldDealloc = rewriter.create<arith::AndIOp>(
- op.getLoc(), callOp.getResult(0), adaptor.getConditions()[i]);
- Value ownership = rewriter.create<arith::AndIOp>(
- op.getLoc(), callOp.getResult(1), adaptor.getConditions()[i]);
- replacements.push_back(ownership);
+ Value idxValue = getConstValue(i);
+ Value shouldDealloc = rewriter.create<memref::LoadOp>(
+ op.getLoc(), deallocCondsMemref, idxValue);
rewriter.create<scf::IfOp>(
op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
@@ -231,10 +281,21 @@ class DeallocOpConversion
});
}
+ SmallVector<Value> replacements;
+ for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
+ Value idxValue = getConstValue(i);
+ Value ownership = rewriter.create<memref::LoadOp>(
+ op.getLoc(), retainCondsMemref, idxValue);
+ replacements.push_back(ownership);
+ }
+
// Deallocate above allocated memrefs again to avoid memory leaks.
// Deallocation will not be run on code after this stage.
rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
+ rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
+ rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
+ rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
rewriter.replaceOp(op, replacements);
return success();
@@ -261,70 +322,95 @@ class DeallocOpConversion
/// Build a helper function per compilation unit that can be called at
/// bufferization dealloc sites to determine aliasing and ownership.
///
- /// The generated function takes two memrefs of indices and one index value as
- /// arguments and returns two boolean values:
- /// * The first memref argument A should contain the result of the
+ /// The generated function takes two memrefs of indices and three memrefs of
+ /// booleans as arguments:
+ /// * The first argument A should contain the result of the
/// extract_aligned_pointer_as_index operation applied to the memrefs to be
/// deallocated
- /// * The second memref argument B should contain the result of the
+ /// * The second argument B should contain the result of the
/// extract_aligned_pointer_as_index operation applied to the memrefs to be
/// retained
- /// * The index argument I represents the currently processed index of
- /// memref A and is needed because aliasing with all previously deallocated
- /// memrefs has to be checked to avoid double deallocation
- /// * The first result indicates whether the memref at position I should be
- /// deallocated
- /// * The second result provides the updated ownership value corresponding
- /// the the memref at position I
- ///
- /// This helper function is supposed to be called for each element in the list
- /// of memrefs to be deallocated to determine the deallocation need and new
- /// ownership indicator, but does not perform the deallocation itself.
+ /// * The third argument C should contain the conditions as passed directly
+ /// to the deallocation operation.
+ /// * The fourth argument D is used to pass results to the caller. Those
+ /// represent the condition under which the memref at the corresponding
+ /// position in A should be deallocated.
+ /// * The fifth argument E is used to pass results to the caller. It
+ /// provides the ownership value corresponding to the memref at the same
+ /// position in B
///
- /// The first scf for loop in the body computes whether the memref at index I
- /// aliases with any memref in the list of retained memrefs.
- /// The second loop additionally checks whether one of the previously
- /// deallocated memrefs aliases with the currently processed one.
+ /// This helper function is supposed to be called once for each
+ /// `bufferization.dealloc` operation to determine the deallocation need and
+ /// new ownership indicator for the retained values, but does not perform the
+ /// deallocation itself.
///
/// Generated code:
/// ```
- /// func.func @dealloc_helper(%arg0: memref<?xindex>,
- /// %arg1: memref<?xindex>,
- /// %arg2: index) -> (i1, i1) {
+ /// func.func @dealloc_helper(
+ /// %dyn_dealloc_base_pointer_list: memref<?xindex>,
+ /// %dyn_retain_base_pointer_list: memref<?xindex>,
+ /// %dyn_cond_list: memref<?xi1>,
+ /// %dyn_dealloc_cond_out: memref<?xi1>,
+ /// %dyn_ownership_out: memref<?xi1>) {
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %true = arith.constant true
- /// %dim = memref.dim %arg1, %c0 : memref<?xindex>
- /// %0 = memref.load %arg0[%arg2] : memref<?xindex>
- /// %1 = scf.for %i = %c0 to %dim step %c1 iter_args(%arg4 = %true) -> (i1){
- /// %4 = memref.load %arg1[%i] : memref<?xindex>
- /// %5 = arith.cmpi ne, %4, %0 : index
- /// %6 = arith.andi %arg4, %5 : i1
- /// scf.yield %6 : i1
+ /// %false = arith.constant false
+ /// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0
+ /// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0
+ /// // Zero initialize result buffer.
+ /// scf.for %i = %c0 to %num_retain_memrefs step %c1 {
+ /// memref.store %false, %dyn_ownership_out[%i] : memref<?xi1>
/// }
- /// %2 = scf.for %i = %c0 to %arg2 step %c1 iter_args(%arg4 = %1) -> (i1) {
- /// %4 = memref.load %arg0[%i] : memref<?xindex>
- /// %5 = arith.cmpi ne, %4, %0 : index
- /// %6 = arith.andi %arg4, %5 : i1
- /// scf.yield %6 : i1
+ /// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 {
+ /// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i]
+ /// %cond = memref.load %dyn_cond_list[%i]
+ /// // Check for aliasing with retained memrefs.
+ /// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs
+ /// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) {
+ /// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j]
+ /// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index
+ /// scf.if %does_alias {
+ /// %curr_ownership = memref.load %dyn_ownership_out[%j]
+ /// %updated_ownership = arith.ori %curr_ownership, %cond : i1
+ /// memref.store %updated_ownership, %dyn_ownership_out[%j]
+ /// }
+ /// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index
+ /// %updated_aggregate = arith.andi %does_not_alias_aggregated,
+ /// %does_not_alias : i1
+ /// scf.yield %updated_aggregate : i1
+ /// }
+ /// // Check for aliasing with dealloc memrefs in the list before the
+ /// // current one, i.e.,
+ /// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j],
+ /// // %dyn_dealloc_base_pointer[i])`
+ /// %does_not_alias_any = scf.for %j = %c0 to %i step %c1
+ /// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) {
+ /// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j]
+ /// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp
+ /// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias
+ /// scf.yield %updated_alias_agg : i1
+ /// }
+ /// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1
+ /// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1>
/// }
- /// %3 = arith.xori %1, %true : i1
- /// return %2, %3 : i1, i1
+ /// return
/// }
/// ```
static func::FuncOp
buildDeallocationHelperFunction(OpBuilder &builder, Location loc,
SymbolTable &symbolTable) {
- Type idxType = builder.getIndexType();
- Type memrefArgType = MemRefType::get({ShapedType::kDynamic}, idxType);
- SmallVector<Type> argTypes{memrefArgType, memrefArgType, idxType};
+ Type indexMemrefType =
+ MemRefType::get({ShapedType::kDynamic}, builder.getIndexType());
+ Type boolMemrefType =
+ MemRefType::get({ShapedType::kDynamic}, builder.getI1Type());
+ SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
+ boolMemrefType, boolMemrefType};
builder.clearInsertionPoint();
// Generate the func operation itself.
auto helperFuncOp = func::FuncOp::create(
- loc, "dealloc_helper",
- builder.getFunctionType(argTypes,
- {builder.getI1Type(), builder.getI1Type()}));
+ loc, "dealloc_helper", builder.getFunctionType(argTypes, {}));
symbolTable.insert(helperFuncOp);
auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
@@ -332,57 +418,101 @@ class DeallocOpConversion
builder.setInsertionPointToStart(&block);
Value toDeallocMemref = helperFuncOp.getArguments()[0];
Value toRetainMemref = helperFuncOp.getArguments()[1];
- Value idxArg = helperFuncOp.getArguments()[2];
+ Value conditionMemref = helperFuncOp.getArguments()[2];
+ Value deallocCondsMemref = helperFuncOp.getArguments()[3];
+ Value retainCondsMemref = helperFuncOp.getArguments()[4];
// Insert some prerequisites.
Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
Value trueValue =
builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
+ Value falseValue =
+ builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
+ Value toDeallocSize =
+ builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
- Value toDealloc =
- builder.create<memref::LoadOp>(loc, toDeallocMemref, idxArg);
-
- // Build the first for loop that computes aliasing with retained memrefs.
- Value noRetainAlias =
- builder
- .create<scf::ForOp>(
- loc, c0, toRetainSize, c1, trueValue,
- [&](OpBuilder &builder, Location loc, Value i,
- ValueRange iterArgs) {
- Value retainValue =
- builder.create<memref::LoadOp>(loc, toRetainMemref, i);
- Value doesntAlias = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, retainValue, toDealloc);
- Value yieldValue = builder.create<arith::AndIOp>(
- loc, iterArgs[0], doesntAlias);
- builder.create<scf::YieldOp>(loc, yieldValue);
- })
- .getResult(0);
-
- // Build the second for loop that adds aliasing with previously deallocated
- // memrefs.
- Value noAlias =
- builder
- .create<scf::ForOp>(
- loc, c0, idxArg, c1, noRetainAlias,
- [&](OpBuilder &builder, Location loc, Value i,
- ValueRange iterArgs) {
- Value prevDeallocValue =
- builder.create<memref::LoadOp>(loc, toDeallocMemref, i);
- Value doesntAlias = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, prevDeallocValue,
- toDealloc);
- Value yieldValue = builder.create<arith::AndIOp>(
- loc, iterArgs[0], doesntAlias);
- builder.create<scf::YieldOp>(loc, yieldValue);
- })
- .getResult(0);
-
- Value ownership =
- builder.create<arith::XOrIOp>(loc, noRetainAlias, trueValue);
- builder.create<func::ReturnOp>(loc, SmallVector<Value>{noAlias, ownership});
+ builder.create<scf::ForOp>(
+ loc, c0, toRetainSize, c1, std::nullopt,
+ [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
+ builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref,
+ i);
+ builder.create<scf::YieldOp>(loc);
+ });
+
+ builder.create<scf::ForOp>(
+ loc, c0, toDeallocSize, c1, std::nullopt,
+ [&](OpBuilder &builder, Location loc, Value outerIter,
+ ValueRange iterArgs) {
+ Value toDealloc =
+ builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
+ Value cond =
+ builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);
+
+ // Build the first for loop that computes aliasing with retained
+ // memrefs.
+ Value noRetainAlias =
+ builder
+ .create<scf::ForOp>(
+ loc, c0, toRetainSize, c1, trueValue,
+ [&](OpBuilder &builder, Location loc, Value i,
+ ValueRange iterArgs) {
+ Value retainValue = builder.create<memref::LoadOp>(
+ loc, toRetainMemref, i);
+ Value doesAlias = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, retainValue,
+ toDealloc);
+ builder.create<scf::IfOp>(
+ loc, doesAlias,
+ [&](OpBuilder &builder, Location loc) {
+ Value retainCondValue =
+ builder.create<memref::LoadOp>(
+ loc, retainCondsMemref, i);
+ Value aggregatedRetainCond =
+ builder.create<arith::OrIOp>(
+ loc, retainCondValue, cond);
+ builder.create<memref::StoreOp>(
+ loc, aggregatedRetainCond, retainCondsMemref,
+ i);
+ builder.create<scf::YieldOp>(loc);
+ });
+ Value doesntAlias = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, retainValue,
+ toDealloc);
+ Value yieldValue = builder.create<arith::AndIOp>(
+ loc, iterArgs[0], doesntAlias);
+ builder.create<scf::YieldOp>(loc, yieldValue);
+ })
+ .getResult(0);
+
+ // Build the second for loop that adds aliasing with previously
+ // deallocated memrefs.
+ Value noAlias =
+ builder
+ .create<scf::ForOp>(
+ loc, c0, outerIter, c1, noRetainAlias,
+ [&](OpBuilder &builder, Location loc, Value i,
+ ValueRange iterArgs) {
+ Value prevDeallocValue = builder.create<memref::LoadOp>(
+ loc, toDeallocMemref, i);
+ Value doesntAlias = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, prevDeallocValue,
+ toDealloc);
+ Value yieldValue = builder.create<arith::AndIOp>(
+ loc, iterArgs[0], doesntAlias);
+ builder.create<scf::YieldOp>(loc, yieldValue);
+ })
+ .getResult(0);
+
+ Value shouldDealoc =
+ builder.create<arith::AndIOp>(loc, noAlias, cond);
+ builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
+ outerIter);
+ builder.create<scf::YieldOp>(loc);
+ });
+
+ builder.create<func::ReturnOp>(loc);
return helperFuncOp;
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 3747ab1562ff42..e981b80a60083a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -755,7 +755,8 @@ LogicalResult DeallocOp::inferReturnTypes(
ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
- inferredReturnTypes = SmallVector<Type>(adaptor.getConditions().getTypes());
+ inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
+ IntegerType::get(context, 1));
return success();
}
@@ -766,44 +767,46 @@ LogicalResult DeallocOp::verify() {
return success();
}
+static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
+ ArrayRef<Value> memrefs,
+ ArrayRef<Value> conditions,
+ PatternRewriter &rewriter) {
+ if (deallocOp.getMemrefs() == memrefs)
+ return failure();
+
+ rewriter.updateRootInPlace(deallocOp, [&]() {
+ deallocOp.getMemrefsMutable().assign(memrefs);
+ deallocOp.getConditionsMutable().assign(conditions);
+ });
+ return success();
+}
+
namespace {
-/// Remove duplicate values in the list of retained memrefs as well as the list
-/// of memrefs to be deallocated. For the latter, we need to make sure the
-/// corresponding condition values match as well, or otherwise have to combine
-/// them (by computing the disjunction of them).
+/// Remove duplicate values in the list of memrefs to be deallocated. We need to
+/// make sure the corresponding condition value is updated accordingly since
+/// their two conditions might not cover the same set of cases. In that case, we
+/// have to combine them (by computing the disjunction of them).
/// Example:
/// ```mlir
-/// %0:2 = bufferization.dealloc (%arg0, %arg0 : ...)
-/// if (%arg1, %arg2)
-/// retain (%arg3, %arg3 : ...)
+/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
-/// %1 = bufferization.dealloc (%arg0 : memref<2xi32>)
-/// if (%0)
-/// retain (%arg3 : memref<2xi32>)
+/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
-struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
+struct DeallocRemoveDuplicateDeallocMemrefs
+ : public OpRewritePattern<DeallocOp> {
using OpRewritePattern<DeallocOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
// Unique memrefs to be deallocated.
- DenseSet<Value> retained(deallocOp.getRetained().begin(),
- deallocOp.getRetained().end());
DenseMap<Value, unsigned> memrefToCondition;
- SmallVector<Value> newMemrefs, newConditions, newRetained;
- SmallVector<int32_t> resultIndices(deallocOp.getMemrefs().size(), -1);
+ SmallVector<Value> newMemrefs, newConditions;
for (auto [i, memref, cond] :
llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
- if (retained.contains(memref)) {
- rewriter.replaceAllUsesWith(deallocOp.getResult(i),
- deallocOp.getConditions()[i]);
- continue;
- }
-
if (memrefToCondition.count(memref)) {
// If the dealloc conditions don't match, we need to make sure that the
// dealloc happens on the union of cases.
@@ -816,42 +819,121 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
newMemrefs.push_back(memref);
newConditions.push_back(cond);
}
- resultIndices[i] = memrefToCondition[memref];
}
+ // Return failure if we don't change anything such that we don't run into an
+ // infinite loop of pattern applications.
+ return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+ rewriter);
+ }
+};
+
+/// Remove duplicate values in the list of retained memrefs. The result of each
+/// dropped duplicate is replaced by the result of the first occurrence of the
+/// same retained value.
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
+/// ```
+/// is canonicalized to
+/// ```mlir
+/// %1 = bufferization.dealloc retain (%arg3 : ...)
+/// ```
+/// and all uses of `%0#0` and `%0#1` are replaced with `%1`.
+struct DeallocRemoveDuplicateRetainedMemrefs
+ : public OpRewritePattern<DeallocOp> {
+ using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DeallocOp deallocOp,
+ PatternRewriter &rewriter) const override {
// Unique retained values
- DenseSet<Value> seen;
+ // Maps each distinct retained value to its index in `newRetained`.
+ DenseMap<Value, unsigned> seen;
+ SmallVector<Value> newRetained;
+ // For each original retained operand (and thus each original result), the
+ // index of the result of the new op that replaces it.
+ SmallVector<unsigned> resultReplacementIdx;
+ unsigned i = 0;
for (auto retained : deallocOp.getRetained()) {
- if (!seen.contains(retained)) {
- seen.insert(retained);
- newRetained.push_back(retained);
+ if (seen.count(retained)) {
+ // Duplicate: forward to the result of the first occurrence.
+ resultReplacementIdx.push_back(seen[retained]);
+ continue;
}
+
+ seen[retained] = i;
+ newRetained.push_back(retained);
+ resultReplacementIdx.push_back(i++);
}
// Return failure if we don't change anything such that we don't run into an
// infinite loop of pattern applications.
- if (newConditions.size() == deallocOp.getConditions().size() &&
- newRetained.size() == deallocOp.getRetained().size())
+ if (newRetained.size() == deallocOp.getRetained().size())
return failure();
- // We need to create a new op because the number of results is always the
- // same as the number of condition operands.
- auto newDealloc = rewriter.create<DeallocOp>(deallocOp.getLoc(), newMemrefs,
- newConditions, newRetained);
- for (auto [i, newIdx] : llvm::enumerate(resultIndices))
- if (newIdx != -1)
- rewriter.replaceAllUsesWith(deallocOp.getResult(i),
- newDealloc.getResult(newIdx));
-
- rewriter.eraseOp(deallocOp);
+ // We need to create a new op because the number of results is always the
+ // same as the number of retained operands.
+ auto newDeallocOp =
+ rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
+ deallocOp.getConditions(), newRetained);
+ SmallVector<Value> replacements(
+ llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
+ return newDeallocOp.getUpdatedConditions()[idx];
+ }));
+ rewriter.replaceOp(deallocOp, replacements);
return success();
}
};
+/// Remove memrefs to be deallocated that are also present in the retained list
+/// since they will always alias and thus never actually be deallocated. Their
+/// dealloc condition is instead ORed into the updated condition (result) of
+/// the matching retained value.
+/// Example:
+/// ```mlir
+/// %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...)
+/// ```
+/// is canonicalized to
+/// ```mlir
+/// %0 = bufferization.dealloc retain (%arg0 : ...)
+/// %1 = arith.ori %0, %arg1 : i1
+/// ```
+/// where all previous users of the dealloc result now use `%1`.
+struct DeallocRemoveDeallocMemrefsContainedInRetained
+ : public OpRewritePattern<DeallocOp> {
+ using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DeallocOp deallocOp,
+ PatternRewriter &rewriter) const override {
+ // Map each retained value to the index of its result (updated condition).
+ DenseMap<Value, unsigned> retained;
+ for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
+ retained[ret] = i;
+
+ // There must not be any duplicates in the retain list anymore because we
+ // would miss updating one of the result values otherwise.
+ if (retained.size() != deallocOp.getRetained().size())
+ return failure();
+
+ SmallVector<Value> newMemrefs, newConditions;
+ for (auto [memref, cond] :
+ llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+ if (retained.contains(memref)) {
+ // Drop the memref and OR its condition into the result of the matching
+ // retained value. The `ori` is inserted after the dealloc because it
+ // consumes the dealloc's result.
+ rewriter.setInsertionPointAfter(deallocOp);
+ auto orOp = rewriter.create<arith::OrIOp>(
+ deallocOp.getLoc(),
+ deallocOp.getUpdatedConditions()[retained[memref]], cond);
+ // Redirect all other users of that result to the `ori`; the `ori` itself
+ // must keep reading the original result.
+ rewriter.replaceAllUsesExcept(
+ deallocOp.getUpdatedConditions()[retained[memref]],
+ orOp.getResult(), orOp);
+ continue;
+ }
+
+ newMemrefs.push_back(memref);
+ newConditions.push_back(cond);
+ }
+
+ // Return failure if we don't change anything such that we don't run into an
+ // infinite loop of pattern applications.
+ return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+ rewriter);
+ }
+};
+
/// Erase deallocation operations where the variadic list of memrefs to
-/// deallocate is emtpy. Example:
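+/// deallocate is empty. All of its results (the updated conditions of the
+/// retained memrefs) are replaced with a constant `false`. Example: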
/// ```mlir
-/// bufferization.dealloc retain (%arg0: memref<2xi32>)
+/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
using OpRewritePattern<DeallocOp>::OpRewritePattern;
@@ -859,7 +941,11 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
if (deallocOp.getMemrefs().empty()) {
- rewriter.eraseOp(deallocOp);
+ Value constFalse = rewriter.create<arith::ConstantOp>(
+ deallocOp.getLoc(), rewriter.getBoolAttr(false));
+ rewriter.replaceOp(
+ deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
+ constFalse));
return success();
}
return failure();
@@ -871,12 +957,12 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
///
/// Example:
/// ```
-/// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
/// if (%arg2, %false)
/// ```
/// becomes
/// ```
-/// %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
+/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
using OpRewritePattern<DeallocOp>::OpRewritePattern;
@@ -884,32 +970,16 @@ struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
SmallVector<Value> newMemrefs, newConditions;
- SmallVector<Value> replacements;
-
- for (auto [res, memref, cond] :
- llvm::zip(deallocOp.getUpdatedConditions(), deallocOp.getMemrefs(),
- deallocOp.getConditions())) {
- if (matchPattern(cond, m_Zero())) {
- replacements.push_back(cond);
- continue;
+ for (auto [memref, cond] :
+ llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+ if (!matchPattern(cond, m_Zero())) {
+ newMemrefs.push_back(memref);
+ newConditions.push_back(cond);
}
- newMemrefs.push_back(memref);
- newConditions.push_back(cond);
- replacements.push_back({});
}
- if (newMemrefs.size() == deallocOp.getMemrefs().size())
- return failure();
-
- auto newDeallocOp = rewriter.create<DeallocOp>(
- deallocOp.getLoc(), newMemrefs, newConditions, deallocOp.getRetained());
- unsigned i = 0;
- for (auto &repl : replacements)
- if (!repl)
- repl = newDeallocOp.getResult(i++);
-
- rewriter.replaceOp(deallocOp, replacements);
- return success();
+ return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+ rewriter);
}
};
@@ -917,9 +987,10 @@ struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<DeallocRemoveDuplicates, EraseEmptyDealloc, EraseAlwaysFalseDealloc>(
- context);
+ // DeallocRemoveDeallocMemrefsContainedInRetained bails out while the retained
+ // list still contains duplicates; DeallocRemoveDuplicateRetainedMemrefs
+ // removes them so that it can apply afterwards.
+ results.add<DeallocRemoveDuplicateDeallocMemrefs,
+ DeallocRemoveDuplicateRetainedMemrefs,
+ DeallocRemoveDeallocMemrefsContainedInRetained, EraseEmptyDealloc,
+ EraseAlwaysFalseDealloc>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index aed232d29d175c..0f80e9a1140e30 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -81,78 +81,106 @@ func.func @conversion_dealloc_empty() {
// CHECK-LABEL: func @conversion_dealloc_simple
// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
// CHECK-SAME: [[ARG1:%.+]]: i1
-func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) -> i1 {
- %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
- return %0 : i1
+func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) {
+ bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+ return
}
-// CHECk: scf.if [[ARG1]] {
-// CHECk-NEXT: memref.dealloc [[ARG0]] : memref<2xf32>
-// CHECk-NEXT: }
-// CHECk-NEXT: [[FALSE:%.+]] = arith.constant false
-// CHECk-NEXT: return [[FALSE]] : i1
+// A single memref with no retained values lowers to a memref.dealloc guarded
+// by its condition; with nothing retained there are no results to return.
+// CHECK: scf.if [[ARG1]] {
+// CHECK-NEXT: memref.dealloc [[ARG0]] : memref<2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return
// -----
-func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1) -> (i1, i1) {
- %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2 : memref<1xf32>)
+func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
return %0#0, %0#1 : i1, i1
}
+// The memrefs, conditions and retained values are packed into memrefs and
+// passed to @dealloc_helper, which writes the dealloc condition of every memref
+// and the updated condition of every retained value into output memrefs that
+// are loaded back afterwards.
// CHECK-LABEL: func @conversion_dealloc_multiple_memrefs_and_retained
-// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>,
-// CHECK-SAME: [[ARG1:%.+]]: memref<5xf32>,
-// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>,
-// CHECK-SAME: [[ARG3:%.+]]: i1,
-// CHECK-SAME: [[ARG4:%.+]]: i1
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<5xf32>,
+// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1,
+// CHECK-SAME: [[ARG5:%.+]]: memref<2xf32>)
// CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex>
-// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<1xindex>
+// CHECK: [[CONDS:%.+]] = memref.alloc() : memref<2xi1>
+// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<2xindex>
// CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
// CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]]
// CHECK-DAG: [[V1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
// CHECK-DAG: memref.store [[V1]], [[TO_DEALLOC_MR]][[[C1]]]
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: memref.store [[ARG3]], [[CONDS]][[[C0]]]
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: memref.store [[ARG4]], [[CONDS]][[[C1]]]
// CHECK-DAG: [[V2:%.+]] = memref.extract_aligned_pointer_as_index [[ARG2]]
// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
// CHECK-DAG: memref.store [[V2]], [[TO_RETAIN_MR]][[[C0]]]
+// CHECK-DAG: [[V3:%.+]] = memref.extract_aligned_pointer_as_index [[ARG5]]
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: memref.store [[V3]], [[TO_RETAIN_MR]][[[C1]]]
// CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<1xindex> to memref<?xindex>
-// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-// CHECK: [[RES0:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C0]])
-// CHECK: [[SHOULD_DEALLOC_0:%.+]] = arith.andi [[RES0]]#0, [[ARG3]]
-// CHECK: [[OWNERSHIP0:%.+]] = arith.andi [[RES0]]#1, [[ARG3]]
+// CHECK-DAG: [[CAST_CONDS:%.+]] = memref.cast [[CONDS]] : memref<2xi1> to memref<?xi1>
+// CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<2xindex> to memref<?xindex>
+// CHECK: [[DEALLOC_CONDS:%.+]] = memref.alloc() : memref<2xi1>
+// CHECK: [[RETAIN_CONDS:%.+]] = memref.alloc() : memref<2xi1>
+// CHECK: [[CAST_DEALLOC_CONDS:%.+]] = memref.cast [[DEALLOC_CONDS]] : memref<2xi1> to memref<?xi1>
+// CHECK: [[CAST_RETAIN_CONDS:%.+]] = memref.cast [[RETAIN_CONDS]] : memref<2xi1> to memref<?xi1>
+// CHECK: call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[CAST_CONDS]], [[CAST_DEALLOC_CONDS]], [[CAST_RETAIN_CONDS]])
+// CHECK: [[C0:%.+]] = arith.constant 0 : index
+// CHECK: [[SHOULD_DEALLOC_0:%.+]] = memref.load [[DEALLOC_CONDS]][[[C0]]]
// CHECK: scf.if [[SHOULD_DEALLOC_0]] {
-// CHECK: memref.dealloc %arg0
+// CHECK: memref.dealloc [[ARG0]]
// CHECK: }
// CHECK: [[C1:%.+]] = arith.constant 1 : index
-// CHECK: [[RES1:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C1]])
-// CHECK: [[SHOULD_DEALLOC_1:%.+]] = arith.andi [[RES1:%.+]]#0, [[ARG4]]
-// CHECK: [[OWNERSHIP1:%.+]] = arith.andi [[RES1:%.+]]#1, [[ARG4]]
+// CHECK: [[SHOULD_DEALLOC_1:%.+]] = memref.load [[DEALLOC_CONDS]][[[C1]]]
// CHECK: scf.if [[SHOULD_DEALLOC_1]]
// CHECK: memref.dealloc [[ARG1]]
// CHECK: }
+// CHECK: [[C0:%.+]] = arith.constant 0 : index
+// CHECK: [[OWNERSHIP0:%.+]] = memref.load [[RETAIN_CONDS]][[[C0]]]
+// CHECK: [[C1:%.+]] = arith.constant 1 : index
+// CHECK: [[OWNERSHIP1:%.+]] = memref.load [[RETAIN_CONDS]][[[C1]]]
// CHECK: memref.dealloc [[TO_DEALLOC_MR]]
// CHECK: memref.dealloc [[TO_RETAIN_MR]]
+// CHECK: memref.dealloc [[CONDS]]
+// CHECK: memref.dealloc [[DEALLOC_CONDS]]
+// CHECK: memref.dealloc [[RETAIN_CONDS]]
// CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]]
+// The helper first resets the updated condition of every retained value to
+// false. Then, for each memref to deallocate, it ORs the memref's condition
+// into the updated condition of every retained value it aliases, and sets the
+// memref's dealloc condition to true only if the memref aliases neither a
+// retained value nor an earlier memref in the list and its own condition holds.
// CHECK: func @dealloc_helper
-// CHECK-SAME: [[ARG0:%.+]]: memref<?xindex>, [[ARG1:%.+]]: memref<?xindex>
-// CHECK-SAME: [[ARG2:%.+]]: index
-// CHECK-SAME: -> (i1, i1)
-// CHECK: [[TO_RETAIN_SIZE:%.+]] = memref.dim [[ARG1]], %c0
-// CHECK: [[TO_DEALLOC:%.+]] = memref.load [[ARG0]][[[ARG2]]] : memref<?xindex>
-// CHECK-NEXT: [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) {
-// CHECK-NEXT: [[RETAIN_VAL:%.+]] = memref.load [[ARG1]][[[ITER]]] : memref<?xindex>
-// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
-// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
-// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-SAME: ([[TO_DEALLOC_MR:%.+]]: memref<?xindex>, [[TO_RETAIN_MR:%.+]]: memref<?xindex>,
+// CHECK-SAME: [[CONDS:%.+]]: memref<?xi1>, [[DEALLOC_CONDS_OUT:%.+]]: memref<?xi1>,
+// CHECK-SAME: [[RETAIN_CONDS_OUT:%.+]]: memref<?xi1>)
+// CHECK: [[TO_DEALLOC_SIZE:%.+]] = memref.dim [[TO_DEALLOC_MR]], %c0
+// CHECK: [[TO_RETAIN_SIZE:%.+]] = memref.dim [[TO_RETAIN_MR]], %c0
+// CHECK: scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 {
+// CHECK-NEXT: memref.store %false, [[RETAIN_CONDS_OUT]][[[ITER]]]
// CHECK-NEXT: }
-// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[ARG2]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) {
-// CHECK-NEXT: [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[ARG0]][[[ITER]]] : memref<?xindex>
-// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
-// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
-// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK: scf.for [[OUTER_ITER:%.+]] = %c0 to [[TO_DEALLOC_SIZE]] step %c1 {
+// CHECK: [[TO_DEALLOC:%.+]] = memref.load [[TO_DEALLOC_MR]][[[OUTER_ITER]]]
+// CHECK-NEXT: [[COND:%.+]] = memref.load [[CONDS]][[[OUTER_ITER]]]
+// CHECK-NEXT: [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) {
+// CHECK-NEXT: [[RETAIN_VAL:%.+]] = memref.load [[TO_RETAIN_MR]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi eq, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT: scf.if [[DOES_ALIAS]]
+// CHECK-NEXT: [[RETAIN_COND:%.+]] = memref.load [[RETAIN_CONDS_OUT]][[[ITER]]]
+// CHECK-NEXT: [[AGG_RETAIN_COND:%.+]] = arith.ori [[RETAIN_COND]], [[COND]] : i1
+// CHECK-NEXT: memref.store [[AGG_RETAIN_COND]], [[RETAIN_CONDS_OUT]][[[ITER]]]
+// CHECK-NEXT: }
+// CHECK-NEXT: [[DOES_NOT_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT: [[AGG_DOES_NOT_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_NOT_ALIAS]] : i1
+// CHECK-NEXT: scf.yield [[AGG_DOES_NOT_ALIAS]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[OUTER_ITER]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) {
+// CHECK-NEXT: [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[TO_DEALLOC_MR]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
+// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: [[DEALLOC_COND:%.+]] = arith.andi [[SHOULD_DEALLOC]], [[COND]] : i1
+// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
// CHECK-NEXT: }
-// CHECK-NEXT: [[OWNERSHIP:%.+]] = arith.xori [[NO_RETAIN_ALIAS]], %true : i1
-// CHECK-NEXT: return [[SHOULD_DEALLOC]], [[OWNERSHIP]] : i1, i1
+// CHECK-NEXT: return
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 96f82f6835dd61..6a4edf1e6335f1 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -282,44 +282,44 @@ func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<
// -----
+// Duplicate memrefs are merged (ORing their conditions when they differ) and
+// the results of duplicate retained values are forwarded to the first occurrence.
-func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1, i1, i1) {
+func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1) {
%0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
- %1:2 = bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
- return %0#0, %0#1, %0#2, %1#0, %1#1 : i1, i1, i1, i1, i1
+ bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
+ return %0#0, %0#1, %0#2 : i1, i1, i1
}
// CHECK-LABEL: func @dealloc_canonicalize_duplicates
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>, [[ARG5:%.+]]: memref<2xi32>)
// CHECK-NEXT: [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG5]] : memref<2xi32>, memref<2xi32>)
// CHECK-NEXT: [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1
-// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
-// CHECK-NEXT: return [[V0]]#0, [[V0]]#1, [[V0]]#1, [[V1]], [[V1]] :
+// CHECK-NEXT: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
+// CHECK-NEXT: return [[V0]]#0, [[V0]]#1, [[V0]]#0 :
// -----
+// A memref that is also retained is removed from the dealloc list and its
+// condition is ORed into the result of the matching retained value; deallocs
+// with nothing left to deallocate are erased.
-func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1) {
+func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1) {
%0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
- %1:2 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+ %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
bufferization.dealloc
bufferization.dealloc retain (%arg0 : memref<2xi32>)
- return %0, %1#0, %1#1 : i1, i1, i1
+ return %0, %1 : i1, i1
}
// CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
-// CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :
+// CHECK-NEXT: [[V1:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: return [[ARG1]], [[V1]] :
// -----
+// Memrefs whose dealloc condition is a constant false are dropped.
-func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) -> (i1, i1) {
+func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) {
%false = arith.constant false
- %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
- return %0#0, %0#1 : i1, i1
+ bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
+ return
}
// CHECK-LABEL: func @dealloc_always_false_condition
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
-// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
-// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
-// CHECK-NEXT: return [[FALSE]], [[V0]] :
+// CHECK-NEXT: bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
+// CHECK-NEXT: return
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 8dded8b4debe3e..3b4bfee5622e9b 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -106,8 +106,16 @@ func.func @invalid_tensor_copy(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
// -----
+// Two memrefs to deallocate, but only one condition.
-func.func @invalid_dealloc_memref_condition_mismatch(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) -> i1 {
+func.func @invalid_dealloc_memref_condition_mismatch(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) {
// expected-error @below{{must have the same number of conditions as memrefs to deallocate}}
- %0 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2)
- return %0 : i1
+ bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2)
+ return
+}
+
+// -----
+
+// The op has one result per retained value, not one per memref to deallocate.
+func.func @invalid_dealloc_wrong_number_of_results(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) -> i1 {
+ // expected-error @below{{operation defines 1 results but was provided 2 to bind}}
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg2) retain (%arg1 : memref<4xi32>)
+ return %0#0 : i1
}
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index 55af9ef37384a9..773f15c1ffcb89 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -73,8 +73,8 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
// CHECK: bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg3) retain (%arg4, %arg5 : memref<?xf32>, memref<*xf64>)
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg3) retain (%arg4, %arg5 : memref<?xf32>, memref<*xf64>)
// CHECK: bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2)
- %1 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2)
+ bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2)
// CHECK: bufferization.dealloc
bufferization.dealloc
- return %0, %1 : i1, i1
+ return %0#0, %0#1 : i1, i1
}
More information about the Mlir-commits
mailing list