[Mlir-commits] [mlir] [mlir][bufferization] Adding the optimize-allocation-liveness pass (PR #101827)
Dennis Filimonov
llvmlistbot at llvm.org
Wed Aug 14 00:11:41 PDT 2024
https://github.com/DennisFily updated https://github.com/llvm/llvm-project/pull/101827
>From 0ff47e7e0cee44526761a910ade30fde93b2eb6c Mon Sep 17 00:00:00 2001
From: "Dennis.Filimonov" <dennis.filimonov at gmail.com>
Date: Fri, 2 Aug 2024 23:31:02 +0300
Subject: [PATCH] Adding the optimize-allocation-liveness pass
This commit will add a pass that is expected to run after the
deallocation pipeline and will move buffer deallocations right after their
last user, thus optimizing the allocation liveness.
---
.../Dialect/Bufferization/Transforms/Passes.h | 5 +
.../Bufferization/Transforms/Passes.td | 16 ++
.../Bufferization/Transforms/CMakeLists.txt | 1 +
.../Transforms/OptimizeAllocationLiveness.cpp | 157 +++++++++++++++
.../optimize-allocation-liveness.mlir | 178 ++++++++++++++++++
5 files changed, 357 insertions(+)
create mode 100644 mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
create mode 100644 mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index c12ed7f5d0180b..72abb5b3f1f94e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -37,6 +37,11 @@ std::unique_ptr<Pass> createBufferDeallocationPass();
std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
DeallocationOptions options = DeallocationOptions());
+/// Creates a pass that finds all temporary allocations and attempts to move
+/// the deallocation of each one right after the last user/dependency of the
+/// allocation (including users of values that may alias it, such as views),
+/// thereby shortening the allocation's live range.
+/// The pass is expected to run after the deallocation pipeline.
+std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
+
/// Creates a pass that optimizes `bufferization.dealloc` operations. For
/// example, it reduces the number of alias checks needed at runtime using
/// static alias analysis.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1cece818dbbbc3..a610ddcc9899ed 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -252,6 +252,22 @@ def BufferDeallocationSimplification :
];
}
+// Pass definition. The constructor named below is declared in Passes.h and
+// implemented in OptimizeAllocationLiveness.cpp.
+def OptimizeAllocationLiveness
+ : Pass<"optimize-allocation-liveness", "func::FuncOp"> {
+ let summary = "This pass optimizes the liveness of temp allocations in the "
+ "input function";
+ let description =
+ [{This pass will find all operations that have a memory allocation effect.
+ It will search for the corresponding deallocation and move it right after
+ the last user of the allocation.
+ This will optimize the liveness of the allocations.
+
+ The pass is expected to run after the deallocation pipeline.}];
+ let constructor =
+ "mlir::bufferization::createOptimizeAllocationLivenessPass()";
+ // NOTE(review): the implementation only moves existing ops and creates
+ // none; confirm whether this memref dialect dependency is required.
+ let dependentDialects = ["mlir::memref::MemRefDialect"];
+}
+
def LowerDeallocations : Pass<"bufferization-lower-deallocations"> {
let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`"
"operations";
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 8617c17e7a5e5e..f27d924416677a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
OneShotModuleBufferize.cpp
OwnershipBasedBufferDeallocation.cpp
TensorCopyInsertion.cpp
+ OptimizeAllocationLiveness.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
new file mode 100644
index 00000000000000..8ff1a134399644
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
@@ -0,0 +1,157 @@
+//===- OptimizeAllocationLiveness.cpp - Optimize allocation liveness ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for optimizing allocation liveness.
+// The pass moves the deallocation operation after the last user of the
+// allocated buffer.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "optimize-allocation-liveness"
+// Debug-only logging helpers: DBGS() starts a line prefixed with the debug
+// type, and LDBG(X) streams X followed by a newline under LLVM_DEBUG.
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+namespace mlir {
+namespace bufferization {
+// Pulls in the tablegen-generated pass base class
+// (impl::OptimizeAllocationLivenessBase) that this pass derives from.
+#define GEN_PASS_DEF_OPTIMIZEALLOCATIONLIVENESS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `a` executes before `b`. Starting from `a` and walking
+/// outwards through its ancestors, the first operation whose block also
+/// contains `b` (or an ancestor of `b`) decides the answer by block order.
+/// Returns false when `b` is nested inside `a` (or inside one of the visited
+/// ancestors of `a`), or when no shared block is found.
+static bool happensBefore(Operation *a, Operation *b) {
+  for (Operation *candidate = a; candidate != nullptr;
+       candidate = candidate->getParentOp()) {
+    // `b` lives inside `candidate`: `candidate` encloses `b` rather than
+    // preceding it.
+    if (candidate->isProperAncestor(b))
+      return false;
+
+    Block *candidateBlock = candidate->getBlock();
+    Operation *peerOfB = candidateBlock->findAncestorOpInBlock(*b);
+    if (peerOfB != nullptr)
+      return candidate->isBeforeInBlock(peerOfB);
+  }
+  return false;
+}
+
+/// Returns the unique user of `value` that frees it, i.e. that reports a
+/// MemoryEffects::Free effect on `value` itself. Returns nullptr when:
+///  - no user frees `value`,
+///  - more than one distinct user frees `value`, or
+///  - the freeing user also frees other memory. Hoisting such an op to the
+///    last use of `value` could release those other buffers too early.
+static Operation *findUserWithFreeSideEffect(Value value) {
+  Operation *freeOpUser = nullptr;
+  for (Operation *user : value.getUsers()) {
+    auto memEffectOp = dyn_cast<MemoryEffectOpInterface>(user);
+    if (!memEffectOp)
+      continue;
+
+    SmallVector<MemoryEffects::EffectInstance, 2> effects;
+    memEffectOp.getEffects(effects);
+
+    bool freesValue = false;
+    bool freesOther = false;
+    for (const MemoryEffects::EffectInstance &effect : effects) {
+      if (!isa<MemoryEffects::Free>(effect.getEffect()))
+        continue;
+      // A Free effect on another value (or on no known value) does not
+      // release `value`; an op that merely reads `value` is not its dealloc.
+      if (effect.getValue() == value)
+        freesValue = true;
+      else
+        freesOther = true;
+    }
+    if (!freesValue)
+      continue;
+    if (freesOther) {
+      LDBG("Free op also frees other memory: " << *user);
+      return nullptr;
+    }
+    // getUsers() yields an op once per use, so an op using `value` several
+    // times must not be mistaken for a second deallocation.
+    if (freeOpUser && freeOpUser != user) {
+      LDBG("Multiple users with free effect found: " << *freeOpUser
+                                                     << " and " << *user);
+      return nullptr;
+    }
+    freeOpUser = user;
+  }
+  return freeOpUser;
+}
+
+/// Returns true when any of the memory effects reported by `memEffectOp` is
+/// an allocation (MemoryEffects::Allocate).
+static bool hasMemoryAllocEffect(MemoryEffectOpInterface memEffectOp) {
+  SmallVector<MemoryEffects::EffectInstance, 2> effectInstances;
+  memEffectOp.getEffects(effectInstances);
+  return llvm::any_of(effectInstances,
+                      [](const MemoryEffects::EffectInstance &instance) {
+                        return isa<MemoryEffects::Allocate>(
+                            instance.getEffect());
+                      });
+}
+
+/// Returns the unique result of `op` on which `op` reports a
+/// MemoryEffects::Allocate effect. Returns a null Value when there is no such
+/// result, or when several distinct results are allocated. Allocate effects on
+/// operands, or without an associated value, are ignored: only an allocated
+/// result can be tracked to its deallocation.
+static Value getUniqueAllocatedResult(MemoryEffectOpInterface op) {
+  SmallVector<MemoryEffects::EffectInstance, 2> effects;
+  op.getEffects(effects);
+  Value allocated;
+  for (const MemoryEffects::EffectInstance &effect : effects) {
+    if (!isa<MemoryEffects::Allocate>(effect.getEffect()))
+      continue;
+    Value value = effect.getValue();
+    if (!value || value.getDefiningOp() != op.getOperation())
+      continue;
+    if (allocated && allocated != value)
+      return Value();
+    allocated = value;
+  }
+  return allocated;
+}
+
+/// Returns true if every operand of `op` is defined by or before `point`, so
+/// that `op` can be placed right after `point` without breaking dominance.
+/// Expects `point` to be in the same block as `op`. Block arguments and values
+/// defined outside that block are treated as available.
+static bool operandsAvailableAfter(Operation *op, Operation *point) {
+  Block *block = point->getBlock();
+  for (Value operand : op->getOperands()) {
+    Operation *def = operand.getDefiningOp();
+    if (!def)
+      continue;
+    Operation *defInBlock = block->findAncestorOpInBlock(*def);
+    if (defInBlock && defInBlock != point &&
+        !defInBlock->isBeforeInBlock(point))
+      return false;
+  }
+  return true;
+}
+
+/// Moves the deallocation of each allocation right after the last (possibly
+/// nested) user of the allocation, or of any value that may alias it, as
+/// computed by BufferViewFlowAnalysis. The move is only made when it is
+/// provably safe:
+///  - the allocation and its deallocation share a block;
+///  - every user is nested in that block;
+///  - the new position is earlier than the current one;
+///  - all operands of the deallocation are available at the new position.
+struct OptimizeAllocationLiveness
+    : public bufferization::impl::OptimizeAllocationLivenessBase<
+          OptimizeAllocationLiveness> {
+  OptimizeAllocationLiveness() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    if (func.isExternal())
+      return;
+
+    BufferViewFlowAnalysis analysis(func);
+
+    func.walk([&](MemoryEffectOpInterface memEffectOp) -> WalkResult {
+      if (!hasMemoryAllocEffect(memEffectOp))
+        return WalkResult::advance();
+
+      Operation *allocOp = memEffectOp.getOperation();
+      LDBG("Checking alloc op: " << *allocOp);
+
+      // Ops without a single identifiable allocated result (for example,
+      // ops with no results at all) cannot be paired with a deallocation.
+      Value allocated = getUniqueAllocatedResult(memEffectOp);
+      if (!allocated)
+        return WalkResult::advance();
+
+      Operation *deallocOp = findUserWithFreeSideEffect(allocated);
+      if (!deallocOp)
+        return WalkResult::advance();
+
+      // Hoisting a deallocation out of a nested region (e.g. a conditional
+      // one) or across blocks would change whether or when it runs, so only
+      // deallocations in the allocation's own block are considered.
+      Block *block = allocOp->getBlock();
+      if (deallocOp->getBlock() != block) {
+        LDBG("Dealloc op is not in the alloc block: " << *deallocOp);
+        return WalkResult::advance();
+      }
+
+      Operation *lastUser = nullptr;
+      for (Value alias : analysis.resolve(allocated)) {
+        for (Operation *user : alias.getUsers()) {
+          // We are looking for the last user other than the dealloc op.
+          if (user == deallocOp)
+            continue;
+
+          // Map (possibly nested) users to their ancestor in this block.
+          Operation *topUser = block->findAncestorOpInBlock(*user);
+          if (!topUser) {
+            // The buffer or one of its aliases is used outside this block,
+            // so no position in this block is known to follow every use.
+            LDBG("User outside the alloc block: " << *user);
+            return WalkResult::advance();
+          }
+          if (!lastUser || happensBefore(lastUser, topUser))
+            lastUser = topUser;
+        }
+      }
+      if (!lastUser)
+        return WalkResult::advance();
+      LDBG("Last user found: " << *lastUser);
+
+      // Only ever shorten the live range (never move past a later use or a
+      // terminator), and never place the deallocation above the definition
+      // of one of its operands.
+      if (!lastUser->isBeforeInBlock(deallocOp) ||
+          !operandsAvailableAfter(deallocOp, lastUser))
+        return WalkResult::advance();
+
+      // Move the dealloc op after the last user.
+      deallocOp->moveAfter(lastUser);
+      LDBG("Moved dealloc op after: " << *lastUser);
+      return WalkResult::advance();
+    });
+  }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// OptimizeAllocationLiveness construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass>
+mlir::bufferization::createOptimizeAllocationLivenessPass() {
+  // Build the concrete pass and hand it out through the generic Pass type.
+  std::unique_ptr<Pass> pass = std::make_unique<OptimizeAllocationLiveness>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir b/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir
new file mode 100644
index 00000000000000..137426190eab9f
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir
@@ -0,0 +1,178 @@
+// RUN: mlir-opt %s --optimize-allocation-liveness --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func private @optimize_alloc_location(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<24x256xf32, 1>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<256xf32, 1>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_5:.*]] = memref.expand_shape %[[VAL_4]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_4]] : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+// CHECK: %[[VAL_7:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: memref.store %[[VAL_7]], %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]] : memref<24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_6]] : memref<24x256xf32, 1>
+// CHECK: return
+// CHECK: }
+
+
+// This test will optimize the location of the %alloc deallocation
+func.func private @optimize_alloc_location(%arg0: memref<45x24x256xf32, 1> , %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1>) -> () {
+ %c1 = arith.constant 1 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+ // Last user of %alloc: its dealloc is expected right after this op.
+ %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+ %cf1 = arith.constant 1.0 : f32
+ // Last user of %alloc_1: its dealloc is expected right after this op.
+ memref.store %cf1, %alloc_1[%c1, %c1] : memref<24x256xf32, 1>
+ // Deallocation positions before the pass runs.
+ memref.dealloc %alloc : memref<45x6144xf32, 1>
+ memref.dealloc %alloc_1 : memref<24x256xf32, 1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @test_multiple_deallocation_moves(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<24x256xf32, 1>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<256xf32, 1>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_5:.*]] = memref.expand_shape %[[VAL_4]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_4]] : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+// CHECK: %[[VAL_7:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_8:.*]] = memref.expand_shape %[[VAL_7]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_7]] : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_9:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_10:.*]] = memref.expand_shape %[[VAL_9]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_9]] : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_12:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_11]] : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_13:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: memref.store %[[VAL_13]], %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]] : memref<24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_6]] : memref<24x256xf32, 1>
+// CHECK: return
+// CHECK: }
+
+
+// This test creates multiple deallocation rearrangements.
+func.func private @test_multiple_deallocation_moves(%arg0: memref<45x24x256xf32, 1> , %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1>) -> () {
+ %c1 = arith.constant 1 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+ %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+ %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+ %expand_shape2 = memref.expand_shape %alloc_2 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+ %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+ %expand_shape3 = memref.expand_shape %alloc_3 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+ %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+ %expand_shape4 = memref.expand_shape %alloc_4 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+ %cf1 = arith.constant 1.0 : f32
+ memref.store %cf1, %alloc_1[%c1, %c1] : memref<24x256xf32, 1>
+ // Except for %alloc_1 (last used by the store above), each deallocation
+ // below is expected to move right after the expand_shape of its buffer.
+ memref.dealloc %alloc : memref<45x6144xf32, 1>
+ memref.dealloc %alloc_1 : memref<24x256xf32, 1>
+ memref.dealloc %alloc_2 : memref<45x6144xf32, 1>
+ memref.dealloc %alloc_3 : memref<45x6144xf32, 1>
+ memref.dealloc %alloc_4 : memref<45x6144xf32, 1>
+ return
+}
+
+// -----
+// CHECK-LABEL: func.func private @test_users_in_different_blocks_linalig_generic(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x20x20xf32, 1>) -> memref<1x32x32xf32, 1> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32x32xf32, 1>
+// CHECK: %[[VAL_4:.*]] = memref.subview %[[VAL_3]][0, 0, 0] [1, 20, 20] [1, 1, 1] : memref<1x32x32xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_4]] : memref<1x20x20xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
+// CHECK: %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x1xf32, 1>
+// CHECK: %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x8x32x1x4xf32, 1>
+// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%[[VAL_6]] : memref<1x8x32x1x4xf32, 1>) {
+// CHECK: ^bb0(%[[VAL_7:.*]]: f32):
+// CHECK: %[[VAL_8:.*]] = linalg.index 0 : index
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]], %[[VAL_2]]] : memref<1x32x32x1xf32, 1>
+// CHECK: linalg.yield %[[VAL_9]] : f32
+// CHECK: }
+// CHECK: memref.dealloc %[[VAL_5]] : memref<1x32x32x1xf32, 1>
+// CHECK: %[[VAL_10:.*]] = memref.collapse_shape %[[VAL_6]] {{\[\[}}0, 1], [2], [3], [4]] : memref<1x8x32x1x4xf32, 1> into memref<8x32x1x4xf32, 1>
+// CHECK: memref.dealloc %[[VAL_6]] : memref<1x8x32x1x4xf32, 1>
+// CHECK: return %[[VAL_3]] : memref<1x32x32xf32, 1>
+// CHECK: }
+
+
+
+// The last user of %alloc_0 is the linalg.generic operation (through the
+// memref.load in its body), so the deallocation of %alloc_0 is moved right
+// after the linalg.generic. The deallocation of %alloc_1 stays right after its
+// last user, the memref.collapse_shape operation.
+func.func private @test_users_in_different_blocks_linalig_generic(%arg0: memref<1x20x20xf32, 1>) -> (memref<1x32x32xf32, 1>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x32x32xf32, 1>
+ %subview = memref.subview %alloc[0, 0, 0] [1, 20, 20] [1, 1, 1] : memref<1x32x32xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
+ memref.copy %arg0, %subview : memref<1x20x20xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
+ %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x1xf32, 1>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x8x32x1x4xf32, 1>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%alloc_1 : memref<1x8x32x1x4xf32, 1>) {
+ ^bb0(%out: f32):
+ %0 = linalg.index 0 : index
+ // Nested use of %alloc_0: the enclosing linalg.generic is its last user.
+ %8 = memref.load %alloc_0[%0, %0, %0, %c0] : memref<1x32x32x1xf32, 1>
+ linalg.yield %8 : f32
+ }
+ // Last user of %alloc_1.
+ %collapse_shape = memref.collapse_shape %alloc_1 [[0, 1], [2], [3], [4]] : memref<1x8x32x1x4xf32, 1> into memref<8x32x1x4xf32, 1>
+ memref.dealloc %alloc_0 : memref<1x32x32x1xf32, 1>
+ memref.dealloc %alloc_1 : memref<1x8x32x1x4xf32, 1>
+ return %alloc : memref<1x32x32xf32, 1>
+}
+
+// -----
+// CHECK-LABEL: func.func private @test_deallocs_in_different_block_forops(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<24x256xf32, 1>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<256xf32, 1>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 45 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 24 : index
+// CHECK: %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_9:.*]] = memref.expand_shape %[[VAL_8]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+// CHECK: %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+// CHECK: %[[VAL_12:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_11]] : memref<45x6144xf32, 1>
+// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_5]] {
+// CHECK: %[[VAL_15:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_13]], %[[VAL_14]], 0] [1, 8, 256] [1, 1, 1] : memref<45x24x256xf32, 1> to memref<1x8x256xf32, strided<[6144, 256, 1], offset: ?>, 1>
+// CHECK: %[[VAL_16:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_14]], 0] [8, 256] [1, 1] : memref<24x256xf32, 1> to memref<8x256xf32, strided<[256, 1], offset: ?>, 1>
+// CHECK: }
+// CHECK: }
+// CHECK: memref.dealloc %[[VAL_10]] : memref<24x256xf32, 1>
+// CHECK: memref.dealloc %[[VAL_8]] : memref<45x6144xf32, 1>
+// CHECK: return
+// CHECK: }
+
+// The deallocations of %alloc and %alloc_1 stay right after the outer scf.for
+// operation, which is their last (transitive) user, although their relative
+// order changes. The deallocation of %alloc_2 is moved right after its last
+// user, the memref.expand_shape operation.
+func.func private @test_deallocs_in_different_block_forops(%arg0: memref<45x24x256xf32, 1>, %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1> ) -> () {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c45 = arith.constant 45 : index
+ %c24 = arith.constant 24 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+ %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
+ %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
+ // Last user of %alloc_2.
+ %expand_shape2 = memref.expand_shape %alloc_2 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
+ scf.for %arg3 = %c0 to %c45 step %c1 {
+ scf.for %arg4 = %c0 to %c24 step %c8 {
+ // Nested uses of %alloc (through %expand_shape) and %alloc_1: the outer
+ // scf.for is their last user in the function body.
+ %subview = memref.subview %expand_shape[%arg3, %arg4, 0] [1, 8, 256] [1, 1, 1] : memref<45x24x256xf32, 1> to memref<1x8x256xf32, strided<[6144, 256, 1], offset: ?>, 1>
+ %subview_3 = memref.subview %alloc_1[%arg4, 0] [8, 256] [1, 1] : memref<24x256xf32, 1> to memref<8x256xf32, strided<[256, 1], offset: ?>, 1>
+ }
+ }
+ memref.dealloc %alloc : memref<45x6144xf32, 1>
+ memref.dealloc %alloc_1 : memref<24x256xf32, 1>
+ memref.dealloc %alloc_2 : memref<45x6144xf32, 1>
+ return
+}
More information about the Mlir-commits
mailing list