[Mlir-commits] [mlir] MLIR: add_buffer_lifetime (PR #186670)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 15 07:00:37 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: AbdallahRashed (AbdallahRashed)
<details>
<summary>Changes</summary>
The goal of this pass is to collect statistics for a single block: the number of allocations and deallocations,
peak live bytes,
and the number of non-overlapping pairs.
---
Full diff: https://github.com/llvm/llvm-project/pull/186670.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+17)
- (added) mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp (+194)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt (+1)
- (added) mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir (+102)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index cd28bd6cf73a5..b1656304fb117 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -184,6 +184,23 @@ def OptimizeAllocationLivenessPass
let dependentDialects = ["mlir::memref::MemRefDialect"];
}
+def PrintBufferLifetimeStatsPass
+    : Pass<"print-buffer-lifetime-stats", "func::FuncOp"> {
+  let summary = "Print buffer lifetime statistics for allocations in a "
+                "function";
+  let description = [{
+    This analysis-only pass inspects the body of a single-block function and
+    tracks every memref value that is allocated by an operation of that block
+    and freed by exactly one operation of the same block. Allocations and
+    frees are recognized through `MemoryEffectOpInterface` (`Allocate` and
+    `Free` effects), which covers `memref.alloc` / `memref.dealloc` pairs.
+
+    The lifetime of a tracked buffer is the half-open interval
+    `[alloc position, dealloc position)` of operation positions within the
+    block. The pass prints to stdout: the number of tracked allocations, the
+    total allocated bytes, the peak live bytes, and the number of
+    non-overlapping allocation pairs that could potentially share memory.
+
+    Limitations: external functions and functions with more than one block
+    are skipped, buffers without a unique free in the same block are not
+    tracked, and dynamically shaped buffers are counted with a size of 0.
+
+    The pass does not modify the IR and creates no operations, so it declares
+    no dependent dialects.
+  }];
+}
+
def LowerDeallocationsPass : Pass<"bufferization-lower-deallocations"> {
let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`"
"operations";
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp
new file mode 100644
index 0000000000000..ec2959a348abd
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp
@@ -0,0 +1,194 @@
+//===- BufferLifetimeStats.cpp - Buffer lifetime statistics pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_PRINTBUFFERLIFETIMESTATSPASS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Number the operations directly contained in `block` in program order,
+/// starting at zero, and return the resulting operation -> position map.
+static DenseMap<Operation *, unsigned> buildOperationIndex(Block &block) {
+  DenseMap<Operation *, unsigned> positions;
+  unsigned nextPosition = 0;
+  for (Operation &current : block) {
+    // Every operation appears once in the block, so this always inserts.
+    positions.try_emplace(&current, nextPosition);
+    ++nextPosition;
+  }
+  return positions;
+}
+
+/// Find the single operation in `block` that frees `allocResult`.
+///
+/// An operation qualifies when it is a direct member of `block`, implements
+/// MemoryEffectOpInterface, and reports a `Free` effect on `allocResult`
+/// itself. Returns nullptr when no such operation exists or when more than
+/// one distinct operation qualifies, since the lifetime is then ambiguous.
+static Operation *findDeallocInSameBlock(Value allocResult, Block *block) {
+  Operation *deallocOp = nullptr;
+  for (Operation *user : allocResult.getUsers()) {
+    // Users in other blocks (including nested regions) are ignored. An op
+    // that uses the value through several operands is listed once per use,
+    // so skip the op already recorded instead of counting it twice.
+    if (user->getBlock() != block || user == deallocOp)
+      continue;
+
+    auto memEffectOp = dyn_cast<MemoryEffectOpInterface>(user);
+    if (!memEffectOp)
+      continue;
+
+    SmallVector<MemoryEffects::EffectInstance, 2> effects;
+    memEffectOp.getEffects(effects);
+
+    // Only a Free effect attached to `allocResult` counts; an op that merely
+    // uses this value while freeing a different one is not its dealloc.
+    bool freesAllocResult = false;
+    for (const MemoryEffects::EffectInstance &effect : effects) {
+      if (isa<MemoryEffects::Free>(effect.getEffect()) &&
+          effect.getValue() == allocResult) {
+        freesAllocResult = true;
+        break;
+      }
+    }
+    if (!freesAllocResult)
+      continue;
+
+    // A second, distinct freeing op makes the lifetime ambiguous.
+    if (deallocOp)
+      return nullptr;
+    deallocOp = user;
+  }
+  return deallocOp;
+}
+
+/// Return the size in bytes of a statically shaped memref whose elements are
+/// integers or floats. Returns 0 when the size cannot be determined: the
+/// shape is dynamic, or the element type has no integer/float bit width
+/// (for example `index`).
+static int64_t getMemRefSizeInBytes(MemRefType type) {
+  // getElementTypeBitWidth() is only defined for integer and float element
+  // types, so every other element type is rejected up front.
+  if (!type.hasStaticShape() || !type.getElementType().isIntOrFloat())
+    return 0;
+
+  // Round each element up to whole bytes instead of rounding the total bit
+  // count. NOTE(review): assumes byte-granular element storage, as in the
+  // default data layout; a packed sub-byte layout would be smaller.
+  int64_t bytesPerElement = (type.getElementTypeBitWidth() + 7) / 8;
+  return type.getNumElements() * bytesPerElement;
+}
+
+/// A buffer lifetime interval: [allocIndex, deallocIndex), in block positions.
+struct LifetimeInterval {
+ Value allocResult; // Memref value defined by the allocating op.
+ unsigned allocIndex; // Position of the allocating op in its block.
+ unsigned deallocIndex; // Position of the freeing op; exclusive end.
+ int64_t sizeInBytes; // Static byte size; 0 if the shape is dynamic.
+};
+
+/// Return true when the half-open ranges of `a` and `b` share no position,
+/// i.e. one buffer is freed at or before the point the other is allocated.
+static bool areNonOverlapping(const LifetimeInterval &a,
+                              const LifetimeInterval &b) {
+  // Two half-open ranges intersect exactly when each starts before the other
+  // ends.
+  const bool overlap =
+      a.allocIndex < b.deallocIndex && b.allocIndex < a.deallocIndex;
+  return !overlap;
+}
+
+//===----------------------------------------------------------------------===//
+// Pass implementation
+//===----------------------------------------------------------------------===//
+
+/// Analysis-only pass: for a single-block function, pairs each memref
+/// allocation in the body with its unique same-block free, then prints the
+/// tracked lifetimes and aggregate statistics to stdout.
+struct PrintBufferLifetimeStats
+    : public bufferization::impl::PrintBufferLifetimeStatsPassBase<
+          PrintBufferLifetimeStats> {
+public:
+  PrintBufferLifetimeStats() = default;
+
+  /// Collect the lifetime intervals of the function body and print the
+  /// statistics. External and multi-block functions produce no output.
+  void runOnOperation() override {
+    // The pass only reads the IR, so cached analyses remain valid.
+    markAllAnalysesPreserved();
+
+    func::FuncOp func = getOperation();
+
+    if (func.isExternal())
+      return;
+
+    // We only handle single-block functions for now.
+    if (!func.getBody().hasOneBlock())
+      return;
+
+    Block &entryBlock = func.getBody().front();
+    DenseMap<Operation *, unsigned> opIndex = buildOperationIndex(entryBlock);
+    SmallVector<LifetimeInterval> intervals;
+
+    // Visit only the ops directly in the entry block: `opIndex` has no entry
+    // for ops nested in regions, so those could never be tracked anyway.
+    for (Operation &op : entryBlock) {
+      auto memEffectOp = dyn_cast<MemoryEffectOpInterface>(&op);
+      if (!memEffectOp)
+        continue;
+
+      SmallVector<MemoryEffects::EffectInstance, 2> effects;
+      memEffectOp.getEffects(effects);
+
+      for (const MemoryEffects::EffectInstance &effect : effects) {
+        if (!isa<MemoryEffects::Allocate>(effect.getEffect()))
+          continue;
+
+        // Keep only allocations that are a result of this very op.
+        Value val = effect.getValue();
+        if (!val || val.getDefiningOp() != &op)
+          continue;
+
+        auto memrefType = dyn_cast<MemRefType>(val.getType());
+        if (!memrefType)
+          continue;
+
+        // Buffers without a unique free in this block are not tracked.
+        Operation *deallocOp = findDeallocInSameBlock(val, &entryBlock);
+        if (!deallocOp)
+          continue;
+
+        // Both ops live in the entry block, so both positions are known.
+        intervals.push_back({val, opIndex.lookup(&op),
+                             opIndex.lookup(deallocOp),
+                             getMemRefSizeInBytes(memrefType)});
+      }
+    }
+
+    // Compute statistics.
+    int64_t totalBytes = 0;
+    for (const LifetimeInterval &interval : intervals)
+      totalBytes += interval.sizeInBytes;
+
+    // The set of live buffers only grows at an allocation, so the peak is
+    // reached at one of the allocation positions; probing those suffices.
+    int64_t peakLiveBytes = 0;
+    for (const LifetimeInterval &probe : intervals) {
+      const unsigned t = probe.allocIndex;
+      int64_t liveBytes = 0;
+      for (const LifetimeInterval &interval : intervals)
+        if (interval.allocIndex <= t && t < interval.deallocIndex)
+          liveBytes += interval.sizeInBytes;
+      peakLiveBytes = std::max(peakLiveBytes, liveBytes);
+    }
+
+    // Count non-overlapping pairs (reuse opportunities).
+    unsigned nonOverlappingPairs = 0;
+    for (unsigned i = 0, e = intervals.size(); i < e; ++i)
+      for (unsigned j = i + 1; j < e; ++j)
+        if (areNonOverlapping(intervals[i], intervals[j]))
+          ++nonOverlappingPairs;
+
+    raw_ostream &os = llvm::outs();
+    os << "--- Buffer Lifetime Statistics for '" << func.getSymName()
+       << "' ---\n";
+    os << " Tracked allocations : " << intervals.size() << "\n";
+    os << " Total allocated bytes : " << totalBytes << "\n";
+    os << " Peak live bytes : " << peakLiveBytes << "\n";
+    os << " Non-overlapping pairs : " << nonOverlappingPairs << "\n";
+
+    for (const LifetimeInterval &interval : intervals) {
+      os << " Buffer: " << interval.allocResult.getType()
+         << " | size=" << interval.sizeInBytes << " | lifetime=["
+         << interval.allocIndex << ", " << interval.deallocIndex << ")\n";
+    }
+    os << "---\n";
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 7c38621be1bb5..27c2bf564f3dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRBufferizationTransforms
Bufferize.cpp
BufferDeallocationSimplification.cpp
+ BufferLifetimeStats.cpp
BufferOptimizations.cpp
BufferResultsToOutParams.cpp
BufferUtils.cpp
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir
new file mode 100644
index 0000000000000..da4d44bcac2b8
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s --print-buffer-lifetime-stats --split-input-file 2>&1 | FileCheck %s
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'sequential_non_overlapping' ---
+// CHECK: Tracked allocations : 2
+// CHECK: Total allocated bytes : 8192
+// CHECK: Peak live bytes : 4096
+// CHECK: Non-overlapping pairs : 1
+// CHECK: Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 3)
+// CHECK: Buffer: memref<512xf64> | size=4096 | lifetime=[5, 7)
+// CHECK: ---
+
+func.func @sequential_non_overlapping(%arg0: memref<1024xf32>,
+ %arg1: memref<512xf64>) {
+ %cst = arith.constant 0.0 : f32
+ %a = memref.alloc() : memref<1024xf32> // 1024 * 4 = 4096 bytes
+ linalg.fill ins(%cst : f32) outs(%a : memref<1024xf32>)
+ memref.dealloc %a : memref<1024xf32>
+
+ %cst2 = arith.constant 0.0 : f64
+ %b = memref.alloc() : memref<512xf64> // 512 * 8 = 4096 bytes
+ linalg.fill ins(%cst2 : f64) outs(%b : memref<512xf64>)
+ memref.dealloc %b : memref<512xf64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'overlapping_lifetimes' ---
+// CHECK: Tracked allocations : 2
+// CHECK: Total allocated bytes : 6144
+// CHECK: Peak live bytes : 6144
+// CHECK: Non-overlapping pairs : 0
+// CHECK: Buffer: memref<512xf32> | size=2048 | lifetime=[0, 4)
+// CHECK: Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 3)
+// CHECK: ---
+
+func.func @overlapping_lifetimes() {
+ %a = memref.alloc() : memref<512xf32> // 2048 bytes
+ %b = memref.alloc() : memref<1024xf32> // 4096 bytes
+ %cst = arith.constant 0.0 : f32
+ memref.dealloc %b : memref<1024xf32>
+ memref.dealloc %a : memref<512xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'three_buffers_mixed' ---
+// CHECK: Tracked allocations : 3
+// CHECK: Total allocated bytes : 10240
+// CHECK: Peak live bytes : 8192
+// CHECK: Non-overlapping pairs : 1
+// CHECK: Buffer: memref<512xf32> | size=2048 | lifetime=[0, 2)
+// CHECK: Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 5)
+// CHECK: Buffer: memref<1024xf32> | size=4096 | lifetime=[3, 6)
+// CHECK: ---
+
+// %a and %b overlap (a=[0,2), b=[1,5))
+// %a and %c don't overlap (a=[0,2), c=[3,6))
+// %b and %c overlap (b=[1,5), c=[3,6))
+// So 1 non-overlapping pair: (%a, %c)
+func.func @three_buffers_mixed() {
+ %a = memref.alloc() : memref<512xf32> // 2048 bytes
+ %b = memref.alloc() : memref<1024xf32> // 4096 bytes
+ memref.dealloc %a : memref<512xf32>
+ %c = memref.alloc() : memref<1024xf32> // 4096 bytes
+ %cst = arith.constant 0.0 : f32
+ memref.dealloc %b : memref<1024xf32>
+ memref.dealloc %c : memref<1024xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'single_alloc' ---
+// CHECK: Tracked allocations : 1
+// CHECK: Total allocated bytes : 256
+// CHECK: Peak live bytes : 256
+// CHECK: Non-overlapping pairs : 0
+// CHECK: Buffer: memref<64xf32> | size=256 | lifetime=[0, 1)
+// CHECK: ---
+
+func.func @single_alloc() {
+ %a = memref.alloc() : memref<64xf32> // 64 * 4 = 256 bytes
+ memref.dealloc %a : memref<64xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'no_allocs' ---
+// CHECK: Tracked allocations : 0
+// CHECK: Total allocated bytes : 0
+// CHECK: Peak live bytes : 0
+// CHECK: Non-overlapping pairs : 0
+// CHECK: ---
+
+func.func @no_allocs(%arg0: memref<1024xf32>) {
+ %cst = arith.constant 0.0 : f32
+ linalg.fill ins(%cst : f32) outs(%arg0 : memref<1024xf32>)
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/186670
More information about the Mlir-commits
mailing list