[Mlir-commits] [mlir] [mlir][bufferization] Ownership dealloc: support `IsolatedFromAbove` (PR #97669)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 3 20:40:21 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: Nikhil Kalra (nikalra)
<details>
<summary>Changes</summary>
Handle `IsolatedFromAbove` operations in `ownership-based-buffer-deallocation` by applying the same contract as at function boundaries. Specifically, `IsolatedFromAbove` ops cannot take ownership of their arguments and instead rely on the caller to deallocate them.
---
Full diff: https://github.com/llvm/llvm-project/pull/97669.diff
5 Files Affected:
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+5)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp (+3-2)
- (added) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir (+125)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+12)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+18-1)
``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 72f47b8b468ea..d9525cb640e1c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -174,6 +174,11 @@ void BufferViewFlowAnalysis::build(Operation *op) {
}
}
+ if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ // Mark the entry block arguments and results as terminal.
+ populateTerminalValues(op);
+ }
+
return WalkResult::advance();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index ca5d0688b5b59..a52906174fb07 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -640,8 +640,9 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {
continue;
// Adhere to function boundary ABI: no ownership of function argument
- // MemRefs is taken.
- if (isa<FunctionOpInterface>(block->getParentOp()) &&
+ // MemRefs is taken. Likewise for ops marked IsolatedFromAbove.
+ if ((isa<FunctionOpInterface>(block->getParentOp()) ||
+ block->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) &&
block->isEntryBlock()) {
Value newArg = buildBoolValue(builder, arg.getLoc(), false);
state.updateOwnership(arg, newArg);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
new file mode 100644
index 0000000000000..0c3ceda5237cc
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
@@ -0,0 +1,125 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+
+// Passes a fresh allocation into an IsolatedFromAbove op whose region yields
+// its block argument unchanged. Expected: the op gets a second result (`:2`),
+// no dealloc is emitted inside its region, and after `test.copy` the caller
+// deallocates both allocations plus the base buffer of the op's first result,
+// the latter guarded by the op's second result.
+func.func @function_call() {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloc() : memref<f64>
+ %ret = test.isolated_one_region_with_recursive_memory_effects %alloc {
+ ^bb0(%arg1: memref<f64>):
+ test.region_yield %arg1 : memref<f64>
+ } : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func @function_call()
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[ALLOC0]]
+// CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>)
+// CHECK: test.region_yield [[ARG]]
+// CHECK-NOT: bufferization.dealloc
+// CHECK: }
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true, %true, [[RET]]#1)
+
+// -----
+
+// Passes an `arith.select` of an alloc and an alloca into an IsolatedFromAbove
+// op whose region yields its block argument unchanged. Expected: no dealloc is
+// emitted inside the op's region; after `test.copy` only the alloc and the base
+// buffer of the op's first result are deallocated, the latter guarded by the
+// op's second result.
+func.func @function_call_requires_merged_ownership_mid_block(%arg0: i1) {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloca() : memref<f64>
+ %0 = arith.select %arg0, %alloc, %alloc2 : memref<f64>
+ %ret = test.isolated_one_region_with_recursive_memory_effects %0 {
+ ^bb0(%arg1: memref<f64>):
+ test.region_yield %arg1 : memref<f64>
+ } : (memref<f64>) -> (memref<f64>)
+ test.copy(%ret, %alloc) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func @function_call_requires_merged_ownership_mid_block
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloca(
+// CHECK-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
+// CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[SELECT]]
+// CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>)
+// CHECK: test.region_yield [[ARG]]
+// CHECK-NOT: bufferization.dealloc
+// CHECK: }
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :
+// CHECK-SAME: if (%true, [[RET]]#1)
+// CHECK-NOT: retain
+// CHECK-NEXT: return
+
+// -----
+
+// Forwards a function argument through an IsolatedFromAbove op and returns the
+// op's result. Expected: the region yields `%false` as the ownership of its
+// block argument; the function returns the op's result directly when the op's
+// second result is true, otherwise a `bufferization.clone` of it; the result's
+// base buffer is deallocated if owned, retaining the returned value.
+func.func @g(%arg0: memref<f32>) -> memref<f32> {
+ %0 = test.isolated_one_region_with_recursive_memory_effects %arg0 {
+ ^bb0(%arg1: memref<f32>):
+ test.region_yield %arg1 : memref<f32>
+ } : (memref<f32>) -> (memref<f32>)
+ return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func.func @g(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[VAL_0]] {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<f32>):
+// CHECK: test.region_yield %[[ARG]], %false : memref<f32>, i1
+// CHECK: } : (memref<f32>) -> (memref<f32>, i1)
+// CHECK: %[[VAL_4:.*]] = scf.if %[[BLOCK]]#1 -> (memref<f32>) {
+// CHECK: scf.yield %[[BLOCK]]#0 : memref<f32>
+// CHECK: } else {
+// CHECK: %[[VAL_6:.*]] = bufferization.clone %[[BLOCK]]#0 : memref<f32> to memref<f32>
+// CHECK: scf.yield %[[VAL_6]] : memref<f32>
+// CHECK: }
+// CHECK: %[[BUF:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[BLOCK]]#0 : memref<f32> -> memref<f32>, index
+// CHECK: %[[VAL_11:.*]] = bufferization.dealloc (%[[BUF]] : memref<f32>) if (%[[BLOCK]]#1) retain (%[[VAL_4]] : memref<f32>)
+// CHECK: return %[[VAL_4]] : memref<f32>
+// CHECK: }
+
+// -----
+
+// The isolated op's region conditionally yields either a new allocation or its
+// block argument. Expected: the inner `scf.if` yields `true` ownership for the
+// new allocation and `false` for the block argument; the region forwards that
+// flag as the op's second result, and the caller deallocates the base buffer of
+// the op's first result guarded by it.
+func.func @alloc_yielded_from_block() {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloc() : memref<f64>
+ %ret = test.isolated_one_region_with_recursive_memory_effects %alloc {
+ ^bb0(%arg1: memref<f64>):
+ %0 = memref.load %arg1[] : memref<f64>
+ %c1 = arith.constant 1.0 : f64
+ %r0 = arith.cmpf oeq, %0, %c1 : f64
+ %1 = scf.if %r0 -> memref<f64> {
+ %alloc3 = memref.alloc() : memref<f64>
+ scf.yield %alloc3 : memref<f64>
+ } else {
+ scf.yield %arg1 : memref<f64>
+ }
+ test.region_yield %1 : memref<f64>
+ } : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func.func @alloc_yielded_from_block() {
+// CHECK: %true = arith.constant true
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<f64>
+// CHECK: %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[ALLOC]] {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<f64>):
+// CHECK: %[[VAL_9:.*]] = arith.cmpf oeq
+// CHECK: %[[VAL_10:.*]]:2 = scf.if %[[VAL_9]] -> (memref<f64>, i1) {
+// CHECK: %[[BLOCK_ALLOC:.*]] = memref.alloc() : memref<f64>
+// CHECK: scf.yield %[[BLOCK_ALLOC]], %true_{{[0-9]*}} : memref<f64>, i1
+// CHECK: } else {
+// CHECK: scf.yield %[[ARG]], %false : memref<f64>, i1
+// CHECK: }
+// CHECK: test.region_yield %[[VAL_10]]#0, %[[VAL_10]]#1 : memref<f64>, i1
+// CHECK: } : (memref<f64>) -> (memref<f64>, i1)
+// CHECK: test.copy
+// CHECK: %[[BUF:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[BLOCK]]#0 : memref<f64> -> memref<f64>, index
+// CHECK: bufferization.dealloc (%[[ALLOC]], %{{.*}}, %[[BUF]] : memref<f64>, memref<f64>, memref<f64>) if (%true, %true, %[[BLOCK]]#1)
+// CHECK: return
+// CHECK: }
+
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index fbaa102d3e33c..6666c9b86db42 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -110,6 +110,18 @@ void IsolatedRegionOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
+//===----------------------------------------------------------------------===//
+// IsolatedOneRegionWithRecursiveMemoryEffectsOp
+//===----------------------------------------------------------------------===//
+
+/// Region control flow for the isolated test op: coming from the parent op,
+/// control enters the single body region; coming from the body region,
+/// control returns to the parent with the op's results as successor inputs.
+void IsolatedOneRegionWithRecursiveMemoryEffectsOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    regions.emplace_back(&getBody());
+  else
+    regions.emplace_back((*this)->getResults());
+}
+
//===----------------------------------------------------------------------===//
// SSACFGRegionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e1ec1428ee6d6..bbe84572868b2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -507,6 +507,23 @@ def OneRegionWithRecursiveMemoryEffectsOp
let regions = (region SizedRegion<1>:$body);
}
+def IsolatedOneRegionWithRecursiveMemoryEffectsOp
+    : TEST_Op<"isolated_one_region_with_recursive_memory_effects", [
+        RecursiveMemoryEffects,
+        IsolatedFromAbove,
+        SingleBlockImplicitTerminator<"RegionYieldOp">,
+        DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
+  let description = [{
+    IsolatedFromAbove op with one single-block region (terminated by `test.region_yield`) and recursive memory effects.
+  }];
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = [{
+    attr-dict-with-keyword $operands $body `:` functional-type(operands, results)
+  }];
+}
+
//===----------------------------------------------------------------------===//
// NoTerminator Operation
//===----------------------------------------------------------------------===//
@@ -2147,7 +2164,7 @@ def RegionYieldOp : TEST_Op<"region_yield",
This operation is used in a region and yields the corresponding type for
that operation.
}];
- let arguments = (ins AnyType:$result);
+ let arguments = (ins Variadic<AnyType>:$result);
let assemblyFormat = [{
$result `:` type($result) attr-dict
}];
``````````
</details>
https://github.com/llvm/llvm-project/pull/97669
More information about the Mlir-commits
mailing list