[Mlir-commits] [mlir] 660fded - [mlir][bufferization] Add specialized lowering for deallocs with one memref but arbitrary retains
Martin Erhart
llvmlistbot at llvm.org
Mon Aug 14 01:59:05 PDT 2023
Author: Martin Erhart
Date: 2023-08-14T08:58:46Z
New Revision: 660fdedec9d6ddcd4dcd311edd60cfc27d74ef55
URL: https://github.com/llvm/llvm-project/commit/660fdedec9d6ddcd4dcd311edd60cfc27d74ef55
DIFF: https://github.com/llvm/llvm-project/commit/660fdedec9d6ddcd4dcd311edd60cfc27d74ef55.diff
LOG: [mlir][bufferization] Add specialized lowering for deallocs with one memref but arbitrary retains
It is often the case that many values in the `memrefs` operand list can be
split off to separate dealloc operations by the
`--buffer-deallocation-simplification` pass, however, the retain list has to be
preserved initially. Further canonicalization can often trim it down
considerably, but some retains may remain. In those cases, the general lowering
would be chosen, but is very inefficient. This commit adds another lowering for
those cases which avoids allocation of auxiliary memrefs and the helper
function while still producing code that is linear in the number of operands of
the dealloc operation.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D157692
Added:
Modified:
mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index f998c8ce172a03..6225e010784789 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -120,6 +120,91 @@ class DeallocOpConversion
return success();
}
+ /// A special case lowering for the deallocation operation with exactly one
+ /// memref, but arbitrary number of retained values. This avoids the helper
+ /// function that the general case needs and thus also avoids storing indices
+ /// to specifically allocated memrefs. The size of the code produced by this
+ /// lowering is linear in the number of retained values.
+ ///
+ /// Example:
+ /// ```mlir
+ /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
+ ///                       retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
+ /// return %0#0, %0#1 : i1, i1
+ /// ```
+ /// ```mlir
+ /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
+ /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
+ /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
+ /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
+ /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
+ /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
+ /// %should_dealloc = arith.andi %not_retained, %cond : i1
+ /// scf.if %should_dealloc {
+ /// memref.dealloc %m : memref<2xf32>
+ /// }
+ /// %true = arith.constant true
+ /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
+ /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
+ /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
+ /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
+ /// return %r0_ownership, %r1_ownership : i1, i1
+ /// ```
+ LogicalResult rewriteOneMemrefMultipleRetainCase(
+ bufferization::DeallocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
+
+ // Extract the aligned base pointer of the memref and of every retained
+ // value as an index; a retained value aliases iff the indices are equal.
+ SmallVector<Value> doesNotAliasList;
+ Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
+ op->getLoc(), adaptor.getMemrefs()[0]);
+ for (Value retained : adaptor.getRetained()) {
+ Value retainedAsIdx =
+ rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
+ retained);
+ Value doesNotAlias = rewriter.create<arith::CmpIOp>(
+ op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
+ doesNotAliasList.push_back(doesNotAlias);
+ }
+
+ // AND-reduce the booleans from above; requires at least one retained value.
+ Value prev = doesNotAliasList.front();
+ for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
+ prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
+
+ // Also consider the condition given by the dealloc operation and perform a
+ // conditional deallocation guarded by that value.
+ Value shouldDealloc = rewriter.create<arith::AndIOp>(
+ op->getLoc(), prev, adaptor.getConditions()[0]);
+
+ rewriter.create<scf::IfOp>(
+ op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
+ builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
+ builder.create<scf::YieldOp>(loc);
+ });
+
+ // Compute the replacement values for the dealloc operation results. This
+ // inserts an already canonicalized form of
+ // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
+ // value r, i.e. `(does_not_alias xor true) and memref_cond`.
+ SmallVector<Value> replacements;
+ Value trueVal = rewriter.create<arith::ConstantOp>(
+ op->getLoc(), rewriter.getBoolAttr(true));
+ for (Value doesNotAlias : doesNotAliasList) {
+ Value aliases =
+ rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
+ Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
+ adaptor.getConditions()[0]);
+ replacements.push_back(result);
+ }
+
+ rewriter.replaceOp(op, replacements);
+
+ return success();
+ }
+
/// Lowering that supports all features the dealloc operation has to offer. It
/// computes the base pointer of each memref (as an index), stores it in a
/// new memref helper structure and passes it to the helper function generated
@@ -310,12 +395,20 @@ class DeallocOpConversion
matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Lower the trivial case.
- if (adaptor.getMemrefs().empty())
- return rewriter.eraseOp(op), success();
+ if (adaptor.getMemrefs().empty()) {
+ Value falseVal = rewriter.create<arith::ConstantOp>(
+ op.getLoc(), rewriter.getBoolAttr(false));
+ rewriter.replaceOp(
+ op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
+ return success();
+ }
if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
+ if (adaptor.getMemrefs().size() == 1)
+ return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
+
return rewriteGeneralCase(op, adaptor, rewriter);
}
@@ -535,8 +628,7 @@ struct BufferizationToMemRefPass
// Build dealloc helper function if there are deallocs.
func::FuncOp helperFuncOp;
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
- if (deallocOp.getMemrefs().size() > 1 ||
- !deallocOp.getRetained().empty()) {
+ if (deallocOp.getMemrefs().size() > 1) {
helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
builder, getOperation()->getLoc(), symbolTable);
return WalkResult::interrupt();
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index 0f80e9a1140e30..95deb165741973 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -66,17 +66,29 @@ func.func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, strided<[10]
memref.dealloc %arg0 : memref<?xf32, strided<[10], offset: ?>>
return %1 : memref<?xf32, strided<[10], offset: ?>>
}
+
// -----
// CHECK-LABEL: func @conversion_dealloc_empty
func.func @conversion_dealloc_empty() {
- // CHECK-NEXT: return
+ // CHECK-NOT: bufferization.dealloc
bufferization.dealloc
return
}
// -----
+func.func @conversion_dealloc_empty_but_retains(%arg0: memref<2xi32>, %arg1: memref<2xi32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @conversion_dealloc_empty
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
+
+// -----
+
// CHECK-NOT: func @deallocHelper
// CHECK-LABEL: func @conversion_dealloc_simple
// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
@@ -93,6 +105,33 @@ func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) {
// -----
+func.func @conversion_dealloc_one_memref_and_multiple_retained(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @conversion_dealloc_one_memref_and_multiple_retained
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<1xf32>, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xf32>)
+// CHECK-DAG: [[M0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
+// CHECK-DAG: [[R0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
+// CHECK-DAG: [[R1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG3]]
+// CHECK-DAG: [[DOES_NOT_ALIAS_R0:%.+]] = arith.cmpi ne, [[M0]], [[R0]] : index
+// CHECK-DAG: [[DOES_NOT_ALIAS_R1:%.+]] = arith.cmpi ne, [[M0]], [[R1]] : index
+// CHECK: [[NOT_RETAINED:%.+]] = arith.andi [[DOES_NOT_ALIAS_R0]], [[DOES_NOT_ALIAS_R1]]
+// CHECK: [[SHOULD_DEALLOC:%.+]] = arith.andi [[NOT_RETAINED]], [[ARG2]]
+// CHECK: scf.if [[SHOULD_DEALLOC]]
+// CHECK: memref.dealloc [[ARG0]]
+// CHECK: }
+// CHECK-DAG: [[ALIASES_R0:%.+]] = arith.xori [[DOES_NOT_ALIAS_R0]], %true
+// CHECK-DAG: [[ALIASES_R1:%.+]] = arith.xori [[DOES_NOT_ALIAS_R1]], %true
+// CHECK-DAG: [[RES0:%.+]] = arith.andi [[ALIASES_R0]], [[ARG2]]
+// CHECK-DAG: [[RES1:%.+]] = arith.andi [[ALIASES_R1]], [[ARG2]]
+// CHECK: return [[RES0]], [[RES1]]
+
+// CHECK-NOT: func @dealloc_helper
+
+// -----
+
func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
return %0#0, %0#1 : i1, i1
More information about the Mlir-commits
mailing list