[flang-commits] [flang] [flang] handle alloca outside of entry blocks in MemoryAllocation (PR #98457)

via flang-commits flang-commits at lists.llvm.org
Thu Jul 11 03:10:26 PDT 2024


https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/98457

This patch generalizes the MemoryAllocation pass (alloca -> heap) to handle fir.alloca regardless of its position in the IR. Previously, it only dealt with fir.alloca in function entry blocks. The logic is placed in a utility that can be used on demand to replace the allocas in an operation with whatever kind of allocation the utility user wants, via callbacks (allocmem, custom runtime calls to instrument the code, ...).
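To illustrate how the callbacks divide the work (selection, allocation, deallocation), here is a minimal C++ sketch built on a toy in-memory model rather than MLIR. `ToyAlloca`, `replaceToyAllocas` and the three callback aliases are hypothetical stand-ins for `fir::AllocaOp`, `fir::replaceAllocas` and the callback types declared in MemoryUtils.h, and their signatures are simplified.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a fir.alloca: only what the callbacks inspect.
struct ToyAlloca {
  std::string name;
  std::size_t numElements;
  bool isDynamic; // has length parameters or shape operands
  bool hasOwner;  // false models "no owning region could be found"
  bool replaced = false;
};

// Simplified analogues of MustRewriteCallBack, AllocaRewriterCallBack and
// DeallocCallBack. The rewriter returns an opaque handle, or std::nullopt to
// refuse the replacement.
using MustRewrite = std::function<bool(const ToyAlloca &)>;
using Rewriter =
    std::function<std::optional<int>(const ToyAlloca &, bool allocaDominates)>;
using Dealloc = std::function<void(int handle)>;

// Returns false if some requested alloca had no owner and was skipped.
// A refusal from the rewriter is not reported as a failure, mirroring
// AllocaReplaceImpl::replace in the patch.
bool replaceToyAllocas(std::vector<ToyAlloca> &allocas,
                       const MustRewrite &mustRewrite,
                       const Rewriter &rewriter, const Dealloc &dealloc) {
  bool replacedAllRequested = true;
  for (ToyAlloca &candidate : allocas) {
    if (!mustRewrite(candidate))
      continue;
    if (!candidate.hasOwner) {
      replacedAllRequested = false;
      continue;
    }
    // Toy assumption: a single exit point that the alloca dominates.
    if (std::optional<int> handle =
            rewriter(candidate, /*allocaDominates=*/true)) {
      candidate.replaced = true;
      dealloc(*handle); // the real utility emits this at every exit point
    }
  }
  return replacedAllRequested;
}
```

In the patch itself, the MemoryAllocation pass plugs `genAllocmem` and `genFreemem` into the allocation and deallocation slots, and uses `keepStackAllocation` to select which allocas to rewrite.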

To do so, this patch formalizes a concept of ownership that was already partly implied and used in passes like stack-reclaim: any operation with the LoopLikeOpInterface, the AutomaticAllocationScope trait, or the IsIsolatedFromAbove trait owns the allocas directly nested inside its regions, and these allocas must not be used after the operation.
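The ownership rule can be sketched with a toy operation tree, assuming hypothetical `ToyOp` nodes that carry the relevant traits as plain booleans. `findOwner` mirrors the parent walk in `fir::AllocaOp::getOwnerRegion`, except that it returns the owning operation instead of the region through which the alloca is reached.

```cpp
#include <cassert>
#include <string>

// Hypothetical toy operation carrying only the traits that matter here.
struct ToyOp {
  std::string name;
  const ToyOp *parent;
  bool isolatedFromAbove = false;
  bool automaticAllocationScope = false;
  bool loopLike = false;
  bool registered = true;
};

// Mirrors fir::AllocaOp::ownsNestedAlloca.
bool ownsNestedAlloca(const ToyOp &op) {
  return op.isolatedFromAbove || op.automaticAllocationScope || op.loopLike;
}

// Walks up from the alloca to the closest owning ancestor. Returns nullptr if
// there is none, or if an unregistered operation is met first (its traits
// cannot be trusted).
const ToyOp *findOwner(const ToyOp &allocaOp) {
  for (const ToyOp *op = allocaOp.parent; op != nullptr; op = op->parent) {
    if (!op->registered)
      return nullptr;
    if (ownsNestedAlloca(*op))
      return op;
  }
  return nullptr;
}
```

An alloca nested in a `fir.if` inside a `fir.do_loop` is therefore owned by the loop, not by the enclosing function, so its storage must be released on each iteration.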

The pass then looks for the exit points of the owning region and uses them to insert deallocations. If it cannot prove that the alloca dominates these exit points, the pass falls back to storing the new address into a C pointer variable created in the entry block of the owning region. This allows inserting conditional deallocations as needed, including right before the alloca itself, to avoid leaks when the alloca is executed multiple times due to loops in the block CFG.
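The fallback can be pictured as ordinary C++ control flow. The sketch below is a hand-written analogue of what the rewritten `@test_unstructured` function does at runtime, assuming hypothetical `toyAlloc`/`toyFree` helpers in place of `fir.allocmem`/`fir.freemem`. The pointer variable starts out null, the alloca position is itself a conditional deallocation point, and the region exit frees through the variable. (The real test also gets a conditional free before the `fir.unreachable` that follows `@abort`; that path is omitted here.)

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Number of heap buffers currently alive in this toy simulation.
static int liveAllocations = 0;

static void *toyAlloc(std::size_t bytes) {
  ++liveAllocations;
  return std::malloc(bytes);
}

static void toyFree(void *ptr) {
  --liveAllocations;
  std::free(ptr);
}

// The allocation sits in a loop body and does not dominate the return, so
// its address goes through a pointer variable initialized to null.
int runUnstructured(int n) {
  void *ptrVar = nullptr; // created at the start of the owning region
  for (int i = n; i > 0; --i) {
    // Deallocation point at the alloca: release the previous buffer, if any.
    if (ptrVar != nullptr)
      toyFree(ptrVar);
    void *buffer = toyAlloc(sizeof(float) * static_cast<std::size_t>(i));
    ptrVar = buffer; // store the new address right after its creation
    // ... uses of buffer ...
  }
  // Region exit: conditional deallocation through the pointer variable.
  if (ptrVar != nullptr)
    toyFree(ptrVar);
  return liveAllocations;
}
```

A structured loop such as `@test_loop` needs none of this: there the alloca dominates the `fir.result` terminator, so `fir.freemem` is emitted directly before it.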

This should fix https://github.com/llvm/llvm-project/issues/88344.

As a next step, I will try to refactor lowering a bit to introduce lifetime operations for alloca, so that deallocation points can be inserted as early as possible.

From aaecc7bfc79d1375e74fa7d168f086b945a291c0 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 8 Jul 2024 03:45:29 -0700
Subject: [PATCH] [flang] handle alloca outside of entry blocks in
 MemoryAllocation

---
 .../include/flang/Optimizer/Dialect/FIROps.td |  13 ++
 .../flang/Optimizer/Transforms/MemoryUtils.h  |  63 +++++
 flang/lib/Optimizer/Dialect/FIROps.cpp        |  21 ++
 flang/lib/Optimizer/Transforms/CMakeLists.txt |   1 +
 .../Optimizer/Transforms/MemoryAllocation.cpp | 143 ++++--------
 .../lib/Optimizer/Transforms/MemoryUtils.cpp  | 220 ++++++++++++++++++
 flang/test/Fir/memory-allocation-opt-2.fir    | 105 +++++++++
 7 files changed, 466 insertions(+), 100 deletions(-)
 create mode 100644 flang/include/flang/Optimizer/Transforms/MemoryUtils.h
 create mode 100644 flang/lib/Optimizer/Transforms/MemoryUtils.cpp
 create mode 100644 flang/test/Fir/memory-allocation-opt-2.fir

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 5b03806614f9b..89c13fa7cebe6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -124,6 +124,13 @@ def fir_AllocaOp : fir_Op<"alloca", [AttrSizedOperandSegments,
     Indeed, a user would likely expect a good Fortran compiler to perform such
     an optimization.
 
+    Stack allocations have a maximum lifetime: their uses must not
+    exceed the lifetime of the closest parent operation with the
+    AutomaticAllocationScope trait, the IsIsolatedFromAbove trait, or the
+    LoopLikeOpInterface interface. This restriction is meant to ease the
+    insertion of stack save and restore operations, and the conversion
+    of stack allocations into heap allocations.
+
     Until Fortran 2018, procedures defaulted to non-recursive. A legal
     implementation could therefore convert stack allocations to global
     allocations. Such a conversion effectively adds the SAVE attribute to all
@@ -183,11 +190,17 @@ def fir_AllocaOp : fir_Op<"alloca", [AttrSizedOperandSegments,
     mlir::Type getAllocatedType();
     bool hasLenParams() { return !getTypeparams().empty(); }
     bool hasShapeOperands() { return !getShape().empty(); }
+    bool isDynamic() { return hasLenParams() || hasShapeOperands(); }
     unsigned numLenParams() { return getTypeparams().size(); }
     operand_range getLenParams() { return getTypeparams(); }
     unsigned numShapeOperands() { return getShape().size(); }
     operand_range getShapeOperands() { return getShape(); }
     static mlir::Type getRefTy(mlir::Type ty);
+    /// Does this operation own the allocas made directly in its regions?
+    static bool ownsNestedAlloca(mlir::Operation *op);
+    /// Get the parent region that owns this alloca. Nullptr if none can be
+    /// identified.
+    mlir::Region* getOwnerRegion();
   }];
 }
 
diff --git a/flang/include/flang/Optimizer/Transforms/MemoryUtils.h b/flang/include/flang/Optimizer/Transforms/MemoryUtils.h
new file mode 100644
index 0000000000000..fdc1b2c23b812
--- /dev/null
+++ b/flang/include/flang/Optimizer/Transforms/MemoryUtils.h
@@ -0,0 +1,63 @@
+//===-- Optimizer/Transforms/MemoryUtils.h ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a utility to replace fir.alloca by dynamic allocation and
+// deallocation. The exact kind of dynamic allocation is left to be defined by
+// the utility user via callbacks (could be fir.allocmem or custom runtime
+// calls).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_MEMORYUTILS_H
+#define FORTRAN_OPTIMIZER_TRANSFORMS_MEMORYUTILS_H
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+
+namespace mlir {
+class RewriterBase;
+}
+
+namespace fir {
+
+/// Type of callbacks that indicate if a given fir.alloca must be
+/// rewritten.
+using MustRewriteCallBack = llvm::function_ref<bool(fir::AllocaOp)>;
+
+/// Type of callbacks that produce the replacement for a given fir.alloca.
+/// The callback is told whether the alloca dominates the identified
+/// deallocation points. It may refuse to replace the alloca, even if the
+/// MustRewriteCallBack previously returned true, in which case it should
+/// return a null value.
+/// The callback should not delete the alloca; the utility will do it.
+using AllocaRewriterCallBack =
+    llvm::function_ref<mlir::Value(mlir::OpBuilder &, fir::AllocaOp,
+                                   bool /*allocaDominatesDeallocLocations*/)>;
+/// Type of callbacks that must generate deallocation of storage obtained via
+/// AllocaRewriterCallBack calls.
+using DeallocCallBack =
+    llvm::function_ref<void(mlir::Location, mlir::OpBuilder &, mlir::Value)>;
+
+/// Utility to replace fir.alloca by dynamic allocations inside \p parentOp.
+/// The MustRewriteCallBack lets the user control which fir.alloca should be
+/// replaced. The AllocaRewriterCallBack lets the user define how the new
+/// memory is allocated, and the DeallocCallBack how it is deallocated. The
+/// boolean result indicates whether the utility replaced all fir.alloca as
+/// requested by the user. Causes of failure are a \p parentOp that does not
+/// own its nested allocas, the presence of unregistered operations, or
+/// OpenMP/ACC recipe operations that return memory allocated in their region.
+bool replaceAllocas(mlir::RewriterBase &rewriter, mlir::Operation *parentOp,
+                    MustRewriteCallBack, AllocaRewriterCallBack,
+                    DeallocCallBack);
+
+} // namespace fir
+
+#endif // FORTRAN_OPTIMIZER_TRANSFORMS_MEMORYUTILS_H
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index a499a6e4f8d04..9e6b88041ba69 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -275,6 +275,27 @@ llvm::LogicalResult fir::AllocaOp::verify() {
   return mlir::success();
 }
 
+bool fir::AllocaOp::ownsNestedAlloca(mlir::Operation *op) {
+  return op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>() ||
+         op->hasTrait<mlir::OpTrait::AutomaticAllocationScope>() ||
+         mlir::isa<mlir::LoopLikeOpInterface>(*op);
+}
+
+mlir::Region *fir::AllocaOp::getOwnerRegion() {
+  mlir::Operation *currentOp = getOperation();
+  while (mlir::Operation *parentOp = currentOp->getParentOp()) {
+    // If the operation was not registered, inquiries about its traits will be
+    // incorrect and it is not possible to reason about the operation. This
+    // should not happen in a normal Fortran compilation flow, but be foolproof.
+    if (!parentOp->isRegistered())
+      return nullptr;
+    if (fir::AllocaOp::ownsNestedAlloca(parentOp))
+      return currentOp->getParentRegion();
+    currentOp = parentOp;
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AllocMemOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 94d94398d696a..3108304240894 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_flang_library(FIRTransforms
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
+  MemoryUtils.cpp
   MemoryAllocation.cpp
   StackArrays.cpp
   MemRefDataFlowOpt.cpp
diff --git a/flang/lib/Optimizer/Transforms/MemoryAllocation.cpp b/flang/lib/Optimizer/Transforms/MemoryAllocation.cpp
index 03b1ae89428af..3f308a8f4b560 100644
--- a/flang/lib/Optimizer/Transforms/MemoryAllocation.cpp
+++ b/flang/lib/Optimizer/Transforms/MemoryAllocation.cpp
@@ -9,6 +9,7 @@
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/MemoryUtils.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Diagnostics.h"
@@ -27,50 +28,18 @@ namespace fir {
 // Number of elements in an array does not determine where it is allocated.
 static constexpr std::size_t unlimitedArraySize = ~static_cast<std::size_t>(0);
 
-namespace {
-class ReturnAnalysis {
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReturnAnalysis)
-
-  ReturnAnalysis(mlir::Operation *op) {
-    if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op))
-      for (mlir::Block &block : func)
-        for (mlir::Operation &i : block)
-          if (mlir::isa<mlir::func::ReturnOp>(i)) {
-            returnMap[op].push_back(&i);
-            break;
-          }
-  }
-
-  llvm::SmallVector<mlir::Operation *> getReturns(mlir::Operation *func) const {
-    auto iter = returnMap.find(func);
-    if (iter != returnMap.end())
-      return iter->second;
-    return {};
-  }
-
-private:
-  llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>>
-      returnMap;
-};
-} // namespace
-
 /// Return `true` if this allocation is to remain on the stack (`fir.alloca`).
 /// Otherwise the allocation should be moved to the heap (`fir.allocmem`).
 static inline bool
-keepStackAllocation(fir::AllocaOp alloca, mlir::Block *entry,
+keepStackAllocation(fir::AllocaOp alloca,
                     const fir::MemoryAllocationOptOptions &options) {
-  // Limitation: only arrays allocated on the stack in the entry block are
-  // considered for now.
-  // TODO: Generalize the algorithm and placement of the freemem nodes.
-  if (alloca->getBlock() != entry)
-    return true;
+  // Move all arrays and characters with runtime-determined size to the heap.
+  if (options.dynamicArrayOnHeap && alloca.isDynamic())
+    return false;
+  // TODO: use data layout to reason in terms of byte size to cover all "big"
+  // entities, which may be scalar derived types.
   if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(alloca.getInType())) {
-    if (fir::hasDynamicSize(seqTy)) {
-      // Move all arrays with runtime determined size to the heap.
-      if (options.dynamicArrayOnHeap)
-        return false;
-    } else {
+    if (!fir::hasDynamicSize(seqTy)) {
       std::int64_t numberOfElements = 1;
       for (std::int64_t i : seqTy.getShape()) {
         numberOfElements *= i;
@@ -82,8 +51,6 @@ keepStackAllocation(fir::AllocaOp alloca, mlir::Block *entry,
       // the heap.
       if (static_cast<std::size_t>(numberOfElements) >
           options.maxStackArraySize) {
-        LLVM_DEBUG(llvm::dbgs()
-                   << "memory allocation opt: found " << alloca << '\n');
         return false;
       }
     }
@@ -91,49 +58,30 @@ keepStackAllocation(fir::AllocaOp alloca, mlir::Block *entry,
   return true;
 }
 
-namespace {
-class AllocaOpConversion : public mlir::OpRewritePattern<fir::AllocaOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  AllocaOpConversion(mlir::MLIRContext *ctx,
-                     llvm::ArrayRef<mlir::Operation *> rets)
-      : OpRewritePattern(ctx), returnOps(rets) {}
-
-  llvm::LogicalResult
-  matchAndRewrite(fir::AllocaOp alloca,
-                  mlir::PatternRewriter &rewriter) const override {
-    auto loc = alloca.getLoc();
-    mlir::Type varTy = alloca.getInType();
-    auto unpackName =
-        [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
-      if (opt)
-        return *opt;
-      return {};
-    };
-    auto uniqName = unpackName(alloca.getUniqName());
-    auto bindcName = unpackName(alloca.getBindcName());
-    auto heap = rewriter.create<fir::AllocMemOp>(
-        loc, varTy, uniqName, bindcName, alloca.getTypeparams(),
-        alloca.getShape());
-    auto insPt = rewriter.saveInsertionPoint();
-    for (mlir::Operation *retOp : returnOps) {
-      rewriter.setInsertionPoint(retOp);
-      [[maybe_unused]] auto free = rewriter.create<fir::FreeMemOp>(loc, heap);
-      LLVM_DEBUG(llvm::dbgs() << "memory allocation opt: add free " << free
-                              << " for " << heap << '\n');
-    }
-    rewriter.restoreInsertionPoint(insPt);
-    rewriter.replaceOpWithNewOp<fir::ConvertOp>(
-        alloca, fir::ReferenceType::get(varTy), heap);
-    LLVM_DEBUG(llvm::dbgs() << "memory allocation opt: replaced " << alloca
-                            << " with " << heap << '\n');
-    return mlir::success();
-  }
+static mlir::Value genAllocmem(mlir::OpBuilder &builder, fir::AllocaOp alloca,
+                               bool /*allocaDominatesDeallocPoints*/) {
+  mlir::Type varTy = alloca.getInType();
+  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
+    if (opt)
+      return *opt;
+    return {};
+  };
+  llvm::StringRef uniqName = unpackName(alloca.getUniqName());
+  llvm::StringRef bindcName = unpackName(alloca.getBindcName());
+  auto heap = builder.create<fir::AllocMemOp>(alloca.getLoc(), varTy, uniqName,
+                                              bindcName, alloca.getTypeparams(),
+                                              alloca.getShape());
+  LLVM_DEBUG(llvm::dbgs() << "memory allocation opt: replaced " << alloca
+                          << " with " << heap << '\n');
+  return heap;
+}
 
-private:
-  llvm::ArrayRef<mlir::Operation *> returnOps;
-};
+static void genFreemem(mlir::Location loc, mlir::OpBuilder &builder,
+                       mlir::Value allocmem) {
+  [[maybe_unused]] auto free = builder.create<fir::FreeMemOp>(loc, allocmem);
+  LLVM_DEBUG(llvm::dbgs() << "memory allocation opt: add free " << free
+                          << " for " << allocmem << '\n');
+}
 
 /// This pass can reclassify memory allocations (fir.alloca, fir.allocmem) based
 /// on heuristics and settings. The intention is to allow better performance and
@@ -144,6 +92,7 @@ class AllocaOpConversion : public mlir::OpRewritePattern<fir::AllocaOp> {
 ///      make it a heap allocation.
 ///   2. If a stack allocation is an array with a runtime evaluated size make
 ///      it a heap allocation.
+namespace {
 class MemoryAllocationOpt
     : public fir::impl::MemoryAllocationOptBase<MemoryAllocationOpt> {
 public:
@@ -184,23 +133,17 @@ class MemoryAllocationOpt
     // If func is a declaration, skip it.
     if (func.empty())
       return;
-
-    const auto &analysis = getAnalysis<ReturnAnalysis>();
-
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addDynamicallyLegalOp<fir::AllocaOp>([&](fir::AllocaOp alloca) {
-      return keepStackAllocation(alloca, &func.front(), options);
-    });
-
-    llvm::SmallVector<mlir::Operation *> returnOps = analysis.getReturns(func);
-    patterns.insert<AllocaOpConversion>(context, returnOps);
-    if (mlir::failed(
-            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
-      mlir::emitError(func.getLoc(),
-                      "error in memory allocation optimization\n");
-      signalPassFailure();
-    }
+    auto tryReplacing = [&](fir::AllocaOp alloca) {
+      bool res = !keepStackAllocation(alloca, options);
+      if (res) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "memory allocation opt: found " << alloca << '\n');
+      }
+      return res;
+    };
+    mlir::IRRewriter rewriter(context);
+    fir::replaceAllocas(rewriter, func.getOperation(), tryReplacing,
+                        genAllocmem, genFreemem);
   }
 
 private:
diff --git a/flang/lib/Optimizer/Transforms/MemoryUtils.cpp b/flang/lib/Optimizer/Transforms/MemoryUtils.cpp
new file mode 100644
index 0000000000000..80228ceba0eaa
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/MemoryUtils.cpp
@@ -0,0 +1,220 @@
+//===- MemoryUtils.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/MemoryUtils.h"
+#include "flang/Optimizer/Builder/Todo.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Dominance.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace {
+class AllocaReplaceImpl {
+public:
+  AllocaReplaceImpl(fir::AllocaRewriterCallBack allocaRewriter,
+                    fir::DeallocCallBack deallocGenerator)
+      : allocaRewriter{allocaRewriter}, deallocGenerator{deallocGenerator} {}
+  bool replace(mlir::RewriterBase &, fir::AllocaOp);
+
+private:
+  mlir::Region *findDeallocationPointsAndOwner(
+      fir::AllocaOp alloca,
+      llvm::SmallVectorImpl<mlir::Operation *> &deallocationPoints);
+  bool
+  allocDominatesDealloc(fir::AllocaOp alloca,
+                        llvm::ArrayRef<mlir::Operation *> deallocationPoints) {
+    return llvm::all_of(deallocationPoints, [&](mlir::Operation *deallocPoint) {
+      return this->dominanceInfo.properlyDominates(alloca.getOperation(),
+                                                   deallocPoint);
+    });
+  }
+  void
+  genIndirectDeallocation(mlir::RewriterBase &, fir::AllocaOp,
+                          llvm::ArrayRef<mlir::Operation *> deallocationPoints,
+                          mlir::Value replacement, mlir::Region &owningRegion);
+
+private:
+  fir::AllocaRewriterCallBack allocaRewriter;
+  fir::DeallocCallBack deallocGenerator;
+  mlir::DominanceInfo dominanceInfo;
+};
+} // namespace
+
+static bool terminatorYieldsMemory(mlir::Operation &terminator) {
+  return llvm::any_of(terminator.getResults(), [](mlir::OpResult res) {
+    return fir::conformsWithPassByRef(res.getType());
+  });
+}
+
+static bool isRegionTerminator(mlir::Operation &terminator) {
+  // Using the ReturnLike trait is tempting, but it is not set on
+  // all the region terminators that matter (like omp::TerminatorOp, which
+  // has no results).
+  // May be true for dead code. It is not a correctness issue and dead code can
+  // be eliminated by running region simplification before this utility is
+  // used.
+  // May also be true for unreachable like terminators (e.g., after an abort
+  // call related to Fortran STOP). This is also OK, the inserted deallocation
+  // will simply never be reached. It is easier for the rest of the code here
+  // to assume there is always at least one deallocation point, so keep
+  // unreachable terminators.
+  return !terminator.hasSuccessors();
+}
+
+mlir::Region *AllocaReplaceImpl::findDeallocationPointsAndOwner(
+    fir::AllocaOp alloca,
+    llvm::SmallVectorImpl<mlir::Operation *> &deallocationPoints) {
+  // Step 1: Identify the operation and region owning the alloca.
+  mlir::Region *owningRegion = alloca.getOwnerRegion();
+  if (!owningRegion)
+    return nullptr;
+  mlir::Operation *owningOp = owningRegion->getParentOp();
+  assert(owningOp && "region expected to be owned");
+  // Step 2: Identify the exit points of the owning region, they are the default
+  // deallocation points. TODO: detect and use lifetime markers to get earlier
+  // deallocation points.
+  bool isOpenACCMPRecipe = mlir::isa<mlir::accomp::RecipeInterface>(owningOp);
+  for (mlir::Block &block : owningRegion->getBlocks())
+    if (mlir::Operation *terminator = block.getTerminator();
+        isRegionTerminator(*terminator)) {
+      // FIXME: OpenACC and OpenMP privatization recipes are standalone
+      // operations meant to be "inlined" later; the value they return may
+      // be the address of a local alloca. It would be incorrect to insert
+      // deallocation before the terminator (this would introduce a use after
+      // free once the recipe is inlined).
+      // This probably requires a redesign or special handling on the
+      // OpenACC/OpenMP side.
+      if (isOpenACCMPRecipe && terminatorYieldsMemory(*terminator))
+        return nullptr;
+      deallocationPoints.push_back(terminator);
+    }
+  // If no terminator of the owningRegion exits the region (every block
+  // terminator has successors), bail and do not attempt to translate it.
+  // All FIR, OpenACC, OpenMP, CUF, and SCF operations with
+  // IsIsolatedFromAbove, AutomaticAllocationScope, or LoopLikeOpInterface
+  // have region-exiting terminators, so this is not expected to happen for
+  // them.
+  if (deallocationPoints.empty())
+    return nullptr;
+
+  // Step 3: detect loops between the alloc and deallocation points.
+  // If such a loop exists, the easy solution is to consider the alloc
+  // as a deallocation point of any previous allocation. This works
+  // because the alloc does not properly dominate itself, so the
+  // inserted deallocation will be conditional.
+  // For now, always assume there may be a loop if the alloca does not
+  // dominate all of the deallocation points. This is a conservative
+  // approach. Introducing the lifetime markers mentioned above will reduce
+  // false positives for allocas made inside if-like constructs or CFGs.
+  if (!allocDominatesDealloc(alloca, deallocationPoints))
+    deallocationPoints.push_back(alloca.getOperation());
+  return owningRegion;
+}
+
+static mlir::Value castIfNeeed(mlir::Location loc, mlir::RewriterBase &rewriter,
+                               mlir::Type newType, mlir::Value value) {
+  if (value.getType() != newType)
+    return rewriter.create<fir::ConvertOp>(loc, newType, value);
+  return value;
+}
+
+void AllocaReplaceImpl::genIndirectDeallocation(
+    mlir::RewriterBase &rewriter, fir::AllocaOp alloca,
+    llvm::ArrayRef<mlir::Operation *> deallocationPoints,
+    mlir::Value replacement, mlir::Region &owningRegion) {
+  mlir::Location loc = alloca.getLoc();
+  auto replacementInsertPoint = rewriter.saveInsertionPoint();
+  // Create a C pointer variable in the entry block to store the alloc and
+  // access it indirectly at the deallocation points it may not dominate.
+  rewriter.setInsertionPointToStart(&owningRegion.front());
+  mlir::Type heapType = fir::HeapType::get(alloca.getInType());
+  mlir::Value ptrVar = rewriter.create<fir::AllocaOp>(loc, heapType);
+  mlir::Value nullPtr = rewriter.create<fir::ZeroOp>(loc, heapType);
+  rewriter.create<fir::StoreOp>(loc, nullPtr, ptrVar);
+  // TODO: introducing a pointer compare op in FIR would help
+  // generating less IR here.
+  mlir::Type intPtrTy = rewriter.getI64Type();
+  mlir::Value c0 = rewriter.create<mlir::arith::ConstantOp>(
+      loc, intPtrTy, rewriter.getIntegerAttr(intPtrTy, 0));
+
+  // Store new storage address right after its creation.
+  rewriter.restoreInsertionPoint(replacementInsertPoint);
+  mlir::Value castReplacement =
+      castIfNeeed(loc, rewriter, heapType, replacement);
+  rewriter.create<fir::StoreOp>(loc, castReplacement, ptrVar);
+
+  // Generate conditional deallocation at every deallocation point.
+  auto genConditionalDealloc = [&](mlir::Location loc) {
+    mlir::Value ptrVal = rewriter.create<fir::LoadOp>(loc, ptrVar);
+    mlir::Value ptrToInt =
+        rewriter.create<fir::ConvertOp>(loc, intPtrTy, ptrVal);
+    mlir::Value isAllocated = rewriter.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::ne, ptrToInt, c0);
+    auto ifOp = rewriter.create<fir::IfOp>(loc, std::nullopt, isAllocated,
+                                           /*withElseRegion=*/false);
+    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    mlir::Value cast =
+        castIfNeeed(loc, rewriter, replacement.getType(), ptrVal);
+    deallocGenerator(loc, rewriter, cast);
+    // Currently there is no need to reset the pointer var because two
+    // deallocation points can never be reached without going through the
+    // alloca.
+    rewriter.setInsertionPointAfter(ifOp);
+  };
+  for (mlir::Operation *deallocPoint : deallocationPoints) {
+    rewriter.setInsertionPoint(deallocPoint);
+    genConditionalDealloc(deallocPoint->getLoc());
+  }
+}
+
+bool AllocaReplaceImpl::replace(mlir::RewriterBase &rewriter,
+                                fir::AllocaOp alloca) {
+  llvm::SmallVector<mlir::Operation *> deallocationPoints;
+  mlir::Region *owningRegion =
+      findDeallocationPointsAndOwner(alloca, deallocationPoints);
+  if (!owningRegion)
+    return false;
+  // Insert the replacement allocation right after the alloca.
+  rewriter.setInsertionPointAfter(alloca.getOperation());
+  bool allocaDominatesDeallocPoints =
+      allocDominatesDealloc(alloca, deallocationPoints);
+  if (mlir::Value replacement =
+          allocaRewriter(rewriter, alloca, allocaDominatesDeallocPoints)) {
+    mlir::Value castReplacement =
+        castIfNeeed(alloca.getLoc(), rewriter, alloca.getType(), replacement);
+    if (allocaDominatesDeallocPoints)
+      for (mlir::Operation *deallocPoint : deallocationPoints) {
+        rewriter.setInsertionPoint(deallocPoint);
+        deallocGenerator(deallocPoint->getLoc(), rewriter, replacement);
+      }
+    else
+      genIndirectDeallocation(rewriter, alloca, deallocationPoints, replacement,
+                              *owningRegion);
+    rewriter.replaceOp(alloca, castReplacement);
+  }
+  return true;
+}
+
+bool fir::replaceAllocas(mlir::RewriterBase &rewriter,
+                         mlir::Operation *parentOp,
+                         MustRewriteCallBack mustReplace,
+                         AllocaRewriterCallBack allocaRewriter,
+                         DeallocCallBack deallocGenerator) {
+  // If the parent operation is not an alloca owner, the code below would risk
+  // modifying IR outside of parentOp.
+  if (!fir::AllocaOp::ownsNestedAlloca(parentOp))
+    return false;
+  auto insertPoint = rewriter.saveInsertionPoint();
+  bool replacedAllRequestedAlloca = true;
+  AllocaReplaceImpl impl(allocaRewriter, deallocGenerator);
+  parentOp->walk([&](fir::AllocaOp alloca) {
+    if (mustReplace(alloca))
+      replacedAllRequestedAlloca &= impl.replace(rewriter, alloca);
+  });
+  rewriter.restoreInsertionPoint(insertPoint);
+  return replacedAllRequestedAlloca;
+}
diff --git a/flang/test/Fir/memory-allocation-opt-2.fir b/flang/test/Fir/memory-allocation-opt-2.fir
new file mode 100644
index 0000000000000..469276858d0bf
--- /dev/null
+++ b/flang/test/Fir/memory-allocation-opt-2.fir
@@ -0,0 +1,105 @@
+// Test memory allocation pass for fir.alloca outside of function entry block
+// RUN: fir-opt --memory-allocation-opt="dynamic-array-on-heap=true" %s | FileCheck %s
+
+func.func @test_loop() {
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  fir.do_loop %arg0 = %c1 to %c100 step %c1 {
+    %1 = fir.alloca !fir.array<?xf32>, %arg0
+    fir.call @bar(%1) : (!fir.ref<!fir.array<?xf32>>) -> ()
+    fir.result
+  }
+  return
+}
+// CHECK-LABEL:   func.func @test_loop() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 100 : index
+// CHECK:           fir.do_loop %[[VAL_2:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_0]] {
+// CHECK:             %[[VAL_3:.*]] = fir.allocmem !fir.array<?xf32>, %[[VAL_2]] {bindc_name = "", uniq_name = ""}
+// CHECK:             %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.heap<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:             fir.call @bar(%[[VAL_4]]) : (!fir.ref<!fir.array<?xf32>>) -> ()
+// CHECK:             fir.freemem %[[VAL_3]] : !fir.heap<!fir.array<?xf32>>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+func.func @test_unstructured(%n : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %0 = fir.alloca index
+  fir.store %c100 to %0 : !fir.ref<index>
+  cf.br ^bb1
+^bb1:  // 2 preds: ^bb0, ^bb4
+  %5 = fir.load %0 : !fir.ref<index>
+  %6 = arith.cmpi sgt, %5, %c0 : index
+  cf.cond_br %6, ^bb2, ^bb5
+^bb2:  // pred: ^bb1
+  %1 = fir.alloca !fir.array<?xf32>, %5
+  fir.call @bar(%1) : (!fir.ref<!fir.array<?xf32>>) -> ()
+  %25 = arith.cmpi slt, %5, %n : index
+  cf.cond_br %25, ^bb3, ^bb4
+^bb3:  // pred: ^bb2
+  fir.call @abort() : () -> ()
+  fir.unreachable
+^bb4:  // pred: ^bb2
+  %28 = arith.subi %5, %c1 : index
+  fir.store %28 to %0 : !fir.ref<index>
+  cf.br ^bb1
+^bb5:  // pred: ^bb1
+  return
+}
+// CHECK-LABEL:   func.func @test_unstructured(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: index) {
+// CHECK:           %[[VAL_1:.*]] = fir.alloca !fir.heap<!fir.array<?xf32>>
+// CHECK:           %[[VAL_2:.*]] = fir.zero_bits !fir.heap<!fir.array<?xf32>>
+// CHECK:           fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_7:.*]] = fir.alloca index
+// CHECK:           fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<index>
+// CHECK:           cf.br ^bb1
+// CHECK:         ^bb1:
+// CHECK:           %[[VAL_8:.*]] = fir.load %[[VAL_7]] : !fir.ref<index>
+// CHECK:           %[[VAL_9:.*]] = arith.cmpi sgt, %[[VAL_8]], %[[VAL_4]] : index
+// CHECK:           cf.cond_br %[[VAL_9]], ^bb2, ^bb5
+// CHECK:         ^bb2:
+// CHECK:           %[[VAL_10:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+// CHECK:           %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.heap<!fir.array<?xf32>>) -> i64
+// CHECK:           %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_3]] : i64
+// CHECK:           fir.if %[[VAL_12]] {
+// CHECK:             fir.freemem %[[VAL_10]] : !fir.heap<!fir.array<?xf32>>
+// CHECK:           }
+// CHECK:           %[[VAL_13:.*]] = fir.allocmem !fir.array<?xf32>, %[[VAL_8]] {bindc_name = "", uniq_name = ""}
+// CHECK:           %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (!fir.heap<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:           fir.store %[[VAL_13]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+// CHECK:           fir.call @bar(%[[VAL_14]]) : (!fir.ref<!fir.array<?xf32>>) -> ()
+// CHECK:           %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_8]], %[[VAL_0]] : index
+// CHECK:           cf.cond_br %[[VAL_15]], ^bb3, ^bb4
+// CHECK:         ^bb3:
+// CHECK:           fir.call @abort() : () -> ()
+// CHECK:           %[[VAL_16:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+// CHECK:           %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (!fir.heap<!fir.array<?xf32>>) -> i64
+// CHECK:           %[[VAL_18:.*]] = arith.cmpi ne, %[[VAL_17]], %[[VAL_3]] : i64
+// CHECK:           fir.if %[[VAL_18]] {
+// CHECK:             fir.freemem %[[VAL_16]] : !fir.heap<!fir.array<?xf32>>
+// CHECK:           }
+// CHECK:           fir.unreachable
+// CHECK:         ^bb4:
+// CHECK:           %[[VAL_19:.*]] = arith.subi %[[VAL_8]], %[[VAL_5]] : index
+// CHECK:           fir.store %[[VAL_19]] to %[[VAL_7]] : !fir.ref<index>
+// CHECK:           cf.br ^bb1
+// CHECK:         ^bb5:
+// CHECK:           %[[VAL_20:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+// CHECK:           %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (!fir.heap<!fir.array<?xf32>>) -> i64
+// CHECK:           %[[VAL_22:.*]] = arith.cmpi ne, %[[VAL_21]], %[[VAL_3]] : i64
+// CHECK:           fir.if %[[VAL_22]] {
+// CHECK:             fir.freemem %[[VAL_20]] : !fir.heap<!fir.array<?xf32>>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+func.func private @bar(!fir.ref<!fir.array<?xf32>>)
+func.func private @abort()
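
The decision made by `keepStackAllocation` in the MemoryAllocation.cpp hunk can be summarized as a standalone sketch. It assumes a hypothetical `ToyAllocaInfo` summary (static extents plus a flag for runtime-sized entities) in place of `fir::AllocaOp`, and a hypothetical `ToyOptions` struct in place of the pass options; it also collapses `alloca.isDynamic()` and `fir::hasDynamicSize` into one flag, which is a simplification.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the pass options and for the alloca properties
// that keepStackAllocation inspects.
struct ToyOptions {
  bool dynamicArrayOnHeap;
  std::size_t maxStackArraySize; // counted in elements, not bytes
};

struct ToyAllocaInfo {
  bool isDynamic;                  // runtime shape or length parameters
  bool isArray;                    // allocates a fir.array
  std::vector<std::int64_t> shape; // static extents, used when !isDynamic
};

// Returns true to keep the allocation on the stack, false to move it to the
// heap.
bool keepOnStack(const ToyAllocaInfo &info, const ToyOptions &options) {
  if (options.dynamicArrayOnHeap && info.isDynamic)
    return false;
  if (info.isArray && !info.isDynamic) {
    std::int64_t numberOfElements = 1;
    for (std::int64_t extent : info.shape)
      numberOfElements *= extent;
    if (static_cast<std::size_t>(numberOfElements) > options.maxStackArraySize)
      return false;
  }
  return true;
}
```

As the TODO in the hunk notes, the size test counts elements rather than bytes, so large scalar derived types stay on the stack.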


