[Mlir-commits] [mlir] [mlir] [memref] Compile-time memref.alloc Scheduling/Merging optimization (PR #95882)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 21 02:27:57 PDT 2024


https://github.com/Menooker updated https://github.com/llvm/llvm-project/pull/95882

From 6ffcc271799f205ad5a36c9ade68d854c06acbe7 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Mon, 3 Jun 2024 17:08:58 +0800
Subject: [PATCH 01/12] basic code

---
 .../mlir/Dialect/MemRef/Transforms/Passes.h   |   4 +
 .../mlir/Dialect/MemRef/Transforms/Passes.td  |  35 +
 .../MemRef/Transforms/StaticMemoryPlanning.h  |  71 ++
 .../Dialect/MemRef/Transforms/CMakeLists.txt  |   2 +
 .../Dialect/MemRef/Transforms/MergeAlloc.cpp  | 352 ++++++++++
 .../Transforms/StaticMemoryPlanning.cpp       | 655 ++++++++++++++++++
 .../Dialect/MemRef/buffer-merge-lifetime.mlir | 129 ++++
 .../test/Dialect/MemRef/buffer-merge-mlp.mlir |  30 +
 mlir/unittests/Dialect/MemRef/CMakeLists.txt  |   2 +
 .../Dialect/MemRef/StaticMemoryPlanning.cpp   | 206 ++++++
 10 files changed, 1486 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/StaticMemoryPlanning.cpp
 create mode 100644 mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
 create mode 100644 mlir/test/Dialect/MemRef/buffer-merge-mlp.mlir
 create mode 100644 mlir/unittests/Dialect/MemRef/StaticMemoryPlanning.cpp

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index d7050156862df..7ffa07bf768af 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -77,6 +77,10 @@ std::unique_ptr<Pass> createExpandStridedMetadataPass();
 /// components.
 std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);
 
+/// Creates an operation pass to merge the local memref allocations.
+std::unique_ptr<Pass>
+createMergeAllocPass(const memref::MergeAllocOptions &o = {});
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 651ee05ae1f3c..f65774464c713 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -253,5 +253,40 @@ def ExpandRealloc : Pass<"expand-realloc"> {
   ];
 }
 
+def MergeAlloc : Pass<"merge-alloc", "func::FuncOp">  {
+  let summary = "Merge multiple memref.alloc and reuse the buffer";
+  let description = [{
+    The pass merges the "mergeable" memref.alloc allocations into a single
+    memref.alloc in their ancestor "allocation scope", to enhance memory
+    reuse and cache locality. A memref.alloc is "mergeable" if it is owned
+    by the current function, is statically shaped, and has an identity layout.
+    An "allocation scope" is the nearest surrounding ancestor operation of a
+    memref.alloc that has the AutomaticAllocationScope trait and is not
+    scf.for. The top-level block of a function and the body of a parallel loop
+    are examples of "allocation scopes". If AutomaticAllocationScope operations
+    are nested, each level is a different "allocation scope".
+    A "mergeable" memref.alloc is replaced by a memref.view into the "merged"
+    buffer, at an offset. The "merged" buffer is allocated at the beginning
+    of the block of the "allocation scope".
+    The offset of each merged buffer is decided by this pass, based on the
+    lifetime of the original memref before merging. The pass schedules the
+    offsets to 1) make sure the offsets and address ranges do not overlap if
+    two "mergeable" allocations have overlapping lifetimes, and 2) reuse the
+    address ranges that are considered "hot" in cache for a later allocation.
+  }];
+  let options = [
+    Option<"optionCheck", "check", "bool",
+       /*default=*/"false",
+       "Skip the mutation of the IR and only mark the lifetime and scope on the"
+       " operations. Useful for debugging and testing.">,
+    Option<"optionNoLocality", "no-consider-locality", "bool",
+       /*default=*/"false",
+       "Don't consider the cache locality when reusing the buffers. "
+       "This option may result in smaller total memory usage.">,
+  ];
+  let dependentDialects = ["memref::MemRefDialect", "arith::ArithDialect"];
+  let constructor = "mlir::memref::createMergeAllocPass()";
+}
+
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
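For reference, a minimal sketch (not part of the patch) of how the new pass
could be added to a pass pipeline from C++. It assumes the tablegen-generated
`MergeAllocOptions` struct exposes the declared option variables `optionCheck`
and `optionNoLocality`:

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Dialect/MemRef/Transforms/Passes.h"
  #include "mlir/Pass/PassManager.h"

  static void addMergeAlloc(mlir::PassManager &pm) {
    mlir::memref::MergeAllocOptions options;
    options.optionCheck = false;      // rewrite the IR, not just mark lifetimes
    options.optionNoLocality = false; // keep the cache-locality heuristic
    // merge-alloc is declared as a func::FuncOp pass in Passes.td above.
    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::memref::createMergeAllocPass(options));
  }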
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h b/mlir/include/mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h
new file mode 100644
index 0000000000000..83f022cb8fa8d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h
@@ -0,0 +1,71 @@
+//===- StaticMemoryPlanning.h - Static memory planning ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_STATICMEMORYPLANNING_H
+#define MLIR_DIALECT_MEMREF_STATICMEMORYPLANNING_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cstddef>
+#include <cstdint>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace mlir {
+class Operation;
+
+namespace memoryplan {
+enum class InplaceKind {
+  ZERO_OFFSET, // this requires that the tensor shares the same base
+               // pointer as the replaced tensor
+  FREE,        // the tensor may be placed at any offset within this tensor
+};
+
+struct MemoryTrace {
+  // unique id of a buffer
+  uintptr_t bufferId;
+  // if > 0, size of the buffer allocation, if = 0, it is a deallocation trace
+  std::size_t size;
+  MemoryTrace(uintptr_t bufferId = 0, std::size_t size = 0)
+      : bufferId{bufferId}, size{size} {}
+};
+
+using ScopeTraceData =
+    llvm::DenseMap<Operation *, llvm::SmallVector<memoryplan::MemoryTrace, 8>>;
+using Traces = llvm::SmallVector<memoryplan::MemoryTrace, 8>;
+using InplaceInfo = std::pair<uintptr_t, InplaceKind>;
+
+using InplaceInfoMap =
+    llvm::DenseMap<uintptr_t, llvm::SmallVector<InplaceInfo>>;
+
+/**
+ * Given a list of memory buffer alloc and free traces, try to use a large
+ * buffer to hold all allocated memory, and statically allocate each memory
+ * buffer from the large buffer for better memory reuse.
+ * @param traces the list of memory alloc and free traces, sorted by event time.
+ * @param alignment the alignment in number of elements
+ * @param hotFirst use the hot buffer first, instead of using best fit in size
+ * @param inplaceMap the map from the buffer to be allocated to the candidate
+ * buffers that it can reuse in place.
+ * @param outSchedule the output schedule for each buffer: the location that
+ * the buffer should be in the large buffer (as an offset in number of elements)
+ * @param outInplaceSelection the output buffer id -> inplace buffer it reuses
+ * @return the size of the large buffer, in number of elements
+ */
+std::size_t scheduleMemoryAllocations(
+    const Traces &traces, std::size_t alignment, bool hotFirst,
+    const InplaceInfoMap &inplaceMap,
+    std::unordered_map<uintptr_t, std::size_t> &outSchedule,
+    std::unordered_map<uintptr_t, std::vector<uintptr_t>>
+        &outInplaceSelection);
+
+} // namespace memoryplan
+} // namespace mlir
+
+#endif
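For reference, a minimal usage sketch (not part of the patch) of the interface
declared above. The trace encodes one alloc event (size > 0) and one dealloc
event (size == 0) per buffer id, sorted by event time, mirroring the unit test
added at the end of this patch:

  #include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"

  void planTwoBuffers() {
    using namespace mlir::memoryplan;
    // buffers 1 (100 elements) and 2 (50 elements) have disjoint lifetimes,
    // so buffer 2 can reuse the start of buffer 1's address range.
    Traces traces = {{1, 100}, {1, 0}, {2, 50}, {2, 0}};
    std::unordered_map<uintptr_t, std::size_t> offsets;
    std::unordered_map<uintptr_t, std::vector<uintptr_t>> inplace;
    std::size_t total = scheduleMemoryAllocations(
        traces, /*alignment=*/1, /*hotFirst=*/false, InplaceInfoMap(), offsets,
        inplace);
    // with these inputs: total == 100, offsets[1] == 0 and offsets[2] == 0
    (void)total;
  }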
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index f150ac7ac2d63..456bbafd9af5e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -14,6 +14,8 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
   RuntimeOpVerification.cpp
+  MergeAlloc.cpp
+  StaticMemoryPlanning.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
new file mode 100644
index 0000000000000..7eca9452ae802
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
@@ -0,0 +1,352 @@
+//===- MergeAlloc.cpp - Merge memref.alloc allocations --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
+
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_MERGEALLOC
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+/// Return `true` if the given MemRef type has a static shape and an identity
+/// (trivial) layout.
+static bool hasStaticIdentityLayout(MemRefType type) {
+  return type.hasStaticShape() && type.getLayout().isIdentity();
+}
+
+namespace {
+static constexpr int64_t NO_ACCESS = -1;
+static constexpr int64_t COMPLEX_ACCESS = -2;
+struct Tick {
+  int64_t firstAccess = NO_ACCESS;
+  int64_t lastAccess = NO_ACCESS;
+
+  void access(int64_t tick) {
+    if (tick == COMPLEX_ACCESS) {
+      firstAccess = COMPLEX_ACCESS;
+      lastAccess = COMPLEX_ACCESS;
+    }
+    if (firstAccess == COMPLEX_ACCESS) {
+      return;
+    }
+    if (firstAccess == NO_ACCESS) {
+      firstAccess = tick;
+    } else {
+      firstAccess = std::min(firstAccess, tick);
+    }
+    lastAccess = std::max(lastAccess, tick);
+  }
+};
+
+bool isMergeableAlloc(Operation *op, int64_t tick) {
+  if (tick == COMPLEX_ACCESS) {
+    return false;
+  }
+  if (!hasStaticIdentityLayout(
+          cast<MemRefType>(op->getResultTypes().front()))) {
+    return false;
+  }
+  // currently only supports alignments of none, 1, 2, 4, 8, 16, 32, or 64
+  auto alignment = cast<memref::AllocOp>(op).getAlignment();
+  if (!alignment) {
+    return true; // ok if no alignment
+  }
+  return alignment > 0 && (64 % alignment.value() == 0);
+}
+
+// find the closest surrounding parent operation that has the
+// AutomaticAllocationScope trait and is not scf.for
+Operation *getAllocScope(Operation *op) {
+  auto parent = op;
+  for (;;) {
+    parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+    if (!parent) {
+      return nullptr;
+    }
+    if (!isa<scf::ForOp>(parent)) {
+      return parent;
+    }
+  }
+}
+
+FailureOr<size_t> getAllocSize(Operation *op) {
+  auto refType = cast<MemRefType>(op->getResultTypes().front());
+  int64_t size = refType.getElementTypeBitWidth() / 8;
+  // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
+  // least have a large enough size for i1
+  size = (size != 0) ? size : 1;
+  for (auto v : refType.getShape()) {
+    size *= v;
+  }
+  if (size > 0) {
+    return static_cast<size_t>(size);
+  }
+  return op->emitError("expected a statically shaped allocation");
+}
+
+// A complex scope object holds additional info for a RegionBranchOpInterface
+// or LoopLikeOpInterface. It contains the scope itself and the referenced
+// alloc ops inside this scope. We use this object to track which buffers this
+// scope accesses. These buffers must have overlapping lifetimes
+struct ComplexScope {
+  Operation *scope;
+  int64_t startTick;
+  llvm::SmallPtrSet<Operation *, 8> operations;
+  ComplexScope(Operation *scope, int64_t startTick)
+      : scope{scope}, startTick{startTick} {}
+  // returns true if an allocation either is not defined in the scope, or
+  // escapes from the scope
+  bool needsResetTick(Operation *scope, Operation *allocation,
+                      const mlir::BufferViewFlowAnalysis &aliasAnaly) const {
+    // if the allocation is not in the scope, conservatively set the ticks
+    if (!scope->isProperAncestor(allocation)) {
+      return true;
+    }
+    // if the allocation or any of its aliases is used outside of the scope
+    for (auto &&alias : aliasAnaly.resolve(allocation->getResult(0))) {
+      for (auto &&userOp : alias.getUsers()) {
+        if (!scope->isProperAncestor(userOp) && !isMemoryEffectFree(userOp)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  // called when walk() runs outside of the scope
+  void onPop(int64_t endTick, const mlir::BufferViewFlowAnalysis &aliasAnaly,
+             llvm::DenseMap<Operation *, Tick> &allocTicks) {
+    for (auto op : operations) {
+      if (needsResetTick(scope, op, aliasAnaly)) {
+        // let all referenced buffers have overlapping lifetimes
+        auto &tick = allocTicks[op];
+        tick.access(startTick);
+        tick.access(endTick);
+      }
+    }
+  }
+};
+
+struct TickCollecter {
+  const mlir::BufferViewFlowAnalysis &aliasAnaly;
+  int64_t curTick = 0;
+  llvm::DenseMap<Operation *, Tick> allocTicks;
+  llvm::SmallVector<ComplexScope> complexScopeStack;
+  TickCollecter(const mlir::BufferViewFlowAnalysis &aliasAnaly)
+      : aliasAnaly{aliasAnaly} {}
+  void popScopeIfNecessary(Operation *op) {
+    // first check if we have walked outside of the previous ComplexScope
+    while (!complexScopeStack.empty()) {
+      auto &scope = complexScopeStack.back();
+      if (!op || !scope.scope->isProperAncestor(op)) {
+        scope.onPop(curTick, aliasAnaly, allocTicks);
+        complexScopeStack.pop_back();
+      } else {
+        break;
+      }
+    }
+  }
+
+  void forwardTick() { curTick++; }
+
+  void accessValue(Value v, bool complex) {
+    if (auto refv = dyn_cast<TypedValue<MemRefType>>(v)) {
+      for (auto &&base : aliasAnaly.resolveReverse(refv)) {
+        auto defop = base.getDefiningOp();
+        if (isa_and_present<memref::AllocOp>(defop)) {
+          allocTicks[defop].access(complex ? COMPLEX_ACCESS : curTick);
+          if (!complexScopeStack.empty()) {
+            complexScopeStack.back().operations.insert(defop);
+          }
+        }
+      }
+    }
+  }
+
+  void onMemrefViews(ViewLikeOpInterface op) {
+    auto viewSrc = op.getViewSource();
+    // don't need to access the first operand, which is "source".
+    // The "source" operand is not really read or written at this point
+    for (auto val : op.getOperation()->getOperands()) {
+      if (val != viewSrc)
+        accessValue(val, false);
+    }
+  }
+
+  void onReturnOp(Operation *op) {
+    bool isTopLevel = isa<func::FuncOp>(op->getParentOp());
+    for (auto val : op->getOperands()) {
+      accessValue(val, isTopLevel);
+    }
+  }
+
+  void onGeneralOp(Operation *op) {
+    for (auto val : op->getOperands()) {
+      accessValue(val, false);
+    }
+  }
+
+  void pushComplexScope(Operation *op) {
+    complexScopeStack.emplace_back(op, curTick);
+  }
+
+  FailureOr<memoryplan::ScopeTraceData> getTrace() {
+    struct TraceWithTick {
+      int64_t tick;
+      memoryplan::MemoryTrace trace;
+      TraceWithTick(int64_t tick, uintptr_t bufferId, size_t size)
+          : tick{tick}, trace{bufferId, size} {}
+    };
+    llvm::DenseMap<Operation *, llvm::SmallVector<TraceWithTick, 8>> raw;
+    for (auto &[op, tick] : allocTicks) {
+      if (!isMergeableAlloc(op, tick.firstAccess)) {
+        continue;
+      }
+      auto scope = getAllocScope(op);
+      if (!scope) {
+        return op->emitError(
+            "This op should be surrounded by an AutomaticAllocationScope");
+      }
+      auto allocSize = getAllocSize(op);
+      if (failed(allocSize)) {
+        return allocSize;
+      }
+      // alloc at tick.firstAccess * 2 and dealloc at tick.lastAccess * 2 + 1,
+      // so that a "dealloc" at tick N still overlaps an "alloc" at tick N
+      raw[scope].emplace_back(tick.firstAccess * 2,
+                              reinterpret_cast<uintptr_t>(op), *allocSize);
+      raw[scope].emplace_back(tick.lastAccess * 2 + 1,
+                              reinterpret_cast<uintptr_t>(op), 0);
+    }
+    memoryplan::ScopeTraceData ret;
+    for (auto &[scope, trace] : raw) {
+      std::stable_sort(trace.begin(), trace.end(),
+                       [](const TraceWithTick &a, const TraceWithTick &b) {
+                         return a.tick < b.tick;
+                       });
+      auto &retTrace = ret[scope];
+      retTrace.reserve(trace.size());
+      for (auto &tr : trace) {
+        retTrace.emplace_back(tr.trace);
+      }
+    }
+    return ret;
+  }
+};
+
+} // namespace
+
+FailureOr<mlir::memoryplan::ScopeTraceData>
+collectMemoryTrace(Operation *root,
+                   const mlir::BufferViewFlowAnalysis &aliasAnaly,
+                   bool markOnly) {
+  TickCollecter collecter{aliasAnaly};
+  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    collecter.popScopeIfNecessary(op);
+    collecter.forwardTick();
+    if (auto viewop = dyn_cast<ViewLikeOpInterface>(op)) {
+      collecter.onMemrefViews(viewop);
+    } else if (op->hasTrait<OpTrait::ReturnLike>()) {
+      collecter.onReturnOp(op);
+    } else if (!isMemoryEffectFree(op)) {
+      // if the op has no memory effects, it doesn't contribute to liveness
+      collecter.onGeneralOp(op);
+    }
+    // finally, if the op is a complex scope, push a new ComplexScope
+    if (isa<RegionBranchOpInterface>(op) || isa<LoopLikeOpInterface>(op)) {
+      collecter.pushComplexScope(op);
+    }
+  });
+  collecter.popScopeIfNecessary(nullptr);
+  if (markOnly) {
+    for (auto &[alloc, tick] : collecter.allocTicks) {
+      auto allocscope = getAllocScope(alloc);
+      alloc->setAttr(
+          "__mergealloc_lifetime",
+          DenseI64ArrayAttr::get(root->getContext(),
+                                 {reinterpret_cast<int64_t>(allocscope),
+                                  tick.firstAccess, tick.lastAccess}));
+      allocscope->setAttr(
+          "__mergealloc_scope",
+          IntegerAttr::get(mlir::IntegerType::get(root->getContext(), 64),
+                           reinterpret_cast<int64_t>(allocscope)));
+    }
+    return mlir::memoryplan::ScopeTraceData();
+  }
+  return collecter.getTrace();
+}
+
+} // namespace memref
+} // namespace mlir
+
+namespace {
+using namespace mlir;
+struct MergeAllocPass : memref::impl::MergeAllocBase<MergeAllocPass> {
+  using parent = memref::impl::MergeAllocBase<MergeAllocPass>;
+  void runOnOperation() override {
+    auto op = getOperation();
+    BufferViewFlowAnalysis aliasAnaly{op};
+    auto tracesOrFail =
+        memref::collectMemoryTrace(op, aliasAnaly, this->optionCheck);
+    if (failed(tracesOrFail)) {
+      signalPassFailure();
+      return;
+    }
+    if (this->optionCheck) {
+      return;
+    }
+    std::unordered_map<uintptr_t, std::vector<uintptr_t>> dummy;
+    for (auto &[scope, traces] : *tracesOrFail) {
+      std::unordered_map<uintptr_t, std::size_t> outSchedule;
+      if (traces.empty())
+        continue;
+      auto total = memoryplan::scheduleMemoryAllocations(
+          traces, 64, !this->optionNoLocality, memoryplan::InplaceInfoMap(),
+          outSchedule, dummy);
+      auto &block = scope->getRegion(0).getBlocks().front();
+      OpBuilder builder{&block.front()};
+      auto alignment =
+          builder.getIntegerAttr(IntegerType::get(op.getContext(), 64), 64);
+      auto alloc = builder.create<memref::AllocOp>(
+          scope->getLoc(),
+          MemRefType::get({static_cast<int64_t>(total)}, builder.getI8Type()),
+          alignment);
+      for (auto &[key, offset] : outSchedule) {
+        auto origBuf = reinterpret_cast<Operation *>(key);
+        builder.setInsertionPoint(origBuf);
+        auto byteShift = builder.create<arith::ConstantIndexOp>(
+            origBuf->getLoc(), static_cast<int64_t>(offset));
+        auto view = builder.create<memref::ViewOp>(
+            origBuf->getLoc(), origBuf->getResultTypes().front(), alloc,
+            byteShift, ValueRange{});
+        origBuf->replaceAllUsesWith(view->getResults());
+        origBuf->erase();
+      }
+    }
+  }
+
+public:
+  MergeAllocPass(const memref::MergeAllocOptions &o) : parent{o} {}
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass>
+mlir::memref::createMergeAllocPass(const memref::MergeAllocOptions &o) {
+  return std::make_unique<MergeAllocPass>(o);
+}
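A small standalone sketch (not part of the patch) of the trace encoding used
by TickCollecter::getTrace() above; the helper name is hypothetical:

  #include <cstddef>
  #include <cstdint>
  #include <utility>
  #include <vector>

  // An allocation whose lifetime is [firstAccess, lastAccess] becomes an
  // alloc event at tick firstAccess * 2 and a dealloc event (size == 0) at
  // tick lastAccess * 2 + 1, so a buffer freed at tick N still overlaps a
  // buffer allocated at the same tick N.
  std::vector<std::pair<int64_t, std::size_t>>
  encodeLifetime(int64_t firstAccess, int64_t lastAccess, std::size_t size) {
    return {{firstAccess * 2, size}, {lastAccess * 2 + 1, 0}};
  }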
diff --git a/mlir/lib/Dialect/MemRef/Transforms/StaticMemoryPlanning.cpp b/mlir/lib/Dialect/MemRef/Transforms/StaticMemoryPlanning.cpp
new file mode 100644
index 0000000000000..11cbdcd989edf
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/StaticMemoryPlanning.cpp
@@ -0,0 +1,655 @@
+//===- StaticMemoryPlanning.cpp - Static memory planning ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <limits>
+#include <map>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+
+namespace mlir {
+namespace memoryplan {
+namespace {
+static constexpr size_t divideAndCeil(size_t x, size_t y) {
+  return (x + y - 1) / y;
+}
+// how the chunk was created
+enum class ChunkType {
+  ORIGIN, // the chunk is directly allocated from the large buffer
+  SPLIT,  // the chunk was obtained by splitting another memory chunk
+  MERGED, // the chunk was obtained by merging several consecutive memory chunks
+};
+
+struct MemoryState;
+
+struct MemoryChunk {
+  ChunkType type;
+  size_t size;
+  bool isfree = true;
+  size_t lastFreedTick = 0;
+  bool isInplaceSplitRemainder = false;
+  // splits the chunk into a left-hand side of the given size and the
+  // remainder, registering both resulting chunks with the state
+  void split(MemoryState *state, size_t size, MemoryChunk *&lhs,
+             MemoryChunk *&rhs);
+  // move the buffer, propagate the message up to the parent chunk. It will
+  // not update the siblings.
+  virtual void move(int64_t startDiff) = 0;
+  // extend the buffer, propagate the message up to the parent chunk. It will
+  // not update the siblings.
+  virtual void extend(int64_t sizeDiff) = 0;
+
+  MemoryChunk(ChunkType type, size_t size) : type(type), size(size) {}
+  // there should be no updates to memory chunks after calling
+  // getStartOffset
+  size_t getStartOffset() {
+    if (cached_start_offset == UNINITIALIZED) {
+      cached_start_offset = getStartOffsetImpl();
+    }
+    return cached_start_offset;
+  }
+  virtual ~MemoryChunk() = default;
+
+  virtual size_t getStartOffsetImpl() = 0;
+
+protected:
+  static constexpr size_t UNINITIALIZED = std::numeric_limits<size_t>::max();
+  size_t cached_start_offset = UNINITIALIZED;
+};
+
+// the memory chunk that is directly allocated from the large buffer
+struct OriginChunk : public MemoryChunk {
+  // no parent
+  // MemoryChunk *parent;
+  size_t start;
+  OriginChunk(size_t start, size_t size)
+      : MemoryChunk{ChunkType::ORIGIN, size}, start(start) {}
+  void move(int64_t startDiff) override { start += startDiff; }
+  void extend(int64_t sizeDiff) override { size += sizeDiff; }
+  size_t getStartOffsetImpl() override { return start; }
+};
+
+// the memory chunk that is split from another chunk
+struct split_chunk_t : public MemoryChunk {
+  MemoryChunk *parent;
+  // if the chunk is the left hand side (smaller starting offset)
+  bool is_lhs_;
+  split_chunk_t(size_t size, MemoryChunk *parent, bool is_lhs)
+      : MemoryChunk{ChunkType::SPLIT, size}, parent(parent), is_lhs_(is_lhs) {}
+  void move(int64_t startDiff) override {
+    if (is_lhs_) {
+      parent->move(startDiff);
+    }
+    // no need to pass message to parent for rhs, since lhs has done so
+  }
+  void extend(int64_t sizeDiff) override {
+    size += sizeDiff;
+    parent->extend(sizeDiff);
+    // if is_lhs, we will later call rhs->move(...)
+  }
+  size_t getStartOffsetImpl() override {
+    if (is_lhs_) {
+      return parent->getStartOffset();
+    } else {
+      return parent->getStartOffset() + parent->size - size;
+    }
+  }
+};
+
+static size_t getSizeOfChunks(const std::vector<MemoryChunk *> &c) {
+  size_t v = 0;
+  for (auto chk : c) {
+    v += chk->size;
+  }
+  return v;
+}
+// the memory chunk that is merged from other chunks
+struct MergedChunk : public MemoryChunk {
+  std::vector<MemoryChunk *> parent;
+  MergedChunk(std::vector<MemoryChunk *> &&parent)
+      : MemoryChunk{ChunkType::MERGED, getSizeOfChunks(parent)},
+        parent(std::move(parent)) {}
+  void move(int64_t startDiff) override {
+    for (auto v : parent) {
+      v->move(startDiff);
+    }
+  }
+  void extend(int64_t sizeDiff) override {
+    size += sizeDiff;
+    parent.back()->extend(sizeDiff);
+  }
+  size_t getStartOffsetImpl() override {
+    return parent.front()->getStartOffset();
+  }
+};
+
+struct MemoryState {
+  // buffer_id -> allocated memory chunk, used to collect the final result
+  std::unordered_map<uintptr_t, MemoryChunk *> allocations;
+  // buffer_id -> currently alive memory chunk, used for in-place
+  // optimization, when a buffer is reused in place by another buffer. The
+  // reused buffer keeps its unchanged MemoryChunk in allocations, because
+  // allocations records the final result for each buffer. curAllocations
+  // tracks the current mapping of buffer_id to MemoryChunk, which may be
+  // different from allocations
+  std::unordered_map<uintptr_t, MemoryChunk *> curAllocations;
+  // all memory chunks that have been created; takes ownership of the
+  // MemoryChunk objects
+  std::vector<std::unique_ptr<MemoryChunk>> chunks;
+  // the current memory chunks, sorted by the starting offset
+  std::vector<MemoryChunk *> curChunks;
+  // free chunks sorted by size
+  std::multimap<size_t, MemoryChunk *> freeChunksBySize;
+  // free chunks sorted by last freed tick
+  std::multimap<size_t, MemoryChunk *> freeChunksByTick;
+  // the current size of the large buffer, in number of elements
+  size_t currentAllocSize = 0;
+  // the alignment in number of elements
+  size_t alignment;
+  // the map from a buffer-id to the buffer-ids that the buffer can reuse
+  // in place
+  const InplaceInfoMap &inplaceMap;
+  std::unordered_map<uintptr_t, std::vector<uintptr_t>> &outInplaceSelection;
+  int tick = 0;
+  bool hotFirst;
+
+  MemoryState(size_t alignment, bool hotFirst, const InplaceInfoMap &inplaceMap,
+              std::unordered_map<uintptr_t, std::vector<uintptr_t>>
+                  &outInplaceSelection)
+      : alignment(alignment), inplaceMap(inplaceMap),
+        outInplaceSelection(outInplaceSelection), hotFirst(hotFirst) {}
+
+  void removeChunkFromMap(MemoryChunk *target, size_t t,
+                          std::multimap<size_t, MemoryChunk *> &m) {
+    auto mapitr = m.equal_range(t);
+    assert(mapitr.first != mapitr.second);
+    for (auto map_i = mapitr.first; map_i != mapitr.second; ++map_i) {
+      if (map_i->second == target) {
+        m.erase(map_i);
+        break;
+      }
+    }
+  }
+  void removeChunkFromFreeList(MemoryChunk *target) {
+    removeChunkFromMap(target, target->size, freeChunksBySize);
+    removeChunkFromMap(target, target->lastFreedTick, freeChunksByTick);
+    target->isfree = false;
+  }
+
+  void addChunkToFreeList(MemoryChunk *target) {
+    freeChunksBySize.insert(std::make_pair(target->size, target));
+    freeChunksByTick.insert(std::make_pair(target->lastFreedTick, target));
+    target->isfree = true;
+  }
+
+  void extendAlloc(MemoryChunk *target, size_t aligned) {
+    // remove the chunk from free list
+    removeChunkFromFreeList(target);
+    int64_t sizeDiff = aligned - target->size;
+    assert(sizeDiff > 0);
+    currentAllocSize += sizeDiff;
+    // extend the target chunk, and also move all buffers to the right of it
+    target->extend(sizeDiff);
+    bool found_target = false;
+    for (auto v : curChunks) {
+      if (v == target) {
+        found_target = true;
+      } else if (found_target) {
+        // v is to the right of the target
+        v->move(sizeDiff);
+      }
+    }
+    assert(found_target);
+    target->isfree = false;
+  }
+
+  MemoryChunk *splitAlloc(MemoryChunk *target, size_t aligned,
+                          MemoryChunk *&rhs_ret) {
+    // found a free chunk that is large enough
+    if (target->size == aligned) {
+      // a perfect match, no need to split
+      auto ret = target;
+      removeChunkFromFreeList(target);
+      return ret;
+    }
+    // split the larger chunk
+    assert(target->size > aligned);
+    auto lhs = std::make_unique<split_chunk_t>(aligned, target, true);
+    auto rhs =
+        std::make_unique<split_chunk_t>(target->size - aligned, target, false);
+    rhs_ret = rhs.get();
+    auto ret = lhs.get();
+
+    auto old_itr_in_cur_chunks =
+        std::find(curChunks.begin(), curChunks.end(), target);
+    assert(old_itr_in_cur_chunks != curChunks.end());
+    // replace old chunk with rhs
+    *old_itr_in_cur_chunks = rhs.get();
+    // insert lhs before rhs
+    curChunks.insert(old_itr_in_cur_chunks, lhs.get());
+    rhs->lastFreedTick = target->lastFreedTick;
+    // add rhs to free list
+    addChunkToFreeList(rhs.get());
+
+    // move ownership
+    chunks.emplace_back(std::move(lhs));
+    chunks.emplace_back(std::move(rhs));
+
+    // remove the old chunk from the free list
+    removeChunkFromFreeList(target);
+    ret->isfree = false;
+    return ret;
+  }
+
+  float calculateSizeScore(size_t chk_size, size_t alloc_size) const {
+    // size_score = abs(chunk_size-alloc_size)/max(chunk_size, alloc_size)
+    int64_t sizeDiff =
+        static_cast<int64_t>(chk_size) - static_cast<int64_t>(alloc_size);
+    float size_max = static_cast<float>(std::max(alloc_size, chk_size));
+    float size_score = -std::abs(sizeDiff) / size_max;
+    // if we don't need to extend the buffer, add a bonus score for it
+    if (alloc_size <= chk_size) {
+      size_score += 1;
+    }
+    // size_score and tick_score are normalized in [-1,1]. We set a weight
+    // for these two scores: 1:1
+    return size_score;
+  }
+
+  // calculates the score of a free chunk to help select the best chunk we
+  // allocate memory from. It considers 2 factors: 1) the free chunk size and
+  // the size of the current memory allocation request. The closer they are,
+  // the better the chunk is. 2) the heat of the chunk. If the chunk's last
+  // freed tick is closer to the current tick, the chunk is better.
+  // The better the chunk is, the greater the score is.
+  float calculateChunkScore(MemoryChunk *chk, size_t alloc_size,
+                            size_t last_tick) const {
+    // if the chunk was freed N ticks ago, its tick score is max(0, 1 - N
+    // * 0.1)
+    float tick_score = static_cast<float>(tick - last_tick) / 10;
+    tick_score = 1 - std::min(tick_score, 1.0f);
+    // size_score and tick_score are normalized in [-1,1]. We set a weight
+    // for these two scores: 1:1
+    return 1 * calculateSizeScore(chk->size, alloc_size) + 1 * tick_score;
+  }
+
+  MemoryChunk *alloc(uintptr_t bufferid, size_t size) {
+    tick++;
+    auto ret = doAlloc(bufferid, size);
+    allocations[bufferid] = ret;
+    curAllocations[bufferid] = ret;
+    return ret;
+  }
+
+  // check if the buffer is split from a base tensor and check the
+  // InplaceInfo for whether it requires zero offset
+  bool checkBufferOffsetForInplace(MemoryChunk *chunk,
+                                   const InplaceInfo *info) {
+    // if the old memory chunk was split from the base tensor
+    bool old_is_split = chunk->isInplaceSplitRemainder;
+    // if the old memory chunk is based on an offset into the base tensor
+    // and we require a zero offset on that tensor, we
+    // cannot reuse it
+    return !(old_is_split && info->second == InplaceKind::ZERO_OFFSET);
+  }
+
+  // find the range of chunks in curChunks that can be merged for inplace
+  // reuse, returns the memory size of the range and the start/end iterators
+  size_t findInplaceMergeRange(
+      MemoryChunk *victim, size_t aligned,
+      const std::unordered_map<MemoryChunk *, const InplaceInfo *> &can_inplace,
+      std::vector<MemoryChunk *>::iterator &to_merge_start,
+      std::vector<MemoryChunk *>::iterator &to_merge_end) {
+    // addChunkToFreeList(chk);
+    auto itr_in_cur_chunks =
+        std::find(curChunks.begin(), curChunks.end(), victim);
+    assert(itr_in_cur_chunks != curChunks.end());
+    // merge to the right if the chunks are free or can be reused in place
+    to_merge_start = itr_in_cur_chunks;
+    to_merge_end = itr_in_cur_chunks + 1;
+    // remember the memory size we already collected. If
+    // current_collected_size is greater than the memory size to alloc, we
+    // can stop searching
+    size_t current_collected_size = victim->size;
+    // look right to see any one we can merge with
+    for (auto itr = itr_in_cur_chunks + 1;
+         itr != curChunks.end() && current_collected_size < aligned; ++itr) {
+      // if the memory chunk is in use and is in can_inplace map, we may
+      // reuse it now
+      auto inplace_info_itr = can_inplace.find(*itr);
+      if ((*itr)->isfree ||
+          (inplace_info_itr != can_inplace.end() &&
+           inplace_info_itr->second->second == InplaceKind::FREE)) {
+        to_merge_end = itr + 1;
+        current_collected_size += (*itr)->size;
+      } else {
+        break;
+      }
+    }
+    return current_collected_size;
+  }
+
+  // inplace alloc memory on a chunk that is in use, but about to be freed.
+  MemoryChunk *doInplaceAlloc(uintptr_t bufferid, size_t aligned) {
+    if (inplaceMap.empty()) {
+      return nullptr;
+    }
+    auto itr_inplace = inplaceMap.find(bufferid);
+    if (itr_inplace == inplaceMap.end()) {
+      return nullptr;
+    }
+    // if the buffer can reuse in place some other buffers that are
+    // still in use but about to be freed
+    const auto &buffer_can_inplace = itr_inplace->second;
+    if (buffer_can_inplace.empty()) {
+      return nullptr;
+    }
+
+    // reversed map, chunk --> buffer id for inplace candidates
+    std::unordered_map<MemoryChunk *, const InplaceInfo *> can_inplace;
+    for (auto &v : buffer_can_inplace) {
+      auto itr = curAllocations.find(v.first);
+      if (itr != curAllocations.end()) {
+        can_inplace[itr->second] = &v;
+      }
+    }
+
+    // stage 1, find a victim based on the memory size that can be freed
+    float target_score = -std::numeric_limits<float>::infinity();
+    MemoryChunk *victim = nullptr;
+    std::vector<MemoryChunk *>::iterator to_merge_start;
+    std::vector<MemoryChunk *>::iterator to_merge_end;
+    size_t current_collected_size = 0;
+    for (auto &bufinfo : buffer_can_inplace) {
+      auto buf_id = bufinfo.first;
+      auto old_buf_itr = curAllocations.find(buf_id);
+      // if the buffer has already been reused by other buffers, skip
+      if (old_buf_itr == curAllocations.end()) {
+        continue;
+      }
+      // the old memory chunk
+      auto old_buf = old_buf_itr->second;
+
+      auto &old_inplace_info = can_inplace[old_buf];
+      if (!checkBufferOffsetForInplace(old_buf, old_inplace_info)) {
+        continue;
+      }
+
+      std::vector<MemoryChunk *>::iterator cur_merge_start;
+      std::vector<MemoryChunk *>::iterator cur_merge_end;
+      auto cur_size = findInplaceMergeRange(old_buf, aligned, can_inplace,
+                                            cur_merge_start, cur_merge_end);
+      float score = calculateSizeScore(cur_size, aligned);
+      if (score > target_score) {
+        target_score = score;
+        victim = old_buf;
+        to_merge_start = cur_merge_start;
+        to_merge_end = cur_merge_end;
+        current_collected_size = cur_size;
+      }
+    }
+    if (current_collected_size * 10 < aligned) {
+      // if the memory that can be reused is too small (less than 10% of the
+      // target size), inplacing has no benefits, skip
+      return nullptr;
+    }
+    if (!victim) {
+      return nullptr;
+    }
+    assert(!victim->isfree);
+
+    victim->lastFreedTick = tick;
+
+    std::vector<MemoryChunk *> merged_buffers(to_merge_start, to_merge_end);
+    for (auto buf : merged_buffers) {
+      auto itr = can_inplace.find(buf);
+      if (itr != can_inplace.end()) {
+        uintptr_t vic_buffer_id = itr->second->first;
+        if (vic_buffer_id) {
+          outInplaceSelection[bufferid].emplace_back(vic_buffer_id);
+          DEBUG_WITH_TYPE("memplan", llvm::dbgs() << "Buffer " << bufferid
+                                                  << " inplace reuses "
+                                                  << vic_buffer_id << "\n");
+        }
+      }
+    }
+    if (current_collected_size < aligned) {
+      // if the collected memory size is still less than the size to
+      // alloc, need to extend
+      auto target_size =
+          aligned - current_collected_size + merged_buffers.back()->size;
+      if (!merged_buffers.back()->isfree) {
+        // if it is not free, we are reusing it in place. Temporarily move it
+        // to the free list
+        addChunkToFreeList(merged_buffers.back());
+      }
+      extendAlloc(merged_buffers.back(), target_size);
+      // after extension of the last buffer, the collected size is equal
+      // to the size to alloc
+      current_collected_size = aligned;
+    }
+
+    // remove from freelist and buffer_id->chunk map
+    for (auto itr = to_merge_start; itr != to_merge_end; ++itr) {
+      auto chunk = *itr;
+      if (chunk->isfree) {
+        removeChunkFromFreeList(chunk);
+      }
+      auto itr_chunk = can_inplace.find(chunk);
+      if (itr_chunk != can_inplace.end()) {
+        curAllocations.erase(itr_chunk->second->first);
+      }
+    }
+
+    MemoryChunk *merged_chunk;
+    // if we need to merge multiple chunks
+    if (to_merge_end - to_merge_start > 1) {
+      // do merge
+      chunks.emplace_back(std::make_unique<MergedChunk>(
+          std::vector<MemoryChunk *>(merged_buffers)));
+      merged_chunk = chunks.back().get();
+      // remove merged chunks from free list and cur_chunk list
+      // add merged chunk to cur_chunks and free_chunks_by_size
+      *to_merge_start = merged_chunk;
+      merged_chunk->lastFreedTick = tick;
+      merged_chunk->isfree = false;
+      curChunks.erase(to_merge_start + 1, to_merge_end);
+    } else {
+      merged_chunk = victim;
+      merged_chunk->lastFreedTick = tick;
+    }
+
+    // merged_chunk is in curChunks and is removed from freelist and
+    // curAllocations map
+    if (current_collected_size == aligned) {
+      // if is extended, or perfect match, just return the chunk
+      merged_chunk->isfree = false;
+      return merged_chunk;
+    } else {
+      // otherwise, there is some unused memory in the chunk to be
+      // reused. We need to split it. If the RHS of the chunk is from an
+      // inplace-reused buffer, we need to add a mapping of the buffer id to
+      // the RHS remaining chunk
+      if (!merged_chunk->isfree) {
+        addChunkToFreeList(merged_chunk);
+      }
+      MemoryChunk *rhs = nullptr;
+      auto ret = splitAlloc(merged_chunk, aligned, rhs);
+      auto itr_chunk = can_inplace.find(merged_buffers.back());
+      if (itr_chunk != can_inplace.end()) {
+        // if the last chunk is from the inplace map, the RHS chunk is not
+        // really freed; we need to remove it from the free list and mark it
+        // as not free.
+        removeChunkFromFreeList(rhs);
+        rhs->isInplaceSplitRemainder = true;
+        // update the buffer id -> chunk map, so that when freeing the
+        // inplaced buffer, we can find the correct remaining buffer
+        curAllocations[itr_chunk->second->first] = rhs;
+      }
+      return ret;
+    }
+  }
+
+  MemoryChunk *doAlloc(uintptr_t bufferid, size_t size) {
+    auto aligned = divideAndCeil(size, alignment) * alignment;
+    // try inplace
+    if (auto inp_ret = doInplaceAlloc(bufferid, aligned)) {
+      return inp_ret;
+    }
+    if (freeChunksBySize.empty()) {
+      chunks.emplace_back(
+          std::make_unique<OriginChunk>(currentAllocSize, aligned));
+      currentAllocSize += aligned;
+      auto ret = chunks.back().get();
+      curChunks.emplace_back(ret);
+      ret->isfree = false;
+      return ret;
+    }
+    if (hotFirst) {
+      MemoryChunk *target = freeChunksByTick.rbegin()->second;
+      float target_score = calculateChunkScore(
+          target, aligned, freeChunksByTick.rbegin()->first);
+      for (auto &kv : freeChunksByTick) {
+        float score = calculateChunkScore(kv.second, aligned, kv.first);
+        if (score > target_score) {
+          target = kv.second;
+          target_score = score;
+        }
+      }
+      if (target->size < aligned) {
+        extendAlloc(target, aligned);
+        return target;
+      } else {
+        MemoryChunk *rhs;
+        return splitAlloc(target, aligned, rhs);
+      }
+    } else {
+      // find a free chunk that best fits the current size
+      // itr will be the smallest chunk whose size >= aligned
+      auto itr = freeChunksBySize.lower_bound(aligned);
+      if (itr == freeChunksBySize.end()) {
+        MemoryChunk *target;
+        // itr points to the last element
+        --itr;
+        // if not found, this means that all free chunks are smaller than the
+        // aligned size, so switch to the largest chunk
+        target = itr->second;
+        extendAlloc(target, aligned);
+        return target;
+      } else {
+        MemoryChunk *rhs;
+        return splitAlloc(itr->second, aligned, rhs);
+      }
+    }
+  }
+
+  void dealloc(MemoryChunk *chk) {
+    tick++;
+    chk->lastFreedTick = tick;
+    addChunkToFreeList(chk);
+    auto itr_in_cur_chunks = std::find(curChunks.begin(), curChunks.end(), chk);
+    assert(itr_in_cur_chunks != curChunks.end());
+    // merge left and right if they are free
+    std::vector<MemoryChunk *>::iterator to_merge_start = itr_in_cur_chunks;
+    std::vector<MemoryChunk *>::iterator to_merge_end = itr_in_cur_chunks + 1;
+    // look left to see any one we can merge with
+    for (auto itr = itr_in_cur_chunks;; --itr) {
+      if ((*itr)->isfree) {
+        to_merge_start = itr;
+      } else {
+        break;
+      }
+      if (itr == curChunks.begin()) {
+        break;
+      }
+    }
+    // look right to see any one we can merge with
+    for (auto itr = itr_in_cur_chunks + 1; itr != curChunks.end(); ++itr) {
+      if ((*itr)->isfree) {
+        to_merge_end = itr + 1;
+      } else {
+        break;
+      }
+    }
+    if (to_merge_end - to_merge_start > 1) {
+      // do merge
+      chunks.emplace_back(std::make_unique<MergedChunk>(
+          std::vector<MemoryChunk *>(to_merge_start, to_merge_end)));
+
+      // remove merged chunks from free list and cur_chunk list
+      for (auto itr = to_merge_start; itr != to_merge_end; ++itr) {
+        auto chunk = *itr;
+        removeChunkFromFreeList(chunk);
+      }
+      // add merged chunk to cur_chunks and free_chunks_by_size
+      *to_merge_start = chunks.back().get();
+      chunks.back()->lastFreedTick = tick;
+      addChunkToFreeList(chunks.back().get());
+      curChunks.erase(to_merge_start + 1, to_merge_end);
+    }
+    // else, no chunks are merged, do nothing
+  }
+
+  void dealloc(uintptr_t bufferid) {
+    auto alocitr = allocations.find(bufferid);
+    assert(alocitr != allocations.end() &&
+           "Cannot find buffer id in allocations");
+    auto itr = curAllocations.find(bufferid);
+    if (itr != curAllocations.end()) {
+      itr->second->isInplaceSplitRemainder = false;
+      dealloc(itr->second);
+      curAllocations.erase(itr);
+    }
+  }
+
+  std::string toString() const {
+    std::stringstream ss;
+    ss << "total size " << currentAllocSize << " ";
+    size_t cur_offset = 0;
+    for (auto buf : curChunks) {
+      ss << "| " << cur_offset << ',' << buf->size << ',' << buf->isfree << " ";
+      cur_offset += buf->size;
+    }
+    return ss.str();
+  }
+};
+} // namespace
+
+size_t scheduleMemoryAllocations(
+    const Traces &traces, std::size_t alignment, bool hotFirst,
+    const InplaceInfoMap &inplaceMap,
+    std::unordered_map<uintptr_t, std::size_t> &outSchedule,
+    std::unordered_map<uintptr_t, std::vector<uintptr_t>>
+        &outInplaceSelection) {
+  MemoryState planner{alignment, hotFirst, inplaceMap, outInplaceSelection};
+  for (auto &trace : traces) {
+    if (trace.size > 0) {
+      planner.alloc(trace.bufferId, trace.size);
+      DEBUG_WITH_TYPE("memplan", llvm::dbgs() << "Alloc " << trace.bufferId
+                                              << ", sz=" << trace.size << "\n"
+                                              << planner.toString() << "\n");
+    } else {
+      planner.dealloc(trace.bufferId);
+      DEBUG_WITH_TYPE("memplan", llvm::dbgs()
+                                     << "Dealloc " << trace.bufferId << "\n"
+                                     << planner.toString() << "\n");
+    }
+  }
+  for (auto &kv : planner.allocations) {
+    outSchedule[kv.first] = kv.second->getStartOffset();
+  }
+  return planner.currentAllocSize;
+}
+
+} // namespace memoryplan
+} // namespace mlir
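A standalone sketch (not part of the patch) that mirrors the chunk-scoring
heuristic in calculateSizeScore / calculateChunkScore above, with a worked
example in the comments; the free-standing function names are hypothetical:

  #include <algorithm>
  #include <cmath>
  #include <cstddef>
  #include <cstdint>

  float sizeScore(std::size_t chunkSize, std::size_t allocSize) {
    int64_t diff =
        static_cast<int64_t>(chunkSize) - static_cast<int64_t>(allocSize);
    float score = -std::abs(static_cast<float>(diff)) /
                  static_cast<float>(std::max(chunkSize, allocSize));
    if (allocSize <= chunkSize)
      score += 1; // bonus: no extension of the large buffer is needed
    return score;
  }

  float chunkScore(std::size_t chunkSize, std::size_t allocSize, int tick,
                   std::size_t lastFreedTick) {
    // a chunk freed N ticks ago has a heat score of max(0, 1 - N * 0.1)
    float tickScore = static_cast<float>(tick - lastFreedTick) / 10;
    tickScore = 1 - std::min(tickScore, 1.0f);
    return sizeScore(chunkSize, allocSize) + tickScore;
  }

  // e.g. chunkScore(100, 100, 5, 4) == 1.9f: a perfect size fit (score 1.0)
  // plus a chunk freed one tick ago (heat 0.9).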
diff --git a/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir b/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
new file mode 100644
index 0000000000000..96cf6b79e5242
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(func.func(merge-alloc{check}))'  %s | FileCheck %s
+
+// CHECK-DAG: func.func @basic() -> memref<8x64xf32>  attributes {__mergealloc_scope = [[TOPSCOPE:[0-9]+]]
+func.func @basic() -> memref<8x64xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  %ctrue = arith.constant 1 : i1
+  // b is used in return, complex lifetime
+  // CHECK-DAG: %[[B:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE]], -2, -2>}
+  %b = memref.alloc() : memref<8x64xf32>
+  "test.source"(%b)  : (memref<8x64xf32>) -> ()
+  // c and d have overlapping lifetimes
+  // CHECK-DAG: %[[C:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE]], 11, 14>}
+  %c = memref.alloc() : memref<8x64xf32>
+  "test.source"(%c)  : (memref<8x64xf32>) -> ()
+  // CHECK-DAG: %[[D:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE]], 13, 13>}
+  %d = memref.alloc() : memref<8x64xf32>
+  "test.source"(%d)  : (memref<8x64xf32>) -> ()
+  "test.source"(%c)  : (memref<8x64xf32>) -> ()
+  // e and f have overlapping lifetimes due to the loop
+  // CHECK-DAG: %[[E:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE]], 17, 22>}
+  // CHECK-DAG: %[[F:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE]], 17, 22>}
+  %e = memref.alloc() : memref<8x64xf32>
+  %f = memref.alloc() : memref<8x64xf32>
+  // CHECK: scf.for
+  scf.for %i = %c0 to %c5 step %c1 {
+    "test.source"(%e)  : (memref<8x64xf32>) -> ()
+    "test.source"(%f)  : (memref<8x64xf32>) -> ()
+    // CHECK-DAG: %[[G:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE]], 21, 21>}
+    %g = memref.alloc() : memref<8x64xf32>
+    "test.source"(%g)  : (memref<8x64xf32>) -> ()
+  }
+  // CHECK-DAG: %[[H:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE]], 24, 39>}
+  %h = memref.alloc() : memref<8x64xf32>
+  // CHECK: scf.forall
+  scf.forall (%iv) in (%c5) {
+    // check that allocs in the forall switch to another scope id
+    // CHECK-NOT: array<i64: [[TOPSCOPE]]
+    // CHECK-DAG: %[[L:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[FORSCOPE:[0-9]+]], 27, 27>}
+    %l = memref.alloc() : memref<8x64xf32>
+    "test.source"(%h)  : (memref<8x64xf32>) -> ()
+    "test.source"(%l)  : (memref<8x64xf32>) -> ()
+    scf.for %i = %c0 to %c5 step %c1 {
+      // CHECK-DAG: %[[G:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[FORSCOPE]], 30, 30>}
+      %g = memref.alloc() : memref<8x64xf32>
+      "test.source"(%g)  : (memref<8x64xf32>) -> ()
+    }
+    // CHECK-DAG: %[[K:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[FORSCOPE]], 33, 38>}
+    %k = memref.alloc() : memref<8x64xf32>
+    scf.if %ctrue {
+      // CHECK-DAG: %[[J:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[FORSCOPE]], 35, 35>}
+      %j = memref.alloc() : memref<8x64xf32>
+      "test.source"(%j)  : (memref<8x64xf32>) -> ()
+    } else {
+      "test.source"(%k)  : (memref<8x64xf32>) -> ()
+    }
+    // CHECK-DAG: {__mergealloc_scope = [[FORSCOPE]] : i64}
+  }
+  return %b : memref<8x64xf32>
+}
+
+// CHECK-DAG: func.func @basic2() attributes {__mergealloc_scope = [[TOPSCOPE2:[0-9]+]]
+func.func @basic2() {
+  // CHECK-DAG: %[[B:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE2]], 4, 6>}
+  %b = memref.alloc() : memref<8x64xi8>
+  %cur = memref.subview %b[1,0][1,64][1,1] : memref<8x64xi8> to memref<1x64xi8, strided<[64, 1], offset: 64>>
+  "test.source"(%cur)  : (memref<1x64xi8, strided<[64, 1], offset: 64>>) -> ()
+  %cur2 = memref.subview %cur[0,0][1,16][1,1] : memref<1x64xi8, strided<[64, 1], offset: 64>> to memref<1x16xi8, strided<[64, 1], offset: 64>>
+  "test.source"(%cur2)  : (memref<1x16xi8, strided<[64, 1], offset: 64>>) -> ()
+  // CHECK-DAG: %[[C:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE2]], 8, 8>}
+  %c = memref.alloc() : memref<8x64xi8>
+  "test.source"(%c)  : (memref<8x64xi8>) -> ()
+  return
+}
+
+// check that the operations without memory effects do not contribute to the lifetime of the buffer
+// CHECK-DAG: func.func @no_mem_effect() attributes {__mergealloc_scope = [[TOPSCOPE3:[0-9]+]]
+func.func @no_mem_effect() {
+  // CHECK: %[[B:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE3]], 4, 4>}
+  %b = memref.alloc() : memref<8x64xi8>
+  %0 = memref.extract_aligned_pointer_as_index %b : memref<8x64xi8> -> index
+  "test.source"(%b)  : (memref<8x64xi8>) -> ()
+  return
+}
+
+// check that aliased buffers' lifetimes are handled correctly
+// CHECK-DAG: func.func @alias_ref(%[[ARG0:.*]]: i1) attributes {__mergealloc_scope = [[TOPSCOPE4:[0-9]+]]
+func.func @alias_ref(%pred : i1) {
+  // CHECK: %[[A:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE4]], 5, 5>}
+  %a = memref.alloc() : memref<8x64xi8>
+  // CHECK: %[[B:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE4]], 5, 6>}
+  %b = memref.alloc() : memref<8x64xi8>
+  %c = arith.select %pred, %a, %b : i1, memref<8x64xi8>
+  "test.source"(%c)  : (memref<8x64xi8>) -> ()
+  "test.source"(%b)  : (memref<8x64xi8>) -> ()
+  return
+}
+
+// CHECK-DAG: func.func @escape_from_if()  attributes {__mergealloc_scope = [[TOPSCOPE5:[0-9]+]]
+func.func @escape_from_if() {
+  %ctrue = arith.constant 1 : i1
+  // check that f is alive through the whole range of the following scf.if
+  // CHECK-DAG: %[[F:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE5]], 4, 13>}
+  %f = memref.alloc() : memref<8x64xf32>
+  // tick of the scf.if starts from 4 and ends at 14
+  // CHECK: scf.if
+  %c = scf.if %ctrue -> memref<8x64xf32> {
+    "test.source"(%f)  : (memref<8x64xf32>) -> ()
+    // CHECK-DAG: %[[G:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE5]], 4, 14>}
+    %g = memref.alloc() : memref<8x64xf32>
+    "test.source"(%g)  : (memref<8x64xf32>) -> ()
+    scf.yield %g : memref<8x64xf32>
+  } else {
+    // h fully overlaps with g
+    // CHECK-DAG: %[[H:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE5]], 4, 14>}
+    %h = memref.alloc() : memref<8x64xf32>
+    "test.source"(%h)  : (memref<8x64xf32>) -> ()
+    // j is only used in the scf.if, so it doesn't need a conservative lifetime
+    // CHECK-DAG: %[[J:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE5]], 12, 12>}
+    %j = memref.alloc() : memref<8x64xf32>
+    "test.source"(%j)  : (memref<8x64xf32>) -> ()
+    scf.yield %h : memref<8x64xf32>
+  }
+  "test.source"(%c)  : (memref<8x64xf32>) -> ()
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/MemRef/buffer-merge-mlp.mlir b/mlir/test/Dialect/MemRef/buffer-merge-mlp.mlir
new file mode 100644
index 0000000000000..ae002c5b0d34c
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/buffer-merge-mlp.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -one-shot-bufferize="unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" --merge-alloc %s | FileCheck %s
+
+func.func @mlp(%x: tensor<128x128xf32>, %y: tensor<128x128xf32>) -> tensor<128x128xf32> {
+   // CHECK-DAG:  %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<131072xi8>
+   // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+   // CHECK-DAG:  %[[VIEW_A:.*]] = memref.view %[[ALLOC]][%[[C0]]][] : memref<131072xi8> to memref<128x128xf32>
+   %a0 = tensor.empty() : tensor<128x128xf32>
+   // CHECK:      linalg.matmul ins
+   // CHECK-SAME: outs(%[[VIEW_A]] : memref<128x128xf32>)
+   %a = linalg.matmul ins(%x, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%a0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   // CHECK-DAG:  %[[C65536:.*]] = arith.constant 65536 : index
+   // CHECK-DAG:  %[[VIEW_B:.*]] = memref.view %[[ALLOC]][%[[C65536]]][] : memref<131072xi8> to memref<128x128xf32>
+   %b0 = tensor.empty() : tensor<128x128xf32>
+   // CHECK:      linalg.matmul ins(%[[VIEW_A]],
+   // CHECK-SAME: outs(%[[VIEW_B]] : memref<128x128xf32>)
+   %b = linalg.matmul ins(%a, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%b0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   // CHECK-DAG:  %[[C0_2:.*]] = arith.constant 0 : index
+   // CHECK-DAG:  %[[VIEW_C:.*]] = memref.view %[[ALLOC]][%[[C0_2]]][] : memref<131072xi8> to memref<128x128xf32>
+   %c0 = tensor.empty() : tensor<128x128xf32>
+   // CHECK:      linalg.matmul ins(%[[VIEW_B]],
+   // CHECK-SAME: outs(%[[VIEW_C]] : memref<128x128xf32>)
+   %c = linalg.matmul ins(%b, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%c0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   // CHECK-DAG:  %[[D:.*]] = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32>
+   // CHECK:      linalg.matmul ins(%[[VIEW_C]],
+   // CHECK-SAME: outs(%[[D]] : memref<128x128xf32>)
+   %d0 = tensor.empty() : tensor<128x128xf32>
+   %d = linalg.matmul ins(%c, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%d0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   // CHECK:      return %[[D]]
+   return %d : tensor<128x128xf32>
+}
\ No newline at end of file
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index c3f349ad8ec55..1c4cf64f6a238 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,7 +1,9 @@
 add_mlir_unittest(MLIRMemRefTests
   InferShapeTest.cpp
+  StaticMemoryPlanning.cpp
 )
 target_link_libraries(MLIRMemRefTests
   PRIVATE
   MLIRMemRefDialect
+  MLIRMemRefTransforms
   )
diff --git a/mlir/unittests/Dialect/MemRef/StaticMemoryPlanning.cpp b/mlir/unittests/Dialect/MemRef/StaticMemoryPlanning.cpp
new file mode 100644
index 0000000000000..60852d8740b0b
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/StaticMemoryPlanning.cpp
@@ -0,0 +1,206 @@
+//===- StaticMemoryPlanning.cpp - Tests for StaticMemoryPlanning-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
+
+#include "gtest/gtest.h"
+#include <iostream>
+
+using namespace mlir;
+TEST(static_memory_planner, TestStaticMemoryPlanning) {
+  /*
+  {0}                   {160}               {280}                 {450}
+  |           0         |           1        |          2          |
+  |    3    |   4/5     |                                          |
+  |    7    |                                     6                |
+            {100}
+  */
+  memoryplan::Traces traces = {{0, 100}, {1, 120}, {2, 100}, {0, 0},
+                               {3, 50},  {4, 60},  {2, 0},   {4, 0},
+                               {8, 100}, {8, 0},   {5, 60},  {5, 0},
+                               {1, 0},   {6, 350}, {3, 0},   {7, 100}};
+  std::unordered_map<uintptr_t, size_t> out;
+  std::unordered_map<uintptr_t, std::vector<uintptr_t>> inplace_selection;
+  size_t total = memoryplan::scheduleMemoryAllocations(
+      traces, 1, false, memoryplan::InplaceInfoMap(), out, inplace_selection);
+  std::unordered_map<uintptr_t, size_t> expected_out = {
+      {0, 0},   {1, 160}, {2, 280}, {3, 0},  {4, 100},
+      {5, 100}, {6, 100}, {7, 0},   {8, 280}};
+  EXPECT_EQ(total, 450UL);
+  EXPECT_EQ(out, expected_out);
+
+  total = memoryplan::scheduleMemoryAllocations(
+      traces, 1, true, memoryplan::InplaceInfoMap(), out, inplace_selection);
+  expected_out = {{0, 0},   {1, 160}, {2, 280}, {3, 0},  {4, 100},
+                  {5, 100}, {6, 100}, {7, 0},   {8, 280}};
+  EXPECT_EQ(total, 450UL);
+  EXPECT_EQ(out, expected_out);
+}
+
+TEST(static_memory_planner, TestStaticMemoryPlanningInplace) {
+  using namespace memoryplan;
+  using inplace_outdata = std::unordered_map<uintptr_t, std::vector<uintptr_t>>;
+  using inplace_data = InplaceInfoMap;
+  // simple inplace (need merge + split)
+  {
+    memoryplan::Traces traces = {{1, 120}, {2, 100}, {3, 200}, {1, 0},
+                                 {2, 0},   {3, 0},   {4, 220}, {4, 0}};
+    std::unordered_map<uintptr_t, size_t> out;
+    inplace_outdata inplace_selection;
+    inplace_data inplace_hint = {
+        {3, {{1, InplaceKind::FREE}, {2, InplaceKind::FREE}}}};
+    size_t total = memoryplan::scheduleMemoryAllocations(
+        traces, 1, false, inplace_hint, out, inplace_selection);
+    std::unordered_map<uintptr_t, size_t> expected_out = {
+        {1, 0}, {2, 120}, {3, 0}, {4, 0}};
+    EXPECT_EQ(total, 220UL);
+    EXPECT_EQ(out, expected_out);
+
+    inplace_outdata expected_inplace = {{3, {1, 2}}};
+    EXPECT_EQ(inplace_selection, expected_inplace);
+  }
+
+  // inplace extend
+  {
+    memoryplan::Traces traces = {{1, 120}, {2, 100}, {4, 250}, {3, 250},
+                                 {1, 0},   {2, 0},   {3, 0},   {4, 0}};
+    std::unordered_map<uintptr_t, size_t> out;
+    inplace_outdata inplace_selection;
+    inplace_data inplace_hint = {
+        {3, {{1, InplaceKind::FREE}, {2, InplaceKind::FREE}}}};
+    size_t total = memoryplan::scheduleMemoryAllocations(
+        traces, 1, false, inplace_hint, out, inplace_selection);
+    std::unordered_map<uintptr_t, size_t> expected_out = {
+        {1, 0}, {2, 120}, {3, 0}, {4, 250}};
+    EXPECT_EQ(total, 500UL);
+    EXPECT_EQ(out, expected_out);
+
+    inplace_outdata expected_inplace = {{3, {1, 2}}};
+    EXPECT_EQ(inplace_selection, expected_inplace);
+  }
+
+  // inplace 2 buffers into one
+  {
+    memoryplan::Traces traces = {{1, 120}, {2, 100}, {3, 150}, {4, 50}, {5, 10},
+                                 {1, 0},   {2, 0},   {3, 0},   {4, 0},  {5, 0}};
+    std::unordered_map<uintptr_t, size_t> out;
+    inplace_outdata inplace_selection;
+    inplace_data inplace_hint = {
+        {3, {{1, InplaceKind::FREE}, {2, InplaceKind::FREE}}},
+        {4, {{1, InplaceKind::FREE}, {2, InplaceKind::FREE}}}};
+    size_t total = memoryplan::scheduleMemoryAllocations(
+        traces, 1, false, inplace_hint, out, inplace_selection);
+    std::unordered_map<uintptr_t, size_t> expected_out = {
+        {1, 0}, {2, 120}, {3, 0}, {4, 150}, {5, 220}};
+    EXPECT_EQ(total, 230UL);
+    EXPECT_EQ(out, expected_out);
+
+    inplace_outdata expected_inplace = {{3, {1, 2}}, {4, {2}}};
+    EXPECT_EQ(inplace_selection, expected_inplace);
+  }
+
+  // inplace 2 buffers into one, but require zero offset
+  {
+    memoryplan::Traces traces = {{1, 120}, {2, 100}, {3, 150}, {4, 50}, {5, 10},
+                                 {1, 0},   {2, 0},   {3, 0},   {4, 0},  {5, 0}};
+    std::unordered_map<uintptr_t, size_t> out;
+    inplace_outdata inplace_selection;
+    inplace_data inplace_hint = {
+        {3, {{1, InplaceKind::FREE}, {2, InplaceKind::ZERO_OFFSET}}},
+        {4, {{1, InplaceKind::FREE}, {2, InplaceKind::FREE}}}};
+    size_t total = memoryplan::scheduleMemoryAllocations(
+        traces, 1, false, inplace_hint, out, inplace_selection);
+    std::unordered_map<uintptr_t, size_t> expected_out = {
+        {1, 0}, {2, 150}, {3, 0}, {4, 150}, {5, 250}};
+    EXPECT_EQ(total, 260UL);
+    EXPECT_EQ(out, expected_out);
+
+    inplace_outdata expected_inplace = {{3, {1}}, {4, {2}}};
+    EXPECT_EQ(inplace_selection, expected_inplace);
+  }
+
+  // inplace 2 buffers into one, but require zero offset for split buffer
+  // buffer4 cannot reuse buffer 2 because it requires zero offset
+  {
+    memoryplan::Traces traces = {{1, 120}, {2, 100}, {3, 150}, {4, 50}, {5, 10},
+                                 {1, 0},   {2, 0},   {3, 0},   {4, 0},  {5, 0}};
+    std::unordered_map<uintptr_t, size_t> out;
+    inplace_outdata inplace_selection;
+    inplace_data inplace_hint = {
+        {3, {{1, InplaceKind::FREE}, {2, InplaceKind::FREE}}},
+        {4, {{1, InplaceKind::FREE}, {2, InplaceKind::ZERO_OFFSET}}}};
+    size_t total = memoryplan::scheduleMemoryAllocations(
+        traces, 1, false, inplace_hint, out, inplace_selection);
+    std::unordered_map<uintptr_t, size_t> expected_out = {
+        {1, 0}, {2, 120}, {3, 0}, {4, 220}, {5, 270}};
+    EXPECT_EQ(total, 280UL);
+    EXPECT_EQ(out, expected_out);
+
+    inplace_outdata expected_inplace = {{3, {1, 2}}};
+    EXPECT_EQ(inplace_selection, expected_inplace);
+  }
+
+  // merge free to the right
+  {
+    memoryplan::Traces traces = {{1, 120}, {2, 100}, {3, 150}, {2, 0}, {4, 150},
+                                 {5, 10},  {1, 0},   {3, 0},   {4, 0}, {5, 0}};
+    std::unordered_map<uintptr_t, size_t> out;
+    inplace_outdata inplace_selection;
+    inplace_data inplace_hint = {{4, {{1, InplaceKind::FREE}}}};
+    size_t total = memoryplan::scheduleMemoryAllocations(
+        traces, 1, false, inplace_hint, out, inplace_selection);
+    std::unordered_map<uintptr_t, size_t> expected_out = {
+        {1, 0}, {2, 120}, {3, 220}, {4, 0}, {5, 150}};
+    EXPECT_EQ(total, 370UL);
+    EXPECT_EQ(out, expected_out);
+
+    inplace_outdata expected_inplace = {{4, {1}}};
+    EXPECT_EQ(inplace_selection, expected_inplace);
+  }
+
+  // perfect matches
+  {
+    memoryplan::Traces traces = {{1, 120}, {2, 100}, {3, 100}, {4, 120},
+                                 {1, 0},   {2, 0},   {3, 0},   {4, 0},
+                                 {5, 200}, {5, 0}};
+    std::unordered_map<uintptr_t, size_t> out;
+    inplace_outdata inplace_selection;
+    inplace_data inplace_hint = {
+        {3, {{1, InplaceKind::FREE}, {2, InplaceKind::FREE}}},
+        {4, {{1, InplaceKind::FREE}, {2, InplaceKind::FREE}}}};
+    size_t total = memoryplan::scheduleMemoryAllocations(
+        traces, 1, false, inplace_hint, out, inplace_selection);
+    std::unordered_map<uintptr_t, size_t> expected_out = {
+        {1, 0}, {2, 120}, {3, 120}, {4, 0}, {5, 0}};
+    EXPECT_EQ(total, 220UL);
+    EXPECT_EQ(out, expected_out);
+
+    inplace_outdata expected_inplace = {{3, {2}}, {4, {1}}};
+    EXPECT_EQ(inplace_selection, expected_inplace);
+  }
+
+  // selected inputs
+  {
+    memoryplan::Traces traces = {{1, 120}, {2, 100}, {3, 100}, {4, 120},
+                                 {1, 0},   {2, 0},   {3, 0},   {4, 0},
+                                 {5, 200}, {5, 0}};
+    std::unordered_map<uintptr_t, size_t> out;
+    inplace_outdata inplace_selection;
+    inplace_data inplace_hint = {{3, {{1, InplaceKind::FREE}}},
+                                 {4, {{2, InplaceKind::FREE}}}};
+    size_t total = memoryplan::scheduleMemoryAllocations(
+        traces, 1, false, inplace_hint, out, inplace_selection);
+    std::unordered_map<uintptr_t, size_t> expected_out = {
+        {1, 0}, {2, 120}, {3, 0}, {4, 120}, {5, 0}};
+    EXPECT_EQ(total, 240UL);
+    EXPECT_EQ(out, expected_out);
+
+    inplace_outdata expected_inplace = {{3, {1}}, {4, {2}}};
+    EXPECT_EQ(inplace_selection, expected_inplace);
+  }
+}
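
For readers following the tests above, here is a minimal standalone sketch of
driving the same planner API. It is illustrative only and rests on a few
assumptions: MemoryTrace entries are (bufferId, size) pairs, a size of 0 marks
the end of that buffer's lifetime, the second argument is the alignment (the
pass uses 64), and the boolean flag mirrors the pass's !optionNoLocality usage.

  #include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
  #include <cstdint>
  #include <unordered_map>
  #include <vector>

  using namespace mlir;

  size_t planTwoBuffers() {
    // Buffer 1 (100 bytes) is freed before buffer 2 (200 bytes) is created,
    // so their lifetimes do not overlap and the planner is free to overlap
    // their memory ranges; the returned value is the merged buffer size.
    memoryplan::Traces traces = {{1, 100}, {1, 0}, {2, 200}, {2, 0}};
    std::unordered_map<uintptr_t, size_t> offsets;
    std::unordered_map<uintptr_t, std::vector<uintptr_t>> inplaceSelection;
    return memoryplan::scheduleMemoryAllocations(
        traces, /*alignment=*/64, /*localityFirst=*/true,
        memoryplan::InplaceInfoMap(), offsets, inplaceSelection);
  }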

>From 44d6a812b4aab7d6de54ab528b1316f73308a319 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Wed, 5 Jun 2024 15:29:43 +0800
Subject: [PATCH 02/12] enhance and add test

---
 .../Dialect/MemRef/Transforms/MergeAlloc.h    |  74 ++++
 .../Dialect/MemRef/Transforms/CMakeLists.txt  |   1 +
 .../Dialect/MemRef/Transforms/MergeAlloc.cpp  | 321 ++-------------
 .../MemRef/Transforms/MergeAllocTickBased.cpp | 378 ++++++++++++++++++
 .../Dialect/MemRef/buffer-merge-invalid.mlir  |  11 +
 .../test/Dialect/MemRef/buffer-merge-mlp.mlir |  30 --
 mlir/test/Dialect/MemRef/buffer-merge.mlir    |  96 +++++
 7 files changed, 583 insertions(+), 328 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
 create mode 100644 mlir/test/Dialect/MemRef/buffer-merge-invalid.mlir
 delete mode 100644 mlir/test/Dialect/MemRef/buffer-merge-mlp.mlir
 create mode 100644 mlir/test/Dialect/MemRef/buffer-merge.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
new file mode 100644
index 0000000000000..45b584e99e044
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
@@ -0,0 +1,74 @@
+//===- MergeAlloc.h - The interfaces for merge alloc pass -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_MERGEALLOC_H
+#define MLIR_DIALECT_MEMREF_MERGEALLOC_H
+
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+
+namespace mlir {
+class BufferViewFlowAnalysis;
+namespace memref {
+struct MergeAllocOptions;
+// abstract base class for lifetime of different buffers. It should hold the
+// lifetime information of buffers that are to be merged in the same allocation
+// in an "allocation scope". TraceCollectorFunc decides which buffers are put
+// into which "allocation scope".
+class LifetimeTrace {
+public:
+  enum TraceKind { TK_TICK };
+  virtual ~LifetimeTrace() = default;
+  LifetimeTrace(TraceKind kind) : kind{kind} {}
+  TraceKind getKind() const { return kind; }
+
+private:
+  TraceKind kind;
+};
+
+// top level memory trace info for multiple scopes. Each key-value pair is the
+// traces and location for buffers in the same "allocation scope"
+struct MemoryTraces {
+  llvm::DenseMap<Operation *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
+  MemoryTraces() = default;
+};
+
+// the memory scheduling result for allocations in the same merged buffer.
+// allocation => offset map. All Operation* in the map should be memref::AllocOp
+// which are in the same LifetimeTrace.
+struct MemorySchedule {
+  size_t totalSize;
+  llvm::DenseMap<Operation *, int64_t> allocToOffset;
+  MemorySchedule() : totalSize{0} {}
+};
+
+using TraceCollectorFunc = llvm::function_ref<FailureOr<MemoryTraces>(
+    Operation *, const BufferViewFlowAnalysis &, const MergeAllocOptions &)>;
+using MemoryPlannerFunc = llvm::function_ref<FailureOr<MemorySchedule>(
+    Operation *, const LifetimeTrace &, const MergeAllocOptions &)>;
+using MemoryMergeMutatorFunc = llvm::function_ref<LogicalResult(
+    Operation *toplevel, Operation *scope, const MemorySchedule &,
+    const MergeAllocOptions &)>;
+
+FailureOr<MemoryTraces>
+tickBasedCollectMemoryTrace(Operation *root,
+                            const mlir::BufferViewFlowAnalysis &aliasAnaly,
+                            const MergeAllocOptions &option);
+
+FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
+                                              const LifetimeTrace &tr,
+                                              const MergeAllocOptions &o);
+LogicalResult tickBasedMutateAllocations(Operation *op, Operation *scope,
+                                         const MemorySchedule &schedule,
+                                         const MergeAllocOptions &o);
+} // namespace memref
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 456bbafd9af5e..357a1beb56a94 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   ResolveShapedTypeResultDims.cpp
   RuntimeOpVerification.cpp
   MergeAlloc.cpp
+  MergeAllocTickBased.cpp
   StaticMemoryPlanning.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
index 7eca9452ae802..100e6bfbcb4db 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/MemRef/Transforms/MergeAlloc.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
 
@@ -24,327 +25,51 @@ namespace memref {
 #define GEN_PASS_DEF_MERGEALLOC
 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
 
-/// Return `true` if the given MemRef type has a static identity layout (i.e.,
-/// no layout).
-static bool hasStaticIdentityLayout(MemRefType type) {
-  return type.hasStaticShape() && type.getLayout().isIdentity();
-}
-
 namespace {
-static constexpr int64_t NO_ACCESS = -1;
-static constexpr int64_t COMPLEX_ACCESS = -2;
-struct Tick {
-  int64_t firstAccess = NO_ACCESS;
-  int64_t lastAccess = NO_ACCESS;
-
-  void access(int64_t tick) {
-    if (tick == COMPLEX_ACCESS) {
-      firstAccess = COMPLEX_ACCESS;
-      lastAccess = COMPLEX_ACCESS;
-    }
-    if (firstAccess == COMPLEX_ACCESS) {
-      return;
-    }
-    if (firstAccess == NO_ACCESS) {
-      firstAccess = tick;
-    } else {
-      firstAccess = std::min(firstAccess, tick);
-    }
-    lastAccess = std::max(lastAccess, tick);
-  }
-};
 
-bool isMergeableAlloc(Operation *op, int64_t tick) {
-  if (tick == COMPLEX_ACCESS) {
-    return false;
+LogicalResult passDriver(Operation *op, const memref::MergeAllocOptions &o,
+                         TraceCollectorFunc tracer, MemoryPlannerFunc planner,
+                         MemoryMergeMutatorFunc mutator) {
+  BufferViewFlowAnalysis aliasAnaly{op};
+  auto tracesOrFail = tracer(op, aliasAnaly, o);
+  if (failed(tracesOrFail)) {
+    return failure();
   }
-  if (!hasStaticIdentityLayout(
-          cast<MemRefType>(op->getResultTypes().front()))) {
-    return false;
+  if (o.optionCheck) {
+    return success();
   }
-  // currently only support alignment: none, 1, 2, 4, 8, 16, 32, 64
-  auto alignment = cast<memref::AllocOp>(op).getAlignment();
-  if (!alignment) {
-    return true; // ok if no alignment
-  }
-  return alignment > 0 && (64 % alignment.value() == 0);
-}
-
-// find the closest surrounding parent operation with AutomaticAllocationScope
-// trait, and is not scf.for
-Operation *getAllocScope(Operation *op) {
-  auto parent = op;
-  for (;;) {
-    parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
-    if (!parent) {
-      return nullptr;
+  for (auto &[scope, traces] : (*tracesOrFail).scopeToTraces) {
+    auto schedule = planner(op, *traces, o);
+    if (failed(schedule)) {
+      return failure();
     }
-    if (!isa<scf::ForOp>(parent)) {
-      return parent;
+    if (failed(mutator(op, scope, *schedule, o))) {
+      return failure();
     }
   }
+  return success();
 }
 
-FailureOr<size_t> getAllocSize(Operation *op) {
-  auto refType = cast<MemRefType>(op->getResultTypes().front());
-  int64_t size = refType.getElementTypeBitWidth() / 8;
-  // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
-  // least have a large enough size for i1
-  size = (size != 0) ? size : 1;
-  for (auto v : refType.getShape()) {
-    size *= v;
-  }
-  if (size > 0) {
-    return static_cast<size_t>(size);
-  }
-  return op->emitError("Expecting static shaped allocation");
-}
-
-// A complex scope object is addition info for a RegionBranchOpInterface or
-// LoopLikeOpInterface. It contains the scope itself, and the referenced alloc
-// ops inside this scope. We use this object to track which buffers this scope
-// accesses. These buffers must have overlapped lifetime
-struct ComplexScope {
-  Operation *scope;
-  int64_t startTick;
-  llvm::SmallPtrSet<Operation *, 8> operations;
-  ComplexScope(Operation *scope, int64_t startTick)
-      : scope{scope}, startTick{startTick} {}
-  // returns true of an allocation either is not defined in the scope, or the
-  // allocation escapes from the scope
-  bool needsResetTick(Operation *scope, Operation *allocation,
-                      const mlir::BufferViewFlowAnalysis &aliasAnaly) const {
-    // if the allocation is not in the scope, conservatively set the ticks
-    if (!scope->isProperAncestor(allocation)) {
-      return true;
-    }
-    // if the allocation and its alias are used outside of the scope
-    for (auto &&alias : aliasAnaly.resolve(allocation->getResult(0))) {
-      for (auto &&userOp : alias.getUsers()) {
-        if (!scope->isProperAncestor(userOp) && !isMemoryEffectFree(userOp)) {
-          return true;
-        }
-      }
-    }
-    return false;
-  }
-
-  // called when walk() runs outside of the scope
-  void onPop(int64_t endTick, const mlir::BufferViewFlowAnalysis &aliasAnaly,
-             llvm::DenseMap<Operation *, Tick> &allocTicks) {
-    for (auto op : operations) {
-      if (needsResetTick(scope, op, aliasAnaly)) {
-        // let all referenced buffers have overlapped lifetime
-        auto &tick = allocTicks[op];
-        tick.access(startTick);
-        tick.access(endTick);
-      }
-    }
-  }
-};
-
-struct TickCollecter {
-  const mlir::BufferViewFlowAnalysis &aliasAnaly;
-  int64_t curTick = 0;
-  llvm::DenseMap<Operation *, Tick> allocTicks;
-  llvm::SmallVector<ComplexScope> complexScopeStack;
-  TickCollecter(const mlir::BufferViewFlowAnalysis &aliasAnaly)
-      : aliasAnaly{aliasAnaly} {}
-  void popScopeIfNecessary(Operation *op) {
-    // first check if we have walked outside of the previous ComplexScope
-    while (!complexScopeStack.empty()) {
-      auto &scope = complexScopeStack.back();
-      if (!op || !scope.scope->isProperAncestor(op)) {
-        scope.onPop(curTick, aliasAnaly, allocTicks);
-        complexScopeStack.pop_back();
-      } else {
-        break;
-      }
-    }
-  }
-
-  void forwardTick() { curTick++; }
-
-  void accessValue(Value v, bool complex) {
-    if (auto refv = dyn_cast<TypedValue<MemRefType>>(v)) {
-      for (auto &&base : aliasAnaly.resolveReverse(refv)) {
-        auto defop = base.getDefiningOp();
-        if (isa_and_present<memref::AllocOp>(defop)) {
-          allocTicks[defop].access(complex ? COMPLEX_ACCESS : curTick);
-          if (!complexScopeStack.empty()) {
-            complexScopeStack.back().operations.insert(defop);
-          }
-        }
-      }
-    }
-  }
-
-  void onMemrefViews(ViewLikeOpInterface op) {
-    auto viewSrc = op.getViewSource();
-    // don't need to access the first operand, which is "source".
-    // The "source" operand is not really read or written at this point
-    for (auto val : op.getOperation()->getOperands()) {
-      if (val != viewSrc)
-        accessValue(val, false);
-    }
-  }
-
-  void onReturnOp(Operation *op) {
-    bool isTopLevel = isa<func::FuncOp>(op->getParentOp());
-    for (auto val : op->getOperands()) {
-      accessValue(val, isTopLevel);
-    }
-  }
-
-  void onGeneralOp(Operation *op) {
-    for (auto val : op->getOperands()) {
-      accessValue(val, false);
-    }
-  }
-
-  void pushComplexScope(Operation *op) {
-    complexScopeStack.emplace_back(op, curTick);
-  }
-
-  FailureOr<memoryplan::ScopeTraceData> getTrace() {
-    struct TraceWithTick {
-      int64_t tick;
-      memoryplan::MemoryTrace trace;
-      TraceWithTick(int64_t tick, uintptr_t bufferId, size_t size)
-          : tick{tick}, trace{bufferId, size} {}
-    };
-    llvm::DenseMap<Operation *, llvm::SmallVector<TraceWithTick, 8>> raw;
-    for (auto &[op, tick] : allocTicks) {
-      if (!isMergeableAlloc(op, tick.firstAccess)) {
-        continue;
-      }
-      auto scope = getAllocScope(op);
-      if (!scope) {
-        return op->emitError(
-            "This op should be surrounded by an AutomaticAllocationScope");
-      }
-      auto allocSize = getAllocSize(op);
-      if (failed(allocSize)) {
-        return allocSize;
-      }
-      // tick.firstAccess * 2 and tick.lastAccess * 2 + 1 to make sure "dealloc"
-      // overlaps "alloc"
-      raw[scope].emplace_back(tick.firstAccess * 2,
-                              reinterpret_cast<uintptr_t>(op), *allocSize);
-      raw[scope].emplace_back(tick.lastAccess * 2 + 1,
-                              reinterpret_cast<uintptr_t>(op), 0);
-    }
-    memoryplan::ScopeTraceData ret;
-    for (auto &[scope, trace] : raw) {
-      std::stable_sort(trace.begin(), trace.end(),
-                       [](const TraceWithTick &a, const TraceWithTick &b) {
-                         return a.tick < b.tick;
-                       });
-      auto &retTrace = ret[scope];
-      retTrace.reserve(trace.size());
-      for (auto &tr : trace) {
-        retTrace.emplace_back(tr.trace);
-      }
-    }
-    return ret;
-  }
-};
-
 } // namespace
-
-FailureOr<mlir::memoryplan::ScopeTraceData>
-collectMemoryTrace(Operation *root,
-                   const mlir::BufferViewFlowAnalysis &aliasAnaly,
-                   bool markOnly) {
-  TickCollecter collecter{aliasAnaly};
-  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
-    collecter.popScopeIfNecessary(op);
-    collecter.forwardTick();
-    if (auto viewop = dyn_cast<ViewLikeOpInterface>(op)) {
-      collecter.onMemrefViews(viewop);
-    } else if (op->hasTrait<OpTrait::ReturnLike>()) {
-      collecter.onReturnOp(op);
-    } else if (!isMemoryEffectFree(op)) {
-      // if the op has no memory effects, it don't contribute to liveness
-      collecter.onGeneralOp(op);
-    }
-    // finally, if op is complex scope, push one ComplexScope
-    if (isa<RegionBranchOpInterface>(op) || isa<LoopLikeOpInterface>(op)) {
-      collecter.pushComplexScope(op);
-    }
-  });
-  collecter.popScopeIfNecessary(nullptr);
-  if (markOnly) {
-    for (auto &[alloc, tick] : collecter.allocTicks) {
-      auto allocscope = getAllocScope(alloc);
-      alloc->setAttr(
-          "__mergealloc_lifetime",
-          DenseI64ArrayAttr::get(root->getContext(),
-                                 {reinterpret_cast<int64_t>(allocscope),
-                                  tick.firstAccess, tick.lastAccess}));
-      allocscope->setAttr(
-          "__mergealloc_scope",
-          IntegerAttr::get(mlir::IntegerType::get(root->getContext(), 64),
-                           reinterpret_cast<int64_t>(allocscope)));
-    }
-    return mlir::memoryplan::ScopeTraceData();
-  }
-  return collecter.getTrace();
-}
-
 } // namespace memref
-} // namespace mlir
 
-namespace {
 using namespace mlir;
 struct MergeAllocPass : memref::impl::MergeAllocBase<MergeAllocPass> {
   using parent = memref::impl::MergeAllocBase<MergeAllocPass>;
   void runOnOperation() override {
     auto op = getOperation();
-    BufferViewFlowAnalysis aliasAnaly{op};
-    auto tracesOrFail =
-        memref::collectMemoryTrace(op, aliasAnaly, this->optionCheck);
-    if (failed(tracesOrFail)) {
-      signalPassFailure();
-      return;
-    }
-    if (this->optionCheck) {
-      return;
-    }
-    std::unordered_map<uintptr_t, std::vector<uintptr_t>> dummy;
-    for (auto &[scope, traces] : *tracesOrFail) {
-      std::unordered_map<uintptr_t, std::size_t> outSchedule;
-      if (traces.empty())
-        continue;
-      auto total = memoryplan::scheduleMemoryAllocations(
-          traces, 64, !this->optionNoLocality, memoryplan::InplaceInfoMap(),
-          outSchedule, dummy);
-      auto &block = scope->getRegion(0).getBlocks().front();
-      OpBuilder builder{&block.front()};
-      auto alignment =
-          builder.getIntegerAttr(IntegerType::get(op.getContext(), 64), 64);
-      auto alloc = builder.create<memref::AllocOp>(
-          scope->getLoc(),
-          MemRefType::get({static_cast<int64_t>(total)}, builder.getI8Type()),
-          alignment);
-      for (auto &[key, offset] : outSchedule) {
-        auto origBuf = reinterpret_cast<Operation *>(key);
-        builder.setInsertionPoint(origBuf);
-        auto byteShift = builder.create<arith::ConstantIndexOp>(
-            origBuf->getLoc(), static_cast<int64_t>(offset));
-        auto view = builder.create<memref::ViewOp>(
-            origBuf->getLoc(), origBuf->getResultTypes().front(), alloc,
-            byteShift, ValueRange{});
-        origBuf->replaceAllUsesWith(view->getResults());
-        origBuf->remove();
-      }
+    if (failed(memref::passDriver(
+            op, memref::MergeAllocOptions{optionCheck, optionNoLocality},
+            memref::tickBasedCollectMemoryTrace, memref::tickBasedPlanMemory,
+            memref::tickBasedMutateAllocations))) {
+        signalPassFailure();
     }
   }
 
 public:
   MergeAllocPass(const memref::MergeAllocOptions &o) : parent{o} {}
 };
-} // namespace
+} // namespace mlir
 
 std::unique_ptr<mlir::Pass>
 mlir::memref::createMergeAllocPass(const memref::MergeAllocOptions &o) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
new file mode 100644
index 0000000000000..85df88d1a5fc4
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
@@ -0,0 +1,378 @@
+//===- MergeAllocTickBased.cpp - Tick-based merge alloc implementation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/Transforms/MergeAlloc.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
+
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace memref {
+
+/// Return `true` if the given MemRef type has a static identity layout (i.e.,
+/// no layout).
+static bool hasStaticIdentityLayout(MemRefType type) {
+  return type.hasStaticShape() && type.getLayout().isIdentity();
+}
+
+namespace {
+static constexpr int64_t NO_ACCESS = -1;
+static constexpr int64_t COMPLEX_ACCESS = -2;
+struct Tick {
+  int64_t firstAccess = NO_ACCESS;
+  int64_t lastAccess = NO_ACCESS;
+
+  void access(int64_t tick) {
+    if (tick == COMPLEX_ACCESS) {
+      firstAccess = COMPLEX_ACCESS;
+      lastAccess = COMPLEX_ACCESS;
+    }
+    if (firstAccess == COMPLEX_ACCESS) {
+      return;
+    }
+    if (firstAccess == NO_ACCESS) {
+      firstAccess = tick;
+    } else {
+      firstAccess = std::min(firstAccess, tick);
+    }
+    lastAccess = std::max(lastAccess, tick);
+  }
+};
+
+bool isMergeableAlloc(Operation *op, int64_t tick) {
+  if (tick == COMPLEX_ACCESS) {
+    return false;
+  }
+  if (!hasStaticIdentityLayout(
+          cast<MemRefType>(op->getResultTypes().front()))) {
+    return false;
+  }
+  // currently only support alignment: none, 1, 2, 4, 8, 16, 32, 64
+  auto alignment = cast<memref::AllocOp>(op).getAlignment();
+  if (!alignment) {
+    return true; // ok if no alignment
+  }
+  return alignment > 0 && (64 % alignment.value() == 0);
+}
+
+// find the closest surrounding parent operation that has the
+// AutomaticAllocationScope trait and is not scf.for
+Operation *getAllocScope(Operation *op) {
+  auto parent = op;
+  for (;;) {
+    parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+    if (!parent) {
+      return nullptr;
+    }
+    if (!isa<scf::ForOp>(parent)) {
+      return parent;
+    }
+  }
+}
+
+FailureOr<size_t> getAllocSize(Operation *op) {
+  auto refType = cast<MemRefType>(op->getResultTypes().front());
+  int64_t size = refType.getElementTypeBitWidth() / 8;
+  // treat bool (i1) as 1 byte. This may not be true for all targets, but we at
+  // least have a large enough size for i1
+  size = (size != 0) ? size : 1;
+  for (auto v : refType.getShape()) {
+    size *= v;
+  }
+  if (size > 0) {
+    return static_cast<size_t>(size);
+  }
+  return op->emitError("Expecting static shaped allocation");
+}
+
+// A complex scope object is additional info for a RegionBranchOpInterface or
+// LoopLikeOpInterface. It contains the scope itself, and the referenced alloc
+// ops inside this scope. We use this object to track which buffers this scope
+// accesses. These buffers must have overlapping lifetimes
+struct ComplexScope {
+  Operation *scope;
+  int64_t startTick;
+  llvm::SmallPtrSet<Operation *, 8> operations;
+  ComplexScope(Operation *scope, int64_t startTick)
+      : scope{scope}, startTick{startTick} {}
+  // returns true if an allocation either is not defined in the scope, or the
+  // allocation escapes from the scope
+  bool needsResetTick(Operation *scope, Operation *allocation,
+                      const mlir::BufferViewFlowAnalysis &aliasAnaly) const {
+    // if the allocation is not in the scope, conservatively set the ticks
+    if (!scope->isProperAncestor(allocation)) {
+      return true;
+    }
+    // if the allocation and its alias are used outside of the scope
+    for (auto &&alias : aliasAnaly.resolve(allocation->getResult(0))) {
+      for (auto &&userOp : alias.getUsers()) {
+        if (!scope->isProperAncestor(userOp) && !isMemoryEffectFree(userOp)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  // called when walk() runs outside of the scope
+  LogicalResult onPop(int64_t endTick,
+                      const mlir::BufferViewFlowAnalysis &aliasAnaly,
+                      llvm::DenseMap<Operation *, Tick> &allocTicks) {
+    // if the complex scope is not recognized by us, and if it accesses memory,
+    // raise an error
+    if (!isa<RegionBranchOpInterface>(scope) &&
+        !isa<LoopLikeOpInterface>(scope) && !operations.empty()) {
+      return scope->emitOpError("expecting RegionBranchOpInterface or "
+                                "LoopLikeOpInterface for merge-alloc");
+    }
+    for (auto op : operations) {
+      if (needsResetTick(scope, op, aliasAnaly)) {
+        // let all referenced buffers have overlapped lifetime
+        auto &tick = allocTicks[op];
+        tick.access(startTick);
+        tick.access(endTick);
+      }
+    }
+    return success();
+  }
+};
+
+struct TickTraceResult : public LifetimeTrace {
+  memoryplan::Traces traces;
+  TickTraceResult() : LifetimeTrace{TK_TICK} {}
+  static bool classof(const LifetimeTrace *S) {
+    return S->getKind() == TK_TICK;
+  }
+};
+
+struct TickCollecter {
+  const mlir::BufferViewFlowAnalysis &aliasAnaly;
+  int64_t curTick = 0;
+  llvm::DenseMap<Operation *, Tick> allocTicks;
+  llvm::SmallVector<ComplexScope> complexScopeStack;
+  TickCollecter(const mlir::BufferViewFlowAnalysis &aliasAnaly)
+      : aliasAnaly{aliasAnaly} {}
+  LogicalResult popScopeIfNecessary(Operation *op) {
+    // first check if we have walked outside of the previous ComplexScope
+    while (!complexScopeStack.empty()) {
+      auto &scope = complexScopeStack.back();
+      if (!op || !scope.scope->isProperAncestor(op)) {
+        if (failed(scope.onPop(curTick, aliasAnaly, allocTicks))) {
+          return failure();
+        }
+        complexScopeStack.pop_back();
+      } else {
+        break;
+      }
+    }
+    return success();
+  }
+
+  void forwardTick() { curTick++; }
+
+  void accessValue(Value v, bool complex) {
+    if (auto refv = dyn_cast<TypedValue<MemRefType>>(v)) {
+      for (auto &&base : aliasAnaly.resolveReverse(refv)) {
+        auto defop = base.getDefiningOp();
+        if (isa_and_present<memref::AllocOp>(defop)) {
+          allocTicks[defop].access(complex ? COMPLEX_ACCESS : curTick);
+          if (!complexScopeStack.empty()) {
+            complexScopeStack.back().operations.insert(defop);
+          }
+        }
+      }
+    }
+  }
+
+  void onMemrefViews(ViewLikeOpInterface op) {
+    auto viewSrc = op.getViewSource();
+    // don't need to access the first operand, which is "source".
+    // The "source" operand is not really read or written at this point
+    for (auto val : op.getOperation()->getOperands()) {
+      if (val != viewSrc)
+        accessValue(val, false);
+    }
+  }
+
+  void onReturnOp(Operation *op) {
+    bool isTopLevel = isa<func::FuncOp>(op->getParentOp());
+    for (auto val : op->getOperands()) {
+      accessValue(val, isTopLevel);
+    }
+  }
+
+  void onGeneralOp(Operation *op) {
+    for (auto val : op->getOperands()) {
+      accessValue(val, false);
+    }
+  }
+
+  void pushComplexScope(Operation *op) {
+    complexScopeStack.emplace_back(op, curTick);
+  }
+
+  FailureOr<MemoryTraces> getTrace() {
+    struct TraceWithTick {
+      Operation* op;
+      int64_t tick;
+      memoryplan::MemoryTrace trace;
+      TraceWithTick(int64_t tick, uintptr_t bufferId, size_t size)
+          : tick{tick}, trace{bufferId, size} {}
+    };
+    llvm::DenseMap<Operation *, llvm::SmallVector<TraceWithTick, 8>> raw;
+    for (auto &[op, tick] : allocTicks) {
+      if (!isMergeableAlloc(op, tick.firstAccess)) {
+        continue;
+      }
+      auto scope = getAllocScope(op);
+      if (!scope) {
+        return op->emitError(
+            "This op should be surrounded by an AutomaticAllocationScope");
+      }
+      auto allocSize = getAllocSize(op);
+      if (failed(allocSize)) {
+        return failure();
+      }
+      // tick.firstAccess * 2 and tick.lastAccess * 2 + 1 to make sure "dealloc"
+      // overlaps "alloc"
+      raw[scope].emplace_back(tick.firstAccess * 2,
+                              reinterpret_cast<uintptr_t>(op), *allocSize);
+      raw[scope].emplace_back(tick.lastAccess * 2 + 1,
+                              reinterpret_cast<uintptr_t>(op), 0);
+    }
+    MemoryTraces ret;
+    for (auto &[scope, trace] : raw) {
+      std::stable_sort(trace.begin(), trace.end(),
+                       [](const TraceWithTick &a, const TraceWithTick &b) {
+                         return a.tick < b.tick;
+                       });
+      auto retTrace = std::make_unique<TickTraceResult>();
+      retTrace->traces.reserve(trace.size());
+      for (auto &tr : trace) {
+        retTrace->traces.emplace_back(tr.trace);
+      }
+      ret.scopeToTraces[scope] = std::move(retTrace);
+    }
+    return ret;
+  }
+};
+} // namespace
+
+FailureOr<MemoryTraces>
+tickBasedCollectMemoryTrace(Operation *root,
+                            const mlir::BufferViewFlowAnalysis &aliasAnaly,
+                            const MergeAllocOptions &option) {
+  TickCollecter collecter{aliasAnaly};
+  LogicalResult result = success();
+  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (failed(collecter.popScopeIfNecessary(op))) {
+      result = failure();
+    }
+    collecter.forwardTick();
+    if (auto viewop = dyn_cast<ViewLikeOpInterface>(op)) {
+      collecter.onMemrefViews(viewop);
+    } else if (op->hasTrait<OpTrait::ReturnLike>()) {
+      collecter.onReturnOp(op);
+    } else if (!isMemoryEffectFree(op)) {
+      // if the op has no memory effects, it doesn't contribute to liveness
+      collecter.onGeneralOp(op);
+    }
+    if (op->getNumRegions() > 0 && !isa<func::FuncOp>(op)) {
+      // finally, if op is complex scope, push one ComplexScope
+      collecter.pushComplexScope(op);
+    }
+  });
+  if (failed(result)) {
+    return result;
+  }
+  if (failed(collecter.popScopeIfNecessary(nullptr))) {
+    return failure();
+  }
+  if (option.optionCheck) {
+    for (auto &[alloc, tick] : collecter.allocTicks) {
+      auto allocscope = getAllocScope(alloc);
+      alloc->setAttr(
+          "__mergealloc_lifetime",
+          DenseI64ArrayAttr::get(root->getContext(),
+                                 {reinterpret_cast<int64_t>(allocscope),
+                                  tick.firstAccess, tick.lastAccess}));
+      allocscope->setAttr(
+          "__mergealloc_scope",
+          IntegerAttr::get(mlir::IntegerType::get(root->getContext(), 64),
+                           reinterpret_cast<int64_t>(allocscope)));
+    }
+    return MemoryTraces();
+  }
+  return collecter.getTrace();
+}
+
+FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
+                                              const LifetimeTrace &tr,
+                                              const MergeAllocOptions &o) {
+  auto traceObj = dyn_cast<TickTraceResult>(&tr);
+  if (!traceObj) {
+    return failure();
+  }
+  auto &traces = traceObj->traces;
+  if (traces.empty()) {
+    return MemorySchedule{};
+  }
+  std::unordered_map<uintptr_t, std::size_t> outSchedule;
+  std::unordered_map<uintptr_t, std::vector<uintptr_t>> dummy;
+  auto total = memoryplan::scheduleMemoryAllocations(
+      traces, 64, !o.optionNoLocality, memoryplan::InplaceInfoMap(),
+      outSchedule, dummy);
+  MemorySchedule ret;
+  ret.totalSize = total;
+  for (auto [k, offset] : outSchedule) {
+    ret.allocToOffset[reinterpret_cast<Operation *>(k)] =
+        static_cast<int64_t>(offset);
+  }
+  return std::move(ret);
+}
+
+LogicalResult tickBasedMutateAllocations(Operation *op, Operation *scope,
+                                         const MemorySchedule &schedule,
+                                         const MergeAllocOptions &o) {
+  if (schedule.allocToOffset.empty()) {
+    return success();
+  }
+  auto &block = scope->getRegion(0).getBlocks().front();
+  OpBuilder builder{&block.front()};
+  auto alignment =
+      builder.getIntegerAttr(IntegerType::get(op->getContext(), 64), 64);
+  auto alloc = builder.create<memref::AllocOp>(
+      scope->getLoc(),
+      MemRefType::get({static_cast<int64_t>(schedule.totalSize)},
+                      builder.getI8Type()),
+      alignment);
+  for (auto &[origBuf, offset] : schedule.allocToOffset) {
+    builder.setInsertionPoint(origBuf);
+    auto byteShift = builder.create<arith::ConstantIndexOp>(
+        origBuf->getLoc(), static_cast<int64_t>(offset));
+    auto view = builder.create<memref::ViewOp>(
+        origBuf->getLoc(), origBuf->getResultTypes().front(), alloc, byteShift,
+        ValueRange{});
+    origBuf->replaceAllUsesWith(view->getResults());
+    origBuf->remove();
+  }
+  return success();
+}
+
+} // namespace memref
+} // namespace mlir
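
A worked example of the tick doubling performed in getTrace() above: suppose
buffer A is last touched and buffer B is first touched by the same op, at tick
t. A's free event is emitted at 2*t + 1 and B's allocation event at 2*t, so
after the stable sort the free comes after the allocation, the planner sees the
two lifetimes as overlapping, and it will not hand both buffers the same
offset. A buffer whose first and last access fall on the same tick t likewise
gets a well-ordered pair of events at 2*t and 2*t + 1.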
diff --git a/mlir/test/Dialect/MemRef/buffer-merge-invalid.mlir b/mlir/test/Dialect/MemRef/buffer-merge-invalid.mlir
new file mode 100644
index 0000000000000..6609cb9d2f6fb
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/buffer-merge-invalid.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics --merge-alloc %s
+
+func.func @block() {
+  %mref = memref.alloc() : memref<8 x f32>
+  %mref2 = memref.alloc() : memref<8 x f32>
+  // expected-error at +1 {{expecting RegionBranchOpInterface or LoopLikeOpInterface for merge-alloc}}
+  "some.block"() ({
+   ^bb0:
+    "some.use"(%mref) : (memref<8 x f32>) -> ()
+   }) : () -> ()
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/MemRef/buffer-merge-mlp.mlir b/mlir/test/Dialect/MemRef/buffer-merge-mlp.mlir
deleted file mode 100644
index ae002c5b0d34c..0000000000000
--- a/mlir/test/Dialect/MemRef/buffer-merge-mlp.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: mlir-opt -one-shot-bufferize="unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" --merge-alloc %s | FileCheck %s
-
-func.func @mlp(%x: tensor<128x128xf32>, %y: tensor<128x128xf32>) -> tensor<128x128xf32> {
-   // CHECK-DAG:  %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<131072xi8>
-   // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-   // CHECK-DAG:  %[[VIEW_A:.*]] = memref.view %[[ALLOC]][%[[C0]]][] : memref<131072xi8> to memref<128x128xf32>
-   %a0 = tensor.empty() : tensor<128x128xf32>
-   // CHECK:      linalg.matmul ins
-   // CHECK-SAME: outs(%[[VIEW_A]] : memref<128x128xf32>)
-   %a = linalg.matmul ins(%x, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%a0: tensor<128x128xf32>) -> tensor<128x128xf32>
-   // CHECK-DAG:  %[[C65536:.*]] = arith.constant 65536 : index
-   // CHECK-DAG:  %[[VIEW_B:.*]] = memref.view %[[ALLOC]][%[[C65536]]][] : memref<131072xi8> to memref<128x128xf32>
-   %b0 = tensor.empty() : tensor<128x128xf32>
-   // CHECK:      linalg.matmul ins(%[[VIEW_A]],
-   // CHECK-SAME: outs(%[[VIEW_B]] : memref<128x128xf32>)
-   %b = linalg.matmul ins(%a, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%b0: tensor<128x128xf32>) -> tensor<128x128xf32>
-   // CHECK-DAG:  %[[C0_2:.*]] = arith.constant 0 : index
-   // CHECK-DAG:  %[[VIEW_C:.*]] = memref.view %[[ALLOC]][%[[C0_2]]][] : memref<131072xi8> to memref<128x128xf32>
-   %c0 = tensor.empty() : tensor<128x128xf32>
-   // CHECK:      linalg.matmul ins(%[[VIEW_B]],
-   // CHECK-SAME: outs(%[[VIEW_C]] : memref<128x128xf32>)
-   %c = linalg.matmul ins(%b, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%c0: tensor<128x128xf32>) -> tensor<128x128xf32>
-   // CHECK-DAG:  %[[D:.*]] = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32>
-   // CHECK:      linalg.matmul ins(%[[VIEW_C]],
-   // CHECK-SAME: outs(%[[D]] : memref<128x128xf32>)
-   %d0 = tensor.empty() : tensor<128x128xf32>
-   %d = linalg.matmul ins(%c, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%d0: tensor<128x128xf32>) -> tensor<128x128xf32>
-   // CHECK:      return %[[D]]
-   return %d : tensor<128x128xf32>
-}
\ No newline at end of file
diff --git a/mlir/test/Dialect/MemRef/buffer-merge.mlir b/mlir/test/Dialect/MemRef/buffer-merge.mlir
new file mode 100644
index 0000000000000..f7e7f0f3d6b62
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/buffer-merge.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt -allow-unregistered-dialect -one-shot-bufferize="unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" --merge-alloc %s | FileCheck %s
+
+func.func @mlp(%x: tensor<128x128xf32>, %y: tensor<128x128xf32>) -> tensor<128x128xf32> {
+   // CHECK-DAG:  %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<131072xi8>
+   // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+   // CHECK-DAG:  %[[VIEW_A:.*]] = memref.view %[[ALLOC]][%[[C0]]][] : memref<131072xi8> to memref<128x128xf32>
+   %a0 = tensor.empty() : tensor<128x128xf32>
+   // CHECK:      linalg.matmul ins
+   // CHECK-SAME: outs(%[[VIEW_A]] : memref<128x128xf32>)
+   %a = linalg.matmul ins(%x, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%a0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   // CHECK-DAG:  %[[C65536:.*]] = arith.constant 65536 : index
+   // CHECK-DAG:  %[[VIEW_B:.*]] = memref.view %[[ALLOC]][%[[C65536]]][] : memref<131072xi8> to memref<128x128xf32>
+   %b0 = tensor.empty() : tensor<128x128xf32>
+   // CHECK:      linalg.matmul ins(%[[VIEW_A]],
+   // CHECK-SAME: outs(%[[VIEW_B]] : memref<128x128xf32>)
+   %b = linalg.matmul ins(%a, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%b0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   // CHECK-DAG:  %[[C0_2:.*]] = arith.constant 0 : index
+   // CHECK-DAG:  %[[VIEW_C:.*]] = memref.view %[[ALLOC]][%[[C0_2]]][] : memref<131072xi8> to memref<128x128xf32>
+   %c0 = tensor.empty() : tensor<128x128xf32>
+   // CHECK:      linalg.matmul ins(%[[VIEW_B]],
+   // CHECK-SAME: outs(%[[VIEW_C]] : memref<128x128xf32>)
+   %c = linalg.matmul ins(%b, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%c0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   // CHECK-DAG:  %[[D:.*]] = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32>
+   // CHECK:      linalg.matmul ins(%[[VIEW_C]],
+   // CHECK-SAME: outs(%[[D]] : memref<128x128xf32>)
+   %d0 = tensor.empty() : tensor<128x128xf32>
+   %d = linalg.matmul ins(%c, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%d0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   // CHECK:      return %[[D]]
+   return %d : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: @basic
+func.func @basic() -> memref<8x64xf32> {
+  // CHECK-DAG: %[[BASE:.*]] = memref.alloc() {alignment = 64 : i64} : memref<4096xi8>
+  // b is used in return, complex lifetime
+  // CHECK-DAG: %[[B:.*]] = memref.alloc()
+  %b = memref.alloc() : memref<8x64xf32>
+  // CHECK:     "test.source"(%[[B]])
+  "test.source"(%b)  : (memref<8x64xf32>) -> ()
+  // c and d have overlapping lifetimes
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C:.*]] = memref.view %[[BASE]][%[[C0]]][] : memref<4096xi8> to memref<8x64xf32>
+  %c = memref.alloc() : memref<8x64xf32>
+  // CHECK:     "test.source"(%[[C]])
+  "test.source"(%c)  : (memref<8x64xf32>) -> ()
+  // CHECK-DAG: %[[C2048:.*]] = arith.constant 2048 : index
+  // CHECK-DAG: %[[D:.*]] = memref.view %[[BASE]][%[[C2048]]][] : memref<4096xi8> to memref<8x64xf32>
+  %d = memref.alloc() : memref<8x64xf32>
+  // CHECK:     "test.source"(%[[D]])
+  "test.source"(%d)  : (memref<8x64xf32>) -> ()
+  // CHECK:     "test.source"(%[[C]])
+  "test.source"(%c)  : (memref<8x64xf32>) -> ()
+  // e can reuse the above memory
+  // CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[E:.*]] = memref.view %[[BASE]][%[[C0_2]]][] : memref<4096xi8> to memref<8x64xf32>
+  %e = memref.alloc() : memref<8x64xf32>
+  // CHECK:     "test.source"(%[[E]])
+  "test.source"(%e)  : (memref<8x64xf32>) -> ()
+  // CHECK:     return %[[B]]
+  return %b : memref<8x64xf32>
+}
+
+// CHECK-LABEL: @withloop
+func.func @withloop() {
+  // CHECK-DAG: %[[BASE2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<6144xi8>
+  // CHECK-DAG: %[[C2048:.*]] = arith.constant 2048 : index
+  // CHECK-DAG: %[[F:.*]] = memref.view %[[BASE2]][%[[C2048]]][] : memref<6144xi8> to memref<8x64xf32>
+  %f = memref.alloc() : memref<8x64xf32>
+  // CHECK-DAG: %[[C033:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[G:.*]] = memref.view %[[BASE2]][%[[C033]]][] : memref<6144xi8> to memref<8x64xf32>
+  %g = memref.alloc() : memref<8x64xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: scf.for
+  scf.for %i = %c0 to %c5 step %c1 {
+      // CHECK:     "test.source"(%[[F]])
+      "test.source"(%f)  : (memref<8x64xf32>) -> ()
+      // CHECK:     "test.source"(%[[G]])
+      "test.source"(%g)  : (memref<8x64xf32>) -> ()
+      // CHECK-DAG: %[[C4096:.*]] = arith.constant 4096 : index
+      // CHECK-DAG: %[[H:.*]] = memref.view %[[BASE2]][%[[C4096]]][] : memref<6144xi8> to memref<8x64xf32>
+      %h = memref.alloc() : memref<8x64xf32>
+      // CHECK:     "test.source"(%[[H]])
+      "test.source"(%h)  : (memref<8x64xf32>) -> ()
+      // CHECK-DAG: %[[C4096_3:.*]] = arith.constant 4096 : index
+      // CHECK-DAG: %[[J:.*]] = memref.view %[[BASE2]][%[[C4096_3]]][] : memref<6144xi8> to memref<8x64xf32>
+      %j = memref.alloc() : memref<8x64xf32>
+      // CHECK:     "test.source"(%[[J]])
+      "test.source"(%j)  : (memref<8x64xf32>) -> ()
+  }
+  return
+}
\ No newline at end of file

>From 2cc74fe67238837f7dcbe003021e4900e1474b51 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Wed, 5 Jun 2024 17:21:26 +0800
Subject: [PATCH 03/12] make options

---
 .../Dialect/MemRef/Transforms/MergeAlloc.h    | 49 ++---------------
 .../mlir/Dialect/MemRef/Transforms/Passes.h   | 52 ++++++++++++++++++-
 .../Dialect/MemRef/Transforms/MergeAlloc.cpp  | 46 ++++++++++------
 .../MemRef/Transforms/MergeAllocTickBased.cpp | 20 +++----
 4 files changed, 97 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
index 45b584e99e044..e1d7f0876ad4c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_MEMREF_MERGEALLOC_H
 #define MLIR_DIALECT_MEMREF_MERGEALLOC_H
 
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
@@ -17,57 +18,17 @@
 namespace mlir {
 class BufferViewFlowAnalysis;
 namespace memref {
-struct MergeAllocOptions;
-// abstract base class for lifetime of different buffers. It should hold the
-// lifetime information of buffers that are to be merged in the same allocation
-// in an "allocation scope". TraceCollectorFunc decides which buffers are put
-// into which "allocation scope".
-class LifetimeTrace {
-public:
-  enum TraceKind { TK_TICK };
-  virtual ~LifetimeTrace() = default;
-  LifetimeTrace(TraceKind kind) : kind{kind} {}
-  TraceKind getKind() const { return kind; }
-
-private:
-  TraceKind kind;
-};
-
-// top level memory trace info for multiple scopes. Each key-value pair is the
-// traces and location for buffers in the same "allocation scope"
-struct MemoryTraces {
-  llvm::DenseMap<Operation *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
-  MemoryTraces() = default;
-};
-
-// the memory scheduling result for allocations in the same merged buffer.
-// allocation => offset map. All Operation* in the map should be memref::AllocOp
-// which are in the same LifetimeTrace.
-struct MemorySchedule {
-  size_t totalSize;
-  llvm::DenseMap<Operation *, int64_t> allocToOffset;
-  MemorySchedule() : totalSize{0} {}
-};
-
-using TraceCollectorFunc = llvm::function_ref<FailureOr<MemoryTraces>(
-    Operation *, const BufferViewFlowAnalysis &, const MergeAllocOptions &)>;
-using MemoryPlannerFunc = llvm::function_ref<FailureOr<MemorySchedule>(
-    Operation *, const LifetimeTrace &, const MergeAllocOptions &)>;
-using MemoryMergeMutatorFunc = llvm::function_ref<LogicalResult(
-    Operation *toplevel, Operation *scope, const MemorySchedule &,
-    const MergeAllocOptions &)>;
-
-FailureOr<MemoryTraces>
+FailureOr<MemoryTraceScopes>
 tickBasedCollectMemoryTrace(Operation *root,
                             const mlir::BufferViewFlowAnalysis &aliasAnaly,
-                            const MergeAllocOptions &option);
+                            const MergeAllocationOptions &option);
 
 FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
                                               const LifetimeTrace &tr,
-                                              const MergeAllocOptions &o);
+                                              const MergeAllocationOptions &o);
 LogicalResult tickBasedMutateAllocations(Operation *op, Operation *scope,
                                          const MemorySchedule &schedule,
-                                         const MergeAllocOptions &o);
+                                         const MergeAllocationOptions &o);
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 7ffa07bf768af..bf4bf55e43c3b 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -36,6 +36,7 @@ namespace vector {
 class VectorDialect;
 } // namespace vector
 
+class BufferViewFlowAnalysis;
 namespace memref {
 //===----------------------------------------------------------------------===//
 // Passes
@@ -77,9 +78,58 @@ std::unique_ptr<Pass> createExpandStridedMetadataPass();
 /// components.
 std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);
 
+// abstract base class for lifetime of different buffers. It should hold the
+// lifetime information of buffers that are to be merged in the same allocation
+// in an "allocation scope". TraceCollectorFunc decides which buffers are put
+// into which "allocation scope".
+class LifetimeTrace {
+public:
+  enum TraceKind { TK_TICK };
+  virtual ~LifetimeTrace() = default;
+  LifetimeTrace(TraceKind kind) : kind{kind} {}
+  TraceKind getKind() const { return kind; }
+
+private:
+  TraceKind kind;
+};
+
+// top level memory trace info for multiple scopes. Each key-value pair is the
+// traces and location for buffers in the same "allocation scope"
+struct MemoryTraceScopes {
+  llvm::DenseMap<Operation *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
+  MemoryTraceScopes() = default;
+};
+
+// the memory scheduling result for allocations in the same merged buffer.
+// allocation => offset map. All Operation* in the map should be memref::AllocOp
+// which are in the same LifetimeTrace.
+struct MemorySchedule {
+  size_t totalSize;
+  llvm::DenseMap<Operation *, int64_t> allocToOffset;
+  MemorySchedule() : totalSize{0} {}
+};
+
+struct MergeAllocationOptions;
+using TraceCollectorFunc = std::function<FailureOr<MemoryTraceScopes>(
+    Operation *, const BufferViewFlowAnalysis &,
+    const MergeAllocationOptions &)>;
+using MemoryPlannerFunc = std::function<FailureOr<MemorySchedule>(
+    Operation *, const LifetimeTrace &, const MergeAllocationOptions &)>;
+using MemoryMergeMutatorFunc = std::function<LogicalResult(
+    Operation *toplevel, Operation *scope, const MemorySchedule &,
+    const MergeAllocationOptions &)>;
+
+struct MergeAllocationOptions {
+  bool checkOnly;
+  bool noLocalityFirst;
+  TraceCollectorFunc tracer;
+  MemoryPlannerFunc planner;
+  MemoryMergeMutatorFunc mutator;
+};
+
 /// Creates an operation pass to merge the local memref allocations
 std::unique_ptr<Pass>
-createMergeAllocPass(const memref::MergeAllocOptions &o = {});
+createMergeAllocPass(const MergeAllocationOptions &o = {});
 
 //===----------------------------------------------------------------------===//
 // Registration
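
These callbacks make the pass pluggable. Below is a minimal sketch, not part of
the patch, of how a downstream pipeline might override only the planner while
keeping the tick-based defaults; it assumes the tick-based hooks declared in
MergeAlloc.h are visible, the code sits inside namespace mlir (or uses
"using namespace mlir;"), and "pm" is a pre-existing PassManager.

  memref::MergeAllocationOptions opts;
  opts.checkOnly = false;
  opts.noLocalityFirst = false;
  opts.tracer = memref::tickBasedCollectMemoryTrace;
  opts.planner = [](Operation *op, const memref::LifetimeTrace &trace,
                    const memref::MergeAllocationOptions &o)
      -> FailureOr<memref::MemorySchedule> {
    // A custom placement policy would go here; this sketch just falls back
    // to the default tick-based planner.
    return memref::tickBasedPlanMemory(op, trace, o);
  };
  opts.mutator = memref::tickBasedMutateAllocations;
  pm.addNestedPass<func::FuncOp>(memref::createMergeAllocPass(opts));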
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
index 100e6bfbcb4db..0c3a00ba1bede 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
@@ -27,23 +27,22 @@ namespace memref {
 
 namespace {
 
-LogicalResult passDriver(Operation *op, const memref::MergeAllocOptions &o,
-                         TraceCollectorFunc tracer, MemoryPlannerFunc planner,
-                         MemoryMergeMutatorFunc mutator) {
+LogicalResult passDriver(Operation *op,
+                         const memref::MergeAllocationOptions &o) {
   BufferViewFlowAnalysis aliasAnaly{op};
-  auto tracesOrFail = tracer(op, aliasAnaly, o);
+  auto tracesOrFail = o.tracer(op, aliasAnaly, o);
   if (failed(tracesOrFail)) {
     return failure();
   }
-  if (o.optionCheck) {
+  if (o.checkOnly) {
     return success();
   }
   for (auto &[scope, traces] : (*tracesOrFail).scopeToTraces) {
-    auto schedule = planner(op, *traces, o);
+    auto schedule = o.planner(op, *traces, o);
     if (failed(schedule)) {
       return failure();
     }
-    if (failed(mutator(op, scope, *schedule, o))) {
+    if (failed(o.mutator(op, scope, *schedule, o))) {
       return failure();
     }
   }
@@ -54,24 +53,41 @@ LogicalResult passDriver(Operation *op, const memref::MergeAllocOptions &o,
 } // namespace memref
 
 using namespace mlir;
-struct MergeAllocPass : memref::impl::MergeAllocBase<MergeAllocPass> {
+class MergeAllocPass : public memref::impl::MergeAllocBase<MergeAllocPass> {
   using parent = memref::impl::MergeAllocBase<MergeAllocPass>;
   void runOnOperation() override {
+    memref::MergeAllocationOptions opt;
+    if (!options) {
+      opt.checkOnly = optionCheck;
+      opt.noLocalityFirst = optionNoLocality;
+      opt.tracer = memref::tickBasedCollectMemoryTrace;
+      opt.planner = memref::tickBasedPlanMemory;
+      opt.mutator = memref::tickBasedMutateAllocations;
+    } else {
+      opt = options.value();
+      if (!opt.tracer)
+        opt.tracer = memref::tickBasedCollectMemoryTrace;
+      if (!opt.planner)
+        opt.planner = memref::tickBasedPlanMemory;
+      if (!opt.mutator)
+        opt.mutator = memref::tickBasedMutateAllocations;
+    }
     auto op = getOperation();
-    if (failed(memref::passDriver(
-            op, memref::MergeAllocOptions{optionCheck, optionNoLocality},
-            memref::tickBasedCollectMemoryTrace, memref::tickBasedPlanMemory,
-            memref::tickBasedMutateAllocations))) {
-        signalPassFailure();
+    if (failed(memref::passDriver(op, opt))) {
+      signalPassFailure();
     }
   }
 
+  std::optional<memref::MergeAllocationOptions> options;
+
 public:
-  MergeAllocPass(const memref::MergeAllocOptions &o) : parent{o} {}
+  MergeAllocPass() = default;
+  explicit MergeAllocPass(const memref::MergeAllocationOptions &o)
+      : options{std::move(o)} {}
 };
 } // namespace mlir
 
 std::unique_ptr<mlir::Pass>
-mlir::memref::createMergeAllocPass(const memref::MergeAllocOptions &o) {
+mlir::memref::createMergeAllocPass(const memref::MergeAllocationOptions &o) {
   return std::make_unique<MergeAllocPass>(o);
 }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
index 85df88d1a5fc4..98d6779e78f9d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
@@ -226,9 +226,9 @@ struct TickCollecter {
     complexScopeStack.emplace_back(op, curTick);
   }
 
-  FailureOr<MemoryTraces> getTrace() {
+  FailureOr<MemoryTraceScopes> getTrace() {
     struct TraceWithTick {
-      Operation* op;
+      Operation *op;
       int64_t tick;
       memoryplan::MemoryTrace trace;
       TraceWithTick(int64_t tick, uintptr_t bufferId, size_t size)
@@ -255,7 +255,7 @@ struct TickCollecter {
       raw[scope].emplace_back(tick.lastAccess * 2 + 1,
                               reinterpret_cast<uintptr_t>(op), 0);
     }
-    MemoryTraces ret;
+    MemoryTraceScopes ret;
     for (auto &[scope, trace] : raw) {
       std::stable_sort(trace.begin(), trace.end(),
                        [](const TraceWithTick &a, const TraceWithTick &b) {
@@ -273,10 +273,10 @@ struct TickCollecter {
 };
 } // namespace
 
-FailureOr<MemoryTraces>
+FailureOr<MemoryTraceScopes>
 tickBasedCollectMemoryTrace(Operation *root,
                             const mlir::BufferViewFlowAnalysis &aliasAnaly,
-                            const MergeAllocOptions &option) {
+                            const MergeAllocationOptions &option) {
   TickCollecter collecter{aliasAnaly};
   LogicalResult result = success();
   root->walk<WalkOrder::PreOrder>([&](Operation *op) {
@@ -303,7 +303,7 @@ tickBasedCollectMemoryTrace(Operation *root,
   if (failed(collecter.popScopeIfNecessary(nullptr))) {
     return failure();
   }
-  if (option.optionCheck) {
+  if (option.checkOnly) {
     for (auto &[alloc, tick] : collecter.allocTicks) {
       auto allocscope = getAllocScope(alloc);
       alloc->setAttr(
@@ -316,14 +316,14 @@ tickBasedCollectMemoryTrace(Operation *root,
           IntegerAttr::get(mlir::IntegerType::get(root->getContext(), 64),
                            reinterpret_cast<int64_t>(allocscope)));
     }
-    return MemoryTraces();
+    return MemoryTraceScopes();
   }
   return collecter.getTrace();
 }
 
 FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
                                               const LifetimeTrace &tr,
-                                              const MergeAllocOptions &o) {
+                                              const MergeAllocationOptions &o) {
   auto traceObj = dyn_cast<TickTraceResult>(&tr);
   if (!traceObj) {
     return failure();
@@ -335,7 +335,7 @@ FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
   std::unordered_map<uintptr_t, std::size_t> outSchedule;
   std::unordered_map<uintptr_t, std::vector<uintptr_t>> dummy;
   auto total = memoryplan::scheduleMemoryAllocations(
-      traces, 64, !o.optionNoLocality, memoryplan::InplaceInfoMap(),
+      traces, 64, !o.noLocalityFirst, memoryplan::InplaceInfoMap(),
       outSchedule, dummy);
   MemorySchedule ret;
   ret.totalSize = total;
@@ -348,7 +348,7 @@ FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
 
 LogicalResult tickBasedMutateAllocations(Operation *op, Operation *scope,
                                          const MemorySchedule &schedule,
-                                         const MergeAllocOptions &o) {
+                                         const MergeAllocationOptions &o) {
   if (schedule.allocToOffset.empty()) {
     return success();
   }

>From c9265a4677e853e99d0c4de1b39966829b0606cf Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Fri, 7 Jun 2024 13:46:18 +0800
Subject: [PATCH 04/12] fix test

---
 mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h | 4 ++--
 mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp    | 4 ++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index bf4bf55e43c3b..5cafc74713210 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -128,8 +128,8 @@ struct MergeAllocationOptions {
 };
 
 /// Creates an operation pass to merge the local memref allocations
-std::unique_ptr<Pass>
-createMergeAllocPass(const MergeAllocationOptions &o = {});
+std::unique_ptr<Pass> createMergeAllocPass(const MergeAllocationOptions &o);
+std::unique_ptr<Pass> createMergeAllocPass();
 
 //===----------------------------------------------------------------------===//
 // Registration
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
index 0c3a00ba1bede..39cde76f37166 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
@@ -91,3 +91,7 @@ std::unique_ptr<mlir::Pass>
 mlir::memref::createMergeAllocPass(const memref::MergeAllocationOptions &o) {
   return std::make_unique<MergeAllocPass>(o);
 }
+
+std::unique_ptr<mlir::Pass> mlir::memref::createMergeAllocPass() {
+  return std::make_unique<MergeAllocPass>();
+}

>From cce1816dea39a5b5330dad6477fe09cbc6cad875 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Fri, 7 Jun 2024 17:03:08 +0800
Subject: [PATCH 05/12] make interfaces

---
 .../Dialect/MemRef/Transforms/MergeAlloc.h    |  35 --
 .../MemRef/Transforms/MergeAllocTickBased.h   | 156 ++++++
 .../mlir/Dialect/MemRef/Transforms/Passes.h   |   5 +-
 .../mlir/Dialect/MemRef/Transforms/Passes.td  |   3 +
 .../Dialect/MemRef/Transforms/MergeAlloc.cpp  |  14 +-
 .../MemRef/Transforms/MergeAllocTickBased.cpp | 443 +++++++++---------
 mlir/test/Dialect/MemRef/buffer-merge.mlir    |   8 +-
 7 files changed, 395 insertions(+), 269 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
 create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
deleted file mode 100644
index e1d7f0876ad4c..0000000000000
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAlloc.h
+++ /dev/null
@@ -1,35 +0,0 @@
-//===- MergeAlloc.h - The interfaces for merge alloc pass -------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_MEMREF_MERGEALLOC_H
-#define MLIR_DIALECT_MEMREF_MERGEALLOC_H
-
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
-#include "mlir/IR/Operation.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
-#include <memory>
-
-namespace mlir {
-class BufferViewFlowAnalysis;
-namespace memref {
-FailureOr<MemoryTraceScopes>
-tickBasedCollectMemoryTrace(Operation *root,
-                            const mlir::BufferViewFlowAnalysis &aliasAnaly,
-                            const MergeAllocationOptions &option);
-
-FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
-                                              const LifetimeTrace &tr,
-                                              const MergeAllocationOptions &o);
-LogicalResult tickBasedMutateAllocations(Operation *op, Operation *scope,
-                                         const MemorySchedule &schedule,
-                                         const MergeAllocationOptions &o);
-} // namespace memref
-} // namespace mlir
-
-#endif
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
new file mode 100644
index 0000000000000..0683426ee0312
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
@@ -0,0 +1,156 @@
+//===- MergeAllocTickBased.h - Tick-based merge alloc interfaces *- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_MERGEALLOCTICKBASED_H
+#define MLIR_DIALECT_MEMREF_MERGEALLOCTICKBASED_H
+
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+
+namespace mlir {
+class BufferViewFlowAnalysis;
+class ViewLikeOpInterface;
+namespace memref {
+
+// Usually ticks should be non-negative numbers. There are two special ticks
+// defined here.
+namespace special_ticks {
+// the memref is not accessed
+static constexpr int64_t NO_ACCESS = -1;
+// complex access happens on this memref
+static constexpr int64_t COMPLEX_ACCESS = -2;
+} // namespace special_ticks
+
+// the collected tick [first, last] for a memref
+struct Tick {
+  // The tick when the buffer is allocated. allocTick is only used to stabilize
+  // the sorting results of the buffers when ticks of different buffers are the
+  // same
+  int64_t allocTick = special_ticks::NO_ACCESS;
+  int64_t firstAccess = special_ticks::NO_ACCESS;
+  int64_t lastAccess = special_ticks::NO_ACCESS;
+
+  // access the memref at the given tick; updates firstAccess and lastAccess
+  // based on the tick
+  void access(int64_t tick);
+};
+
+// A complex scope object is additional info for a RegionBranchOpInterface or
+// LoopLikeOpInterface. It contains the scope itself, and the referenced alloc
+// ops inside this scope. We use this object to track which buffers this scope
+// accesses. These buffers must have overlapped lifetime
+struct ComplexScope {
+  Operation *scope;
+  int64_t startTick;
+  llvm::SmallPtrSet<Operation *, 8> operations;
+  ComplexScope(Operation *scope, int64_t startTick)
+      : scope{scope}, startTick{startTick} {}
+};
+
+// the top-level collected lifetime trace for merge-alloc pass
+struct TickTraceResult : public LifetimeTrace {
+  memoryplan::Traces traces;
+  TickTraceResult() : LifetimeTrace{TK_TICK} {}
+  static bool classof(const LifetimeTrace *S) {
+    return S->getKind() == TK_TICK;
+  }
+};
+
+// the internal states for TickCollecter analysis for a function
+struct TickCollecterStates {
+  // the alias analysis result for the function
+  const mlir::BufferViewFlowAnalysis &aliasAnaly;
+  const MergeAllocationOptions &opt;
+  // the current tick for the current callback of walk(). It will be by default
+  // incremented by 1 for each visited op
+  int64_t curTick = 0;
+  // the currently collected AllocOp -> [start, end] map
+  llvm::DenseMap<Operation *, Tick> allocTicks;
+  // the stack of ComplexScopes for the current visited position in the IR
+  llvm::SmallVector<ComplexScope> complexScopeStack;
+  TickCollecterStates(const mlir::BufferViewFlowAnalysis &aliasAnaly,
+                      const MergeAllocationOptions &opt)
+      : aliasAnaly{aliasAnaly}, opt{opt} {}
+};
+
+struct TickCollecter {
+  TickCollecter() = default;
+  virtual LogicalResult popScopeIfNecessary(TickCollecterStates *s,
+                                            Operation *op) const;
+
+  virtual void forwardTick(TickCollecterStates *s) const;
+
+  virtual void accessValue(TickCollecterStates *s, Value v, bool complex) const;
+
+  virtual void onMemrefViews(TickCollecterStates *s,
+                             ViewLikeOpInterface op) const;
+
+  virtual void onReturnOp(TickCollecterStates *s, Operation *op) const;
+
+  virtual void onAllocOp(TickCollecterStates *s, Operation *op) const;
+
+  virtual void onGeneralOp(TickCollecterStates *s, Operation *op) const;
+
+  virtual void pushComplexScope(TickCollecterStates *s, Operation *op) const;
+
+  // called when walk() runs outside of the scope
+  LogicalResult onPopComplexScope(TickCollecterStates *s,
+                                  int64_t endTick) const;
+
+  // returns true if an allocation either is not defined in the scope, or the
+  // allocation escapes from the scope
+  bool needsResetTick(TickCollecterStates *s, Operation *scope,
+                      Operation *allocation) const;
+
+  virtual bool isMergeableAlloc(TickCollecterStates *s, Operation *op,
+                                int64_t tick) const;
+
+  // find the closest surrounding parent operation with the
+  // AutomaticAllocationScope trait that is not scf.for
+  virtual Operation *getAllocScope(TickCollecterStates *s, Operation *op) const;
+
+  virtual FailureOr<size_t> getAllocSize(TickCollecterStates *s,
+                                         Operation *op) const;
+
+  virtual FailureOr<MemoryTraceScopes> getTrace(TickCollecterStates *s) const;
+
+  virtual FailureOr<MemoryTraceScopes>
+  operator()(Operation *root, const mlir::BufferViewFlowAnalysis &aliasAnaly,
+             const MergeAllocationOptions &option) const;
+
+  virtual ~TickCollecter() = default;
+};
+
+struct MergeAllocDefaultMutator {
+  virtual Value buildAlloc(OpBuilder &build, Operation *scope, int64_t size,
+                           int64_t alignment) const;
+  virtual Value buildView(OpBuilder &build, Operation *scope,
+                          Operation *origAllocOp, Value mergedAlloc,
+                          int64_t byteOffset) const;
+  virtual LogicalResult operator()(Operation *op, Operation *scope,
+                                   const MemorySchedule &schedule,
+                                   const MergeAllocationOptions &o) const;
+  MergeAllocDefaultMutator() = default;
+  virtual ~MergeAllocDefaultMutator() = default;
+};
+
+FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
+                                              const LifetimeTrace &tr,
+                                              const MergeAllocationOptions &o);
+
+} // namespace memref
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 5cafc74713210..69f401502fc4e 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -120,8 +120,9 @@ using MemoryMergeMutatorFunc = std::function<LogicalResult(
     const MergeAllocationOptions &)>;
 
 struct MergeAllocationOptions {
-  bool checkOnly;
-  bool noLocalityFirst;
+  bool checkOnly = false;
+  bool noLocalityFirst = false;
+  int64_t alignment = 64;
   TraceCollectorFunc tracer;
   MemoryPlannerFunc planner;
   MemoryMergeMutatorFunc mutator;
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index f65774464c713..4562dc5c8548b 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -283,6 +283,9 @@ def MergeAlloc : Pass<"merge-alloc", "func::FuncOp">  {
        /*default=*/"false",
        "Don't consider the cache locality when reusing the buffers. "
        "This option may result in smaller total memory usage.">,
+    Option<"optionAlignment", "alignment", "int64_t",
+       /*default=*/"64",
+       "The alignment of the merged allocations">,
   ];
   let dependentDialects = ["memref::MemRefDialect", "arith::ArithDialect"];
   let constructor = "mlir::memref::createMergeAllocPass()";
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
index 39cde76f37166..b8451c641218f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/MemRef/Transforms/MergeAlloc.h"
+#include "mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
 
@@ -60,17 +60,21 @@ class MergeAllocPass : public memref::impl::MergeAllocBase<MergeAllocPass> {
     if (!options) {
       opt.checkOnly = optionCheck;
       opt.noLocalityFirst = optionNoLocality;
-      opt.tracer = memref::tickBasedCollectMemoryTrace;
+      opt.alignment = optionAlignment;
+      opt.tracer = memref::TickCollecter();
       opt.planner = memref::tickBasedPlanMemory;
-      opt.mutator = memref::tickBasedMutateAllocations;
+      opt.mutator = memref::MergeAllocDefaultMutator();
     } else {
       opt = options.value();
       if (!opt.tracer)
-        opt.tracer = memref::tickBasedCollectMemoryTrace;
+        opt.tracer = memref::TickCollecter();
       if (!opt.planner)
         opt.planner = memref::tickBasedPlanMemory;
       if (!opt.mutator)
-        opt.mutator = memref::tickBasedMutateAllocations;
+        opt.mutator = memref::MergeAllocDefaultMutator();
+    }
+    if (opt.alignment <= 0) {
+      signalPassFailure();
+      return;
+    }
     auto op = getOperation();
     if (failed(memref::passDriver(op, opt))) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
index 98d6779e78f9d..dac3c986ee3c3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
-#include "mlir/Dialect/MemRef/Transforms/MergeAlloc.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/StaticMemoryPlanning.h"
 
@@ -24,37 +24,137 @@
 namespace mlir {
 namespace memref {
 
+using namespace special_ticks;
+
 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
 /// no layout).
 static bool hasStaticIdentityLayout(MemRefType type) {
   return type.hasStaticShape() && type.getLayout().isIdentity();
 }
 
-namespace {
-static constexpr int64_t NO_ACCESS = -1;
-static constexpr int64_t COMPLEX_ACCESS = -2;
-struct Tick {
-  int64_t firstAccess = NO_ACCESS;
-  int64_t lastAccess = NO_ACCESS;
+void Tick::access(int64_t tick) {
+  if (tick == COMPLEX_ACCESS) {
+    firstAccess = COMPLEX_ACCESS;
+    lastAccess = COMPLEX_ACCESS;
+  }
+  if (firstAccess == COMPLEX_ACCESS) {
+    return;
+  }
+  if (firstAccess == NO_ACCESS) {
+    firstAccess = tick;
+  } else {
+    firstAccess = std::min(firstAccess, tick);
+  }
+  lastAccess = std::max(lastAccess, tick);
+}
 
-  void access(int64_t tick) {
-    if (tick == COMPLEX_ACCESS) {
-      firstAccess = COMPLEX_ACCESS;
-      lastAccess = COMPLEX_ACCESS;
+bool TickCollecter::needsResetTick(TickCollecterStates *s, Operation *scope,
+                                   Operation *allocation) const {
+  // if the allocation is not in the scope, conservatively set the ticks
+  if (!scope->isProperAncestor(allocation)) {
+    return true;
+  }
+  // if the allocation and its alias are used outside of the scope
+  for (auto &&alias : s->aliasAnaly.resolve(allocation->getResult(0))) {
+    for (auto &&userOp : alias.getUsers()) {
+      if (!scope->isProperAncestor(userOp) && !isMemoryEffectFree(userOp)) {
+        return true;
+      }
     }
-    if (firstAccess == COMPLEX_ACCESS) {
-      return;
+  }
+  return false;
+}
+
+LogicalResult TickCollecter::onPopComplexScope(TickCollecterStates *s,
+                                               int64_t endTick) const {
+  const auto &scope = s->complexScopeStack.back();
+  // if the complex scope is not recognized by us, and if it accesses memory,
+  // raise an error
+  if (!isa<RegionBranchOpInterface>(scope.scope) &&
+      !isa<LoopLikeOpInterface>(scope.scope) && !scope.operations.empty()) {
+    return scope.scope->emitOpError("expecting RegionBranchOpInterface or "
+                                    "LoopLikeOpInterface for merge-alloc");
+  }
+  for (auto op : scope.operations) {
+    if (needsResetTick(s, scope.scope, op)) {
+      // let all referenced buffers have overlapped lifetime
+      auto &tick = s->allocTicks[op];
+      tick.access(scope.startTick);
+      tick.access(endTick);
     }
-    if (firstAccess == NO_ACCESS) {
-      firstAccess = tick;
+  }
+  return success();
+}
+
+LogicalResult TickCollecter::popScopeIfNecessary(TickCollecterStates *s,
+                                                 Operation *op) const {
+  // first check if we have walked outside of the previous ComplexScope
+  while (!s->complexScopeStack.empty()) {
+    auto &scope = s->complexScopeStack.back();
+    if (!op || !scope.scope->isProperAncestor(op)) {
+      if (failed(onPopComplexScope(s, s->curTick))) {
+        return failure();
+      }
+      s->complexScopeStack.pop_back();
     } else {
-      firstAccess = std::min(firstAccess, tick);
+      break;
     }
-    lastAccess = std::max(lastAccess, tick);
   }
-};
+  return success();
+}
+
+void TickCollecter::forwardTick(TickCollecterStates *s) const { s->curTick++; }
 
-bool isMergeableAlloc(Operation *op, int64_t tick) {
+void TickCollecter::accessValue(TickCollecterStates *s, Value v,
+                                bool complex) const {
+  if (auto refv = dyn_cast<TypedValue<MemRefType>>(v)) {
+    for (auto &&base : s->aliasAnaly.resolveReverse(refv)) {
+      auto defop = base.getDefiningOp();
+      if (isa_and_present<memref::AllocOp>(defop)) {
+        s->allocTicks[defop].access(complex ? COMPLEX_ACCESS : s->curTick);
+        if (!s->complexScopeStack.empty()) {
+          s->complexScopeStack.back().operations.insert(defop);
+        }
+      }
+    }
+  }
+}
+
+void TickCollecter::onMemrefViews(TickCollecterStates *s,
+                                  ViewLikeOpInterface op) const {
+  auto viewSrc = op.getViewSource();
+  // don't need to access the first operand, which is "source".
+  // The "source" operand is not really read or written at this point
+  for (auto val : op.getOperation()->getOperands()) {
+    if (val != viewSrc)
+      accessValue(s, val, false);
+  }
+}
+
+void TickCollecter::onReturnOp(TickCollecterStates *s, Operation *op) const {
+  bool isTopLevel = isa<func::FuncOp>(op->getParentOp());
+  for (auto val : op->getOperands()) {
+    accessValue(s, val, isTopLevel);
+  }
+}
+
+void TickCollecter::onAllocOp(TickCollecterStates *s, Operation *op) const {
+  s->allocTicks[op].allocTick = s->curTick;
+}
+
+void TickCollecter::onGeneralOp(TickCollecterStates *s, Operation *op) const {
+  for (auto val : op->getOperands()) {
+    accessValue(s, val, false);
+  }
+}
+
+void TickCollecter::pushComplexScope(TickCollecterStates *s,
+                                     Operation *op) const {
+  s->complexScopeStack.emplace_back(op, s->curTick);
+}
+
+bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op,
+                                     int64_t tick) const {
   if (tick == COMPLEX_ACCESS) {
     return false;
   }
@@ -62,17 +162,17 @@ bool isMergeableAlloc(Operation *op, int64_t tick) {
           cast<MemRefType>(op->getResultTypes().front()))) {
     return false;
   }
-  // currently only support alignment: none, 1, 2, 4, 8, 16, 32, 64
   auto alignment = cast<memref::AllocOp>(op).getAlignment();
   if (!alignment) {
     return true; // ok if no alignment
   }
-  return alignment > 0 && (64 % alignment.value() == 0);
+  return alignment > 0 && (s->opt.alignment % alignment.value() == 0);
 }
 
 // find the closest surrounding parent operation with AutomaticAllocationScope
 // trait, and is not scf.for
-Operation *getAllocScope(Operation *op) {
+Operation *TickCollecter::getAllocScope(TickCollecterStates *s,
+                                        Operation *op) const {
   auto parent = op;
   for (;;) {
     parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
@@ -85,7 +185,8 @@ Operation *getAllocScope(Operation *op) {
   }
 }
 
-FailureOr<size_t> getAllocSize(Operation *op) {
+FailureOr<size_t> TickCollecter::getAllocSize(TickCollecterStates *s,
+                                              Operation *op) const {
   auto refType = cast<MemRefType>(op->getResultTypes().front());
   int64_t size = refType.getElementTypeBitWidth() / 8;
   // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
@@ -100,212 +201,93 @@ FailureOr<size_t> getAllocSize(Operation *op) {
   return op->emitError("Expecting static shaped allocation");
 }
 
-// A complex scope object is addition info for a RegionBranchOpInterface or
-// LoopLikeOpInterface. It contains the scope itself, and the referenced alloc
-// ops inside this scope. We use this object to track which buffers this scope
-// accesses. These buffers must have overlapped lifetime
-struct ComplexScope {
-  Operation *scope;
-  int64_t startTick;
-  llvm::SmallPtrSet<Operation *, 8> operations;
-  ComplexScope(Operation *scope, int64_t startTick)
-      : scope{scope}, startTick{startTick} {}
-  // returns true of an allocation either is not defined in the scope, or the
-  // allocation escapes from the scope
-  bool needsResetTick(Operation *scope, Operation *allocation,
-                      const mlir::BufferViewFlowAnalysis &aliasAnaly) const {
-    // if the allocation is not in the scope, conservatively set the ticks
-    if (!scope->isProperAncestor(allocation)) {
-      return true;
-    }
-    // if the allocation and its alias are used outside of the scope
-    for (auto &&alias : aliasAnaly.resolve(allocation->getResult(0))) {
-      for (auto &&userOp : alias.getUsers()) {
-        if (!scope->isProperAncestor(userOp) && !isMemoryEffectFree(userOp)) {
-          return true;
-        }
-      }
-    }
-    return false;
-  }
-
-  // called when walk() runs outside of the scope
-  LogicalResult onPop(int64_t endTick,
-                      const mlir::BufferViewFlowAnalysis &aliasAnaly,
-                      llvm::DenseMap<Operation *, Tick> &allocTicks) {
-    // if the complex scope is not recognized by us, and if it accesses memory,
-    // raise an error
-    if (!isa<RegionBranchOpInterface>(scope) &&
-        !isa<LoopLikeOpInterface>(scope) && !operations.empty()) {
-      return scope->emitOpError("expecting RegionBranchOpInterface or "
-                                "LoopLikeOpInterface for merge-alloc");
-    }
-    for (auto op : operations) {
-      if (needsResetTick(scope, op, aliasAnaly)) {
-        // let all referenced buffers have overlapped lifetime
-        auto &tick = allocTicks[op];
-        tick.access(startTick);
-        tick.access(endTick);
-      }
-    }
-    return success();
-  }
-};
-
-struct TickTraceResult : public LifetimeTrace {
-  memoryplan::Traces traces;
-  TickTraceResult() : LifetimeTrace{TK_TICK} {}
-  static bool classof(const LifetimeTrace *S) {
-    return S->getKind() == TK_TICK;
-  }
-};
-
-struct TickCollecter {
-  const mlir::BufferViewFlowAnalysis &aliasAnaly;
-  int64_t curTick = 0;
-  llvm::DenseMap<Operation *, Tick> allocTicks;
-  llvm::SmallVector<ComplexScope> complexScopeStack;
-  TickCollecter(const mlir::BufferViewFlowAnalysis &aliasAnaly)
-      : aliasAnaly{aliasAnaly} {}
-  LogicalResult popScopeIfNecessary(Operation *op) {
-    // first check if we have walked outside of the previous ComplexScope
-    while (!complexScopeStack.empty()) {
-      auto &scope = complexScopeStack.back();
-      if (!op || !scope.scope->isProperAncestor(op)) {
-        if (failed(scope.onPop(curTick, aliasAnaly, allocTicks))) {
-          return failure();
-        }
-        complexScopeStack.pop_back();
-      } else {
-        break;
-      }
-    }
-    return success();
-  }
-
-  void forwardTick() { curTick++; }
-
-  void accessValue(Value v, bool complex) {
-    if (auto refv = dyn_cast<TypedValue<MemRefType>>(v)) {
-      for (auto &&base : aliasAnaly.resolveReverse(refv)) {
-        auto defop = base.getDefiningOp();
-        if (isa_and_present<memref::AllocOp>(defop)) {
-          allocTicks[defop].access(complex ? COMPLEX_ACCESS : curTick);
-          if (!complexScopeStack.empty()) {
-            complexScopeStack.back().operations.insert(defop);
-          }
-        }
-      }
-    }
-  }
-
-  void onMemrefViews(ViewLikeOpInterface op) {
-    auto viewSrc = op.getViewSource();
-    // don't need to access the first operand, which is "source".
-    // The "source" operand is not really read or written at this point
-    for (auto val : op.getOperation()->getOperands()) {
-      if (val != viewSrc)
-        accessValue(val, false);
+FailureOr<MemoryTraceScopes>
+TickCollecter::getTrace(TickCollecterStates *s) const {
+  struct TraceWithTick {
+    // just a tie-breaker when two ticks are the same
+    int64_t allocTick;
+    int64_t tick;
+    memoryplan::MemoryTrace trace;
+    TraceWithTick(int64_t allocTick, int64_t tick, uintptr_t bufferId,
+                  size_t size)
+        : allocTick{allocTick}, tick{tick}, trace{bufferId, size} {}
+  };
+  llvm::DenseMap<Operation *, llvm::SmallVector<TraceWithTick, 8>> raw;
+  for (auto &[op, tick] : s->allocTicks) {
+    if (!isMergeableAlloc(s, op, tick.firstAccess)) {
+      continue;
     }
-  }
-
-  void onReturnOp(Operation *op) {
-    bool isTopLevel = isa<func::FuncOp>(op->getParentOp());
-    for (auto val : op->getOperands()) {
-      accessValue(val, isTopLevel);
+    auto scope = getAllocScope(s, op);
+    if (!scope) {
+      return op->emitError(
+          "This op should be surrounded by an AutomaticAllocationScope");
     }
-  }
-
-  void onGeneralOp(Operation *op) {
-    for (auto val : op->getOperands()) {
-      accessValue(val, false);
+    auto allocSize = getAllocSize(s, op);
+    if (failed(allocSize)) {
+      return failure();
     }
+    // tick.firstAccess * 2 and tick.lastAccess * 2 + 1 to make sure "dealloc"
+    // overlaps "alloc"
+    raw[scope].emplace_back(tick.allocTick, tick.firstAccess * 2,
+                            reinterpret_cast<uintptr_t>(op), *allocSize);
+    raw[scope].emplace_back(tick.allocTick, tick.lastAccess * 2 + 1,
+                            reinterpret_cast<uintptr_t>(op), 0);
   }
-
-  void pushComplexScope(Operation *op) {
-    complexScopeStack.emplace_back(op, curTick);
-  }
-
-  FailureOr<MemoryTraceScopes> getTrace() {
-    struct TraceWithTick {
-      Operation *op;
-      int64_t tick;
-      memoryplan::MemoryTrace trace;
-      TraceWithTick(int64_t tick, uintptr_t bufferId, size_t size)
-          : tick{tick}, trace{bufferId, size} {}
-    };
-    llvm::DenseMap<Operation *, llvm::SmallVector<TraceWithTick, 8>> raw;
-    for (auto &[op, tick] : allocTicks) {
-      if (!isMergeableAlloc(op, tick.firstAccess)) {
-        continue;
-      }
-      auto scope = getAllocScope(op);
-      if (!scope) {
-        return op->emitError(
-            "This op should be surrounded by an AutomaticAllocationScope");
-      }
-      auto allocSize = getAllocSize(op);
-      if (failed(allocSize)) {
-        return failure();
-      }
-      // tick.firstAccess * 2 and tick.lastAccess * 2 + 1 to make sure "dealloc"
-      // overlaps "alloc"
-      raw[scope].emplace_back(tick.firstAccess * 2,
-                              reinterpret_cast<uintptr_t>(op), *allocSize);
-      raw[scope].emplace_back(tick.lastAccess * 2 + 1,
-                              reinterpret_cast<uintptr_t>(op), 0);
-    }
-    MemoryTraceScopes ret;
-    for (auto &[scope, trace] : raw) {
-      std::stable_sort(trace.begin(), trace.end(),
-                       [](const TraceWithTick &a, const TraceWithTick &b) {
-                         return a.tick < b.tick;
-                       });
-      auto retTrace = std::make_unique<TickTraceResult>();
-      retTrace->traces.reserve(trace.size());
-      for (auto &tr : trace) {
-        retTrace->traces.emplace_back(tr.trace);
-      }
-      ret.scopeToTraces[scope] = std::move(retTrace);
+  MemoryTraceScopes ret;
+  for (auto &[scope, trace] : raw) {
+    std::stable_sort(trace.begin(), trace.end(),
+                     [](const TraceWithTick &a, const TraceWithTick &b) {
+                       if (a.tick == b.tick) {
+                         return a.allocTick < b.allocTick;
+                       }
+                       return a.tick < b.tick;
+                     });
+    auto retTrace = std::make_unique<TickTraceResult>();
+    retTrace->traces.reserve(trace.size());
+    for (auto &tr : trace) {
+      retTrace->traces.emplace_back(tr.trace);
     }
-    return ret;
+    ret.scopeToTraces[scope] = std::move(retTrace);
   }
-};
-} // namespace
+  return ret;
+}
 
 FailureOr<MemoryTraceScopes>
-tickBasedCollectMemoryTrace(Operation *root,
-                            const mlir::BufferViewFlowAnalysis &aliasAnaly,
-                            const MergeAllocationOptions &option) {
-  TickCollecter collecter{aliasAnaly};
+TickCollecter::operator()(Operation *root,
+                          const mlir::BufferViewFlowAnalysis &aliasAnaly,
+                          const MergeAllocationOptions &option) const {
+  TickCollecterStates s{aliasAnaly, option};
+  TickCollecter collecter;
   LogicalResult result = success();
   root->walk<WalkOrder::PreOrder>([&](Operation *op) {
-    if (failed(collecter.popScopeIfNecessary(op))) {
+    if (failed(collecter.popScopeIfNecessary(&s, op))) {
       result = failure();
     }
-    collecter.forwardTick();
+    collecter.forwardTick(&s);
     if (auto viewop = dyn_cast<ViewLikeOpInterface>(op)) {
-      collecter.onMemrefViews(viewop);
+      collecter.onMemrefViews(&s, viewop);
     } else if (op->hasTrait<OpTrait::ReturnLike>()) {
-      collecter.onReturnOp(op);
+      collecter.onReturnOp(&s, op);
+    } else if (isa<AllocOp>(op)) {
+      collecter.onAllocOp(&s, op);
     } else if (!isMemoryEffectFree(op)) {
       // if the op has no memory effects, it don't contribute to liveness
-      collecter.onGeneralOp(op);
+      collecter.onGeneralOp(&s, op);
     }
     if (op->getNumRegions() > 0 && !isa<func::FuncOp>(op)) {
       // finally, if op is complex scope, push one ComplexScope
-      collecter.pushComplexScope(op);
+      collecter.pushComplexScope(&s, op);
     }
   });
   if (failed(result)) {
     return result;
   }
-  if (failed(collecter.popScopeIfNecessary(nullptr))) {
+  if (failed(collecter.popScopeIfNecessary(&s, nullptr))) {
     return failure();
   }
   if (option.checkOnly) {
-    for (auto &[alloc, tick] : collecter.allocTicks) {
-      auto allocscope = getAllocScope(alloc);
+    for (auto &[alloc, tick] : s.allocTicks) {
+      auto allocscope = getAllocScope(&s, alloc);
       alloc->setAttr(
           "__mergealloc_lifetime",
           DenseI64ArrayAttr::get(root->getContext(),
@@ -318,7 +300,7 @@ tickBasedCollectMemoryTrace(Operation *root,
     }
     return MemoryTraceScopes();
   }
-  return collecter.getTrace();
+  return collecter.getTrace(&s);
 }
 
 FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
@@ -335,7 +317,7 @@ FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
   std::unordered_map<uintptr_t, std::size_t> outSchedule;
   std::unordered_map<uintptr_t, std::vector<uintptr_t>> dummy;
   auto total = memoryplan::scheduleMemoryAllocations(
-      traces, 64, !o.noLocalityFirst, memoryplan::InplaceInfoMap(),
+      traces, o.alignment, !o.noLocalityFirst, memoryplan::InplaceInfoMap(),
       outSchedule, dummy);
   MemorySchedule ret;
   ret.totalSize = total;
@@ -346,29 +328,44 @@ FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
   return std::move(ret);
 }
 
-LogicalResult tickBasedMutateAllocations(Operation *op, Operation *scope,
-                                         const MemorySchedule &schedule,
-                                         const MergeAllocationOptions &o) {
+Value MergeAllocDefaultMutator::buildAlloc(OpBuilder &builder, Operation *scope,
+                                           int64_t size,
+                                           int64_t alignmentInt) const {
+  auto &block = scope->getRegion(0).getBlocks().front();
+  builder.setInsertionPointToStart(&block);
+  auto alignment = builder.getIntegerAttr(
+      IntegerType::get(builder.getContext(), 64), alignmentInt);
+  auto alloc = builder.create<memref::AllocOp>(
+      scope->getLoc(), MemRefType::get({size}, builder.getI8Type()), alignment);
+  return alloc;
+}
+Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Operation *scope,
+                                          Operation *origAllocOp,
+                                          Value mergedAlloc,
+                                          int64_t byteOffset) const {
+  builder.setInsertionPoint(origAllocOp);
+  auto byteShift =
+      builder.create<arith::ConstantIndexOp>(origAllocOp->getLoc(), byteOffset);
+  return builder.create<memref::ViewOp>(origAllocOp->getLoc(),
+                                        origAllocOp->getResultTypes().front(),
+                                        mergedAlloc, byteShift, ValueRange{});
+}
+
+LogicalResult
+MergeAllocDefaultMutator::operator()(Operation *op, Operation *scope,
+                                     const MemorySchedule &schedule,
+                                     const MergeAllocationOptions &o) const {
   if (schedule.allocToOffset.empty()) {
     return success();
   }
-  auto &block = scope->getRegion(0).getBlocks().front();
-  OpBuilder builder{&block.front()};
-  auto alignment =
-      builder.getIntegerAttr(IntegerType::get(op->getContext(), 64), 64);
-  auto alloc = builder.create<memref::AllocOp>(
-      scope->getLoc(),
-      MemRefType::get({static_cast<int64_t>(schedule.totalSize)},
-                      builder.getI8Type()),
-      alignment);
+  OpBuilder builder{op->getContext()};
+  auto alloc = buildAlloc(
+      builder, scope, static_cast<int64_t>(schedule.totalSize), o.alignment);
   for (auto &[origBuf, offset] : schedule.allocToOffset) {
-    builder.setInsertionPoint(origBuf);
-    auto byteShift = builder.create<arith::ConstantIndexOp>(
-        origBuf->getLoc(), static_cast<int64_t>(offset));
-    auto view = builder.create<memref::ViewOp>(
-        origBuf->getLoc(), origBuf->getResultTypes().front(), alloc, byteShift,
-        ValueRange{});
-    origBuf->replaceAllUsesWith(view->getResults());
+    origBuf->replaceAllUsesWith(
+        buildView(builder, scope, origBuf, alloc, static_cast<int64_t>(offset))
+            .getDefiningOp()
+            ->getResults());
     origBuf->remove();
   }
   return success();
diff --git a/mlir/test/Dialect/MemRef/buffer-merge.mlir b/mlir/test/Dialect/MemRef/buffer-merge.mlir
index f7e7f0f3d6b62..e491e2479a157 100644
--- a/mlir/test/Dialect/MemRef/buffer-merge.mlir
+++ b/mlir/test/Dialect/MemRef/buffer-merge.mlir
@@ -63,11 +63,11 @@ func.func @basic() -> memref<8x64xf32> {
 // CHECK-LABEL: @withloop
 func.func @withloop() {
   // CHECK-DAG: %[[BASE2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<6144xi8>
-  // CHECK-DAG: %[[C2048:.*]] = arith.constant 2048 : index
-  // CHECK-DAG: %[[F:.*]] = memref.view %[[BASE2]][%[[C2048]]][] : memref<6144xi8> to memref<8x64xf32>
-  %f = memref.alloc() : memref<8x64xf32>
   // CHECK-DAG: %[[C033:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[G:.*]] = memref.view %[[BASE2]][%[[C033]]][] : memref<6144xi8> to memref<8x64xf32>
+  // CHECK-DAG: %[[F:.*]] = memref.view %[[BASE2]][%[[C033]]][] : memref<6144xi8> to memref<8x64xf32>
+  %f = memref.alloc() : memref<8x64xf32>
+  // CHECK-DAG: %[[C2048:.*]] = arith.constant 2048 : index
+  // CHECK-DAG: %[[G:.*]] = memref.view %[[BASE2]][%[[C2048]]][] : memref<6144xi8> to memref<8x64xf32>
   %g = memref.alloc() : memref<8x64xf32>
 
   %c0 = arith.constant 0 : index

>From eeabd530ee885d882d49fe993ad158ed4888a838 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Wed, 12 Jun 2024 10:32:33 +0800
Subject: [PATCH 06/12] add doc(WIP)

---
 mlir/docs/MemrefMergeAlloc.md                 | 222 ++++++++++++++++++
 .../MemRef/Transforms/MergeAllocTickBased.h   |  85 ++++---
 .../mlir/Dialect/MemRef/Transforms/Passes.h   |  18 +-
 3 files changed, 285 insertions(+), 40 deletions(-)
 create mode 100644 mlir/docs/MemrefMergeAlloc.md

diff --git a/mlir/docs/MemrefMergeAlloc.md b/mlir/docs/MemrefMergeAlloc.md
new file mode 100644
index 0000000000000..835c556c4fbab
--- /dev/null
+++ b/mlir/docs/MemrefMergeAlloc.md
@@ -0,0 +1,222 @@
+# Compile-time memref.alloc Scheduling and Merging
+
+This document describes a compile-time optimization on `memref.alloc` to reduce memory usage and improve memory locality.
+
+## Current status of bufferization and memref pass pipeline
+Bufferization is the process in MLIR of converting ops with tensor semantics to ops with memref semantics.
+One-Shot Bufferize is a tensor bufferization pass designed for IR in destination-passing style, with aggressive in-place bufferization. The older, partial bufferization was built around multiple dialects, and the community is gradually deprecating it in favor of One-Shot Bufferize.
+The goal of bufferization is to use and copy as little memory as possible. As a result, the existing focus is on deciding in-place versus out-of-place bufferization for the OpOperands and OpResults of individual ops, without considering much about overall memory reuse across operators within a sub-graph (or partition).
+
+The current implementation of the bufferization and memref pass pipeline focuses on copy avoidance and in-place reuse of memory. Consider a computation graph of 4 matmul layers sharing the same weight:
+```mlir
+func.func @mlp(%x: tensor<128x128xf32>, %y: tensor<128x128xf32>) -> tensor<128x128xf32> {
+   %a0 = tensor.empty() : tensor<128x128xf32>
+   %a = linalg.matmul ins(%x, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%a0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   %b0 = tensor.empty() : tensor<128x128xf32>
+   %b = linalg.matmul ins(%a, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%b0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   %c0 = tensor.empty() : tensor<128x128xf32>
+   %c = linalg.matmul ins(%b, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%c0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   %d0 = tensor.empty() : tensor<128x128xf32>
+   %d = linalg.matmul ins(%c, %y: tensor<128x128xf32>, tensor<128x128xf32>) outs(%d0: tensor<128x128xf32>) -> tensor<128x128xf32>
+   return %d : tensor<128x128xf32>
+}
+```
+
+The bufferization pass will create a `memref.alloc` for each of the tensors `a0`, `b0` and `c0`. The bufferized IR should look like:
+
+```mlir
+func.func @mlp(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x128xf32> {
+   %a0 = memref.alloc() : memref<128x128xf32>
+   linalg.matmul ins(%x, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%a0: memref<128x128xf32>)
+   %b0 = memref.alloc() : memref<128x128xf32>
+   linalg.matmul ins(%a0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%b0: memref<128x128xf32>)
+   %c0 = memref.alloc() : memref<128x128xf32>
+   linalg.matmul ins(%b0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%c0: memref<128x128xf32>)
+   %d0 = memref.alloc() : memref<128x128xf32>
+   linalg.matmul ins(%c0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%d0: memref<128x128xf32>)
+   return %d0 : memref<128x128xf32>
+}
+```
+
+Without further optimizations, 3 temporary buffers will be allocated at runtime for these tensors. However, as we can see in the IR, buffer `a0` is no longer used when buffer `c0` is allocated. So buffer `c0` can reuse the memory of buffer `a0`, reducing the memory footprint and improving locality.
+
+An observation of the current bufferization and memref passes is that they do not perform memory buffer planning, i.e., reusing buffers/memrefs for a smaller total size and better locality.
+
+## Merge-alloc pass
+An optimization pass has been introduced to consolidate multiple allocations (`memref.alloc` ops) into a single `memref.alloc` op; each static-shaped `memref.alloc` op is transformed into a "slice" of the `single allocated buffer`, via `memref.view` with a compile-time decided `offset`. This optimization works on `memref` ops instead of `tensor` ops, so it should run after the bufferization pass and before buffer deallocation ops are added.
+
+While merging the memory allocations, the transform should consider the lifetime of each allocated `memref`. By lifetime, we mean the range of time during which a memref allocated by `memref.alloc` is actively used. References to `view`s of a "base" `memref` contribute to the lifetime of the "base". A later `memref.alloc` may reuse the memory of a previously allocated memref if the lifetimes of the two do not overlap. The transform performs this "reuse" by setting the `offset` of the later `memref.view` to a position within the memory range of the previous allocation's `memref.view` on the `single allocated buffer`.
+
+Below is the expected transformation result of the example IR in the above section:
+
+```mlir
+func.func @mlp(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x128xf32> {
+   %single_buffer = memref.alloc() : memref<131072xi8> // 128*128*sizeof(f32)*2
+   %a0 = memref.view %single_buffer[0][] : memref<131072xi8> to memref<128x128xf32> // a0 takes the memory from byte offset 0
+   linalg.matmul ins(%x, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%a0: memref<128x128xf32>)
+   %b0 = memref.view %single_buffer[65536][] : memref<131072xi8> to memref<128x128xf32> // b0 takes the memory from byte offset 128*128*sizeof(f32)
+   linalg.matmul ins(%a0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%b0: memref<128x128xf32>) 
+   %c0 = memref.view %single_buffer[0][] : memref<131072xi8> to memref<128x128xf32> // c0 takes the memory from byte offset 0
+   linalg.matmul ins(%b0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%c0: memref<128x128xf32>)
+   %d0 = memref.alloc() : memref<128x128xf32> // d0 is returned, do not merge
+   linalg.matmul ins(%c0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%d0: memref<128x128xf32>)
+   return %d0 : memref<128x128xf32>
+}
+```
+
+There is a single allocation, `single_buffer`, for all temporary buffers, and the `alloc` ops for `a0`, `b0` and `c0` are removed. The returned memref `d0` is untouched. The memrefs `a0`, `b0` and `c0` are replaced by `memref.view` ops on `single_buffer`. Since the lifetimes of `a0` and `b0` overlap, the transformation "allocates" different memory ranges on `single_buffer` - note that `a0` and `b0` have different offsets, `%single_buffer[0]` and `%single_buffer[65536]`, and their memory ranges do not overlap. The lifetime of `c0` does not overlap with that of `a0`, so `c0` can reuse the memory range of `a0` by setting its offset to `%single_buffer[0]`, the same as `a0`'s. The final temporary-buffer allocation size will be `128*128*sizeof(f32)*2` bytes instead of three `memref<128x128xf32>` buffers as in the original IR.
+
+
+## Other solutions besides merge-alloc
+
+Another (not yet existing) approach to the memory reuse issue is to insert `memref.dealloc` as soon as a buffer is no longer used. For example, in the "matmul" example above, `memref.dealloc` could be inserted after the last use of `a0` at `linalg.matmul ins(%a0, %y...)`, as sketched after the list below. Then, even without the memref merging transformation, a common runtime memory allocator would try to reuse the memory freed by `memref.dealloc(%a0)` when allocating the buffer for `c0`. However, this approach has some disadvantages compared to the compile-time memref merging transformation of this proposal:
+1. it depends on the implementation of the runtime memory allocator.
+2. the runtime memory allocator does not have a full picture of the program's future allocation/deallocation patterns. For example, if we change the above example so that the size of buffer `c0` is greater than the size of `a0`, the runtime memory allocator will likely not be able to reuse the memory of `a0` for `c0`, because the free memory chunk of `a0` does not fit the allocation of `c0`. In contrast, the optimization proposed in this document knows the allocation patterns, so it can place the memory chunk for `a0` at the right position in the `single allocation buffer` so that the allocation of `c0` fits into it.
+3. calling the runtime memory allocator for each buffer introduces more runtime overhead than a single merged allocation.
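+
+A minimal sketch of this dealloc-based alternative on the earlier example is shown below. It is hypothetical; the exact insertion points of the `memref.dealloc` ops would be decided by a liveness analysis:
+
+```mlir
+func.func @mlp_dealloc(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x128xf32> {
+   %a0 = memref.alloc() : memref<128x128xf32>
+   linalg.matmul ins(%x, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%a0: memref<128x128xf32>)
+   %b0 = memref.alloc() : memref<128x128xf32>
+   linalg.matmul ins(%a0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%b0: memref<128x128xf32>)
+   memref.dealloc %a0 : memref<128x128xf32> // %a0 is dead after its last use above
+   %c0 = memref.alloc() : memref<128x128xf32> // the runtime allocator may reuse %a0's memory here
+   linalg.matmul ins(%b0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%c0: memref<128x128xf32>)
+   memref.dealloc %b0 : memref<128x128xf32>
+   %d0 = memref.alloc() : memref<128x128xf32>
+   linalg.matmul ins(%c0, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%d0: memref<128x128xf32>)
+   memref.dealloc %c0 : memref<128x128xf32>
+   return %d0 : memref<128x128xf32>
+}
+```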
+
+However, utilizing a runtime memory allocator can be viewed as a supplement to compile-time allocation merging, for example, to handle memrefs with dynamic shapes. These two memory optimization approaches should coexist and cooperate in the pass pipeline.
+
+## General framework for implementation of merge-alloc
+
+To make the merge-alloc pass capable of handling different hardware architectures and runtime requirements, it is implemented as a general pipeline of the following stages:
+
+1. Collect the memory alias via `BufferViewFlowAnalysis`
+2. Collect the memory lifetime traces
+3. Schedule the buffers by an allocation algorithm to compute the offsets of each allocations
+4. Rewrite the IR to replace allocations with views of merged buffers
+
+Steps 2, 3 and 4 can be implemented by developers to customize the pass for their own use cases. A tick-based pipeline is provided as the default implementation and is discussed in the next section.
+
+The following concepts should be defined by the implementation of the pass:
+ * Mergeable allocation: the memref.alloc operations that should be merged by the pass. Other memref.alloc operations that are not "mergeable" are left untouched by the pass.
+ * Allocation scope: for each mergeable memref.alloc operation, there should be one surrounding ancestor operation called the "allocation scope". After merge-alloc, the memory allocation for that memref.alloc operation is hoisted and merged into the block of that "allocation scope". An "allocation scope" contains a single merged allocation for the mergeable allocations in it.
+ * Lifetime trace: for each mergeable memref.alloc operation, a "lifetime trace" should be collected, indicating the "allocation scope" and the liveness of the allocated buffer. The contents of a "lifetime trace" are implementation-defined.
+
+
+The sections below give more details on each step of the pipeline.
+
+### Collect the memory lifetime traces
+
+This is the first stage that a developer can customize in merge-alloc. It collects the lifetime trace for each mergeable memref.alloc operation. An implementation of the lifetime trace collector should define which allocations are mergeable and find their allocation scopes. It should also provide a data structure to hold the detailed liveness of each buffer.
+
+This step is abstracted as a `TraceCollectorFunc` function. The merge-alloc framework defines the abstract interfaces for the lifetime trace collector and the collected traces as below:
+
+```c++
+/// abstract base class for lifetime of buffers in the same "allocation scope".
+/// It should hold the lifetime information of buffers that are to be merged in
+/// the same allocation in an "allocation scope". TraceCollectorFunc decides
+/// which buffers are put into which "allocation scope".
+class LifetimeTrace {
+public:
+  virtual ~LifetimeTrace() = default;
+};
+
+/// top level memory trace info for multiple scopes. Each key-value is the
+///  "allocation scope" and the LifetimeTrace
+struct MemoryTraceScopes {
+  llvm::DenseMap<Operation *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
+  MemoryTraceScopes() = default;
+};
+
+using TraceCollectorFunc = std::function<FailureOr<MemoryTraceScopes>(
+    Operation *, const BufferViewFlowAnalysis &,
+    const MergeAllocationOptions &)>;
+```
+
+### Memory planning and scheduling
+
+This step is abstracted as a `MemoryPlannerFunc` function. It accepts the `MemoryTraceScopes` collected by the previous step. For each allocation scope in `MemoryTraceScopes`, it decides the total merged allocation size and the offset of each mergeable allocation inside the allocation scope. The abstract interfaces are shown below:
+
+```c++
+/// the memory scheduling result for allocations in the same allocation scope.
+/// allocation => offset map. All Operation* in the map should be
+/// memref::AllocOp which are in the same LifetimeTrace.
+struct MemorySchedule {
+  size_t totalSize;
+  llvm::DenseMap<Operation *, int64_t> allocToOffset;
+  MemorySchedule() : totalSize{0} {}
+};
+
+using MemoryPlannerFunc = std::function<FailureOr<MemorySchedule>(
+    Operation *, const LifetimeTrace &, const MergeAllocationOptions &)>;
+```
+
+### Rewriting the IR
+
+Given the `MemorySchedule` from the previous step, this step rewrites the IR to create the merged allocation in each allocation scope and to replace the mergeable memref.alloc ops with views on the merged allocations, using the offsets computed in the `MemorySchedule`. This step is abstracted as a `MemoryMergeMutatorFunc` function.
+
+```c++
+using MemoryMergeMutatorFunc = std::function<LogicalResult(
+    Operation *toplevel, Operation *scope, const MemorySchedule &,
+    const MergeAllocationOptions &)>;
+```
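+
+As a minimal sketch (not itself part of the pass), a downstream pipeline could assemble the three customizable stages into `MergeAllocationOptions` and register the pass as below. `pm` is assumed to be an existing `mlir::PassManager`; the callbacks shown are the default tick-based implementations, and any stage left unset falls back to the same defaults inside the pass:
+
+```c++
+memref::MergeAllocationOptions options;
+options.checkOnly = false;
+options.alignment = 64;
+options.tracer = memref::TickCollecter();              // stage 2: trace collection
+options.planner = memref::tickBasedPlanMemory;          // stage 3: memory planning
+options.mutator = memref::MergeAllocDefaultMutator();   // stage 4: IR rewriting
+pm.addNestedPass<func::FuncOp>(memref::createMergeAllocPass(options));
+```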
+
+
+## Tick-based Implementation for merge-alloc
+
+A tick-based implementation of merge-alloc is provided by default. The basic idea of tick-based allocation merging is as follows (a small illustration follows the list):
+
+1. Each of the operations in a function is assigned a "tick". An operation with a smaller tick is expected to be executed before one with a larger tick
+2. Collect the first referenced tick and the last referenced tick for each mergeable allocation. If a buffer is referenced in loops and branches, special handling is needed.
+3. For each allocation scope, linearize the first referenced tick and the last referenced tick of mergeable allocations inside of it into a single linear timeline
+4. Use a "static-memory-planner" to handle the linear timeline
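+
+The sketch below illustrates the tick assignment on a tiny function. The tick numbers are illustrative only (the pass also counts surrounding and auxiliary ops); what matters is the relative order and the derived `[first, last]` access ranges:
+
+```mlir
+func.func @tick_example(%v: f32) {
+  %c0 = arith.constant 0 : index            // tick 1
+  %a = memref.alloc() : memref<8xf32>       // tick 2
+  %b = memref.alloc() : memref<8xf32>       // tick 3
+  memref.store %v, %a[%c0] : memref<8xf32>  // tick 4: %a is accessed at [4, 4]
+  memref.store %v, %b[%c0] : memref<8xf32>  // tick 5: %b is accessed at [5, 5]
+  return                                    // tick 6
+}
+```
+
+Since the access ranges `[4, 4]` and `[5, 5]` do not overlap, the planner may later assign `%a` and `%b` the same offset in the merged buffer.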
+
+### Basic concepts
+
+In the context of tick-based merge-alloc, mergeable allocations and allocation scopes are defined as follows.
+
+#### Mergeable allocation
+
+The pass considers merging a `memref.alloc` only if
+ * the ownership of the memref does not escape from the function, that is, the current function is responsible for allocating and deallocating this memref
+ * and the allocated memref is contiguous, statically shaped and has an identity layout.
+
+In tick-based merge-alloc, we call these `memref.alloc` **mergeable** allocations.
+
+The memrefs passed as function arguments or returned by the function will be untouched by this optimization.
+
+#### Allocation scopes
+
+The transformation first needs to identify the allocation scopes, which are MLIR operations containing at least one region, and
+ * implementing `AutomaticAllocationScope`
+ * and is not `scf.for` (allocations in an `scf.for` can be hoisted to parent `AutomaticAllocationScope`)
+
+For example, below is an example IR of a function with nested `scf.forall` ops.
+
+```mlir
+func.func @mlp(...) { // <---- alloc scope 1
+   scf.for(...) { // <---- NOT an alloc scope!
+      // allocation inside will be merged to alloc scope 1 above
+   }
+   ...
+   scf.forall(...) { // <---- alloc scope 2
+      ...
+      // allocation here will be merged to alloc scope 2
+      %buf = memref.alloc() : ...
+      scf.forall(...) { // <---- alloc scope 3
+      }
+   }
+}
+```
+
+There will be three allocation scopes as marked in the comments above. An allocation scope marks the position to insert the `single allocation buffer` after allocation merging. After the transformation, all "mergeable" `memref.alloc` will be merged to the `single allocation buffer` of the nearest ancestor `alloc scope`.
+
+### Tick-based trace collection
+
+walk()
+Alias
+Branch
+Sort-malloc-free
+
+### Static Memory planner
+
+
+The transformation consists of an analysis sub-pass and a mutation sub-pass. For each `alloc scope`, the analysis sub-pass finds the lifetime of each mergeable `memref.alloc` belonging to the `alloc scope`. And given the lifetime of each allocation, a memory planning algorithm will be run to find the `single allocation buffer` size of each `alloc scope` and the `offset` for each mergeable allocation within its `single allocation buffer`. Based on the memory planning result, the mutation sub-pass transforms the IR to
+1. insert `memref.alloc` at the front of `alloc scope` body for its `single allocation buffer`
+2. replace mergeable `memref.alloc` with `memref.view` on its `alloc scope`'s `single allocation buffer`
+
+Ticks are assigned on each operation in the `func.func` by a increasing counter with pre-order recursive walking of the IR, as the "execution tick" for each operation. The lifetime analysis pass will assign two integers for each mergeable allocations as the analysis result: `begin_tick` and `end_tick`, to indicate the first and last tick of the use of the allocated memref in the IR. There should be special handling for loop and branch ops (`RegionBranchOpInterface` or `LoopLikeOpInterface`) which references memrefs allocated in parent scopes, to avoid wrong reuse of buffers used in the loop.
+
+The analysis result for each mergeable allocations will be an integer range `[begin_tick,end_tick]`, where `begin_tick <= end_tick`.
+
+The collected ticks for each buffer will be processed by the memory planning algorithm. It should output the total size of the `single allocation buffer` for each `alloc scope`, and the `offset` for each individual mergeable buffer. The algorithm should also consider the locality of the buffers, when multiple buffer location candidates are available.
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
index 0683426ee0312..75209e2df38d0 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
@@ -24,33 +24,34 @@ class BufferViewFlowAnalysis;
 class ViewLikeOpInterface;
 namespace memref {
 
-// Usually ticks should be non-negative numbers. There are two special ticks
-// defined here.
+/// Usually ticks should be non-negative numbers. There are two special ticks
+/// defined here.
 namespace special_ticks {
-// the memref is not accessed
+/// the memref is not accessed
 static constexpr int64_t NO_ACCESS = -1;
-// complex access happens on this memref
+/// complex access happens on this memref, like func.return
 static constexpr int64_t COMPLEX_ACCESS = -2;
 } // namespace special_ticks
 
-// the collected tick [first, last] for a memref
+/// the collected tick [first, last] for a memref allocation
 struct Tick {
-  // The tick when the buffer is allocated. allocTick is only used to stablize
-  // the sorting results of the buffers when ticks of different buffers are the
-  // same
+  /// The tick when the buffer is allocated. allocTick is only used to stabilize
+  /// the sorting results of the buffers when ticks of different buffers are the
+  /// same
   int64_t allocTick = special_ticks::NO_ACCESS;
   int64_t firstAccess = special_ticks::NO_ACCESS;
   int64_t lastAccess = special_ticks::NO_ACCESS;
 
-  // access the memref at the tick, will update firstAccess and lastAccess based
-  // on the tick
+  /// access the memref at the tick, will update firstAccess and lastAccess
+  /// based on the tick
   void access(int64_t tick);
 };
 
-// A complex scope object is addition info for a RegionBranchOpInterface or
-// LoopLikeOpInterface. It contains the scope itself, and the referenced alloc
-// ops inside this scope. We use this object to track which buffers this scope
-// accesses. These buffers must have overlapped lifetime
+/// A complex scope object is additional info for a RegionBranchOpInterface or
+/// LoopLikeOpInterface. It contains the scope itself, and the referenced alloc
+/// ops inside this scope. It is used in TickCollecter and TickCollecterStates
+/// internally to track which buffers this scope accesses. These buffers must
+/// have overlapped lifetime
 struct ComplexScope {
   Operation *scope;
   int64_t startTick;
@@ -59,7 +60,7 @@ struct ComplexScope {
       : scope{scope}, startTick{startTick} {}
 };
 
-// the top-level collected lifetime trace for merge-alloc pass
+/// the top-level collected lifetime trace for merge-alloc pass
 struct TickTraceResult : public LifetimeTrace {
   memoryplan::Traces traces;
   TickTraceResult() : LifetimeTrace{TK_TICK} {}
@@ -68,25 +69,33 @@ struct TickTraceResult : public LifetimeTrace {
   }
 };
 
-// the internal states for TickCollecter analysis for a function
+/// the internal states for TickCollecter analysis for a function
 struct TickCollecterStates {
-  // the alias analysis result for the function
+  /// the alias analysis result for the function
   const mlir::BufferViewFlowAnalysis &aliasAnaly;
   const MergeAllocationOptions &opt;
-  // the current tick for the current callback of walk(). It will be by default
-  // incremented by 1 for each visited op
+  /// the current tick for the current callback of walk(). It will be by default
+  /// incremented by 1 for each visited op
   int64_t curTick = 0;
-  // the currently collected AllocOp -> [start, end] map
+  /// the currently collected AllocOp -> [start, end] map
   llvm::DenseMap<Operation *, Tick> allocTicks;
-  // the stack of ComplexScopes for the current visited position in the IR
+  /// the stack of ComplexScopes for the current visited position in the IR
   llvm::SmallVector<ComplexScope> complexScopeStack;
   TickCollecterStates(const mlir::BufferViewFlowAnalysis &aliasAnaly,
                       const MergeAllocationOptions &opt)
       : aliasAnaly{aliasAnaly}, opt{opt} {}
 };
 
+/// the tick-based memory lifetime collector. This class overrides operator() so
+/// that a TickCollecter object can be passed to a TraceCollectorFunc. This
+/// collector assigns a monotonically increasing "tick" for each operation, by
+/// Operation::walk<WalkOrder::PreOrder>() order. It collects the first reference
+/// tick and the last reference tick for each `memref.alloc` operation as the
+/// lifetime trace stored in `TickTraceResult`
 struct TickCollecter {
   TickCollecter() = default;
+  /// called on each operation before calling onXXXX() below. It may call
+  /// onPopComplexScope internally when walk() runs out of a ComplexScope
   virtual LogicalResult popScopeIfNecessary(TickCollecterStates *s,
                                             Operation *op) const;
 
@@ -105,27 +114,35 @@ struct TickCollecter {
 
   virtual void pushComplexScope(TickCollecterStates *s, Operation *op) const;
 
-  // called when walk() runs outside of the scope
+  /// called when walk() runs outside of the scope
   LogicalResult onPopComplexScope(TickCollecterStates *s,
                                   int64_t endTick) const;
 
-  // returns true of an allocation either is not defined in the scope, or the
-  // allocation escapes from the scope
-  bool needsResetTick(TickCollecterStates *s, Operation *scope,
-                      Operation *allocation) const;
-
+  /// returns true if an allocation is either not defined in the scope, or the
+  /// allocation escapes from the scope
+  virtual bool needsResetTick(TickCollecterStates *s, Operation *scope,
+                              Operation *allocation) const;
+  /// returns true if the memref.alloc op is "merge-able". If false is returned,
+  /// this memref.alloc will be untouched by merge-alloc
   virtual bool isMergeableAlloc(TickCollecterStates *s, Operation *op,
                                 int64_t tick) const;
 
-  // find the closest surrounding parent operation with AutomaticAllocationScope
-  // trait, and is not scf.for
+  /// gets the "AllocScope" op of a memref.alloc op. The default implementation
+  /// finds the closest surrounding parent operation that has the
+  /// AutomaticAllocationScope trait and is not scf.for
   virtual Operation *getAllocScope(TickCollecterStates *s, Operation *op) const;
-
+  /// returns the allocation size in bytes of a memref.alloc op. May fail when
+  /// the shape is not static
   virtual FailureOr<size_t> getAllocSize(TickCollecterStates *s,
                                          Operation *op) const;
-
+  /// consolidate the traces for each buffer in each scope, recorded in
+  /// TickCollecterStates. It is called after walk() in operator()
   virtual FailureOr<MemoryTraceScopes> getTrace(TickCollecterStates *s) const;
 
+  /// the top-level entry for the collector. It creates a TickCollecterStates
+  /// and applies Operation::walk<WalkOrder::PreOrder>() on the root function Op.
+  /// And it finally calls getTrace() to collect the `TickTraceResult` for each
+  /// scope in MemoryTraceScopes.
   virtual FailureOr<MemoryTraceScopes>
   operator()(Operation *root, const mlir::BufferViewFlowAnalysis &aliasAnaly,
              const MergeAllocationOptions &option) const;
@@ -133,9 +150,15 @@ struct TickCollecter {
   virtual ~TickCollecter() = default;
 };
 
+/// the default merge-alloc IR mutator. It can be passed to a
+/// MemoryMergeMutatorFunc as an argument of merge-alloc
 struct MergeAllocDefaultMutator {
+  /// builds a memory alloc op at the scope with the given size and alignment
+  /// in bytes
   virtual Value buildAlloc(OpBuilder &build, Operation *scope, int64_t size,
                            int64_t alignment) const;
+  /// builds a memory view op for the original memref.alloc op (origAllocOp)
+  /// and the merged single allocation (mergedAlloc)
   virtual Value buildView(OpBuilder &build, Operation *scope,
                           Operation *origAllocOp, Value mergedAlloc,
                           int64_t byteOffset) const;
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 69f401502fc4e..8030bb7046bfb 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -78,10 +78,10 @@ std::unique_ptr<Pass> createExpandStridedMetadataPass();
 /// components.
 std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);
 
-// abstract base class for lifetime of different buffers. It should hold the
-// lifetime informantion of buffers that are to be merged in the same allocation
-// in an "allocation scope". TraceCollectorFunc decides which buffers are put
-// into which "allocation scope".
+/// abstract base class for lifetime of buffers in the same "allocation scope".
+/// It should hold the lifetime information of buffers that are to be merged in
+/// the same allocation in an "allocation scope". TraceCollectorFunc decides
+/// which buffers are put into which "allocation scope".
 class LifetimeTrace {
 public:
   enum TraceKind { TK_TICK };
@@ -93,16 +93,16 @@ class LifetimeTrace {
   TraceKind kind;
 };
 
-// top level memory trace info for multiple scopes. Each key-value is the
-// traces and location for buffers in the same "allocation scope"
+/// top level memory trace info for multiple scopes. Each key-value pair maps
+/// an "allocation scope" to its LifetimeTrace
 struct MemoryTraceScopes {
   llvm::DenseMap<Operation *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
   MemoryTraceScopes() = default;
 };
 
-// the memory scheduling result for allocations in the same merged buffer.
-// allocation => offset map. All Operation* in the map should be memref::AllocOp
-// which are in the same LifetimeTrace.
+/// the memory scheduling result for allocations in the same allocation scope.
+/// allocation => offset map. All Operation* in the map should be
+/// memref::AllocOp which are in the same LifetimeTrace.
 struct MemorySchedule {
   size_t totalSize;
   llvm::DenseMap<Operation *, int64_t> allocToOffset;

>From 670265f889c7c24c982fc25f95a00cbe803cfae4 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Wed, 12 Jun 2024 15:26:39 +0800
Subject: [PATCH 07/12] doc

---
 mlir/docs/MemrefMergeAlloc.md                 | 351 +++++++++++++++---
 .../MemRef/Transforms/MergeAllocTickBased.cpp |  11 +-
 2 files changed, 304 insertions(+), 58 deletions(-)

diff --git a/mlir/docs/MemrefMergeAlloc.md b/mlir/docs/MemrefMergeAlloc.md
index 835c556c4fbab..e1c69a8d1f5e6 100644
--- a/mlir/docs/MemrefMergeAlloc.md
+++ b/mlir/docs/MemrefMergeAlloc.md
@@ -1,13 +1,24 @@
 # Compile-time memref.alloc Scheduling and Merging
 
-This document describes a compile-time optimization on `memref.alloc` to reduce memory usage and improve memory locality.
+This document describes a compile-time optimization on `memref.alloc` to reduce
+memory usage and improve memory locality.
 
 ## Current status of bufferization and memref pass pipeline
-Bufferization is a process in the current MLIR of converting ops with tensor semantics to ops with memref semantics.
-One-Shot Bufferize is a new tensor bufferization pass designed for IR in destination-passing style, and with aggressive in-place bufferization. The older/partial bufferization was built around multiple dialects. The community is trying to gradually deprecate the older bufferization and replace them with one-shot bufferization.
-The goal of bufferization is to use as little memory as possible and copy as little memory as possible, as a result, the exsiting focus is to determine in-place or out-of-place among the OpOperand and OpResult of individual ops, while not considering much about the overall memory reuse across Operators within a sub-graph (or partition).
-
-The current implementation of Bufferization and memref pass pipeline focuses on copy-avoidance and in-place reusing of the memory. Consider a computation graph of 4 layers of matmul sharing the same weight:
+Bufferization is a process in the current MLIR of converting ops with tensor
+semantics to ops with memref semantics. One-Shot Bufferize is a new tensor
+bufferization pass designed for IR in destination-passing style, and with
+aggressive in-place bufferization. The older/partial bufferization was built
+around multiple dialects. The community is trying to gradually deprecate the
+older bufferization and replace it with one-shot bufferization. The goal of
+bufferization is to use as little memory as possible and copy as little memory
+as possible. As a result, the existing focus is to determine in-place or
+out-of-place among the OpOperand and OpResult of individual ops, while not
+considering much about the overall memory reuse across Operators within a
+sub-graph (or partition).
+
+The current implementation of Bufferization and memref pass pipeline focuses on
+copy-avoidance and in-place reusing of the memory. Consider a computation graph
+of 4 layers of matmul sharing the same weight:
 ```mlir
 func.func @mlp(%x: tensor<128x128xf32>, %y: tensor<128x128xf32>) -> tensor<128x128xf32> {
    %a0 = tensor.empty() : tensor<128x128xf32>
@@ -22,7 +33,8 @@ func.func @mlp(%x: tensor<128x128xf32>, %y: tensor<128x128xf32>) -> tensor<128x1
 }
 ```
 
-The bufferization pass will create an `memref.alloc` for each of the tensor `a0`, `b0` and `c0`. The bufferization result should be like:
+The bufferization pass will create a `memref.alloc` for each of the tensors
+`a0`, `b0` and `c0`. The bufferization result should look like:
 
 ```mlir
 func.func @mlp(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x128xf32> {
@@ -38,16 +50,36 @@ func.func @mlp(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x1
 }
 ```
 
-Without further optimizations, 3 temp buffers will be allocated at the runtime for these tensors. However, as we can see in the IR, the buffer `a0` is no longer used when buffer `c0` is allocated. So buffer `c0` can reuse the memory buffer of buffer `a0`, to reduce the memory size footprint and improve the locality.
+Without further optimizations, 3 temp buffers will be allocated at runtime
+for these tensors. However, as we can see in the IR, the buffer `a0` is no
+longer used when buffer `c0` is allocated. So buffer `c0` can reuse the memory
+buffer of buffer `a0`, to reduce the memory size footprint and improve the
+locality.
 
-An observation of the current bufferization and memref passes is that they do not consider the memory buffer planning - to reuse the buffer/memref for less total size and better locality.
+An observation of the current bufferization and memref passes is that they do
+not consider overall memory buffer planning, i.e., reusing buffers/memrefs to
+reduce the total allocation size and improve locality.
 
 ## Merge-alloc pass
-An optimization pass has been introduced to consolidate multiple allocations (`memref.alloc` ops) into a single `memref.alloc` op and each static-shaped `memref.alloc` op will be transformed into a "slice" from the `single allocated buffer` with `memref.view` and some compile-time decided `offsets`. This optimization works on `memref` instead of `tensor` ops, so it should be executed after bufferization pass, and before adding buffer deallocation ops.
-
-While merging the memory allocations, the transform should consider the lifetime of each allocated `memref`s. By lifetime, we mean the range of time when an memref allocated from `memref.alloc` is actively used. The references on `view`s of a "base" `memref` should contribute to the lifetime of the "base". A later `memref.alloc` should consider to reuse the memory of a previously allocated memref, if the lifetime of these two does not overlap. The transform will perform the "reusing" of memory by setting the `offset` of the later `memref.view` to a position within the memory range of a previous allocation's `memref.view` on the `single allocated buffer`.
-
-Below is the expected transformation result of the example IR in the above section:
+An optimization pass has been introduced to consolidate multiple allocations
+(`memref.alloc` ops) into a single `memref.alloc` op. Each static-shaped
+`memref.alloc` op will be transformed into a "slice" from the `single allocated
+buffer` with `memref.view` and some compile-time decided `offsets`. This
+optimization works on `memref` instead of `tensor` ops, so it should be executed
+after bufferization pass, and before adding buffer deallocation ops.
+
+While merging the memory allocations, the transform should consider the lifetime
+of each allocated `memref`s. By lifetime, we mean the range of time when an
+memref allocated from `memref.alloc` is actively used. The references on `view`s
+of a "base" `memref` should contribute to the lifetime of the "base". A later
+`memref.alloc` should consider to reuse the memory of a previously allocated
+memref, if the lifetime of these two does not overlap. The transform will
+perform the "reusing" of memory by setting the `offset` of the later
+`memref.view` to a position within the memory range of a previous allocation's
+`memref.view` on the `single allocated buffer`.
+
+Below is the expected transformation result of the example IR in the above
+section:
 
 ```mlir
 func.func @mlp(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x128xf32> {
@@ -64,42 +96,92 @@ func.func @mlp(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x1
 }
 ```
 
-There is one single allocation `single_buffer` for all temp buffers and `alloc` ops for `a0`, `b0` and `c0` are removed. The returned memref `d0` is untouched. The memrefs `a0`, `b0` and `c0` are replaced by `memref.view` on `single_buffer`. Since `a0` and `b0`'s lifetime overlaps, the transformation will "allocate" different memory ranges on the `single_buffer` - note that `a0` and `b0` has different offsets `%single_buffer[0]` and `%single_buffer[65536]` and the memory ranges does not overlap. The memref `c0` does not overlap with `a0` in their lifetime, so that `c0` can reuse the memory range of `a0` by setting of offset to `%single_buffer[0]`, which is the same of `a0`. The final allocation size of temp memory buffer will be `128*128*sizeof(f32)*2` instead of three `memref<128x128xf32>` buffers in the original IR.
+There is one single allocation `single_buffer` for all temp buffers, and the
+`alloc` ops for `a0`, `b0` and `c0` are removed. The returned memref `d0` is
+untouched. The memrefs `a0`, `b0` and `c0` are replaced by `memref.view` ops on
+`single_buffer`. Since the lifetimes of `a0` and `b0` overlap, the
+transformation will "allocate" different memory ranges on the `single_buffer`:
+note that `a0` and `b0` have different offsets, `%single_buffer[0]` and
+`%single_buffer[65536]`, and the memory ranges do not overlap. The lifetime of
+the memref `c0` does not overlap with that of `a0`, so `c0` can reuse the
+memory range of `a0` by setting its offset to `%single_buffer[0]`, which is the
+same as `a0`'s. The final allocation size of the temp memory buffer will be
+`128*128*sizeof(f32)*2` bytes instead of three `memref<128x128xf32>` buffers in
+the original IR.
 
 
 ## Other solutions besides merge-alloc
 
-Another (not yet existing) approach to resolve the memory reusing issue is to insert `memref.dealloc` as soon as the buffer is no longer used. For example, in the above "matmul" example, `memref.dealloc` can be inserted after the last use of `a0` at `linalg.matmul ins(%a0, %y...)`. So even without memref merging transformation, a common runtime memory allocator will try to reuse the memory free'd by `memref.dealloc(%a0)` when allocating buffer for `c0`. However, there are some disadvantages of this approach comparing to the compile-time memref merging transformation of this proposal:
+Another (not yet existing) approach to resolve the memory reusing issue is to
+insert `memref.dealloc` as soon as the buffer is no longer used. For example, in
+the above "matmul" example, `memref.dealloc` can be inserted after the last use
+of `a0` at `linalg.matmul ins(%a0, %y...)`. So even without the memref merging
+transformation, a common runtime memory allocator will try to reuse the memory
+free'd by `memref.dealloc(%a0)` when allocating the buffer for `c0`. However,
+there are some disadvantages of this approach compared to the compile-time
+memref merging transformation of this proposal:
 1. it depends on the implementation of the runtime memory allocator.
-2. the runtime memory allocator does not have full picture of the future allocation/deallocation patterns of the program. For example, if we change the above example to make buffer size `c0` greater than size of `a0`, the runtime memory allocator will not likely to be able to reuse the memory of `a0` for `c0`, becuase the free memory chunk size of `a0` does not fit allocation of `c0`. In contrast, the proposed optimization of this document has the knowledge of the allocation patterns. Thus, it can put the memory chunk for `a0` in a right place of the `single allocation buffer`, so that the allocation of `c0` can fit into it.
-3. calling runtime memory allocator for each buffer introduces more run time overhead than a single merged allocation after allocation merging.
-
-However, utilizing runtime memory allocator can be viewed as a supplementary approach of the allocation merging at compile-time, for example, to handle memref with dynamic shapes. These two memory optimization approaches should coexist and cowork in the pass pipeline.
+2. the runtime memory allocator does not have a full picture of the future
+   allocation/deallocation patterns of the program. For example, if we change
+   the above example to make the size of buffer `c0` greater than the size of
+   `a0`, the runtime memory allocator will likely not be able to reuse the
+   memory of `a0` for `c0`, because the free memory chunk of `a0` does not fit
+   the allocation of `c0`. In contrast, the proposed optimization of this
+   document has knowledge of the allocation patterns. Thus, it can put the
+   memory chunk for `a0` in the right place of the `single allocation buffer`,
+   so that the allocation of `c0` can fit into it.
+3. calling the runtime memory allocator for each buffer introduces more runtime
+   overhead than a single merged allocation after allocation merging.
+
+However, utilizing a runtime memory allocator can be viewed as a supplementary
+approach to compile-time allocation merging, for example, to handle memrefs
+with dynamic shapes. These two memory optimization approaches should coexist
+and cooperate in the pass pipeline.
 
 ## General framework for implementation of merge-alloc
 
-To make merge-alloc pass capable of handling different hardware architectures and runtime requirements, the pass is implemented as a general pipeline of the following stages:
+To make merge-alloc pass capable of handling different hardware architectures
+and runtime requirements, the pass is implemented as a general pipeline of the
+following stages:
 
 1. Collect the memory alias via `BufferViewFlowAnalysis`
 2. Collect the memory lifetime traces
-3. Schedule the buffers by an allocation algorithm to compute the offsets of each allocations
+3. Schedule the buffers by an allocation algorithm to compute the offsets of
+   each allocation
 4. Rewrite the IR to replace allocations with views of merged buffers
 
-The steps 2, 3 and 4 can be implemented by the developers to customize the pass for their own use cases. A tick-based pipeline of the pass is provided as the default implementation, which will be discussed in the next section. 
+The steps 2, 3 and 4 can be implemented by the developers to customize the pass
+for their own use cases. A tick-based pipeline of the pass is provided as the
+default implementation, which will be discussed in the next section. 
 
 The following concepts should be defined by the implementation of the pass:
- * Mergeable alloction: the memref.alloc operations that should be merged by the pass. Other memref.alloc operations that are not "mergeable" should be untouched by the pass
- * Allocation scope: for each mergeable memref.alloc operation, there should be one ancestor surrounding operation called "allocation scope". The memory allocation after merge-alloc for that memref.alloc operation should be hoisted and merged to the block of that "allocation scope". A "allocation scope" should contain a single merged allocation for the mergeable allocation in it.
- * Lifetime trace: for each mergeable memref.alloc operation, the "lifetime trace" should be collected, indicating the "allocation scope" and the liveness of the buffer allocated. The contents of a "lifetime trace" is implementation-defined
+ * Mergeable allocation: the memref.alloc operations that should be merged by the
+   pass. Other memref.alloc operations that are not "mergeable" should be
+   untouched by the pass
+ * Allocation scope: for each mergeable memref.alloc operation, there should be
+   one ancestor surrounding operation called "allocation scope". The memory
+   allocation after merge-alloc for that memref.alloc operation should be
+   hoisted and merged to the block of that "allocation scope". An "allocation
+   scope" should contain a single merged allocation for the mergeable
+   allocations in it.
+ * Lifetime trace: for each mergeable memref.alloc operation, the "lifetime
+   trace" should be collected, indicating the "allocation scope" and the
+   liveness of the allocated buffer. The contents of a "lifetime trace" are
+   implementation-defined
 
 
 There are some more details on each step of the pipeline above.
 
 ### Collect the memory lifetime traces
 
-This is the first stage that a developer can customize in merge-alloc. It should collect the lifetime traces for each of the mergable memref.alloc operation. An implementation of the lifetime trace collector should define which allocations are mergeable and find the allocation scopes of them. It should also implement a data structure to hold the detailed liveness of each buffers.
+This is the first stage that a developer can customize in merge-alloc. It should
+collect the lifetime traces for each of the mergeable memref.alloc operations.
+An implementation of the lifetime trace collector should define which
+allocations are mergeable and find their allocation scopes. It should also
+implement a data structure to hold the detailed liveness of each buffer.
 
-This step is abstracted in a `TraceCollectorFunc` function. The merge-alloc framework defines the abstract interfaces for lifetime trace collector and the collected traces as below:
+This step is abstracted in a `TraceCollectorFunc` function. The merge-alloc
+framework defines the abstract interfaces for lifetime trace collector and the
+collected traces as below:
 
 ```c++
 /// abstract base class for lifetime of buffers in the same "allocation scope".
@@ -125,7 +207,11 @@ using TraceCollectorFunc = std::function<FailureOr<MemoryTraceScopes>(
 
 ### Memory planning and scheduling
 
-This step is abstracted in a `MemoryPlannerFunc` function. It accepts the `MemoryTraceScopes` collected by the previous step. For each allocation scope in `MemoryTraceScopes`, it decides the total merged allocation size and the offsets for each mergeable allocation inside of the allocation scope. The abstract interfaces are shown below:
+This step is abstracted in a `MemoryPlannerFunc` function. It accepts the
+`MemoryTraceScopes` collected by the previous step. For each allocation scope in
+`MemoryTraceScopes`, it decides the total merged allocation size and the offsets
+for each mergeable allocation inside of the allocation scope. The abstract
+interfaces are shown below:
 
 ```c++
 /// the memory scheduling result for allocations in the same allocation scope.
@@ -143,7 +229,11 @@ using MemoryPlannerFunc = std::function<FailureOr<MemorySchedule>(
 
 ### Rewriting the IR
 
-Given the `MemorySchedule` of the previous step, this step rewrites the IR to create the merged allocation in each of the allocation scopes, to replace the mergable memref.alloc with views on the merged allocations with the offsets calculated in the `MemorySchedule`. This step is abstracted in a `MemoryMergeMutatorFunc` function.
+Given the `MemorySchedule` of the previous step, this step rewrites the IR to
+create the merged allocation in each of the allocation scopes and to replace
+each mergeable memref.alloc with a view on the merged allocation, using the
+offsets calculated in the `MemorySchedule`. This step is abstracted in a
+`MemoryMergeMutatorFunc` function.
 
 ```c++
 using MemoryMergeMutatorFunc = std::function<LogicalResult(
@@ -154,32 +244,55 @@ using MemoryMergeMutatorFunc = std::function<LogicalResult(
 
 ## Tick-based Implementation for merge-alloc
 
-A tick-based implementation of merge-alloc in provided by default. The basic idea of the tick-based allocation merging is that
-
-1. Each of the operations in a function is assigned a "tick". An operation with a smaller tick is expected to be executed before one with a larger tick
-2. Collect the first referenced tick and the last referenced tick for each mergeable allocation. If a buffer is referenced in loops and branches, special handling is needed.
-3. For each allocation scope, linearize the first referenced tick and the last referenced tick of mergeable allocations inside of it into a single linear timeline
+A tick-based implementation of merge-alloc is provided by default. The basic
+idea of the tick-based allocation merging is that
+
+1. Each of the operations in a function is assigned a "tick". An operation with
+   a smaller tick is expected to be executed before one with a larger tick
+2. Collect the first referenced tick and the last referenced tick for each
+   mergeable allocation. If a buffer is referenced in loops and branches,
+   special handling is needed.
+3. For each allocation scope, linearize the first referenced tick and the last
+   referenced tick of mergeable allocations inside of it into a single linear
+   timeline
 4. Use a "static-memory-planner" to handle the linear timeline
 
+Limitations of Tick-based merge-alloc:
+ * only contiguous, statically shaped memrefs with an identity layout are
+   considered. Others are disregarded
+ * only `RegionBranchOpInterface` or `LoopLikeOpInterface` operations are
+   allowed to access memrefs inside the operations' child regions. Other
+   operations containing regions should not access memrefs inside them.
+   Otherwise, a pass error could occur.
+
 ### Basic concepts
 
-In the context of tick-based merge-alloc, mergeable alloction and allocation scope are defined as follows
+In the context of tick-based merge-alloc, mergeable allocation and allocation
+scope are defined as follows
 
 #### Mergeable allocation
 
 The pass should only consider merging a `memref.alloc` if
- * the ownership of the memref does not escape from the function. That is, the current function is responsible to alloc and dealloc this memref
- * and, the allocated memref is contiguous and has static shape and identical layout.
+ * the ownership of the memref does not escape from the function. That is, the
+   current function is responsible for allocating and deallocating this memref
+ * and the allocated memref is contiguous and has a static shape and an
+   identity layout.
+ * and the memref is in the default memory space (this restriction may be
+   removed in the future)
 
-In tick-based merge-alloc, we call these `memref.alloc` **mergeable** allocations.
+In tick-based merge-alloc, we call these `memref.alloc` **mergeable**
+allocations.
 
-The memrefs passed by function arguments, or returned by the function will be untouched by this optimization.
+The memrefs passed as function arguments or returned by the function will be
+untouched by this optimization.
 
 #### Allocation scopes
 
-The transformation first needs to identify the allocation scopes, which are mlir operaions containing non-zero regions, and
+The transformation first needs to identify the allocation scopes, which are
+MLIR operations containing at least one region, and
  * implementing `AutomaticAllocationScope`
- * and is not `scf.for` (allocations in an `scf.for` can be hoisted to parent `AutomaticAllocationScope`)
+ * and is not `scf.for` (allocations in an `scf.for` can be hoisted to parent
+   `AutomaticAllocationScope`)
 
 For example, below is an example IR of a function with nested `scf.forall` ops.
 
@@ -199,24 +312,154 @@ func.func @mlp(...) { // <---- alloc scope 1
 }
 ```
 
-There will be three allocation scopes as marked in the comments above. An allocation scope marks the position to insert the `single allocation buffer` after allocation merging. After the transformation, all "mergeable" `memref.alloc` will be merged to the `single allocation buffer` of the nearest ancestor `alloc scope`.
+There will be three allocation scopes as marked in the comments above. An
+allocation scope marks the position to insert the `single allocation buffer`
+after allocation merging. After the transformation, all "mergeable"
+`memref.alloc` will be merged to the `single allocation buffer` of the nearest
+ancestor `alloc scope`.
 
 ### Tick-based trace collection
 
-walk()
-Alias
-Branch
-Sort-malloc-free
+This section discusses how ticks are collected and how the pass consolidates the
+ticks to get the lifetime traces for each mergeable allocation.
+
+Ticks are assigned to each operation in the `func.func` by an increasing counter
+with a pre-order recursive `walk()` of the IR, as the "execution tick" of each
+operation. After walking the IR, the pass assigns two integers to each mergeable
+allocation as the analysis result: `begin_tick` and `end_tick`, to indicate the
+first and last tick of the use of the allocated memref in the IR. Note that
+aliasing of memref buffers is also considered during tick collection. When an
+operation which is not memory-effect-free accesses memrefs via its operands, the
+ticks for the referenced memrefs and their aliasing memrefs should be updated.
+The alias analysis is performed by `BufferViewFlowAnalysis`.
+
+The collected result for each mergeable allocation will be an integer range
+`[begin_tick,end_tick]` (both boundaries are inclusive), where `begin_tick <=
+end_tick`. If the tick ranges of two mergeable allocations in the same
+allocation scope do not overlap, these two buffers can share the same memory
+address.
+
+There should be special handling for loop and branch ops
+(`RegionBranchOpInterface` or `LoopLikeOpInterface`) which reference memrefs
+allocated in parent scopes, to avoid incorrect reuse of buffers used in the
+loops or branches.
+
+For example, consider the code like:
+
+```mlir
+func.func @basic() {
+  ...
+  %e = memref.alloc() : memref<8x64xf32> // tick = 0
+  %f = memref.alloc() : memref<8x64xf32> // tick = 1
+  scf.for %i = %c0 to %c3 step %c1 {     // tick = 2
+      "test.use"(%e)  : (memref<8x64xf32>) -> () // tick = 3
+      "test.use"(%f)  : (memref<8x64xf32>) -> () // tick = 4
+  }
+}
+```
 
-### Static Memory planner
+A manual inspection of the IR shows that buffers `e` and `f` have overlapping
+lifetimes, because the access pattern in the loop is `e f e f e f`. Thus,
+buffers `e` and `f` should not share the same memory address. However, the
+collected ticks for the two buffers show that they are only accessed at tick 3
+and tick 4, respectively.
 
+To produce the correct lifetime analysis result, the tick collector will
+conservatively extend the lifetime of the accessed memrefs in loop and branch
+ops (`RegionBranchOpInterface` or `LoopLikeOpInterface`), to make them span at
+least the begin tick and end tick of the loop or branch op.
 
-The transformantion is consist of an analysis sub-pass and a mutation sub-pass. For each `alloc scope`, the analysis sub-pass finds the lifetime of each mergeable `memref.alloc` belonging to the `alloc scope`. And given the lifetime of each allocation, a memory planning algorithm will be run to find the `single allocation buffer` size of each `alloc scope` and the `offset` for each mergeable allocation within its `single allocation buffer`. Based on the memory planning result, the mutation sub-pass transforms the IR to
-1. insert `memref.alloc` at the front of `alloc scope` body for its `single allocation buffer`
-2. replace mergeable `memref.alloc` with `memref.view` on its `alloc scope`'s `single allocation buffer`
+In the above example, the lifetimes of both buffers `e` and `f` will be
+extended to the tick range of the parent `scf.for` op, i.e. `[2, 4]`.
 
-Ticks are assigned on each operation in the `func.func` by a increasing counter with pre-order recursive walking of the IR, as the "execution tick" for each operation. The lifetime analysis pass will assign two integers for each mergeable allocations as the analysis result: `begin_tick` and `end_tick`, to indicate the first and last tick of the use of the allocated memref in the IR. There should be special handling for loop and branch ops (`RegionBranchOpInterface` or `LoopLikeOpInterface`) which references memrefs allocated in parent scopes, to avoid wrong reuse of buffers used in the loop.
+In some special cases, when the `memref.alloc` is in the block of a loop or
+branch, and the buffer is not used outside of the loop or branch, the tick
+collector does not need to conservatively extend the ticks of the allocations.
+For example:
 
-The analysis result for each mergeable allocations will be an integer range `[begin_tick,end_tick]`, where `begin_tick <= end_tick`.
+```mlir
+func.func @basic() {
+  ...
+  scf.for %i = %c0 to %c3 step %c1 {     // tick = 0
+      %g = memref.alloc() : memref<8x64xf32> // tick = 1
+      %h = memref.alloc() : memref<8x64xf32> // tick = 2
+      "test.use"(%g)  : (memref<8x64xf32>) -> () // tick = 3
+      "test.use"(%h)  : (memref<8x64xf32>) -> () // tick = 4
+  }
+}
+```
+
+The buffer `g` has lifetime tick range `[3,3]` and `h` has `[4,4]`, because they
+are allocated within the loop. Thus `g` and `h` have non-overlapping lifetimes.
+
+The remaining part of this section will discuss how the tick collector
+consolidates the lifetime trace results.
+
+After calling `walk()` on the function operation, there will be a map of
+`AllocOp => [begin_tick,end_tick]` collected for each allocation scope. From the
+view of an allocation scope, this gives a timeline of first-access and
+last-access events of the mergeable allocations in the scope, sorted by tick in
+increasing order. The tick traces for buffers inside an allocation scope are
+then linearized into a stream of first-access and last-access events as the
+lifetime traces. For example, suppose an allocation scope has allocations with
+the following ticks:
+
+```
+buffer=A, tick=[1,4], size=16
+buffer=B, tick=[2,3], size=64
+buffer=C, tick=[5,6], size=16
+```
+
+The linearized lifetime trace will be
+
+```
+alloc(A,16)
+alloc(B,64)
+free(B)
+free(A)
+alloc(C,16)
+free(C)
+```
 
-The collected ticks for each buffer will be processed by the memory planning algorithm. It should output the total size of the `single allocation buffers` for each `alloc scopes`, and the `offsets` for each individual mergeable buffers. The algorithm should also consider the locality of the buffer to use, when multiple buffer localtion candidates are available.
+The static memory planner discussed in the next section will take the linearized
+lifetime trace as input.
+
+### Static Memory Planner
+
+The static memory planner is a compile-time memory allocator that plans the
+memory for a list of allocations. It operates on a contiguous "base-buffer"
+logically. For each "alloc" event in the linearized lifetime trace (see the
+above section), the memory planner logically "allocates" a contiguous range in
+the "base-buffer". The start-offset of the "allocated" range will be returned
+as the memory planning result of the mergeable allocation, for the later IR
+rewriting.
+
+The implementation of the static memory planner is very similar to a naive
+chunk-based runtime memory allocator like `malloc`. The logical memory is
+managed as memory chunks, each of which represents a contiguous range of the
+"base-buffer". A memory chunk is marked either "free" or "in-use" based on its
+allocation state. The memory planner reads the linearized alloc/free events in
+the order they appear in the collected traces. On an "alloc" event, the memory
+planner finds an appropriate free chunk and splits it into two chunks: one for
+the memory range of the allocation, and another for the remaining free memory
+range. On a "free" event, the memory planner marks the memory chunk as "free".
+If the neighbouring memory chunks are also "free", the planner further merges
+them into a larger free chunk.
+
+An improvement of the static memory planner over runtime memory allocators is
+that, if a "free" memory chunk is smaller than an allocation, the memory
+planner is allowed to "extend" the "free" memory chunk to match the allocation
+size. This is not possible for runtime memory allocators, because extending a
+memory chunk would require moving previously allocated memory to different
+addresses. However, in our compile-time memory planner, all the allocations are
+logical, and the offsets of the allocated memory ranges can still be adjusted
+by a later allocation. This improvement helps to reduce memory fragmentation.
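+
+To make the chunk bookkeeping concrete, below is a much-simplified, standalone
+sketch of the idea: first-fit allocation by splitting a free chunk, and merging
+neighbouring free chunks on free. It is not the in-tree planner implementation;
+alignment handling, the chunk-extension trick and the locality cost-model
+described above are all omitted.
+
+```c++
+#include <cstddef>
+#include <vector>
+
+struct Chunk {
+  size_t offset;
+  size_t size;
+  bool inUse;
+};
+
+// "allocate" a range of `size` bytes; returns its offset in the base-buffer.
+size_t allocateChunk(std::vector<Chunk> &chunks, size_t size) {
+  for (size_t i = 0; i < chunks.size(); ++i) {
+    if (chunks[i].inUse || chunks[i].size < size)
+      continue;
+    size_t offset = chunks[i].offset;
+    if (chunks[i].size > size) // split off the remainder as a new free chunk
+      chunks.insert(chunks.begin() + i + 1,
+                    Chunk{offset + size, chunks[i].size - size, false});
+    chunks[i].size = size;
+    chunks[i].inUse = true;
+    return offset;
+  }
+  // no free chunk fits: append a new in-use chunk, growing the base-buffer
+  size_t end = chunks.empty() ? 0 : chunks.back().offset + chunks.back().size;
+  chunks.push_back(Chunk{end, size, true});
+  return end;
+}
+
+// mark the chunk starting at `offset` as free and merge free neighbours.
+void freeChunk(std::vector<Chunk> &chunks, size_t offset) {
+  for (size_t i = 0; i < chunks.size(); ++i) {
+    if (chunks[i].offset != offset)
+      continue;
+    chunks[i].inUse = false;
+    if (i + 1 < chunks.size() && !chunks[i + 1].inUse) { // merge with next
+      chunks[i].size += chunks[i + 1].size;
+      chunks.erase(chunks.begin() + i + 1);
+    }
+    if (i > 0 && !chunks[i - 1].inUse) { // merge with previous
+      chunks[i - 1].size += chunks[i].size;
+      chunks.erase(chunks.begin() + i);
+    }
+    return;
+  }
+}
+```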
+
+On an "alloc" event, the memory planner needs to choose one candidate from all
+"free" memory chunks. A memory chunk that is recently free'd is considered "hot"
+in cache. In the default configuration (when `no-consider-locality` option is
+not specified to the merge-alloc pass), static memory planner considers both
+cache-locality and the degree of matching of allocation size and the chunk size
+for each free memory chunks, with a simple cost-model. With
+`no-consider-locality` option is specified, static memory planner will choose
+the best matched free memory chunk in the chunk size.
\ No newline at end of file
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
index dac3c986ee3c3..d5ac190f2a1b9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
@@ -27,8 +27,12 @@ namespace memref {
 using namespace special_ticks;
 
 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
-/// no layout).
-static bool hasStaticIdentityLayout(MemRefType type) {
+/// no layout) and default memory space.
+static bool isMemRefTypeOk(MemRefType type) {
+  IntegerAttr intMemorySpace =
+      llvm::dyn_cast_or_null<IntegerAttr>(type.getMemorySpace());
+  if (intMemorySpace && intMemorySpace.getValue() != 0)
+    return false;
   return type.hasStaticShape() && type.getLayout().isIdentity();
 }
 
@@ -158,8 +162,7 @@ bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op,
   if (tick == COMPLEX_ACCESS) {
     return false;
   }
-  if (!hasStaticIdentityLayout(
-          cast<MemRefType>(op->getResultTypes().front()))) {
+  if (!isMemRefTypeOk(cast<MemRefType>(op->getResultTypes().front()))) {
     return false;
   }
   auto alignment = cast<memref::AllocOp>(op).getAlignment();

>From a63cf8d5cb4f35249ae0107e467e6ac025cdd227 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Fri, 14 Jun 2024 16:11:08 +0800
Subject: [PATCH 08/12] fix comments

---
 .../MemRef/Transforms/MergeAllocTickBased.h   |  8 +++---
 .../mlir/Dialect/MemRef/Transforms/Passes.h   |  2 +-
 .../mlir/Dialect/MemRef/Transforms/Passes.td  | 14 +++++-----
 .../Dialect/MemRef/Transforms/MergeAlloc.cpp  |  6 ++---
 .../MemRef/Transforms/MergeAllocTickBased.cpp | 27 +++++++++++--------
 .../Dialect/MemRef/buffer-merge-lifetime.mlir |  2 +-
 6 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
index 75209e2df38d0..fe107ae756c57 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
@@ -27,10 +27,10 @@ namespace memref {
 /// Usually ticks should be non-negative numbers. There are two special ticks
 /// defined here.
 namespace special_ticks {
-/// the memref is not accessed
+/// the memref is not yet accessed
 static constexpr int64_t NO_ACCESS = -1;
-/// complex access happens on this memref, like func.return
-static constexpr int64_t COMPLEX_ACCESS = -2;
+/// untraceable access happens on this memref, like func.return
+static constexpr int64_t UNTRACEABLE_ACCESS = -2;
 } // namespace special_ticks
 
 /// the collected tick [first, last] for a memref allocation
@@ -44,7 +44,7 @@ struct Tick {
 
   /// access the memref at the tick, will update firstAccess and lastAccess
   /// based on the tick
-  void access(int64_t tick);
+  void update(int64_t tick);
 };
 
 /// A complex scope object is addition info for a RegionBranchOpInterface or
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 8030bb7046bfb..6c812fbf5a7b6 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -121,7 +121,7 @@ using MemoryMergeMutatorFunc = std::function<LogicalResult(
 
 struct MergeAllocationOptions {
   bool checkOnly = false;
-  bool noLocalityFirst = false;
+  std::string plannerOptions;
   int64_t alignment = 64;
   TraceCollectorFunc tracer;
   MemoryPlannerFunc planner;
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 4562dc5c8548b..b9b03154b359c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -275,14 +275,16 @@ def MergeAlloc : Pass<"merge-alloc", "func::FuncOp">  {
 address ranges that are considered "hot" in cache for a later allocation.
   }];
   let options = [
-    Option<"optionCheck", "check", "bool",
+    Option<"optionAnalysisOnly", "analysis-only", "bool",
        /*default=*/"false",
        "Skip the mutation of the IR and only mark the lifetime and scope on the"
-       " operations. Useful for debugging and testing.">,
-    Option<"optionNoLocality", "no-consider-locality", "bool",
-       /*default=*/"false",
-       "Don't consider the cache locality when reusing the buffers. "
-       "This option may result in smaller total memory usage.">,
+       " attr of operations. Useful for debugging and testing.">,
+    Option<"plannerOptions", "planner-options", "std::string",
+       /*default=*/"\"\"",
+       "The options for the memory-planner. `cost-model` for using a cost-"
+       "model considering both cache locality and memory size. `size-first`"
+       " may generate allocations with smaller total size without considering"
+       " cache locality. By default `cost-model` is used.">,
     Option<"optionAlignment", "alignment", "int64_t",
        /*default=*/"64",
        "The alignment of the merged allocations">,
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
index b8451c641218f..ade68752fa33f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
@@ -1,4 +1,4 @@
-//===- MergeAlloc.cpp - Calling convention conversion ---------------------===//
+//===- MergeAlloc.cpp - General framework for merge-allocation ------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -58,8 +58,8 @@ class MergeAllocPass : public memref::impl::MergeAllocBase<MergeAllocPass> {
   void runOnOperation() override {
     memref::MergeAllocationOptions opt;
     if (!options) {
-      opt.checkOnly = optionCheck;
-      opt.noLocalityFirst = optionNoLocality;
+      opt.checkOnly = optionAnalysisOnly;
+      opt.plannerOptions = plannerOptions;
       opt.alignment = optionAlignment;
       opt.tracer = memref::TickCollecter();
       opt.planner = memref::tickBasedPlanMemory;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
index d5ac190f2a1b9..ba40705f48b24 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
@@ -36,12 +36,12 @@ static bool isMemRefTypeOk(MemRefType type) {
   return type.hasStaticShape() && type.getLayout().isIdentity();
 }
 
-void Tick::access(int64_t tick) {
-  if (tick == COMPLEX_ACCESS) {
-    firstAccess = COMPLEX_ACCESS;
-    lastAccess = COMPLEX_ACCESS;
+void Tick::update(int64_t tick) {
+  if (tick == UNTRACEABLE_ACCESS) {
+    firstAccess = UNTRACEABLE_ACCESS;
+    lastAccess = UNTRACEABLE_ACCESS;
   }
-  if (firstAccess == COMPLEX_ACCESS) {
+  if (firstAccess == UNTRACEABLE_ACCESS) {
     return;
   }
   if (firstAccess == NO_ACCESS) {
@@ -83,8 +83,8 @@ LogicalResult TickCollecter::onPopComplexScope(TickCollecterStates *s,
     if (needsResetTick(s, scope.scope, op)) {
       // let all referenced buffers have overlapped lifetime
       auto &tick = s->allocTicks[op];
-      tick.access(scope.startTick);
-      tick.access(endTick);
+      tick.update(scope.startTick);
+      tick.update(endTick);
     }
   }
   return success();
@@ -115,7 +115,7 @@ void TickCollecter::accessValue(TickCollecterStates *s, Value v,
     for (auto &&base : s->aliasAnaly.resolveReverse(refv)) {
       auto defop = base.getDefiningOp();
       if (isa_and_present<memref::AllocOp>(defop)) {
-        s->allocTicks[defop].access(complex ? COMPLEX_ACCESS : s->curTick);
+        s->allocTicks[defop].update(complex ? UNTRACEABLE_ACCESS : s->curTick);
         if (!s->complexScopeStack.empty()) {
           s->complexScopeStack.back().operations.insert(defop);
         }
@@ -159,7 +159,7 @@ void TickCollecter::pushComplexScope(TickCollecterStates *s,
 
 bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op,
                                      int64_t tick) const {
-  if (tick == COMPLEX_ACCESS) {
+  if (tick == UNTRACEABLE_ACCESS) {
     return false;
   }
   if (!isMemRefTypeOk(cast<MemRefType>(op->getResultTypes().front()))) {
@@ -311,16 +311,21 @@ FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
                                               const MergeAllocationOptions &o) {
   auto traceObj = dyn_cast<TickTraceResult>(&tr);
   if (!traceObj) {
-    return failure();
+    return op->emitOpError("Unrecognized trace result.");
   }
   auto &traces = traceObj->traces;
   if (traces.empty()) {
     return MemorySchedule{};
   }
+  bool useCostModel =
+      o.plannerOptions.empty() || o.plannerOptions == "cost-model";
+  if (!useCostModel && o.plannerOptions != "size-first") {
+    return op->emitOpError("Unrecognized planner option");
+  }
   std::unordered_map<uintptr_t, std::size_t> outSchedule;
   std::unordered_map<uintptr_t, std::vector<uintptr_t>> dummy;
   auto total = memoryplan::scheduleMemoryAllocations(
-      traces, o.alignment, !o.noLocalityFirst, memoryplan::InplaceInfoMap(),
+      traces, o.alignment, useCostModel, memoryplan::InplaceInfoMap(),
       outSchedule, dummy);
   MemorySchedule ret;
   ret.totalSize = total;
diff --git a/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir b/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
index 96cf6b79e5242..1a64b23b7b9c3 100644
--- a/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
+++ b/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(func.func(merge-alloc{check}))'  %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(func.func(merge-alloc{analysis-only}))'  %s | FileCheck %s
 
 // CHECK-DAG: func.func @basic() -> memref<8x64xf32>  attributes {__mergealloc_scope = [[TOPSCOPE:[0-9]+]]
 func.func @basic() -> memref<8x64xf32> {

>From 342bbc67ae0e91a304a347c43dc6515d7d87535f Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 18 Jun 2024 11:45:02 +0800
Subject: [PATCH 09/12] update comments

---
 mlir/docs/MemrefMergeAlloc.md | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/docs/MemrefMergeAlloc.md b/mlir/docs/MemrefMergeAlloc.md
index e1c69a8d1f5e6..1a8f5a81bb084 100644
--- a/mlir/docs/MemrefMergeAlloc.md
+++ b/mlir/docs/MemrefMergeAlloc.md
@@ -69,20 +69,20 @@ optimization works on `memref` instead of `tensor` ops, so it should be executed
 after bufferization pass, and before adding buffer deallocation ops.
 
 While merging the memory allocations, the transform should consider the lifetime
-of each allocated `memref`s. By lifetime, we mean the range of time when an
-memref allocated from `memref.alloc` is actively used. The references on `view`s
+of each allocated `memref`s. By lifetime, we mean the range of time when the
+memory allocated from `memref.alloc` is actively used. The references on `view`s
 of a "base" `memref` should contribute to the lifetime of the "base". A later
 `memref.alloc` should consider to reuse the memory of a previously allocated
 memref, if the lifetime of these two does not overlap. The transform will
 perform the "reusing" of memory by setting the `offset` of the later
 `memref.view` to a position within the memory range of a previous allocation's
-`memref.view` on the `single allocated buffer`.
+`memref.alloc` from the `single allocated buffer`.
 
 Below is the expected transformation result of the example IR in the above
 section:
 
 ```mlir
 func.func @mlp(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x128xf32> {
    %single_buffer = memref.alloc() : memref<131072xi8> // 128*128*sizeof(f32)*2
    %a0 = memref.view %single_buffer[0][] : memref<131072xi8> to memref<128x128xf32> // a0 takes the memory from byte offset 0
    linalg.matmul ins(%x, %y: memref<128x128xf32>, memref<128x128xf32>) outs(%a0: memref<128x128xf32>)
@@ -457,9 +457,9 @@ fragmentation.
 
 On an "alloc" event, the memory planner needs to choose one candidate from all
 "free" memory chunks. A memory chunk that is recently free'd is considered "hot"
-in cache. In the default configuration (when `no-consider-locality` option is
-not specified to the merge-alloc pass), static memory planner considers both
+in cache. In the default configuration (when `planner-options=size-first` is
+not specified to the merge-alloc pass), the static memory planner considers both
 cache-locality and the degree of matching of allocation size and the chunk size
 for each free memory chunk, with a simple cost-model. With
-`no-consider-locality` option is specified, static memory planner will choose
-the best matched free memory chunk in the chunk size.
\ No newline at end of file
+the `planner-options=size-first` option specified, the static memory planner
+chooses the free memory chunk whose size best matches the allocation.
\ No newline at end of file
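
To make the two planner behaviours concrete: under `planner-options=size-first` the
planner does a pure best-fit over the free list, while the default cost model also
rewards chunks that were freed recently (and are therefore likely still hot in cache).
The sketch below is a minimal, self-contained illustration of that difference; the
names and the cost weights are invented here and are not the actual
`StaticMemoryPlanning` code.

```c++
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// A free chunk in the single merged buffer. `freedAtTick` records when it was
// last released, so "recently freed" chunks can be preferred for locality.
struct FreeChunk {
  std::size_t offset;
  std::size_t size;
  int64_t freedAtTick;
};

// Pick a chunk for an allocation of `allocSize` at logical time `nowTick`.
// sizeFirst == true  : best-fit only (smallest chunk that still fits).
// sizeFirst == false : toy cost model that also rewards recently freed chunks.
std::optional<std::size_t> pickChunk(const std::vector<FreeChunk> &freeList,
                                     std::size_t allocSize, int64_t nowTick,
                                     bool sizeFirst) {
  std::optional<std::size_t> best;
  double bestCost = 0.0;
  for (std::size_t i = 0; i < freeList.size(); ++i) {
    const FreeChunk &c = freeList[i];
    if (c.size < allocSize)
      continue; // does not fit
    // Wasted bytes if this chunk is chosen.
    double waste = static_cast<double>(c.size - allocSize);
    // "Coldness": how long ago this chunk was freed.
    double coldness = static_cast<double>(nowTick - c.freedAtTick);
    double cost = sizeFirst ? waste : waste + 4.0 * coldness;
    if (!best || cost < bestCost) {
      best = i;
      bestCost = cost;
    }
  }
  return best; // index into freeList, or nullopt if nothing fits
}
```

With `sizeFirst` set, the chunk with the least wasted space always wins; with the toy
cost model a slightly larger but recently freed chunk can be preferred.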

>From 74fe5a32124705bb26db84f617b5383bb7f00a00 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 18 Jun 2024 11:56:07 +0800
Subject: [PATCH 10/12] use interrupt

---
 .../Dialect/MemRef/Transforms/MergeAllocTickBased.cpp  | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
index ba40705f48b24..e890b653df9b8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
@@ -261,10 +261,9 @@ TickCollecter::operator()(Operation *root,
                           const MergeAllocationOptions &option) const {
   TickCollecterStates s{aliasAnaly, option};
   TickCollecter collecter;
-  LogicalResult result = success();
-  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
+  auto result = root->walk<WalkOrder::PreOrder>([&](Operation *op) {
     if (failed(collecter.popScopeIfNecessary(&s, op))) {
-      result = failure();
+      return WalkResult::interrupt();
     }
     collecter.forwardTick(&s);
     if (auto viewop = dyn_cast<ViewLikeOpInterface>(op)) {
@@ -281,9 +280,10 @@ TickCollecter::operator()(Operation *root,
       // finally, if op is complex scope, push one ComplexScope
       collecter.pushComplexScope(&s, op);
     }
+    return WalkResult::advance();
   });
-  if (failed(result)) {
-    return result;
+  if (result.wasInterrupted()) {
+    return failure();
   }
   if (failed(collecter.popScopeIfNecessary(&s, nullptr))) {
     return failure();
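
The switch above from a captured `LogicalResult` flag to `WalkResult::interrupt()`
means the traversal actually stops at the first failing operation instead of
continuing over the rest of the IR. A standalone C++ analogue of the interrupt-style
walk (toy types only, not the MLIR walker) looks like this:

```c++
#include <vector>

// Toy stand-in for an operation tree; not an MLIR type.
struct ToyOp {
  bool bad = false;
  std::vector<ToyOp> children;
};

enum class Walk { Advance, Interrupt };

// Pre-order visit that stops as soon as the callback asks to interrupt,
// mirroring walk<WalkOrder::PreOrder> + WalkResult::interrupt().
template <typename Fn>
bool walkInterrupted(const ToyOp &op, Fn &&fn) {
  if (fn(op) == Walk::Interrupt)
    return true;
  for (const ToyOp &child : op.children)
    if (walkInterrupted(child, fn))
      return true;
  return false;
}

bool hasBadOp(const ToyOp &root) {
  // With a plain bool flag the whole tree would still be visited after the
  // first hit; returning Interrupt ends the walk immediately.
  return walkInterrupted(root, [](const ToyOp &op) {
    return op.bad ? Walk::Interrupt : Walk::Advance;
  });
}
```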

>From dd50c18dcfbf243de409ceefc895f0c921d45d3d Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Fri, 21 Jun 2024 16:06:11 +0800
Subject: [PATCH 11/12] fix comments

---
 mlir/docs/MemrefMergeAlloc.md                 | 58 +++++++++----------
 .../MemRef/Transforms/MergeAllocTickBased.h   |  6 +-
 .../mlir/Dialect/MemRef/Transforms/Passes.h   |  4 +-
 .../MemRef/Transforms/MergeAllocTickBased.cpp | 42 ++++++++------
 .../Dialect/MemRef/buffer-merge-invalid.mlir  |  4 +-
 .../Dialect/MemRef/buffer-merge-lifetime.mlir | 30 +++++++++-
 mlir/test/Dialect/MemRef/buffer-merge.mlir    | 36 +++++++++++-
 7 files changed, 124 insertions(+), 56 deletions(-)

diff --git a/mlir/docs/MemrefMergeAlloc.md b/mlir/docs/MemrefMergeAlloc.md
index 1a8f5a81bb084..944984c4bc0a9 100644
--- a/mlir/docs/MemrefMergeAlloc.md
+++ b/mlir/docs/MemrefMergeAlloc.md
@@ -7,11 +7,9 @@ memory usage and improve memory locality.
 Bufferization is a process in the current MLIR of converting ops with tensor
 semantics to ops with memref semantics. One-Shot Bufferize is a new tensor
 bufferization pass designed for IR in destination-passing style, and with
-aggressive in-place bufferization. The older/partial bufferization was built
-around multiple dialects. The community is trying to gradually deprecate the
-older bufferization and replace them with one-shot bufferization. The goal of
+aggressive in-place bufferization. The goal of
 bufferization is to use as little memory as possible and copy as little memory
-as possible, as a result, the exsiting focus is to determine in-place or
+as possible; as a result, the existing focus is to determine in-place or
 out-of-place among the OpOperand and OpResult of individual ops, while not
 considering much about the overall memory reuse across Operators within a
 sub-graph (or partition).
@@ -34,7 +32,7 @@ func.func @mlp(%x: tensor<128x128xf32>, %y: tensor<128x128xf32>) -> tensor<128x1
 ```
 
 The bufferization pass will create an `memref.alloc` for each of the tensor
-`a0`, `b0` and `c0`. The bufferization result should be like:
+`a0`, `b0` and `c0`. The bufferization result looks like:
 
 ```mlir
 func.func @mlp(%x: memref<128x128xf32>, %y: memref<128x128xf32>) -> memref<128x128xf32> {
@@ -62,16 +60,16 @@ total size and better locality.
 
 ## Merge-alloc pass
 An optimization pass has been introduced to consolidate multiple allocations
-(`memref.alloc` ops) into a single `memref.alloc` op and each static-shaped
-`memref.alloc` op will be transformed into a "slice" from the `single allocated
-buffer` with `memref.view` and some compile-time decided `offsets`. This
+(`memref.alloc` ops) into a single `memref.alloc` op, and each "mergeable"
+`memref.alloc` op will be transformed into a "slice" from the "single allocated
+buffer" with `memref.view` and a compile-time-decided `offset`. This
 optimization works on `memref` instead of `tensor` ops, so it should be executed
 after bufferization pass, and before adding buffer deallocation ops.
 
 While merging the memory allocations, the transform should consider the lifetime
 of each allocated `memref`. By lifetime, we mean the range of time when the
-memory allocated from `memref.alloc` is actively used. The references on `view`s
-of a "base" `memref` should contribute to the lifetime of the "base". A later
+memory allocated from `memref.alloc` is actively used. Views (aliases) into a
+"base" memref should contribute to the lifetime of the "base". A later
 `memref.alloc` should consider to reuse the memory of a previously allocated
 memref, if the lifetime of these two does not overlap. The transform will
 perform the "reusing" of memory by setting the `offset` of the later
@@ -154,15 +152,14 @@ for their own use cases. A tick-based pipeline of the pass is provided as the
 default implementation, which will be discussed in the next section. 
 
 The following concepts should be defined by the implementation of the pass:
- * Mergeable alloction: the memref.alloc operations that should be merged by the
-   pass. Other memref.alloc operations that are not "mergeable" should be
+ * Mergeable allocation: the memref.alloc operations that should be merged by
+   the pass. Other memref.alloc operations that are not "mergeable" should be
    untouched by the pass
  * Allocation scope: for each mergeable memref.alloc operation, there should be
-   one ancestor surrounding operation called "allocation scope". The memory
+   one ancestor surrounding basic block called "allocation scope". The memory
    allocation after merge-alloc for that memref.alloc operation should be
-   hoisted and merged to the block of that "allocation scope". A "allocation
-   scope" should contain a single merged allocation for the mergeable allocation
-   in it.
+   hoisted and merged to that basic block. An "allocation scope" should
+   contain a single merged allocation for the mergeable allocations in it.
  * Lifetime trace: for each mergeable memref.alloc operation, the "lifetime
    trace" should be collected, indicating the "allocation scope" and the
    liveness of the buffer allocated. The contents of a "lifetime trace" is
@@ -196,7 +193,7 @@ public:
 /// top level memory trace info for multiple scopes. Each key-value is the
 ///  "allocation scope" and the LifetimeTrace
 struct MemoryTraceScopes {
-  llvm::DenseMap<Operation *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
+  llvm::DenseMap<Block *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
   MemoryTraceScopes() = default;
 };
 
@@ -237,7 +234,7 @@ calculated in the `MemorySchedule`. This step is abstracted in a
 
 ```c++
 using MemoryMergeMutatorFunc = std::function<LogicalResult(
-    Operation *toplevel, Operation *scope, const MemorySchedule &,
+    Operation *toplevel, Block *scope, const MemorySchedule &,
     const MergeAllocationOptions &)>;
 ```
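
For orientation, the three hooks (`TraceCollectorFunc`, `MemoryPlannerFunc`,
`MemoryMergeMutatorFunc`) form a collect-plan-mutate pipeline that downstream users
can override stage by stage. The following is only a shape-level sketch with invented
stand-in types, not the PR's real data structures:

```c++
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins for MemoryTraceScopes / MemorySchedule; the real types
// live in this PR's headers and carry MLIR operations instead of strings.
struct ToyTraces {
  std::vector<std::pair<std::string, std::size_t>> allocs;
};
struct ToySchedule {
  std::size_t totalSize = 0;
  std::map<std::string, std::int64_t> offsets;
};

using ToyCollector = std::function<ToyTraces()>;
using ToyPlanner = std::function<ToySchedule(const ToyTraces &)>;
using ToyMutator = std::function<bool(const ToySchedule &)>;

// Driver with the same shape as the pass driver: any of the three stages can
// be replaced independently by a downstream user.
bool runMergeAllocPipeline(const ToyCollector &collect, const ToyPlanner &plan,
                           const ToyMutator &mutate) {
  ToyTraces traces = collect();
  ToySchedule schedule = plan(traces);
  return mutate(schedule);
}
```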
 
@@ -260,23 +257,24 @@ idea of the tick-based allocation merging is that
 Limitations of Tick-based merge-alloc:
  * only contiguous, static shaped and identical layout memrefs are considered.
    Others are disregarded
- * only `RegionBranchOpInterface` or `LoopLikeOpInterface` operations are
+ * only `RegionBranchOpInterface` operations are
    allowed to access memref inside the operations' children regions. Other
   operations containing regions should not access memref inside. Otherwise, a
    pass error could occur.
 
 ### Basic concepts
 
-In the context of tick-based merge-alloc, mergeable alloction and allocation
+In the context of tick-based merge-alloc, mergeable allocation and allocation
 scope are defined as follows
 
-#### Mergeable alloction
+#### Mergeable allocation
 
 The pass should only consider merging a `memref.alloc` if
- * the ownership of the memref does not escape from the function. That is, the
-   current function is responsible to alloc and dealloc this memref
- * and the allocated memref is contiguous and has static shape and identical
-   layout.
+ * the ownership of the memref does not escape from the function or the body of
+   the loop. That is, the memref and its aliases should not be returned or
+   yielded by a function or a loop.
+ * and the memref is "dense" in its strides (points to a contiguous range of
+   memory) and it has a static shape
  * and memref is in the default memory space (this restriction may be removed in
    the future)
 
@@ -288,11 +286,11 @@ untouched by this optimization.
 
 #### Allocation scopes
 
-The transformation first needs to identify the allocation scopes, which are mlir
-operaions containing non-zero regions, and
- * implementing `AutomaticAllocationScope`
- * and is not `scf.for` (allocations in an `scf.for` can be hoisted to parent
-   `AutomaticAllocationScope`)
+The transformation first needs to identify the allocation scopes, which are
+the single basic blocks of parent operations that
+ * implement `AutomaticAllocationScope`
+ * and are not `scf.for` (allocations in an `scf.for` can be hoisted to the
+   parent `AutomaticAllocationScope`)
 
 For example, below is an example IR of a function with nested `scf.forall` ops.
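
As a toy illustration of that scope lookup, separate from the pass's actual
implementation, the search can be pictured as a walk up the parent chain that skips
for-loops (all names below are invented for the sketch):

```c++
// Minimal stand-in for an op in a parent chain; not an MLIR type.
struct ToyScopedOp {
  bool isAutomaticAllocationScope = false;
  bool isForLoop = false;
  ToyScopedOp *parent = nullptr;
};

// Walk upwards from the alloc until an AutomaticAllocationScope that is not a
// for-loop is found; allocations inside for-loops are hoisted past the loop.
ToyScopedOp *findAllocationScope(ToyScopedOp *allocOp) {
  for (ToyScopedOp *p = allocOp->parent; p; p = p->parent)
    if (p->isAutomaticAllocationScope && !p->isForLoop)
      return p;
  return nullptr; // no enclosing scope: leave the allocation untouched
}
```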
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
index fe107ae756c57..9dfbadf1cd456 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
@@ -155,14 +155,14 @@ struct TickCollecter {
 struct MergeAllocDefaultMutator {
   /// builds an memory alloc op at the scope with given size and alignment in
   /// bytes
-  virtual Value buildAlloc(OpBuilder &build, Operation *scope, int64_t size,
+  virtual Value buildAlloc(OpBuilder &build, Block *scope, int64_t size,
                            int64_t alignment) const;
   /// builds an memory view op for original memref.alloc op (origAllocOp) and
   /// the merged single allocation (mergedAlloc)
-  virtual Value buildView(OpBuilder &build, Operation *scope,
+  virtual Value buildView(OpBuilder &build, Block *scope,
                           Operation *origAllocOp, Value mergedAlloc,
                           int64_t byteOffset) const;
-  virtual LogicalResult operator()(Operation *op, Operation *scope,
+  virtual LogicalResult operator()(Operation *op, Block *scope,
                                    const MemorySchedule &schedule,
                                    const MergeAllocationOptions &o) const;
   MergeAllocDefaultMutator() = default;
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 6c812fbf5a7b6..2fbc57ffa62f3 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -96,7 +96,7 @@ class LifetimeTrace {
 /// top level memory trace info for multiple scopes. Each key-value is the
 ///  "allocation scope" and the LifetimeTrace
 struct MemoryTraceScopes {
-  llvm::DenseMap<Operation *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
+  llvm::DenseMap<Block *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
   MemoryTraceScopes() = default;
 };
 
@@ -116,7 +116,7 @@ using TraceCollectorFunc = std::function<FailureOr<MemoryTraceScopes>(
 using MemoryPlannerFunc = std::function<FailureOr<MemorySchedule>(
     Operation *, const LifetimeTrace &, const MergeAllocationOptions &)>;
 using MemoryMergeMutatorFunc = std::function<LogicalResult(
-    Operation *toplevel, Operation *scope, const MemorySchedule &,
+    Operation *toplevel, Block *scope, const MemorySchedule &,
     const MergeAllocationOptions &)>;
 
 struct MergeAllocationOptions {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
index e890b653df9b8..437120fd766be 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
@@ -26,8 +26,8 @@ namespace memref {
 
 using namespace special_ticks;
 
-/// Return `true` if the given MemRef type has a static identity layout (i.e.,
-/// no layout) and default memory space.
+/// Return `true` if the given MemRef type has static shapes
+/// and default memory space.
 static bool isMemRefTypeOk(MemRefType type) {
   IntegerAttr intMemorySpace =
       llvm::dyn_cast_or_null<IntegerAttr>(type.getMemorySpace());
@@ -74,10 +74,9 @@ LogicalResult TickCollecter::onPopComplexScope(TickCollecterStates *s,
   const auto &scope = s->complexScopeStack.back();
   // if the complex scope is not recognized by us, and if it accesses memory,
   // raise an error
-  if (!isa<RegionBranchOpInterface>(scope.scope) &&
-      !isa<LoopLikeOpInterface>(scope.scope) && !scope.operations.empty()) {
-    return scope.scope->emitOpError("expecting RegionBranchOpInterface or "
-                                    "LoopLikeOpInterface for merge-alloc");
+  if (!isa<RegionBranchOpInterface>(scope.scope) && !scope.operations.empty()) {
+    return scope.scope->emitOpError(
+        "expecting RegionBranchOpInterface for merge-alloc");
   }
   for (auto op : scope.operations) {
     if (needsResetTick(s, scope.scope, op)) {
@@ -136,9 +135,12 @@ void TickCollecter::onMemrefViews(TickCollecterStates *s,
 }
 
 void TickCollecter::onReturnOp(TickCollecterStates *s, Operation *op) const {
-  bool isTopLevel = isa<func::FuncOp>(op->getParentOp());
+  // if a memref escapes from a function or a loop, we need to mark it
+  // unmergeable
+  bool isEscape = isa<func::FuncOp>(op->getParentOp()) ||
+                  isa<LoopLikeOpInterface>(op->getParentOp());
   for (auto val : op->getOperands()) {
-    accessValue(s, val, isTopLevel);
+    accessValue(s, val, isEscape);
   }
 }
 
@@ -215,7 +217,7 @@ TickCollecter::getTrace(TickCollecterStates *s) const {
                   size_t size)
         : allocTick{allocTick}, tick{tick}, trace{bufferId, size} {}
   };
-  llvm::DenseMap<Operation *, llvm::SmallVector<TraceWithTick, 8>> raw;
+  llvm::DenseMap<Block *, llvm::SmallVector<TraceWithTick, 8>> raw;
   for (auto &[op, tick] : s->allocTicks) {
     if (!isMergeableAlloc(s, op, tick.firstAccess)) {
       continue;
@@ -225,15 +227,21 @@ TickCollecter::getTrace(TickCollecterStates *s) const {
       return op->emitError(
           "This op should be surrounded by an AutomaticAllocationScope");
     }
+    if (scope->getNumRegions() != 1 ||
+        scope->getRegion(0).getBlocks().size() != 1) {
+      return op->emitError("This op should be surrounded by an "
+                           "AutomaticAllocationScope of single block");
+    }
+    auto block = &*scope->getRegion(0).getBlocks().begin();
     auto allocSize = getAllocSize(s, op);
     if (failed(allocSize)) {
       return failure();
     }
     // tick.firstAccess * 2 and tick.lastAccess * 2 + 1 to make sure "dealloc"
     // overlaps "alloc"
-    raw[scope].emplace_back(tick.allocTick, tick.firstAccess * 2,
+    raw[block].emplace_back(tick.allocTick, tick.firstAccess * 2,
                             reinterpret_cast<uintptr_t>(op), *allocSize);
-    raw[scope].emplace_back(tick.allocTick, tick.lastAccess * 2 + 1,
+    raw[block].emplace_back(tick.allocTick, tick.lastAccess * 2 + 1,
                             reinterpret_cast<uintptr_t>(op), 0);
   }
   MemoryTraceScopes ret;
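
The `firstAccess * 2` / `lastAccess * 2 + 1` encoding above is worth a small worked
example: if buffer A is last used at tick `t` and buffer B is first used at the same
tick `t`, A's dealloc event (`2t + 1`) sorts after B's alloc event (`2t`), so the two
lifetimes are seen as overlapping and will not be assigned the same memory. A tiny
standalone check of that property (illustrative only):

```c++
#include <cassert>
#include <cstdint>

// Encode alloc/dealloc events so that a buffer last used at tick t and a
// buffer first used at the same tick t are still treated as overlapping.
int64_t allocEvent(int64_t firstAccess) { return firstAccess * 2; }
int64_t deallocEvent(int64_t lastAccess) { return lastAccess * 2 + 1; }

int main() {
  // Buffer A: live over ticks [3, 7]. Buffer B: first used at tick 7.
  int64_t aDealloc = deallocEvent(7); // 15
  int64_t bAlloc = allocEvent(7);     // 14
  // B is allocated before A is released, so their lifetimes overlap and the
  // planner must not assign them the same offset.
  assert(bAlloc < aDealloc);
  // Buffer C, first used at tick 8, only starts after A is gone and may reuse it.
  assert(allocEvent(8) > aDealloc);
  return 0;
}
```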
@@ -336,18 +344,18 @@ FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
   return std::move(ret);
 }
 
-Value MergeAllocDefaultMutator::buildAlloc(OpBuilder &builder, Operation *scope,
+Value MergeAllocDefaultMutator::buildAlloc(OpBuilder &builder, Block *block,
                                            int64_t size,
                                            int64_t alignmentInt) const {
-  auto &block = scope->getRegion(0).getBlocks().front();
-  builder.setInsertionPointToStart(&block);
+  builder.setInsertionPointToStart(block);
   auto alignment = builder.getIntegerAttr(
       IntegerType::get(builder.getContext(), 64), alignmentInt);
   auto alloc = builder.create<memref::AllocOp>(
-      scope->getLoc(), MemRefType::get({size}, builder.getI8Type()), alignment);
+      block->getParentOp()->getLoc(),
+      MemRefType::get({size}, builder.getI8Type()), alignment);
   return alloc;
 }
-Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Operation *scope,
+Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Block *scope,
                                           Operation *origAllocOp,
                                           Value mergedAlloc,
                                           int64_t byteOffset) const {
@@ -360,7 +368,7 @@ Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Operation *scope,
 }
 
 LogicalResult
-MergeAllocDefaultMutator::operator()(Operation *op, Operation *scope,
+MergeAllocDefaultMutator::operator()(Operation *op, Block *scope,
                                      const MemorySchedule &schedule,
                                      const MergeAllocationOptions &o) const {
   if (schedule.allocToOffset.empty()) {
diff --git a/mlir/test/Dialect/MemRef/buffer-merge-invalid.mlir b/mlir/test/Dialect/MemRef/buffer-merge-invalid.mlir
index 6609cb9d2f6fb..b98c2dcd93511 100644
--- a/mlir/test/Dialect/MemRef/buffer-merge-invalid.mlir
+++ b/mlir/test/Dialect/MemRef/buffer-merge-invalid.mlir
@@ -3,9 +3,9 @@
 func.func @block() {
   %mref = memref.alloc() : memref<8 x f32>
   %mref2 = memref.alloc() : memref<8 x f32>
-  // expected-error at +1 {{expecting RegionBranchOpInterface or LoopLikeOpInterface for merge-alloc}}
+  // expected-error at +1 {{expecting RegionBranchOpInterface for merge-alloc}}
   "some.block"() ({
    ^bb0:
     "some.use"(%mref) : (memref<8 x f32>) -> ()
    }) : () -> ()
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir b/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
index 1a64b23b7b9c3..6f2de79bfc6c3 100644
--- a/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
+++ b/mlir/test/Dialect/MemRef/buffer-merge-lifetime.mlir
@@ -126,4 +126,32 @@ func.func @escape_from_if() {
   }
   "test.source"(%c)  : (memref<8x64xf32>) -> ()
   return
-}
\ No newline at end of file
+}
+
+// CHECK-DAG: func.func @escape_from_for()  attributes {__mergealloc_scope = [[TOPSCOPE6:[0-9]+]]
+func.func @escape_from_for() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  // check that f has untraceable lifetime, due to being yielded by the for loop
+  // CHECK-DAG: %[[F:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE6]], -2, -2>}
+  %f = memref.alloc() : memref<8x64xf32>
+  %out = scf.for %i = %c0 to %c5 step %c1 iter_args(%buf = %f) -> (memref<8x64xf32>) {
+    "test.source"(%buf)  : (memref<8x64xf32>) -> ()
+    // check that g has untraceable lifetime, due to being yielded by the for loop
+    // CHECK-DAG: %[[G:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE6]], -2, -2>}
+    %g = memref.alloc() : memref<8x64xf32>
+    "test.source"(%g)  : (memref<8x64xf32>) -> ()
+    %ctrue = "test.source"()  : () -> i1
+    %c = scf.if %ctrue -> memref<8x64xf32> {
+      scf.yield %g : memref<8x64xf32>
+    } else {
+      scf.yield %buf : memref<8x64xf32>
+    }
+    scf.yield %c : memref<8x64xf32>
+  }
+  "test.source"(%out)  : (memref<8x64xf32>) -> ()
+  return
+}
diff --git a/mlir/test/Dialect/MemRef/buffer-merge.mlir b/mlir/test/Dialect/MemRef/buffer-merge.mlir
index e491e2479a157..a30d63974017b 100644
--- a/mlir/test/Dialect/MemRef/buffer-merge.mlir
+++ b/mlir/test/Dialect/MemRef/buffer-merge.mlir
@@ -60,6 +60,40 @@ func.func @basic() -> memref<8x64xf32> {
   return %b : memref<8x64xf32>
 }
 
+// CHECK-LABEL: @different_size_and_dtype
+func.func @different_size_and_dtype() {
+  // CHECK-DAG: %[[BASE_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3072xi8>
+  // c and d have overlapping lifetimes
+  // CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C_3:.*]] = memref.view %[[BASE_3]][%[[C0_3]]][] : memref<3072xi8> to memref<8x64xf32>
+  %c = memref.alloc() : memref<8x64xf32>
+  // CHECK:     "test.source"(%[[C_3]])
+  "test.source"(%c)  : (memref<8x64xf32>) -> ()
+  // CHECK-DAG: %[[C2048_3:.*]] = arith.constant 2048 : index
+  // CHECK-DAG: %[[D_3:.*]] = memref.view %[[BASE_3]][%[[C2048_3]]][] : memref<3072xi8> to memref<64xf32>
+  %d = memref.alloc() : memref<64xf32>
+  // CHECK:     "test.source"(%[[D_3]])
+  "test.source"(%d)  : (memref<64xf32>) -> ()
+  // last use of d
+  // e can reuse d's memory
+  // CHECK-DAG: %[[C2048_3_2:.*]] = arith.constant 2048 : index
+  // CHECK-DAG: %[[E_3:.*]] = memref.view %[[BASE_3]][%[[C2048_3_2]]][] : memref<3072xi8> to memref<2x2x32xf64>
+  %e = memref.alloc() : memref<2x2x32xf64>
+  // CHECK:     "test.source"(%[[E_3]])
+  "test.source"(%e)  : (memref<2x2x32xf64>) -> ()
+  // CHECK:     "test.source"(%[[C_3]])
+  "test.source"(%c)  : (memref<8x64xf32>) -> ()
+
+  // e and c are free'd. f can reuse the memory across e and c
+  // CHECK-DAG: %[[C0_3_2:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[F_3:.*]] = memref.view %[[BASE_3]][%[[C0_3_2]]][] : memref<3072xi8> to memref<2x4x32xf64>
+  %f = memref.alloc() : memref<2x4x32xf64>
+  // CHECK:     "test.source"(%[[F_3]])
+  "test.source"(%f)  : (memref<2x4x32xf64>) -> ()
+  // CHECK:     return
+  return
+}
+
 // CHECK-LABEL: @withloop
 func.func @withloop() {
   // CHECK-DAG: %[[BASE2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<6144xi8>
@@ -93,4 +127,4 @@ func.func @withloop() {
       "test.source"(%j)  : (memref<8x64xf32>) -> ()
   }
   return
-}
\ No newline at end of file
+}

>From b40e5122e1ae9642d2dbbe206f246bf02eadccd0 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Fri, 21 Jun 2024 17:27:05 +0800
Subject: [PATCH 12/12] add addr space

---
 mlir/docs/MemrefMergeAlloc.md                 | 11 ++--
 .../MemRef/Transforms/MergeAllocTickBased.h   |  9 ++-
 .../mlir/Dialect/MemRef/Transforms/Passes.h   | 10 +++-
 .../Dialect/MemRef/Transforms/MergeAlloc.cpp  |  4 +-
 .../MemRef/Transforms/MergeAllocTickBased.cpp | 60 ++++++++++++-------
 mlir/test/Dialect/MemRef/buffer-merge.mlir    | 49 +++++++++++++++
 6 files changed, 112 insertions(+), 31 deletions(-)

diff --git a/mlir/docs/MemrefMergeAlloc.md b/mlir/docs/MemrefMergeAlloc.md
index 944984c4bc0a9..b9ae7ff808106 100644
--- a/mlir/docs/MemrefMergeAlloc.md
+++ b/mlir/docs/MemrefMergeAlloc.md
@@ -187,13 +187,15 @@ collected traces as below:
 /// which buffers are put into which "allocation scope".
 class LifetimeTrace {
 public:
-  virtual ~LifetimeTrace() = default;
+  virtual Block *getAllocScope() const = 0;
+  virtual Attribute getMemorySpace() const = 0;
 };
 
-/// top level memory trace info for multiple scopes. Each key-value is the
-///  "allocation scope" and the LifetimeTrace
+/// top level memory trace info for multiple scopes. Each element of scopeTraces
+/// should contain an "allocation scope" and the implementation-defined lifetime
+/// data
 struct MemoryTraceScopes {
-  llvm::DenseMap<Block *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
+  llvm::SmallVector<std::unique_ptr<LifetimeTrace>> scopeTraces;
   MemoryTraceScopes() = default;
 };
 
@@ -216,6 +218,7 @@ interfaces are shown below:
 /// memref::AllocOp which are in the same LifetimeTrace.
 struct MemorySchedule {
   size_t totalSize;
+  Attribute memorySpace;
   llvm::DenseMap<Operation *, int64_t> allocToOffset;
   MemorySchedule() : totalSize{0} {}
 };
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
index 9dfbadf1cd456..8b70ff60cb0ee 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/MergeAllocTickBased.h
@@ -63,7 +63,12 @@ struct ComplexScope {
 /// the top-level collected lifetime trace for merge-alloc pass
 struct TickTraceResult : public LifetimeTrace {
   memoryplan::Traces traces;
-  TickTraceResult() : LifetimeTrace{TK_TICK} {}
+  Block *block;
+  Attribute memorySpace;
+  TickTraceResult(Block *block, Attribute memorySpace)
+      : LifetimeTrace{TK_TICK}, block{block}, memorySpace{memorySpace} {}
+  Block *getAllocScope() const override { return block; }
+  Attribute getMemorySpace() const override { return memorySpace; }
   static bool classof(const LifetimeTrace *S) {
     return S->getKind() == TK_TICK;
   }
@@ -156,7 +161,7 @@ struct MergeAllocDefaultMutator {
   /// builds an memory alloc op at the scope with given size and alignment in
   /// bytes
   virtual Value buildAlloc(OpBuilder &build, Block *scope, int64_t size,
-                           int64_t alignment) const;
+                           int64_t alignment, Attribute memorySpace) const;
   /// builds an memory view op for original memref.alloc op (origAllocOp) and
   /// the merged single allocation (mergedAlloc)
   virtual Value buildView(OpBuilder &build, Block *scope,
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 2fbc57ffa62f3..54f2e78372c41 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -88,15 +88,18 @@ class LifetimeTrace {
   virtual ~LifetimeTrace() = default;
   LifetimeTrace(TraceKind kind) : kind{kind} {}
   TraceKind getKind() const { return kind; }
+  virtual Block *getAllocScope() const = 0;
+  virtual Attribute getMemorySpace() const = 0;
 
 private:
   TraceKind kind;
 };
 
-/// top level memory trace info for multiple scopes. Each key-value is the
-///  "allocation scope" and the LifetimeTrace
+/// top level memory trace info for multiple scopes. Each element of scopeTraces
+/// should contain an "allocation scope" and the implementation-defined lifetime
+/// data
 struct MemoryTraceScopes {
-  llvm::DenseMap<Block *, std::unique_ptr<LifetimeTrace>> scopeToTraces;
+  llvm::SmallVector<std::unique_ptr<LifetimeTrace>> scopeTraces;
   MemoryTraceScopes() = default;
 };
 
@@ -105,6 +108,7 @@ struct MemoryTraceScopes {
 /// memref::AllocOp which are in the same LifetimeTrace.
 struct MemorySchedule {
   size_t totalSize;
+  Attribute memorySpace;
   llvm::DenseMap<Operation *, int64_t> allocToOffset;
   MemorySchedule() : totalSize{0} {}
 };
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
index ade68752fa33f..9cfae357596a9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAlloc.cpp
@@ -37,12 +37,12 @@ LogicalResult passDriver(Operation *op,
   if (o.checkOnly) {
     return success();
   }
-  for (auto &[scope, traces] : (*tracesOrFail).scopeToTraces) {
+  for (auto &traces : (*tracesOrFail).scopeTraces) {
     auto schedule = o.planner(op, *traces, o);
     if (failed(schedule)) {
       return failure();
     }
-    if (failed(o.mutator(op, scope, *schedule, o))) {
+    if (failed(o.mutator(op, traces->getAllocScope(), *schedule, o))) {
       return failure();
     }
   }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
index 437120fd766be..a828e44b05457 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MergeAllocTickBased.cpp
@@ -28,13 +28,7 @@ using namespace special_ticks;
 
 /// Return `true` if the given MemRef type has static shapes
 /// and default memory space.
-static bool isMemRefTypeOk(MemRefType type) {
-  IntegerAttr intMemorySpace =
-      llvm::dyn_cast_or_null<IntegerAttr>(type.getMemorySpace());
-  if (intMemorySpace && intMemorySpace.getValue() != 0)
-    return false;
-  return type.hasStaticShape() && type.getLayout().isIdentity();
-}
+static bool isMemRefTypeOk(MemRefType type) { return type.hasStaticShape(); }
 
 void Tick::update(int64_t tick) {
   if (tick == UNTRACEABLE_ACCESS) {
@@ -217,7 +211,9 @@ TickCollecter::getTrace(TickCollecterStates *s) const {
                   size_t size)
         : allocTick{allocTick}, tick{tick}, trace{bufferId, size} {}
   };
-  llvm::DenseMap<Block *, llvm::SmallVector<TraceWithTick, 8>> raw;
+  llvm::DenseMap<std::pair<Block *, Attribute>,
+                 llvm::SmallVector<TraceWithTick, 8>>
+      raw;
   for (auto &[op, tick] : s->allocTicks) {
     if (!isMergeableAlloc(s, op, tick.firstAccess)) {
       continue;
@@ -237,15 +233,18 @@ TickCollecter::getTrace(TickCollecterStates *s) const {
     if (failed(allocSize)) {
       return failure();
     }
+    auto key = std::make_pair(
+        block, cast<MemRefType>(op->getResultTypes().front()).getMemorySpace());
     // tick.firstAccess * 2 and tick.lastAccess * 2 + 1 to make sure "dealloc"
     // overlaps "alloc"
-    raw[block].emplace_back(tick.allocTick, tick.firstAccess * 2,
-                            reinterpret_cast<uintptr_t>(op), *allocSize);
-    raw[block].emplace_back(tick.allocTick, tick.lastAccess * 2 + 1,
-                            reinterpret_cast<uintptr_t>(op), 0);
+    raw[key].emplace_back(tick.allocTick, tick.firstAccess * 2,
+                          reinterpret_cast<uintptr_t>(op), *allocSize);
+    raw[key].emplace_back(tick.allocTick, tick.lastAccess * 2 + 1,
+                          reinterpret_cast<uintptr_t>(op), 0);
   }
   MemoryTraceScopes ret;
-  for (auto &[scope, trace] : raw) {
+  for (auto &[scopeAndSpace, trace] : raw) {
+    const auto &[scope, memSpace] = scopeAndSpace;
     std::stable_sort(trace.begin(), trace.end(),
                      [](const TraceWithTick &a, const TraceWithTick &b) {
                        if (a.tick == b.tick) {
@@ -253,13 +252,29 @@ TickCollecter::getTrace(TickCollecterStates *s) const {
                        }
                        return a.tick < b.tick;
                      });
-    auto retTrace = std::make_unique<TickTraceResult>();
+    auto retTrace = std::make_unique<TickTraceResult>(scope, memSpace);
     retTrace->traces.reserve(trace.size());
     for (auto &tr : trace) {
       retTrace->traces.emplace_back(tr.trace);
     }
-    ret.scopeToTraces[scope] = std::move(retTrace);
+    ret.scopeTraces.emplace_back(std::move(retTrace));
   }
+  // stabilize the order of scopes for testing
+  std::stable_sort(
+      ret.scopeTraces.begin(), ret.scopeTraces.end(),
+      [](const std::unique_ptr<LifetimeTrace> &a,
+         const std::unique_ptr<LifetimeTrace> &b) {
+        int64_t aFirstSize = -1, bFirstSize = -1;
+        if (auto &traces = static_cast<TickTraceResult *>(a.get())->traces;
+            !traces.empty()) {
+          aFirstSize = traces.front().size;
+        }
+        if (auto &traces = static_cast<TickTraceResult *>(b.get())->traces;
+            !traces.empty()) {
+          bFirstSize = traces.front().size;
+        }
+        return aFirstSize < bFirstSize;
+      });
   return ret;
 }
 
@@ -337,6 +352,7 @@ FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
       outSchedule, dummy);
   MemorySchedule ret;
   ret.totalSize = total;
+  ret.memorySpace = tr.getMemorySpace();
   for (auto [k, offset] : outSchedule) {
     ret.allocToOffset[reinterpret_cast<Operation *>(k)] =
         static_cast<int64_t>(offset);
@@ -345,14 +361,17 @@ FailureOr<MemorySchedule> tickBasedPlanMemory(Operation *op,
 }
 
 Value MergeAllocDefaultMutator::buildAlloc(OpBuilder &builder, Block *block,
-                                           int64_t size,
-                                           int64_t alignmentInt) const {
+                                           int64_t size, int64_t alignmentInt,
+                                           Attribute memorySpace) const {
   builder.setInsertionPointToStart(block);
   auto alignment = builder.getIntegerAttr(
       IntegerType::get(builder.getContext(), 64), alignmentInt);
   auto alloc = builder.create<memref::AllocOp>(
       block->getParentOp()->getLoc(),
-      MemRefType::get({size}, builder.getI8Type()), alignment);
+      MemRefType::get({size}, builder.getI8Type(),
+                      /*layout*/ MemRefLayoutAttrInterface(), memorySpace),
+      alignment);
+
   return alloc;
 }
 Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Block *scope,
@@ -375,8 +394,9 @@ MergeAllocDefaultMutator::operator()(Operation *op, Block *scope,
     return success();
   }
   OpBuilder builder{op->getContext()};
-  auto alloc = buildAlloc(
-      builder, scope, static_cast<int64_t>(schedule.totalSize), o.alignment);
+  auto alloc =
+      buildAlloc(builder, scope, static_cast<int64_t>(schedule.totalSize),
+                 o.alignment, schedule.memorySpace);
   for (auto &[origBuf, offset] : schedule.allocToOffset) {
     origBuf->replaceAllUsesWith(
         buildView(builder, scope, origBuf, alloc, static_cast<int64_t>(offset))
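
Keying the collected traces by the (allocation scope, memory space) pair means each
address space gets its own plan and its own merged base buffer, which is what the
`different_mem_space` test below checks with separate `%[[BASE_WG]]` and
`%[[BASE_PRIVATE]]` allocations. A standalone sketch of the grouping, with strings
standing in for `Block *` and the memory-space attribute:

```c++
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct AllocInfo {
  std::string name;
  std::size_t size;
};

// One planning bucket per (allocation scope, memory space): each bucket later
// becomes its own merged base allocation in that address space.
using ScopeAndSpace = std::pair<std::string, std::string>;

std::map<ScopeAndSpace, std::vector<AllocInfo>>
groupBySpace(const std::vector<std::pair<ScopeAndSpace, AllocInfo>> &allocs) {
  std::map<ScopeAndSpace, std::vector<AllocInfo>> buckets;
  for (const auto &[key, info] : allocs)
    buckets[key].push_back(info);
  return buckets;
}

// Example: a workgroup buffer and a private buffer in the same scope land in
// different buckets, matching the separate base allocations in the test below.
```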
diff --git a/mlir/test/Dialect/MemRef/buffer-merge.mlir b/mlir/test/Dialect/MemRef/buffer-merge.mlir
index a30d63974017b..eb612d03cdc44 100644
--- a/mlir/test/Dialect/MemRef/buffer-merge.mlir
+++ b/mlir/test/Dialect/MemRef/buffer-merge.mlir
@@ -128,3 +128,52 @@ func.func @withloop() {
   }
   return
 }
+
+
+// CHECK-LABEL: @different_mem_space
+func.func @different_mem_space() {
+  // CHECK-DAG: %[[BASE_DEFAULT:.*]] = memref.alloc() {alignment = 64 : i64} : memref<4096xi8>
+  // CHECK-DAG: %[[BASE_WG:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3584xi8, #gpu.address_space<workgroup>>
+  // CHECK-DAG: %[[BASE_PRIVATE:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1280xi8, #gpu.address_space<private>>
+  // CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[defaultv:.*]] = memref.view %[[BASE_DEFAULT]][%[[C0_4]]][] : memref<4096xi8> to memref<8x64xf32>
+  %defaultv = memref.alloc() : memref<8x64xf32>
+  // CHECK:     "test.source"(%[[defaultv]])
+  "test.source"(%defaultv)  : (memref<8x64xf32>) -> ()
+  // CHECK-DAG: %[[C2048_4:.*]] = arith.constant 2048 : index
+  // CHECK-DAG: %[[defaultv2:.*]] = memref.view %[[BASE_DEFAULT]][%[[C2048_4]]][] : memref<4096xi8> to memref<8x64xf32>
+  %defaultv2 = memref.alloc() : memref<8x64xf32>
+  // CHECK-DAG: %[[C2048_4_2:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[priv:.*]] = memref.view %[[BASE_PRIVATE]][%[[C2048_4_2]]][] : memref<1280xi8, #gpu.address_space<private>> to memref<4x64xf32, #gpu.address_space<private>>
+  %priv = memref.alloc() : memref<4x64xf32, #gpu.address_space<private>>
+  // CHECK:     "test.source"(%[[priv]])
+  "test.source"(%priv)  : (memref<4x64xf32, #gpu.address_space<private>>) -> ()
+  // CHECK-DAG: %[[C0_4_2:.*]] = arith.constant 1024 : index
+  // CHECK-DAG: %[[priv2:.*]] = memref.view %[[BASE_PRIVATE]][%[[C0_4_2]]][] : memref<1280xi8, #gpu.address_space<private>> to memref<64xf32, #gpu.address_space<private>>
+  %priv2 = memref.alloc() : memref<64xf32,  #gpu.address_space<private>>
+  // CHECK:     "test.source"(%[[priv2]])
+  "test.source"(%priv2)  : (memref<64xf32, #gpu.address_space<private>>) -> ()
+  // CHECK:     "test.source"(%[[priv]])
+  "test.source"(%priv)  : (memref<4x64xf32, #gpu.address_space<private>>) -> ()
+
+
+  // CHECK-DAG: %[[C2048_4_3:.*]] = arith.constant 1536 : index
+  // CHECK-DAG: %[[sharedv:.*]] = memref.view %[[BASE_WG]][%[[C2048_4_3]]][] : memref<3584xi8, #gpu.address_space<workgroup>> to memref<8x64xf32, #gpu.address_space<workgroup>>
+  %sharedv = memref.alloc() : memref<8x64xf32, #gpu.address_space<workgroup>>
+  // CHECK-DAG: %[[C0_4_3:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[sharedv2:.*]] = memref.view %[[BASE_WG]][%[[C0_4_3]]][] : memref<3584xi8, #gpu.address_space<workgroup>> to memref<6x64xf32, #gpu.address_space<workgroup>>
+  %sharedv2 = memref.alloc() : memref<6x64xf32, #gpu.address_space<workgroup>>
+  // CHECK:     "test.source"(%[[sharedv2]])
+  // CHECK:     "test.source"(%[[sharedv]])
+  // CHECK:     "test.source"(%[[sharedv2]])
+  "test.source"(%sharedv2)  : (memref<6x64xf32, #gpu.address_space<workgroup>>) -> ()
+  "test.source"(%sharedv)  : (memref<8x64xf32, #gpu.address_space<workgroup>>) -> ()
+  "test.source"(%sharedv2)  : (memref<6x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+
+  // CHECK:     "test.source"(%[[defaultv2]])
+  // CHECK:     "test.source"(%[[defaultv]])
+  "test.source"(%defaultv2)  : (memref<8x64xf32>) -> ()
+  "test.source"(%defaultv)  : (memref<8x64xf32>) -> ()
+  return
+}


