[Mlir-commits] [mlir] [mlir][bufferization] Better analysis around allocs and block arguments (PR #67923)

Matthias Springer llvmlistbot at llvm.org
Sun Oct 1 09:04:34 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/67923

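A value produced by a buffer allocation op is guaranteed *not* to be the same allocation as any block argument of a block that (directly or transitively) contains that op: the block argument is bound before the allocation executes, so the allocation yields a new, distinct buffer. This fact enables more aggressive simplification of `bufferization.dealloc` ops.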

>From 447aaa3d8c548a3e4f9a491c8bf5587e0e848a2b Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 1 Oct 2023 18:03:11 +0200
Subject: [PATCH] [mlir][bufferization] Better analysis around allocs and block
 arguments

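A value produced by a buffer allocation op is guaranteed *not* to be the same allocation as any block argument of a block that (directly or transitively) contains that op: the block argument is bound before the allocation executes, so the allocation yields a new, distinct buffer. This fact enables more aggressive simplification of `bufferization.dealloc` ops.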
---
 .../BufferDeallocationSimplification.cpp      | 50 ++++++++++++++++---
 .../dealloc-region-branchop-interface.mlir    | 13 ++---
 .../buffer-deallocation-simplification.mlir   | 27 ++++++++++
 3 files changed, 76 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 5a6de372e2310dc..a834b7d8f89fb72 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -49,19 +49,53 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
   return success();
 }
 
-/// Checks if 'memref' may or must alias a MemRef in 'memrefList'. It is often a
+/// Return the "base" of the memref `value`: follow the reverse use-def chain
+/// through ViewLikeOpInterface ops and stop at the first non-view value.
+static Value getViewBase(Value value) {
+  while (value) {
+    if (auto viewOp = value.getDefiningOp<ViewLikeOpInterface>())
+      value = viewOp.getViewSource();
+    else
+      break;
+  }
+  return value;
+}
+
+/// Return "true" if `v1` and `v2` (after looking through view-like ops) are
+/// guaranteed to be distinct, non-aliasing allocations: one of them is the
+/// result of an allocating op that is nested in the block owning the other
+/// one, which must be a block argument. Best-effort only: "false" does not
+/// imply that the values may alias. Intended to be replaced by a proper
+/// "is same allocation" analysis.
+static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
+  // True if `allocated` is allocated inside the block owning `maybeBbArg`.
+  auto allocInArgBlock = [](Value allocated, Value maybeBbArg) -> bool {
+    Operation *allocOp = allocated.getDefiningOp();
+    if (!allocOp || !hasEffect<MemoryEffects::Allocate>(allocOp, allocated))
+      return false;
+    auto bbArg = dyn_cast<BlockArgument>(maybeBbArg);
+    return bbArg && bbArg.getOwner()->findAncestorOpInBlock(*allocOp);
+  };
+  Value base1 = getViewBase(v1);
+  Value base2 = getViewBase(v2);
+  return allocInArgBlock(base1, base2) || allocInArgBlock(base2, base1);
+}
+
+/// Checks if `memref` may or must alias a MemRef in `otherList`. It is often a
 /// requirement of optimization patterns that there cannot be any aliasing
-/// memref in order to perform the desired simplification. The 'allowSelfAlias'
-/// argument indicates whether 'memref' may be present in 'memrefList' which
+/// memref in order to perform the desired simplification. The `allowSelfAlias`
+/// argument indicates whether `memref` may be present in `otherList` which
 /// makes this helper function applicable to situations where we already know
-/// that 'memref' is in the list but also when we don't want it in the list.
+/// that `memref` is in the list but also when we don't want it in the list.
 static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
-                                     ValueRange memrefList, Value memref,
+                                     ValueRange otherList, Value memref,
                                      bool allowSelfAlias) {
-  for (auto mr : memrefList) {
-    if (allowSelfAlias && mr == memref)
+  for (auto other : otherList) {
+    if (allowSelfAlias && other == memref)
+      continue;
+    if (distinctAllocAndBlockArgument(other, memref))
       continue;
-    if (!analysis.alias(mr, memref).isNo())
+    if (!analysis.alias(other, memref).isNo())
       return true;
   }
   return false;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index dc372749fc074be..ad7c4c783e907f3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -270,7 +270,8 @@ func.func @loop_alloc(
 //       CHECK: [[V0:%.+]]:2 = scf.for {{.*}} iter_args([[ARG6:%.+]] = [[ARG3]], [[ARG7:%.+]] = %false
 //       CHECK:   [[ALLOC1:%.+]] = memref.alloc()
 //       CHECK:   [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG6]]
-//       CHECK:   bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]]) retain ([[ALLOC1]] :
+//       CHECK:   bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]])
+//   CHECK-NOT:       retain
 //       CHECK:   scf.yield [[ALLOC1]], %true
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
@@ -563,8 +564,8 @@ func.func @while_two_arg(%arg0: index) {
 //       CHECK: ^bb0([[ARG1:%.+]]: memref<?xf32>, [[ARG2:%.+]]: memref<?xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1):
 //       CHECK:   [[ALLOC1:%.+]] = memref.alloc(
 //       CHECK:   [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
-//       CHECK:   [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG4]]) retain ([[ARG1]], [[ALLOC1]] :
-//       CHECK:   [[OWN_AGG:%.+]] = arith.ori [[OWN]]#0, [[ARG3]]
+//       CHECK:   [[OWN:%.+]] = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG4]]) retain ([[ARG1]] :
+//       CHECK:   [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[ARG3]]
 //       CHECK:   scf.yield [[ARG1]], [[ALLOC1]], [[OWN_AGG]], %true
 //       CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
 //       CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#1
@@ -594,10 +595,10 @@ func.func @while_three_arg(%arg0: index) {
 //       CHECK:   [[ALLOC1:%.+]] = memref.alloc(
 //       CHECK:   [[ALLOC2:%.+]] = memref.alloc(
 //       CHECK:   [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG1]]
-//       CHECK:   [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
 //       CHECK:   [[BASE2:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG3]]
-//       CHECK:   [[OWN:%.+]]:3 = bufferization.dealloc ([[BASE0]], [[BASE1]], [[BASE2]], [[ALLOC1]] :{{.*}}) if ([[ARG4]], [[ARG5]], [[ARG6]], %true{{[0-9_]*}}) retain ([[ALLOC2]], [[ALLOC1]], [[ARG2]] :
-//       CHECK:   scf.yield [[ALLOC2]], [[ALLOC1]], [[ARG2]], %true{{[0-9_]*}}, %true{{[0-9_]*}}, [[OWN]]#2 :
+//       CHECK:   [[OWN:%.+]] = bufferization.dealloc ([[BASE0]], [[BASE2]] :{{.*}}) if ([[ARG4]], [[ARG6]]) retain ([[ARG2]] :
+//       CHECK:   [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[ARG5]]
+//       CHECK:   scf.yield [[ALLOC2]], [[ALLOC1]], [[ARG2]], %true{{[0-9_]*}}, %true{{[0-9_]*}}, [[OWN_AGG]] :
 //       CHECK: }
 //       CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
 //       CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#1
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index 98eb038df30a3b9..e192e9870becdbc 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -133,3 +133,30 @@ func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_c
 //       CHECK:   [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
 //       CHECK:   bufferization.dealloc ([[BASE]] :{{.*}}) if (%true{{[0-9_]*}})
 //  CHECK-NEXT:   return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :
+
+// -----
+// The loop-body alloc cannot alias the iter_arg; the dealloc needs no retain.
+func.func @alloc_and_bbarg(%arg0: memref<5xf32>, %arg1: index, %arg2: index, %arg3: index) -> f32 {
+  %true = arith.constant true
+  %false = arith.constant false
+  %0:2 = scf.for %arg4 = %arg1 to %arg2 step %arg3 iter_args(%arg5 = %arg0, %arg6 = %false) -> (memref<5xf32>, i1) {
+    %alloc = memref.alloc() : memref<5xf32>
+    memref.copy %arg5, %alloc : memref<5xf32> to memref<5xf32>
+    %base_buffer_0, %offset_1, %sizes_2, %strides_3 = memref.extract_strided_metadata %arg5 : memref<5xf32> -> memref<f32>, index, index, index
+    %2 = bufferization.dealloc (%base_buffer_0, %alloc : memref<f32>, memref<5xf32>) if (%arg6, %true) retain (%alloc : memref<5xf32>)
+    scf.yield %alloc, %2 : memref<5xf32>, i1
+  }
+  %1 = memref.load %0#0[%arg1] : memref<5xf32>
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %0#0 : memref<5xf32> -> memref<f32>, index, index, index
+  bufferization.dealloc (%base_buffer : memref<f32>) if (%0#1)
+  return %1 : f32
+}
+
+// CHECK-LABEL: func @alloc_and_bbarg
+//       CHECK:   [[TRUE:%true[0-9_]*]] = arith.constant true
+//       CHECK:   scf.for {{.*}} iter_args([[ITER:%[a-zA-Z0-9_]+]] = %{{.*}}, %{{.*}} = %{{.*}})
+//       CHECK:     [[ALLOC:%[a-zA-Z0-9_]+]] = memref.alloc
+//       CHECK:     [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ITER]]
+//       CHECK:     bufferization.dealloc ([[BASE]] : memref<f32>)
+//   CHECK-NOT:     retain
+//       CHECK:     scf.yield [[ALLOC]], [[TRUE]]



More information about the Mlir-commits mailing list