[Mlir-commits] [mlir] [mlir][bufferization] Adding the optimize-allocation-liveness pass (PR #101827)

Dennis Filimonov llvmlistbot at llvm.org
Sat Aug 3 07:41:57 PDT 2024


https://github.com/DennisFily created https://github.com/llvm/llvm-project/pull/101827

This adds a pass, expected to run after the buffer deallocation pipeline, that moves each buffer deallocation to right after the last user of the buffer or of one of its aliases, thereby shortening the liveness of temporary allocations.

>From c693f3b6dd819d33bb02d8f1f1158406f25f1ebf Mon Sep 17 00:00:00 2001
From: "Dennis.Filimonov" <dennis.filimonov at gmail.com>
Date: Fri, 2 Aug 2024 23:31:02 +0300
Subject: [PATCH] Adding the optimize-allocation-liveness pass

This commit adds a pass that is expected to run after the buffer
deallocation pipeline and moves buffer deallocations to right after their
last user, thus shortening the liveness of temporary allocations.
---
 .../Dialect/Bufferization/Transforms/Passes.h |   5 +
 .../Bufferization/Transforms/Passes.td        |  16 ++
 .../Bufferization/Transforms/CMakeLists.txt   |   1 +
 .../Transforms/OptimizeAllocationLiveness.cpp | 142 ++++++++++++++
 .../optimize-allocation-liveness.mlir         | 185 ++++++++++++++++++
 5 files changed, 349 insertions(+)
 create mode 100644 mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index c12ed7f5d0180..c7914830b77b7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -37,6 +37,11 @@ std::unique_ptr<Pass> createBufferDeallocationPass();
 std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
     DeallocationOptions options = DeallocationOptions());
 
+/// Creates a pass that visits `memref.alloc` operations that have a
+/// `memref.dealloc` user (temporary buffers) and moves that deallocation to
+/// right after the last user of the allocation or of one of its aliases,
+/// shortening the live range of the allocation.
+/// The pass is expected to run after the buffer deallocation pipeline.
+std::unique_ptr<Pass> createOptimizeAllocationlivenessPass();
+
 /// Creates a pass that optimizes `bufferization.dealloc` operations. For
 /// example, it reduces the number of alias checks needed at runtime using
 /// static alias analysis.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1cece818dbbbc..619853704ec50 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -232,6 +232,22 @@ def OwnershipBasedBufferDeallocation : Pass<
   ];
 }
 
+
+def OptimizeAllocationliveness : Pass<
+    "optimize-allocation-liveness", "func::FuncOp"> {
+  let summary = "Optimizes the liveness of temporary allocations in a function";
+  let description = [{
+    This pass visits the `memref.alloc` operations that are deallocated by a
+    `memref.dealloc` in the same block, i.e. temporary buffers. For each of
+    them it finds the last user of the allocation or of any of its aliases
+    and moves the deallocation right after that user, shrinking the live
+    range of the allocation to the minimum.
+
+    The pass is expected to run after the buffer deallocation pipeline, which
+    may place deallocations later than necessary (e.g. at the end of the
+    function).
+  }];
+  let constructor = "mlir::bufferization::createOptimizeAllocationlivenessPass()";
+  let dependentDialects = [
+    "mlir::memref::MemRefDialect"
+  ];
+}
+
 def BufferDeallocationSimplification :
     Pass<"buffer-deallocation-simplification"> {
   let summary = "Optimizes `bufferization.dealloc` operation for more "
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 8617c17e7a5e5..f27d924416677 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   OneShotModuleBufferize.cpp
   OwnershipBasedBufferDeallocation.cpp
   TensorCopyInsertion.cpp
+  OptimizeAllocationLiveness.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
new file mode 100644
index 0000000000000..d52bc3f8f7f22
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
@@ -0,0 +1,142 @@
+//===- OptimizeAllocationLiveness.cpp - Optimize allocation liveness ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an algorithm for optimizing allocation liveness.
+// The algorithm moves each dealloc op to right after the last user of the
+// allocation, within the same block as the allocation.
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "llvm/Support/Debug.h"
+
+#include <optional>
+#include <utility>
+
+#define DEBUG_TYPE "optimize-allocation-liveness"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_OPTIMIZEALLOCATIONLIVENESS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `a` is ordered before `b`. Starting at `a` and walking up
+/// its chain of parent ops, the first op whose block also contains an
+/// ancestor of `b` decides the answer by block order. Returns false when one
+/// of the visited ops encloses `b` (i.e. `b` is nested inside it) or when no
+/// common block is found.
+/// TODO: this mirrors an existing upstream helper; move it to a shared place.
+static bool happensBefore(Operation *a, Operation *b) {
+  Operation *current = a;
+  while (true) {
+    if (current->isProperAncestor(b))
+      return false;
+    Block *currentBlock = current->getBlock();
+    if (Operation *ancestorOfB = currentBlock->findAncestorOpInBlock(*b))
+      return current->isBeforeInBlock(ancestorOfB);
+    current = current->getParentOp();
+    if (!current)
+      return false;
+  }
+}
+
+/// Returns the unique operation of type `T` that uses `val`.
+/// Fails if `val` has no user of type `T`, or if it has more than one such
+/// user (e.g. one `memref.dealloc` per branch of a conditional): in that case
+/// there is no single operation the caller could safely act on. This check
+/// is done unconditionally rather than via `assert`, so release builds do
+/// not silently pick an arbitrary user.
+/// TODO find proper location for this helper function.
+template <typename T> FailureOr<T> getUserOfType(Value val) {
+  T found;
+  for (Operation *user : val.getUsers()) {
+    auto typed = dyn_cast<T>(user);
+    if (!typed)
+      continue;
+    // The same op is listed once per operand that uses `val`; only a
+    // genuinely different op makes the match ambiguous.
+    if (found && found.getOperation() != user)
+      return failure();
+    found = typed;
+  }
+  if (!found)
+    return failure();
+  return found;
+}
+
+/// Pass that shrinks the live range of temporary buffers. For every
+/// `memref.alloc` with a unique `memref.dealloc` user in the same block, the
+/// dealloc is moved to right after the last operation in that block that
+/// (directly, or through an op nested in its regions) uses the allocation or
+/// any alias reported by BufferViewFlowAnalysis.
+struct OptimizeAllocationliveness
+    : public bufferization::impl::OptimizeAllocationlivenessBase<
+          OptimizeAllocationliveness> {
+public:
+  OptimizeAllocationliveness() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    if (func.isExternal())
+      return;
+
+    // Nothing to do if the function contains no dealloc at all. A walk is
+    // used (rather than `getOps`, which only visits the top-level operations
+    // of the body) so that deallocs nested in regions, e.g. loop bodies, are
+    // also taken into account.
+    bool hasDealloc =
+        func.walk([](memref::DeallocOp) { return WalkResult::interrupt(); })
+            .wasInterrupted();
+    if (!hasDealloc)
+      return;
+
+    BufferViewFlowAnalysis analysis(func);
+    func.walk([&](memref::AllocOp allocOp) {
+      LDBG("Checking alloc op: " << allocOp);
+
+      FailureOr<memref::DeallocOp> deallocOp =
+          getUserOfType<memref::DeallocOp>(allocOp.getMemref());
+      if (failed(deallocOp))
+        return WalkResult::advance();
+
+      // Only a dealloc in the same block as the alloc can be moved: moving
+      // it across blocks/regions would change on which paths the buffer is
+      // freed.
+      if ((*deallocOp)->getBlock() != allocOp->getBlock()) {
+        LDBG("Dealloc is not in the alloc's block, skipping");
+        return WalkResult::advance();
+      }
+
+      Operation *lastUser = findLastUser(allocOp, analysis);
+      if (!lastUser)
+        return WalkResult::advance();
+
+      LDBG("Last user found: " << *lastUser);
+      assert(lastUser->getBlock() == allocOp->getBlock() &&
+             "last user must be in the alloc's block");
+      // Move the dealloc op right after the last user.
+      (*deallocOp)->moveAfter(lastUser);
+      LDBG("Moved dealloc op after: " << *lastUser);
+
+      return WalkResult::advance();
+    });
+  }
+
+private:
+  /// Returns the operation in `allocOp`'s block that is the last one (in
+  /// program order) to use `allocOp` or one of its aliases; a user nested in
+  /// a region is represented by its ancestor in that block. `memref.dealloc`
+  /// and `func.return` users are ignored. Returns nullptr when there is no
+  /// such user, or when moving the dealloc is not known to be safe: an alias
+  /// is used outside the allocation's block, or the last user is a block
+  /// terminator (nothing can be placed after it).
+  static Operation *findLastUser(memref::AllocOp allocOp,
+                                 BufferViewFlowAnalysis &analysis) {
+    Block *allocBlock = allocOp->getBlock();
+    Operation *lastUser = nullptr;
+    for (Value alias : analysis.resolve(allocOp.getMemref())) {
+      for (Operation *user : alias.getUsers()) {
+        // We are looking for non-dealloc users; a return op is not expected
+        // to use a buffer that is deallocated locally.
+        if (isa<memref::DeallocOp, func::ReturnOp>(user))
+          continue;
+
+        // Find the ancestor of `user` that lives in the alloc's block.
+        Operation *topUser = allocBlock->findAncestorOpInBlock(*user);
+        if (!topUser) {
+          LDBG("Alias used outside the alloc's block: " << *user);
+          return nullptr;
+        }
+        if (!lastUser || happensBefore(lastUser, topUser))
+          lastUser = topUser;
+      }
+    }
+    if (lastUser && lastUser->hasTrait<OpTrait::IsTerminator>()) {
+      LDBG("Last user is a terminator: " << *lastUser);
+      return nullptr;
+    }
+    return lastUser;
+  }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// OptimizeAllocationliveness construction
+//===----------------------------------------------------------------------===//
+
+/// Builds a fresh instance of the optimize-allocation-liveness pass, which
+/// runs on `func.func` operations.
+std::unique_ptr<Pass>
+mlir::bufferization::createOptimizeAllocationlivenessPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<OptimizeAllocationliveness>();
+  return pass;
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir b/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir
new file mode 100644
index 0000000000000..6357c9af44ed3
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir
@@ -0,0 +1,185 @@
+// RUN: mlir-opt %s --optimize-allocation-liveness --split-input-file | FileCheck %s
+
+// CHECK-LABEL:   func.func private @optimize_alloc_location(
+// CHECK-SAME:                             %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
+// CHECK-SAME:                             %[[VAL_1:.*]]: memref<24x256xf32, 1>,
+// CHECK-SAME:                             %[[VAL_2:.*]]: memref<256xf32, 1>) {
+// CHECK:           %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_4:.*]] = memref.expand_shape %[[VAL_3]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK:           memref.dealloc %[[VAL_3]] : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_1]], %[[VAL_2]] : memref<24x256xf32, 1>, memref<256xf32, 1>) outs(%[[VAL_5]] : memref<24x256xf32, 1>) {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK:             %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK:             linalg.yield %[[VAL_9]] : f32
+// CHECK:           }
+// CHECK:           memref.dealloc %[[VAL_5]] : memref<24x256xf32, 1>
+// CHECK:           return
+// CHECK:         }
+
+// The dealloc of %alloc is moved up to right after its only alias user, the
+// memref.expand_shape. The dealloc of %alloc_1 stays right after the
+// linalg.generic that writes into it.
+func.func private @optimize_alloc_location(%arg0: memref<45x24x256xf32, 1> , %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1>) -> () {
+
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+  %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg1, %arg2 : memref<24x256xf32, 1>, memref<256xf32, 1>) outs(%alloc_1 : memref<24x256xf32, 1>) {
+  ^bb0(%in: f32, %in_3: f32, %out: f32):
+    %0 = arith.addf %in, %in_3 : f32
+    linalg.yield %0 : f32
+  }
+  memref.dealloc %alloc : memref<45x6144xf32, 1>
+  memref.dealloc %alloc_1 : memref<24x256xf32, 1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @test_multiple_deallocation_moves(
+// CHECK-SAME:                                                        %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
+// CHECK-SAME:                                                        %[[VAL_1:.*]]: memref<24x256xf32, 1>,
+// CHECK-SAME:                                                        %[[VAL_2:.*]]: memref<256xf32, 1>) {
+// CHECK:           %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_4:.*]] = memref.expand_shape %[[VAL_3]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK:           memref.dealloc %[[VAL_3]] : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+// CHECK:           %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_7:.*]] = memref.expand_shape %[[VAL_6]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK:           memref.dealloc %[[VAL_6]] : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_9:.*]] = memref.expand_shape %[[VAL_8]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK:           memref.dealloc %[[VAL_8]] : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_11:.*]] = memref.expand_shape %[[VAL_10]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK:           memref.dealloc %[[VAL_10]] : memref<45x6144xf32, 1>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_1]], %[[VAL_2]] : memref<24x256xf32, 1>, memref<256xf32, 1>) outs(%[[VAL_5]] : memref<24x256xf32, 1>) {
+// CHECK:           ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32):
+// CHECK:             %[[VAL_15:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
+// CHECK:             linalg.yield %[[VAL_15]] : f32
+// CHECK:           }
+// CHECK:           memref.dealloc %[[VAL_5]] : memref<24x256xf32, 1>
+// CHECK:           return
+// CHECK:         }
+
+// This test exercises multiple deallocation moves in one function: the
+// deallocs of %alloc, %alloc_2, %alloc_3 and %alloc_4 each move up to right
+// after the corresponding memref.expand_shape, while the dealloc of %alloc_1
+// stays right after the linalg.generic that writes into it.
+func.func private @test_multiple_deallocation_moves(%arg0: memref<45x24x256xf32, 1> , %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1>) -> () {
+
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+  %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+  %expand_shape2 = memref.expand_shape %alloc_2 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+  %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+  %expand_shape3 = memref.expand_shape %alloc_3 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+  %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+  %expand_shape4 = memref.expand_shape %alloc_4 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg1, %arg2 : memref<24x256xf32, 1>, memref<256xf32, 1>) outs(%alloc_1 : memref<24x256xf32, 1>) {
+  ^bb0(%in: f32, %in_3: f32, %out: f32):
+    %0 = arith.addf %in, %in_3 : f32
+    linalg.yield %0 : f32
+  }
+  memref.dealloc %alloc : memref<45x6144xf32, 1>
+  memref.dealloc %alloc_1 : memref<24x256xf32, 1>
+  memref.dealloc %alloc_2 : memref<45x6144xf32, 1>
+  memref.dealloc %alloc_3 : memref<45x6144xf32, 1>
+  memref.dealloc %alloc_4 : memref<45x6144xf32, 1>
+  return
+}
+
+// -----
+// CHECK-LABEL:   func.func private @test_users_in_different_blocks_linalig_generic(
+// CHECK-SAME:                                                                      %[[VAL_0:.*]]: memref<1x20x20xf32, 1>) -> (memref<8x32x1x4xf32, 1>, memref<1x32x32xf32, 1>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32x32xf32, 1>
+// CHECK:           %[[VAL_4:.*]] = memref.subview %[[VAL_3]][0, 0, 0] [1, 20, 20] [1, 1, 1] : memref<1x32x32xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
+// CHECK:           memref.copy %[[VAL_0]], %[[VAL_4]] : memref<1x20x20xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
+// CHECK:           %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x1xf32, 1>
+// CHECK:           %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x8x32x1x4xf32, 1>
+// CHECK:           linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%[[VAL_6]] : memref<1x8x32x1x4xf32, 1>) {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: f32):
+// CHECK:             %[[VAL_8:.*]] = linalg.index 0 : index
+// CHECK:             %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]], %[[VAL_2]]] : memref<1x32x32x1xf32, 1>
+// CHECK:             linalg.yield %[[VAL_9]] : f32
+// CHECK:           }
+// CHECK:           memref.dealloc %[[VAL_5]] : memref<1x32x32x1xf32, 1>
+// CHECK:           %[[VAL_10:.*]] = memref.collapse_shape %[[VAL_6]] {{\[\[}}0, 1], [2], [3], [4]] : memref<1x8x32x1x4xf32, 1> into memref<8x32x1x4xf32, 1>
+// CHECK:           memref.dealloc %[[VAL_6]] : memref<1x8x32x1x4xf32, 1>
+// CHECK:           return %[[VAL_10]], %[[VAL_3]] : memref<8x32x1x4xf32, 1>, memref<1x32x32xf32, 1>
+// CHECK:         }
+
+
+// The dealloc of %alloc_0 is moved to right after the linalg.generic, because
+// the only user of %alloc_0 (a memref.load) is nested inside the region of
+// that linalg.generic, i.e. in a different block than the allocation.
+// The dealloc of %alloc_1 ends up right after the memref.collapse_shape,
+// which is the last non-return user of %alloc_1 or its alias.
+// NOTE(review): %collapse_shape aliases %alloc_1 and is returned even though
+// %alloc_1 is deallocated; the pass ignores func.return users, so this input
+// returns a view of freed memory. Consider restructuring this test.
+func.func private @test_users_in_different_blocks_linalig_generic(%arg0: memref<1x20x20xf32, 1>) -> (memref<8x32x1x4xf32, 1> , memref<1x32x32xf32, 1> ) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x32x32xf32, 1>
+  %subview = memref.subview %alloc[0, 0, 0] [1, 20, 20] [1, 1, 1] : memref<1x32x32xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
+  memref.copy %arg0, %subview : memref<1x20x20xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
+  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x1xf32, 1>
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x8x32x1x4xf32, 1>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%alloc_1 : memref<1x8x32x1x4xf32, 1>) {
+  ^bb0(%out: f32):
+    %0 = linalg.index 0 : index
+    %8 = memref.load %alloc_0[%0, %0, %0, %c0] : memref<1x32x32x1xf32, 1>
+    linalg.yield %8 : f32
+  }
+  %collapse_shape = memref.collapse_shape %alloc_1 [[0, 1], [2], [3], [4]] : memref<1x8x32x1x4xf32, 1> into memref<8x32x1x4xf32, 1>
+  memref.dealloc %alloc_0 : memref<1x32x32x1xf32, 1>
+  memref.dealloc %alloc_1 : memref<1x8x32x1x4xf32, 1>
+  return %collapse_shape, %alloc : memref<8x32x1x4xf32, 1>, memref<1x32x32xf32, 1>
+}
+
+// -----
+// CHECK-LABEL:   func.func private @test_deallocs_in_different_block_forops(
+// CHECK-SAME:                                                               %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
+// CHECK-SAME:                                                               %[[VAL_1:.*]]: memref<24x256xf32, 1>,
+// CHECK-SAME:                                                               %[[VAL_2:.*]]: memref<256xf32, 1>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 45 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 24 : index
+// CHECK:           %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_9:.*]] = memref.expand_shape %[[VAL_8]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK:           %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+// CHECK:           %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK:           %[[VAL_12:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK:           memref.dealloc %[[VAL_11]] : memref<45x6144xf32, 1>
+// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
+// CHECK:             scf.for %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_5]] {
+// CHECK:               %[[VAL_15:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_13]], %[[VAL_14]], 0] [1, 8, 256] [1, 1, 1] : memref<45x24x256xf32, 1> to memref<1x8x256xf32, strided<[6144, 256, 1], offset: ?>, 1>
+// CHECK:               %[[VAL_16:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_14]], 0] [8, 256] [1, 1] : memref<24x256xf32, 1> to memref<8x256xf32, strided<[256, 1], offset: ?>, 1>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           memref.dealloc %[[VAL_10]] : memref<24x256xf32, 1>
+// CHECK:           memref.dealloc %[[VAL_8]] : memref<45x6144xf32, 1>
+// CHECK:           return
+// CHECK:         }
+
+// %alloc (through %expand_shape) and %alloc_1 are both used inside the nested
+// scf.for loops, so their deallocs are placed right after the outer scf.for;
+// after the pass the dealloc of %alloc_1 comes first.
+// The dealloc of %alloc_2 moves up to right after its last user, the
+// memref.expand_shape.
+func.func private @test_deallocs_in_different_block_forops(%arg0: memref<45x24x256xf32, 1>, %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1> ) -> () {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %c45 = arith.constant 45 : index
+  %c24 = arith.constant 24 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+  %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+  %expand_shape2 = memref.expand_shape %alloc_2 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+  scf.for %arg3 = %c0 to %c45 step %c1 {
+    scf.for %arg4 = %c0 to %c24 step %c8 {
+      %subview = memref.subview %expand_shape[%arg3, %arg4, 0] [1, 8, 256] [1, 1, 1] : memref<45x24x256xf32, 1> to memref<1x8x256xf32, strided<[6144, 256, 1], offset: ?>, 1>
+      %subview_3 = memref.subview %alloc_1[%arg4, 0] [8, 256] [1, 1] : memref<24x256xf32, 1> to memref<8x256xf32, strided<[256, 1], offset: ?>, 1>    
+    }
+  }
+  memref.dealloc %alloc : memref<45x6144xf32, 1>
+  memref.dealloc %alloc_1 : memref<24x256xf32, 1>
+  memref.dealloc %alloc_2 : memref<45x6144xf32, 1>
+  return
+}
\ No newline at end of file



More information about the Mlir-commits mailing list