[Mlir-commits] [mlir] MLIR: add_buffer_lifetime (PR #186670)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 15 07:06:07 PDT 2026


https://github.com/AbdallahRashed updated https://github.com/llvm/llvm-project/pull/186670

>From fe9ca8f8e0e52815c9c134fb74b97dc122c8a3ee Mon Sep 17 00:00:00 2001
From: AbdallahRashed <abdallah.mrashed at gmail.com>
Date: Sun, 15 Mar 2026 14:56:45 +0100
Subject: [PATCH] MLIR: add_buffer_lifetime

The goal of this patch is to collect some statistics for a single-block
function: the number of allocation/deallocation pairs,
peak live bytes, and
non-overlapping pairs.
---
 .../Bufferization/Transforms/Passes.td        |  17 ++
 .../Transforms/BufferLifetimeStats.cpp        | 193 ++++++++++++++++++
 .../Bufferization/Transforms/CMakeLists.txt   |   1 +
 .../Transforms/buffer-lifetime-stats.mlir     | 102 +++++++++
 4 files changed, 313 insertions(+)
 create mode 100644 mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index cd28bd6cf73a5..b1656304fb117 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -184,6 +184,23 @@ def OptimizeAllocationLivenessPass
   let dependentDialects = ["mlir::memref::MemRefDialect"];
 }
 
+def PrintBufferLifetimeStatsPass
+    : Pass<"print-buffer-lifetime-stats", "func::FuncOp"> {
+  let summary = "Print buffer lifetime statistics for allocations in a "
+                "function";
+  let description = [{
+    This analysis-only pass handles single-block functions. It collects
+    memref.alloc / memref.dealloc pairs in that block via
+    MemoryEffectOpInterface, computes lifetime intervals from operation order,
+    and prints statistics: number of tracked allocations, total allocated
+    bytes, peak live bytes, and the number of non-overlapping allocation pairs
+    that could potentially share memory.
+
+    The pass does not modify the IR.
+  }];
+  let dependentDialects = ["mlir::memref::MemRefDialect"];
+}
+
 def LowerDeallocationsPass : Pass<"bufferization-lower-deallocations"> {
   let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`"
                 "operations";
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp
new file mode 100644
index 0000000000000..86f037634ce43
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferLifetimeStats.cpp
@@ -0,0 +1,193 @@
+//===- BufferLifetimeStats.cpp - Buffer lifetime statistics pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_PRINTBUFFERLIFETIMESTATSPASS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Map every operation of `block` to its zero-based position in the block.
+static DenseMap<Operation *, unsigned> buildOperationIndex(Block &block) {
+  DenseMap<Operation *, unsigned> positions;
+  unsigned nextPosition = 0;
+  for (Operation &current : block)
+    positions.try_emplace(&current, nextPosition++);
+  return positions;
+}
+
+/// Find the unique op in `block` that frees `allocResult`, or nullptr.
+///
+/// A user counts as a deallocation only if it sits directly in `block` and
+/// declares a MemoryEffects::Free effect on `allocResult` itself. Returns
+/// nullptr if no user, or more than one distinct user, frees the value.
+static Operation *findDeallocInSameBlock(Value allocResult, Block *block) {
+  Operation *deallocOp = nullptr;
+  for (Operation *user : allocResult.getUsers()) {
+    if (user->getBlock() != block)
+      continue;
+    // Ignore Free effects that this user declares on other values.
+    if (!hasEffect<MemoryEffects::Free>(user, allocResult))
+      continue;
+    // getUsers() lists an op once per use; repeats are the same dealloc.
+    if (deallocOp && deallocOp != user)
+      return nullptr;
+    deallocOp = user;
+  }
+  return deallocOp;
+}
+
+/// Compute the size in bytes of a statically-shaped memref of int/float
+/// elements; returns 0 when the size cannot be determined statically.
+static int64_t getMemRefSizeInBytes(MemRefType type) {
+  // getElementTypeBitWidth() asserts on non-int/float element types.
+  if (!type.hasStaticShape() || !type.getElementType().isIntOrFloat())
+    return 0;
+  return (type.getNumElements() * type.getElementTypeBitWidth() + 7) / 8;
+}
+
+/// A buffer lifetime as a half-open interval [allocIndex, deallocIndex).
+struct LifetimeInterval {
+  Value allocResult; // Memref value whose lifetime is described.
+  unsigned allocIndex; // Block position of the defining (allocating) op.
+  unsigned deallocIndex; // Block position of the op that frees the value.
+  int64_t sizeInBytes; // Byte size computed by getMemRefSizeInBytes().
+};
+
+/// Return true when one interval ends at or before the other one begins.
+static bool areNonOverlapping(const LifetimeInterval &a,
+                              const LifetimeInterval &b) {
+  return !(a.allocIndex < b.deallocIndex && b.allocIndex < a.deallocIndex);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass implementation
+//===----------------------------------------------------------------------===//
+
+/// Pass that prints allocation lifetime statistics for one function.
+///
+/// The pass is analysis-only: it writes a report to llvm::outs() and never
+/// mutates the IR. External functions and functions whose body has more than
+/// one block are skipped without printing anything. Only allocations whose
+/// defining op and unique dealloc both sit in the entry block are tracked.
+struct PrintBufferLifetimeStats
+    : public bufferization::impl::PrintBufferLifetimeStatsPassBase<
+          PrintBufferLifetimeStats> {
+public:
+  PrintBufferLifetimeStats() = default;
+
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+
+    // Nothing to analyze without a body; multi-block bodies are unsupported.
+    if (funcOp.isExternal() || !funcOp.getBody().hasOneBlock())
+      return;
+
+    Block &body = funcOp.getBody().front();
+    DenseMap<Operation *, unsigned> positions = buildOperationIndex(body);
+    SmallVector<LifetimeInterval> lifetimes;
+
+    // Record one interval per memref allocation that has a unique dealloc
+    // located in the same block as the allocating op.
+    body.walk([&](MemoryEffectOpInterface effectOp) {
+      Operation *allocOp = effectOp.getOperation();
+      SmallVector<MemoryEffects::EffectInstance, 2> opEffects;
+      effectOp.getEffects(opEffects);
+
+      for (const MemoryEffects::EffectInstance &instance : opEffects) {
+        Value buffer = instance.getValue();
+        // Keep only Allocate effects on a value defined by this very op.
+        if (!isa<MemoryEffects::Allocate>(instance.getEffect()) || !buffer ||
+            buffer.getDefiningOp() != allocOp)
+          continue;
+        auto bufferType = dyn_cast<MemRefType>(buffer.getType());
+        if (!bufferType)
+          continue;
+
+        Operation *freeOp =
+            findDeallocInSameBlock(buffer, allocOp->getBlock());
+        if (!freeOp)
+          continue;
+
+        // Ops nested in regions are not in `positions`; skip their buffers.
+        auto allocPos = positions.find(allocOp);
+        auto freePos = positions.find(freeOp);
+        if (allocPos == positions.end() || freePos == positions.end())
+          continue;
+
+        lifetimes.push_back({buffer, allocPos->second, freePos->second,
+                             getMemRefSizeInBytes(bufferType)});
+      }
+    });
+
+    // Sum of the sizes of all tracked buffers.
+    int64_t totalBytes = 0;
+    for (const LifetimeInterval &lifetime : lifetimes)
+      totalBytes += lifetime.sizeInBytes;
+
+    // Peak live bytes: live memory only changes at alloc/dealloc indices, so
+    // it suffices to evaluate it at each distinct interval endpoint.
+    SmallVector<unsigned> endpoints;
+    for (const LifetimeInterval &lifetime : lifetimes)
+      endpoints.append({lifetime.allocIndex, lifetime.deallocIndex});
+    llvm::sort(endpoints);
+    endpoints.erase(llvm::unique(endpoints), endpoints.end());
+
+    int64_t peakLiveBytes = 0;
+    for (unsigned point : endpoints) {
+      int64_t liveAtPoint = 0;
+      for (const LifetimeInterval &lifetime : lifetimes)
+        if (lifetime.allocIndex <= point && point < lifetime.deallocIndex)
+          liveAtPoint += lifetime.sizeInBytes;
+      peakLiveBytes = std::max(peakLiveBytes, liveAtPoint);
+    }
+
+    // Count unordered pairs of buffers whose lifetimes are disjoint; each
+    // such pair is a candidate for sharing the same memory.
+    unsigned disjointPairs = 0;
+    for (size_t lhs = 0, count = lifetimes.size(); lhs < count; ++lhs)
+      for (size_t rhs = lhs + 1; rhs < count; ++rhs)
+        disjointPairs += areNonOverlapping(lifetimes[lhs], lifetimes[rhs]);
+
+    // Emit the report.
+    llvm::raw_ostream &os = llvm::outs();
+    os << "--- Buffer Lifetime Statistics for '" << funcOp.getSymName()
+       << "' ---\n";
+    os << "  Tracked allocations     : " << lifetimes.size() << "\n";
+    os << "  Total allocated bytes    : " << totalBytes << "\n";
+    os << "  Peak live bytes          : " << peakLiveBytes << "\n";
+    os << "  Non-overlapping pairs    : " << disjointPairs << "\n";
+
+    // One line per tracked buffer, in discovery order.
+    for (const LifetimeInterval &lifetime : lifetimes) {
+      os << "  Buffer: " << lifetime.allocResult.getType()
+         << " | size=" << lifetime.sizeInBytes << " | lifetime=["
+         << lifetime.allocIndex << ", " << lifetime.deallocIndex << ")\n";
+    }
+    os << "---\n";
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 7c38621be1bb5..27c2bf564f3dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRBufferizationTransforms
   Bufferize.cpp
   BufferDeallocationSimplification.cpp
+  BufferLifetimeStats.cpp
   BufferOptimizations.cpp
   BufferResultsToOutParams.cpp
   BufferUtils.cpp
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir
new file mode 100644
index 0000000000000..da4d44bcac2b8
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-lifetime-stats.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s --print-buffer-lifetime-stats --split-input-file 2>&1 | FileCheck %s
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'sequential_non_overlapping' ---
+// CHECK:   Tracked allocations     : 2
+// CHECK:   Total allocated bytes    : 8192
+// CHECK:   Peak live bytes          : 4096
+// CHECK:   Non-overlapping pairs    : 1
+// CHECK:   Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 3)
+// CHECK:   Buffer: memref<512xf64> | size=4096 | lifetime=[5, 7)
+// CHECK: ---
+
+func.func @sequential_non_overlapping(%arg0: memref<1024xf32>,
+                                       %arg1: memref<512xf64>) {
+  %cst = arith.constant 0.0 : f32
+  %a = memref.alloc() : memref<1024xf32>       // 1024 * 4 = 4096 bytes
+  linalg.fill ins(%cst : f32) outs(%a : memref<1024xf32>)
+  memref.dealloc %a : memref<1024xf32>
+
+  %cst2 = arith.constant 0.0 : f64
+  %b = memref.alloc() : memref<512xf64>        // 512 * 8 = 4096 bytes
+  linalg.fill ins(%cst2 : f64) outs(%b : memref<512xf64>)
+  memref.dealloc %b : memref<512xf64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'overlapping_lifetimes' ---
+// CHECK:   Tracked allocations     : 2
+// CHECK:   Total allocated bytes    : 6144
+// CHECK:   Peak live bytes          : 6144
+// CHECK:   Non-overlapping pairs    : 0
+// CHECK:   Buffer: memref<512xf32> | size=2048 | lifetime=[0, 4)
+// CHECK:   Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 3)
+// CHECK: ---
+
+func.func @overlapping_lifetimes() {
+  %a = memref.alloc() : memref<512xf32>        // 2048 bytes
+  %b = memref.alloc() : memref<1024xf32>       // 4096 bytes
+  %cst = arith.constant 0.0 : f32
+  memref.dealloc %b : memref<1024xf32>
+  memref.dealloc %a : memref<512xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'three_buffers_mixed' ---
+// CHECK:   Tracked allocations     : 3
+// CHECK:   Total allocated bytes    : 10240
+// CHECK:   Peak live bytes          : 8192
+// CHECK:   Non-overlapping pairs    : 1
+// CHECK:   Buffer: memref<512xf32> | size=2048 | lifetime=[0, 2)
+// CHECK:   Buffer: memref<1024xf32> | size=4096 | lifetime=[1, 5)
+// CHECK:   Buffer: memref<1024xf32> | size=4096 | lifetime=[3, 6)
+// CHECK: ---
+
+// %a and %b overlap (a=[0,2), b=[1,5))
+// %a and %c don't overlap (a=[0,2), c=[3,6))
+// %b and %c overlap (b=[1,5), c=[3,6))
+// So 1 non-overlapping pair: (%a, %c)
+func.func @three_buffers_mixed() {
+  %a = memref.alloc() : memref<512xf32>        // 2048 bytes
+  %b = memref.alloc() : memref<1024xf32>       // 4096 bytes
+  memref.dealloc %a : memref<512xf32>
+  %c = memref.alloc() : memref<1024xf32>       // 4096 bytes
+  %cst = arith.constant 0.0 : f32
+  memref.dealloc %b : memref<1024xf32>
+  memref.dealloc %c : memref<1024xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'single_alloc' ---
+// CHECK:   Tracked allocations     : 1
+// CHECK:   Total allocated bytes    : 256
+// CHECK:   Peak live bytes          : 256
+// CHECK:   Non-overlapping pairs    : 0
+// CHECK:   Buffer: memref<64xf32> | size=256 | lifetime=[0, 1)
+// CHECK: ---
+
+func.func @single_alloc() {
+  %a = memref.alloc() : memref<64xf32>         // 64 * 4 = 256 bytes
+  memref.dealloc %a : memref<64xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: --- Buffer Lifetime Statistics for 'no_allocs' ---
+// CHECK:   Tracked allocations     : 0
+// CHECK:   Total allocated bytes    : 0
+// CHECK:   Peak live bytes          : 0
+// CHECK:   Non-overlapping pairs    : 0
+// CHECK: ---
+
+func.func @no_allocs(%arg0: memref<1024xf32>) {
+  %cst = arith.constant 0.0 : f32
+  linalg.fill ins(%cst : f32) outs(%arg0 : memref<1024xf32>)
+  return
+}



More information about the Mlir-commits mailing list