[Mlir-commits] [mlir] MLIR: add_buffer_lifetime (PR #186670)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 15 07:00:35 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: AbdallahRashed (AbdallahRashed)

<details>
<summary>Changes</summary>

The goal of this change is to collect some statistics for a single block:
the number of allocation/deallocation pairs,
the peak live bytes,
and the number of non-overlapping allocation pairs.

---
Full diff: https://github.com/llvm/llvm-project/pull/186670.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+17) 
- (added) mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp (+194) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt (+1) 
- (added) mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir (+102) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index cd28bd6cf73a5..b1656304fb117 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -184,6 +184,23 @@ def OptimizeAllocationLivenessPass
   let dependentDialects = ["mlir::memref::MemRefDialect"];
 }
 
+def PrintBufferLifetimeStatsPass
+    : Pass<"print-buffer-lifetime-stats", "func::FuncOp"> {
+  let summary = "Print buffer lifetime statistics for allocations in a "
+                "function";
+  let description = [{
+    This analysis-only pass inspects the operations directly in the body of a
+    single-block function. An operation is tracked as an allocation when it
+    reports a `MemoryEffects::Allocate` effect (queried through
+    `MemoryEffectOpInterface`) on one of its own memref results, e.g.
+    `memref.alloc`, and exactly one operation in the same block frees that
+    result, e.g. `memref.dealloc`. Each tracked buffer gets a half-open
+    lifetime interval `[alloc position, dealloc position)` based on the order
+    of operations in the block.
+
+    For every analyzed function the pass prints to standard output: the
+    number of tracked allocations, the total allocated bytes, the peak live
+    bytes, and the number of allocation pairs with non-overlapping lifetimes
+    that could potentially share memory, followed by one line per tracked
+    buffer. Buffers with a dynamic shape are reported with a size of 0 bytes.
+    Functions without a body or with more than one block are skipped.
+
+    The pass does not modify the IR and creates no operations.
+  }];
+}
+
 def LowerDeallocationsPass : Pass<"bufferization-lower-deallocations"> {
   let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`"
                 "operations";
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp
new file mode 100644
index 0000000000000..ec2959a348abd
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp
@@ -0,0 +1,194 @@
+//===- BufferLifetimeStats.cpp - Buffer lifetime statistics pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/Support/raw_ostream.h"
+
+// Pull in the TableGen-generated definition for the pass declared as
+// PrintBufferLifetimeStatsPass in Passes.td; it provides the
+// bufferization::impl::PrintBufferLifetimeStatsPassBase class used below.
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_PRINTBUFFERLIFETIMESTATSPASS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Map every operation directly contained in `block` to its zero-based
+/// position in program order. Operations inside nested regions of those
+/// operations are not numbered.
+static DenseMap<Operation *, unsigned> buildOperationIndex(Block &block) {
+  DenseMap<Operation *, unsigned> positions;
+  unsigned next = 0;
+  for (Operation &op : block) {
+    positions.try_emplace(&op, next);
+    ++next;
+  }
+  return positions;
+}
+
+/// Return the operation in `block` that frees `allocResult`, or nullptr if no
+/// operation in `block` frees it or if more than one distinct operation does.
+///
+/// A user counts as freeing `allocResult` only if it implements
+/// MemoryEffectOpInterface and reports a MemoryEffects::Free effect whose
+/// value is `allocResult` itself; frees of other values reported by the same
+/// user are ignored.
+static Operation *findDeallocInSameBlock(Value allocResult, Block *block) {
+  Operation *deallocOp = nullptr;
+  for (Operation *user : allocResult.getUsers()) {
+    // Frees in other blocks (e.g. nested regions) cannot be ordered against
+    // the allocation by position in `block`.
+    if (user->getBlock() != block)
+      continue;
+    auto memEffectOp = dyn_cast<MemoryEffectOpInterface>(user);
+    if (!memEffectOp)
+      continue;
+    SmallVector<MemoryEffects::EffectInstance, 2> effects;
+    memEffectOp.getEffects(effects);
+    bool freesAllocResult = false;
+    for (const MemoryEffects::EffectInstance &effect : effects)
+      if (isa<MemoryEffects::Free>(effect.getEffect()) &&
+          effect.getValue() == allocResult)
+        freesAllocResult = true;
+    if (!freesAllocResult)
+      continue;
+    // getUsers() yields a user once per use, so the same operation may be
+    // visited more than once; only a second, distinct freeing operation makes
+    // the dealloc ambiguous.
+    if (deallocOp && deallocOp != user)
+      return nullptr;
+    deallocOp = user;
+  }
+  return deallocOp;
+}
+
+/// Compute the size in bytes of a buffer of the given memref type, rounding a
+/// trailing partial byte up.
+///
+/// Returns 0 when the size cannot be derived from the type alone:
+///   - the shape is dynamic, or
+///   - the element type is not an integer or float type (e.g. `index`, whose
+///     width depends on the data layout, or vector/complex/custom element
+///     types). Asking such element types for their bit width would hit the
+///     `isIntOrFloat()` assertion in `getIntOrFloatBitWidth()`.
+static int64_t getMemRefSizeInBytes(MemRefType type) {
+  if (!type.hasStaticShape() || !type.getElementType().isIntOrFloat())
+    return 0;
+  int64_t numElements = type.getNumElements();
+  int64_t bitsPerElement = type.getElementTypeBitWidth();
+  return (numElements * bitsPerElement + 7) / 8;
+}
+
+/// Lifetime of one tracked buffer, expressed as the half-open range of
+/// operation positions [allocIndex, deallocIndex) in the analyzed block.
+struct LifetimeInterval {
+  /// The memref value produced by the allocating operation.
+  Value allocResult;
+  /// Position of the allocating operation.
+  unsigned allocIndex = 0;
+  /// Position of the operation that frees `allocResult`.
+  unsigned deallocIndex = 0;
+  /// Buffer size in bytes as computed from the memref type.
+  int64_t sizeInBytes = 0;
+};
+
+/// Return true when the two half-open lifetimes are disjoint, i.e. one buffer
+/// is freed at or before the position where the other one is allocated.
+static bool areNonOverlapping(const LifetimeInterval &a,
+                              const LifetimeInterval &b) {
+  bool intersect =
+      b.allocIndex < a.deallocIndex && a.allocIndex < b.deallocIndex;
+  return !intersect;
+}
+
+//===----------------------------------------------------------------------===//
+// Pass implementation
+//===----------------------------------------------------------------------===//
+
+/// Analysis-only pass that prints buffer lifetime statistics for each
+/// single-block function; the IR is never modified.
+///
+/// A buffer is tracked when an operation directly in the entry block reports
+/// a MemoryEffects::Allocate effect on one of its own memref results and
+/// findDeallocInSameBlock finds a freeing operation for it in that block.
+/// Lifetimes are half-open intervals [allocIndex, deallocIndex) over
+/// operation positions in the entry block.
+struct PrintBufferLifetimeStats
+    : public bufferization::impl::PrintBufferLifetimeStatsPassBase<
+          PrintBufferLifetimeStats> {
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // Declarations have no body, and multi-block functions would need a
+    // CFG-aware ordering of operations; both are skipped without output.
+    if (func.isExternal() || !func.getBody().hasOneBlock())
+      return;
+
+    SmallVector<LifetimeInterval> intervals =
+        collectIntervals(func.getBody().front());
+
+    int64_t totalBytes = 0;
+    for (const LifetimeInterval &interval : intervals)
+      totalBytes += interval.sizeInBytes;
+    int64_t peakLiveBytes = computePeakLiveBytes(intervals);
+    unsigned nonOverlappingPairs = countNonOverlappingPairs(intervals);
+
+    // Assemble the whole report first and emit it with a single write, so
+    // reports for different functions are less likely to interleave when
+    // functions are processed in parallel.
+    // NOTE(review): llvm::outs() itself is not synchronized; confirm whether
+    // a locked stream is needed for multithreaded runs.
+    std::string report;
+    llvm::raw_string_ostream os(report);
+    os << "--- Buffer Lifetime Statistics for '" << func.getSymName()
+       << "' ---\n";
+    os << "  Tracked allocations     : " << intervals.size() << "\n";
+    os << "  Total allocated bytes    : " << totalBytes << "\n";
+    os << "  Peak live bytes          : " << peakLiveBytes << "\n";
+    os << "  Non-overlapping pairs    : " << nonOverlappingPairs << "\n";
+    for (const LifetimeInterval &interval : intervals)
+      os << "  Buffer: " << interval.allocResult.getType()
+         << " | size=" << interval.sizeInBytes << " | lifetime=["
+         << interval.allocIndex << ", " << interval.deallocIndex << ")\n";
+    os << "---\n";
+    llvm::outs() << os.str();
+  }
+
+private:
+  /// Collect one interval per tracked allocation among the operations
+  /// directly in `block`, in program order. Nested regions are not scanned,
+  /// because their operations have no position in `block`.
+  static SmallVector<LifetimeInterval> collectIntervals(Block &block) {
+    DenseMap<Operation *, unsigned> opIndex = buildOperationIndex(block);
+    SmallVector<LifetimeInterval> intervals;
+    for (Operation &op : block) {
+      auto memEffectOp = dyn_cast<MemoryEffectOpInterface>(&op);
+      if (!memEffectOp)
+        continue;
+      SmallVector<MemoryEffects::EffectInstance, 2> effects;
+      memEffectOp.getEffects(effects);
+      for (const MemoryEffects::EffectInstance &effect : effects) {
+        if (!isa<MemoryEffects::Allocate>(effect.getEffect()))
+          continue;
+        // Only allocations of the operation's own results are tracked.
+        Value val = effect.getValue();
+        if (!val || val.getDefiningOp() != &op)
+          continue;
+        auto memrefType = dyn_cast<MemRefType>(val.getType());
+        if (!memrefType)
+          continue;
+        Operation *deallocOp = findDeallocInSameBlock(val, &block);
+        if (!deallocOp)
+          continue;
+        // Both `op` and `deallocOp` are directly in `block`, so both have an
+        // entry in `opIndex`.
+        intervals.push_back({val, opIndex.lookup(&op),
+                             opIndex.lookup(deallocOp),
+                             getMemRefSizeInBytes(memrefType)});
+      }
+    }
+    return intervals;
+  }
+
+  /// Return the largest total size of intervals that are live at the same
+  /// position; an interval [a, d) is live at position t when a <= t < d.
+  /// The live set only changes at interval endpoints, so probing the
+  /// distinct endpoints is sufficient. Returns 0 for no intervals.
+  static int64_t computePeakLiveBytes(ArrayRef<LifetimeInterval> intervals) {
+    SmallVector<unsigned> timePoints;
+    for (const LifetimeInterval &interval : intervals) {
+      timePoints.push_back(interval.allocIndex);
+      timePoints.push_back(interval.deallocIndex);
+    }
+    llvm::sort(timePoints);
+    timePoints.erase(llvm::unique(timePoints), timePoints.end());
+
+    int64_t peakLiveBytes = 0;
+    for (unsigned t : timePoints) {
+      int64_t liveBytes = 0;
+      for (const LifetimeInterval &interval : intervals)
+        if (interval.allocIndex <= t && t < interval.deallocIndex)
+          liveBytes += interval.sizeInBytes;
+      peakLiveBytes = std::max(peakLiveBytes, liveBytes);
+    }
+    return peakLiveBytes;
+  }
+
+  /// Count unordered pairs of intervals whose lifetimes do not overlap,
+  /// i.e. candidate pairs of buffers that could share memory.
+  static unsigned
+  countNonOverlappingPairs(ArrayRef<LifetimeInterval> intervals) {
+    unsigned count = 0;
+    for (size_t i = 0, e = intervals.size(); i < e; ++i)
+      for (size_t j = i + 1; j < e; ++j)
+        if (areNonOverlapping(intervals[i], intervals[j]))
+          ++count;
+    return count;
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 7c38621be1bb5..27c2bf564f3dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRBufferizationTransforms
   Bufferize.cpp
   BufferDeallocationSimplification.cpp
+  BufferLifetimeStats.cpp
   BufferOptimizations.cpp
   BufferResultsToOutParams.cpp
   BufferUtils.cpp
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir
new file mode 100644
index 0000000000000..da4d44bcac2b8
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s --print-buffer-lifetime-stats --split-input-file 2>&1 | FileCheck %s
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'sequential_non_overlapping' ---
+// CHECK:   Tracked allocations     : 2
+// CHECK:   Total allocated bytes    : 8192
+// CHECK:   Peak live bytes          : 4096
+// CHECK:   Non-overlapping pairs    : 1
+// CHECK:   Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 3)
+// CHECK:   Buffer: memref<512xf64> | size=4096 | lifetime=[5, 7)
+// CHECK: ---
+
+// Op positions in the entry block: %cst=0, %a=1, fill=2, dealloc %a=3,
+// %cst2=4, %b=5, fill=6, dealloc %b=7, return=8.
+func.func @sequential_non_overlapping(%arg0: memref<1024xf32>,
+                                       %arg1: memref<512xf64>) {
+  %cst = arith.constant 0.0 : f32
+  %a = memref.alloc() : memref<1024xf32>       // 1024 * 4 = 4096 bytes
+  linalg.fill ins(%cst : f32) outs(%a : memref<1024xf32>)
+  memref.dealloc %a : memref<1024xf32>
+
+  %cst2 = arith.constant 0.0 : f64
+  %b = memref.alloc() : memref<512xf64>        // 512 * 8 = 4096 bytes
+  linalg.fill ins(%cst2 : f64) outs(%b : memref<512xf64>)
+  memref.dealloc %b : memref<512xf64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'overlapping_lifetimes' ---
+// CHECK:   Tracked allocations     : 2
+// CHECK:   Total allocated bytes    : 6144
+// CHECK:   Peak live bytes          : 6144
+// CHECK:   Non-overlapping pairs    : 0
+// CHECK:   Buffer: memref<512xf32> | size=2048 | lifetime=[0, 4)
+// CHECK:   Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 3)
+// CHECK: ---
+
+// Op positions: %a=0, %b=1, %cst=2, dealloc %b=3, dealloc %a=4, return=5.
+// The unused %cst only separates the allocations from the deallocations.
+func.func @overlapping_lifetimes() {
+  %a = memref.alloc() : memref<512xf32>        // 2048 bytes
+  %b = memref.alloc() : memref<1024xf32>       // 4096 bytes
+  %cst = arith.constant 0.0 : f32
+  memref.dealloc %b : memref<1024xf32>
+  memref.dealloc %a : memref<512xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'three_buffers_mixed' ---
+// CHECK:   Tracked allocations     : 3
+// CHECK:   Total allocated bytes    : 10240
+// CHECK:   Peak live bytes          : 8192
+// CHECK:   Non-overlapping pairs    : 1
+// CHECK:   Buffer: memref<512xf32> | size=2048 | lifetime=[0, 2)
+// CHECK:   Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 5)
+// CHECK:   Buffer: memref<1024xf32> | size=4096 | lifetime=[3, 6)
+// CHECK: ---
+
+// %a and %b overlap (a=[0,2), b=[1,5))
+// %a and %c don't overlap (a=[0,2), c=[3,6))
+// %b and %c overlap (b=[1,5), c=[3,6))
+// So 1 non-overlapping pair: (%a, %c)
+// Op positions: %a=0, %b=1, dealloc %a=2, %c=3, %cst=4, dealloc %b=5,
+// dealloc %c=6. Peak live bytes occur at positions 3-4, where %b and %c are
+// both live (4096 + 4096 = 8192).
+func.func @three_buffers_mixed() {
+  %a = memref.alloc() : memref<512xf32>        // 2048 bytes
+  %b = memref.alloc() : memref<1024xf32>       // 4096 bytes
+  memref.dealloc %a : memref<512xf32>
+  %c = memref.alloc() : memref<1024xf32>       // 4096 bytes
+  %cst = arith.constant 0.0 : f32
+  memref.dealloc %b : memref<1024xf32>
+  memref.dealloc %c : memref<1024xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'single_alloc' ---
+// CHECK:   Tracked allocations     : 1
+// CHECK:   Total allocated bytes    : 256
+// CHECK:   Peak live bytes          : 256
+// CHECK:   Non-overlapping pairs    : 0
+// CHECK:   Buffer: memref<64xf32> | size=256 | lifetime=[0, 1)
+// CHECK: ---
+
+// Op positions: %a=0, dealloc %a=1, return=2. A single buffer forms no pair.
+func.func @single_alloc() {
+  %a = memref.alloc() : memref<64xf32>         // 64 * 4 = 256 bytes
+  memref.dealloc %a : memref<64xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'no_allocs' ---
+// CHECK:   Tracked allocations     : 0
+// CHECK:   Total allocated bytes    : 0
+// CHECK:   Peak live bytes          : 0
+// CHECK:   Non-overlapping pairs    : 0
+// CHECK: ---
+
+// No operation allocates, so the report is still printed with all counts at
+// zero and no per-buffer lines.
+func.func @no_allocs(%arg0: memref<1024xf32>) {
+  %cst = arith.constant 0.0 : f32
+  linalg.fill ins(%cst : f32) outs(%arg0 : memref<1024xf32>)
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/186670


More information about the Mlir-commits mailing list