[Mlir-commits] [mlir] [mlir][bufferization] Ownership dealloc: support `IsolatedFromAbove` (PR #97669)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 3 20:40:21 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-bufferization

Author: Nikhil Kalra (nikalra)

<details>
<summary>Changes</summary>

Handle `IsolatedFromAbove` operations in `ownership-based-buffer-deallocation` by using the same contract as function boundaries. Specifically, `IsolatedFromAbove` ops do not take ownership of their entry-block arguments; they rely on the caller to deallocate them.

---
Full diff: https://github.com/llvm/llvm-project/pull/97669.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+5) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp (+3-2) 
- (added) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir (+125) 
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+12) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+18-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 72f47b8b468ea..d9525cb640e1c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -174,6 +174,11 @@ void BufferViewFlowAnalysis::build(Operation *op) {
         }
       }
 
+      if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+        // Mark the entry block arguments and results as terminal.
+        populateTerminalValues(op);
+      }
+
       return WalkResult::advance();
     }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index ca5d0688b5b59..a52906174fb07 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -640,8 +640,9 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {
       continue;
 
     // Adhere to function boundary ABI: no ownership of function argument
-    // MemRefs is taken.
-    if (isa<FunctionOpInterface>(block->getParentOp()) &&
+    // MemRefs is taken. Likewise for ops marked IsolatedFromAbove.
+    if ((isa<FunctionOpInterface>(block->getParentOp()) ||
+         block->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) &&
         block->isEntryBlock()) {
       Value newArg = buildBoolValue(builder, arg.getLoc(), false);
       state.updateOwnership(arg, newArg);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
new file mode 100644
index 0000000000000..0c3ceda5237cc
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
@@ -0,0 +1,125 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN:   -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+
+func.func @function_call() {
+  %alloc = memref.alloc() : memref<f64>
+  %alloc2 = memref.alloc() : memref<f64>
+  %ret = test.isolated_one_region_with_recursive_memory_effects %alloc {
+  ^bb0(%arg1: memref<f64>):
+    test.region_yield %arg1 : memref<f64>
+  } : (memref<f64>) -> memref<f64>
+  test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+  return
+}
+
+// CHECK-LABEL: func @function_call()
+//       CHECK: [[ALLOC0:%.+]] = memref.alloc(
+//  CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
+//  CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[ALLOC0]]
+//  CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>)
+//       CHECK:   test.region_yield [[ARG]]
+//   CHECK-NOT:   bufferization.dealloc
+//       CHECK: }
+//  CHECK-NEXT: test.copy
+//  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
+//  CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true, %true, [[RET]]#1)
+
+// -----
+
+func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
+  %alloc = memref.alloc() : memref<f64>
+  %alloc2 = memref.alloca() : memref<f64>
+  %0 = arith.select %arg0, %alloc, %alloc2 : memref<f64>
+  %ret = test.isolated_one_region_with_recursive_memory_effects %0 {
+  ^bb0(%arg1: memref<f64>):
+    test.region_yield %arg1 : memref<f64>
+  } : (memref<f64>) -> (memref<f64>)
+  test.copy(%ret, %alloc) : (memref<f64>, memref<f64>)
+  return
+}
+
+// CHECK-LABEL: func @function_call_requries_merged_ownership_mid_block
+//       CHECK:   [[ALLOC0:%.+]] = memref.alloc(
+//  CHECK-NEXT:   [[ALLOC1:%.+]] = memref.alloca(
+//  CHECK-NEXT:   [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
+//  CHECK-NEXT:   [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[SELECT]]
+//  CHECK-NEXT:   ^bb0([[ARG:%.+]]: memref<f64>)
+//       CHECK:     test.region_yield [[ARG]]
+//   CHECK-NOT:     bufferization.dealloc
+//       CHECK:   }
+//  CHECK-NEXT:   test.copy
+//  CHECK-NEXT:   [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
+//  CHECK-NEXT:   bufferization.dealloc ([[ALLOC0]], [[BASE]] :
+//  CHECK-SAME:     if (%true, [[RET]]#1)
+//   CHECK-NOT:     retain
+//  CHECK-NEXT:   return
+
+// -----
+
+func.func @g(%arg0: memref<f32>) -> memref<f32> {
+  %0 = test.isolated_one_region_with_recursive_memory_effects %arg0 {
+  ^bb0(%arg1: memref<f32>):
+    test.region_yield %arg1 : memref<f32>
+  } : (memref<f32>) -> (memref<f32>)
+  return %0 : memref<f32>
+}
+
+// CHECK-LABEL:   func.func @g(
+// CHECK-SAME:                 %[[VAL_0:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[VAL_0]] {
+// CHECK:           ^bb0(%[[ARG:.*]]: memref<f32>):
+// CHECK:             test.region_yield %[[ARG]], %false : memref<f32>, i1
+// CHECK:           } : (memref<f32>) -> (memref<f32>, i1)
+// CHECK:           %[[VAL_4:.*]] = scf.if %[[BLOCK]]#1 -> (memref<f32>) {
+// CHECK:             scf.yield %[[BLOCK]]#0 : memref<f32>
+// CHECK:           } else {
+// CHECK:             %[[VAL_6:.*]] = bufferization.clone %[[BLOCK]]#0 : memref<f32> to memref<f32>
+// CHECK:             scf.yield %[[VAL_6]] : memref<f32>
+// CHECK:           }
+// CHECK:           %[[BUF:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[BLOCK]]#0 : memref<f32> -> memref<f32>, index
+// CHECK:           %[[VAL_11:.*]] = bufferization.dealloc (%[[BUF]] : memref<f32>) if (%[[BLOCK]]#1) retain (%[[VAL_4]] : memref<f32>)
+// CHECK:           return %[[VAL_4]] : memref<f32>
+// CHECK:         }
+
+// -----
+
+func.func @alloc_yielded_from_block() {
+  %alloc = memref.alloc() : memref<f64>
+  %alloc2 = memref.alloc() : memref<f64>
+  %ret = test.isolated_one_region_with_recursive_memory_effects %alloc {
+  ^bb0(%arg1: memref<f64>):
+    %0 = memref.load %arg1[] : memref<f64>
+    %c1 = arith.constant 1.0 : f64
+    %r0 = arith.cmpf oeq, %0, %c1 : f64
+    %1 = scf.if %r0 -> memref<f64> {
+      %alloc3 = memref.alloc() : memref<f64>
+      scf.yield %alloc3 : memref<f64>
+    } else {
+      scf.yield %arg1 : memref<f64>
+    }
+    test.region_yield %1 : memref<f64>
+  } : (memref<f64>) -> memref<f64>
+  test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+  return
+}
+
+// CHECK-LABEL:   func.func @alloc_yielded_from_block() {
+// CHECK:           %true = arith.constant true
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() : memref<f64>
+// CHECK:           %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[ALLOC]] {
+// CHECK:           ^bb0(%[[ARG:.*]]: memref<f64>):
+// CHECK:             %[[VAL_9:.*]] = arith.cmpf oeq
+// CHECK:             %[[VAL_10:.*]]:2 = scf.if %[[VAL_9]] -> (memref<f64>, i1) {
+// CHECK:               %[[BLOCK_ALLOC:.*]] = memref.alloc() : memref<f64>
+// CHECK:               scf.yield %[[BLOCK_ALLOC]], %true_{{[0-9]*}} : memref<f64>, i1
+// CHECK:             } else {
+// CHECK:               scf.yield %[[ARG]], %false : memref<f64>, i1
+// CHECK:             }
+// CHECK:             test.region_yield %[[VAL_10]]#0, %[[VAL_10]]#1 : memref<f64>, i1
+// CHECK:           } : (memref<f64>) -> (memref<f64>, i1)
+// CHECK:           test.copy
+// CHECK:           %[[BUF:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[BLOCK]]#0 : memref<f64> -> memref<f64>, index
+// CHECK:           bufferization.dealloc (%[[ALLOC]], %{{.*}}, %[[BUF]] : memref<f64>, memref<f64>, memref<f64>) if (%true, %true, %[[BLOCK]]#1)
+// CHECK:           return
+// CHECK:         }
+
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index fbaa102d3e33c..6666c9b86db42 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -110,6 +110,18 @@ void IsolatedRegionOp::print(OpAsmPrinter &p) {
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// IsolatedOneRegionWithRecursiveMemoryEffectsOp
+//===----------------------------------------------------------------------===//
+
+void IsolatedOneRegionWithRecursiveMemoryEffectsOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    regions.emplace_back(&getBody());
+  else
+    regions.emplace_back((*this)->getResults());
+}
+
 //===----------------------------------------------------------------------===//
 // SSACFGRegionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e1ec1428ee6d6..bbe84572868b2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -507,6 +507,23 @@ def OneRegionWithRecursiveMemoryEffectsOp
   let regions = (region SizedRegion<1>:$body);
 }
 
+def IsolatedOneRegionWithRecursiveMemoryEffectsOp
+    : TEST_Op<"isolated_one_region_with_recursive_memory_effects", [
+        RecursiveMemoryEffects,
+        IsolatedFromAbove,
+        SingleBlockImplicitTerminator<"RegionYieldOp">,
+        DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
+  let description = [{
+    IsolatedFromAbove Op that has one region and recursive side effects.
+  }];
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = [{
+    attr-dict-with-keyword $operands $body `:` functional-type(operands, results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NoTerminator Operation
 //===----------------------------------------------------------------------===//
@@ -2147,7 +2164,7 @@ def RegionYieldOp : TEST_Op<"region_yield",
     This operation is used in a region and yields the corresponding type for
     that operation.
   }];
-  let arguments = (ins AnyType:$result);
+  let arguments = (ins Variadic<AnyType>:$result);
   let assemblyFormat = [{
     $result `:` type($result) attr-dict
   }];

``````````

</details>


https://github.com/llvm/llvm-project/pull/97669


More information about the Mlir-commits mailing list