[flang-commits] [flang] [flang][runtime] Avoid call recursion in CopyElement runtime. (PR #101421)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Jul 31 15:12:29 PDT 2024


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/101421

Device compilers may fail to determine the maximum stack size required
by a kernel that calls CopyElement, because CopyElement and CopyArray
may call each other recursively. To avoid this, CopyElement now keeps
its pending copies on an explicit, dynamically allocated Stack instead
of recursing. To avoid dynamic allocation on the host in simple cases,
the Stack implementation reserves space for a fixed number of elements
(which ends up being allocated on the program stack). I tested both the
pre-allocated and the 0-reserve configurations on the host, and all tests
passed. The actual reserve values can be tuned as needed.


>From ca6890b2271dd2a2616060930a9e9354c7f7d8b7 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 31 Jul 2024 15:04:17 -0700
Subject: [PATCH] [flang][runtime] Avoid call recursion in CopyElement runtime.

Device compilers may fail to determine the maximum stack size required
by a kernel that calls CopyElement, because CopyElement and CopyArray
may call each other recursively. To avoid this, CopyElement now keeps
its pending copies on an explicit, dynamically allocated Stack instead
of recursing. To avoid dynamic allocation on the host in simple cases,
the Stack implementation reserves space for a fixed number of elements
(which ends up being allocated on the program stack). I tested both the
pre-allocated and the 0-reserve configurations on the host, and all tests
passed. The actual reserve values can be tuned as needed.
---
 flang/runtime/copy.cpp | 215 +++++++++++++++++++++++++++++------------
 flang/runtime/copy.h   |   4 -
 flang/runtime/stack.h  | 136 ++++++++++++++++++++++++++
 3 files changed, 291 insertions(+), 64 deletions(-)
 create mode 100644 flang/runtime/stack.h

diff --git a/flang/runtime/copy.cpp b/flang/runtime/copy.cpp
index 7cf9483654141..c2dbbc4a11c06 100644
--- a/flang/runtime/copy.cpp
+++ b/flang/runtime/copy.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "copy.h"
+#include "stack.h"
 #include "terminator.h"
 #include "type-info.h"
 #include "flang/Runtime/allocatable.h"
@@ -14,76 +15,177 @@
 #include <cstring>

 namespace Fortran::runtime {
+namespace {
+using StaticDescTy = StaticDescriptor<maxRank, true, 0>;
+
+// A structure describing the data copy that needs to be done
+// from one descriptor to another. It is a helper structure
+// for CopyElement.
+struct CopyDescriptor {
+  // A constructor specifying all members explicitly.
+  RT_API_ATTRS CopyDescriptor(const Descriptor &to, const SubscriptValue toAt[],
+      const Descriptor &from, const SubscriptValue fromAt[],
+      std::size_t elements, bool usesStaticDescriptors = false)
+      : to_(to), from_(from), elements_(elements),
+        usesStaticDescriptors_(usesStaticDescriptors) {
+    for (int dim{0}; dim < to.rank(); ++dim) {
+      toAt_[dim] = toAt[dim];
+    }
+    for (int dim{0}; dim < from.rank(); ++dim) {
+      fromAt_[dim] = fromAt[dim];
+    }
+  }
+  // The number of elements to copy is initialized from the to descriptor.
+  // The current element subscripts are initialized from the lower bounds
+  // of the to and from descriptors.
+  // The caller must ensure that from describes the same number of elements.
+  RT_API_ATTRS CopyDescriptor(const Descriptor &to, const Descriptor &from,
+      bool usesStaticDescriptors = false)
+      : to_(to), from_(from), elements_(to.Elements()),
+        usesStaticDescriptors_(usesStaticDescriptors) {
+    to.GetLowerBounds(toAt_);
+    from.GetLowerBounds(fromAt_);
+  }
+
+  // Descriptor of the destination.
+  const Descriptor &to_;
+  // A subscript specifying the current element position to copy to.
+  SubscriptValue toAt_[maxRank];
+  // Descriptor of the source.
+  const Descriptor &from_;
+  // A subscript specifying the current element position to copy from.
+  SubscriptValue fromAt_[maxRank];
+  // Number of elements left to copy.
+  std::size_t elements_;
+  // Must be true, if the to and from descriptors are allocated
+  // by the CopyElement runtime. The allocated memory belongs
+  // to a separate stack that needs to be popped in correspondence
+  // with popping such a CopyDescriptor node.
+  bool usesStaticDescriptors_;
+};
+
+// A pair of StaticDescTy elements.
+struct StaticDescriptorsPair {
+  StaticDescTy to;
+  StaticDescTy from;
+};
+} // namespace
+
 RT_OFFLOAD_API_GROUP_BEGIN

+// Copies the element of from at fromAt into the element of to at toAt and
+// deep-copies its allocatable and automatic components, including those
+// nested in derived-type data components, using explicit stacks (no recursion).
 RT_API_ATTRS void CopyElement(const Descriptor &to, const SubscriptValue toAt[],
     const Descriptor &from, const SubscriptValue fromAt[],
     Terminator &terminator) {
-  char *toPtr{to.Element<char>(toAt)};
-  char *fromPtr{from.Element<char>(fromAt)};
-  RUNTIME_CHECK(terminator, to.ElementBytes() == from.ElementBytes());
-  std::memcpy(toPtr, fromPtr, to.ElementBytes());
-  // Deep copy allocatable and automatic components if any.
-  if (const auto *addendum{to.Addendum()}) {
-    if (const auto *derived{addendum->derivedType()};
-        derived && !derived->noDestructionNeeded()) {
-      RUNTIME_CHECK(terminator,
-          from.Addendum() && derived == from.Addendum()->derivedType());
-      const Descriptor &componentDesc{derived->component()};
-      const typeInfo::Component *component{
-          componentDesc.OffsetElement<typeInfo::Component>()};
-      std::size_t nComponents{componentDesc.Elements()};
-      for (std::size_t j{0}; j < nComponents; ++j, ++component) {
-        if (component->genre() == typeInfo::Component::Genre::Allocatable ||
-            component->genre() == typeInfo::Component::Genre::Automatic) {
-          Descriptor &toDesc{
-              *reinterpret_cast<Descriptor *>(toPtr + component->offset())};
-          if (toDesc.raw().base_addr != nullptr) {
-            toDesc.set_base_addr(nullptr);
-            RUNTIME_CHECK(terminator, toDesc.Allocate() == CFI_SUCCESS);
-            const Descriptor &fromDesc{*reinterpret_cast<const Descriptor *>(
-                fromPtr + component->offset())};
-            CopyArray(toDesc, fromDesc, terminator);
-          }
-        } else if (component->genre() == typeInfo::Component::Genre::Data &&
-            component->derivedType() &&
-            !component->derivedType()->noDestructionNeeded()) {
-          SubscriptValue extents[maxRank];
-          const typeInfo::Value *bounds{component->bounds()};
-          for (int dim{0}; dim < component->rank(); ++dim) {
-            SubscriptValue lb{bounds[2 * dim].GetValue(&to).value_or(0)};
-            SubscriptValue ub{bounds[2 * dim + 1].GetValue(&to).value_or(0)};
-            extents[dim] = ub >= lb ? ub - lb + 1 : 0;
+#if !defined(RT_DEVICE_COMPILATION)
+  constexpr unsigned copyStackReserve{16};
+  constexpr unsigned descriptorStackReserve{6};
+#else
+  // Always use dynamic allocation on the device to avoid
+  // big stack sizes. This may be tuned as needed.
+  constexpr unsigned copyStackReserve{0};
+  constexpr unsigned descriptorStackReserve{0};
+#endif
+  // Keep a stack of CopyDescriptor's to avoid recursive calls.
+  Stack<CopyDescriptor, copyStackReserve> copyStack{terminator};
+  // Keep a separate stack of StaticDescTy pairs. These descriptors
+  // may be used for representing copies of Component::Genre::Data
+  // components (since they do not have their descriptors allocated
+  // in memory).
+  Stack<StaticDescriptorsPair, descriptorStackReserve> descriptorsStack{
+      terminator};
+  copyStack.emplace(to, toAt, from, fromAt, /*elements=*/std::size_t{1});
+
+  while (!copyStack.empty()) {
+    CopyDescriptor &currentCopy{copyStack.top()};
+    std::size_t &elements{currentCopy.elements_};
+    if (elements == 0) {
+      // This copy has been exhausted.
+      if (currentCopy.usesStaticDescriptors_) {
+        // Pop the static descriptors, if they were used
+        // for the current copy.
+        descriptorsStack.pop();
+      }
+      copyStack.pop();
+      continue;
+    }
+    const Descriptor &curTo{currentCopy.to_};
+    SubscriptValue *curToAt{currentCopy.toAt_};
+    const Descriptor &curFrom{currentCopy.from_};
+    SubscriptValue *curFromAt{currentCopy.fromAt_};
+    char *toPtr{curTo.Element<char>(curToAt)};
+    char *fromPtr{curFrom.Element<char>(curFromAt)};
+    RUNTIME_CHECK(terminator, curTo.ElementBytes() == curFrom.ElementBytes());
+    // TODO: the memcpy can be optimized when both to and from are contiguous.
+    // Moreover, if we came here from an Component::Genre::Data component,
+    // all the per-element copies are redundant, because the parent
+    // has already been copied as a whole.
+    std::memcpy(toPtr, fromPtr, curTo.ElementBytes());
+    --elements;
+    if (elements != 0) {
+      curTo.IncrementSubscripts(curToAt);
+      curFrom.IncrementSubscripts(curFromAt);
+    }
+
+    // Deep copy allocatable and automatic components if any.
+    if (const auto *addendum{curTo.Addendum()}) {
+      if (const auto *derived{addendum->derivedType()};
+          derived && !derived->noDestructionNeeded()) {
+        RUNTIME_CHECK(terminator,
+            curFrom.Addendum() && derived == curFrom.Addendum()->derivedType());
+        const Descriptor &componentDesc{derived->component()};
+        const typeInfo::Component *component{
+            componentDesc.OffsetElement<typeInfo::Component>()};
+        std::size_t nComponents{componentDesc.Elements()};
+        for (std::size_t j{0}; j < nComponents; ++j, ++component) {
+          if (component->genre() == typeInfo::Component::Genre::Allocatable ||
+              component->genre() == typeInfo::Component::Genre::Automatic) {
+            Descriptor &toDesc{
+                *reinterpret_cast<Descriptor *>(toPtr + component->offset())};
+            if (toDesc.raw().base_addr != nullptr) {
+              toDesc.set_base_addr(nullptr);
+              RUNTIME_CHECK(terminator, toDesc.Allocate() == CFI_SUCCESS);
+              const Descriptor &fromDesc{*reinterpret_cast<const Descriptor *>(
+                  fromPtr + component->offset())};
+              // Source and destination must have matching element counts.
+              RUNTIME_CHECK(
+                  terminator, toDesc.Elements() == fromDesc.Elements());
+              copyStack.emplace(toDesc, fromDesc);
+            }
+          } else if (component->genre() == typeInfo::Component::Genre::Data &&
+              component->derivedType() &&
+              !component->derivedType()->noDestructionNeeded()) {
+            SubscriptValue extents[maxRank];
+            const typeInfo::Value *bounds{component->bounds()};
+            std::size_t componentElements{1};
+            for (int dim{0}; dim < component->rank(); ++dim) {
+              SubscriptValue lb{bounds[2 * dim].GetValue(&curTo).value_or(0)};
+              SubscriptValue ub{
+                  bounds[2 * dim + 1].GetValue(&curTo).value_or(0)};
+              extents[dim] = ub >= lb ? ub - lb + 1 : 0;
+              componentElements *= extents[dim];
+            }
+            if (componentElements != 0) {
+              const typeInfo::DerivedType &compType{*component->derivedType()};
+              // Place a pair of static descriptors onto the descriptors stack.
+              descriptorsStack.emplace();
+              StaticDescriptorsPair &descs{descriptorsStack.top()};
+              Descriptor &toCompDesc{descs.to.descriptor()};
+              toCompDesc.Establish(compType, toPtr + component->offset(),
+                  component->rank(), extents);
+              Descriptor &fromCompDesc{descs.from.descriptor()};
+              fromCompDesc.Establish(compType, fromPtr + component->offset(),
+                  component->rank(), extents);
+              copyStack.emplace(toCompDesc, fromCompDesc,
+                  /*usesStaticDescriptors=*/true);
+            }
           }
-          const typeInfo::DerivedType &compType{*component->derivedType()};
-          StaticDescriptor<maxRank, true, 0> toStaticDescriptor;
-          Descriptor &toCompDesc{toStaticDescriptor.descriptor()};
-          toCompDesc.Establish(compType, toPtr + component->offset(),
-              component->rank(), extents);
-          StaticDescriptor<maxRank, true, 0> fromStaticDescriptor;
-          Descriptor &fromCompDesc{fromStaticDescriptor.descriptor()};
-          fromCompDesc.Establish(compType, fromPtr + component->offset(),
-              component->rank(), extents);
-          CopyArray(toCompDesc, fromCompDesc, terminator);
         }
       }
     }
   }
 }
-
-RT_API_ATTRS void CopyArray(
-    const Descriptor &to, const Descriptor &from, Terminator &terminator) {
-  std::size_t elements{to.Elements()};
-  RUNTIME_CHECK(terminator, elements == from.Elements());
-  SubscriptValue toAt[maxRank], fromAt[maxRank];
-  to.GetLowerBounds(toAt);
-  from.GetLowerBounds(fromAt);
-  while (elements-- > 0) {
-    CopyElement(to, toAt, from, fromAt, terminator);
-    to.IncrementSubscripts(toAt);
-    from.IncrementSubscripts(fromAt);
-  }
-}
-
 RT_OFFLOAD_API_GROUP_END
 } // namespace Fortran::runtime
diff --git a/flang/runtime/copy.h b/flang/runtime/copy.h
index 5d725de725735..542660530bfb6 100644
--- a/flang/runtime/copy.h
+++ b/flang/runtime/copy.h
@@ -21,9 +21,8 @@ namespace Fortran::runtime {
+// Component data nested inside the copied element is deep-copied
+// iteratively (using explicit stacks) rather than through recursive
+// calls, so that device compilers can bound the required stack size.
 RT_API_ATTRS void CopyElement(const Descriptor &to, const SubscriptValue toAt[],
     const Descriptor &from, const SubscriptValue fromAt[], Terminator &);

-// Copies data from one allocated descriptor's array to another.
-RT_API_ATTRS void CopyArray(
-    const Descriptor &to, const Descriptor &from, Terminator &);
-
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_COPY_H_
diff --git a/flang/runtime/stack.h b/flang/runtime/stack.h
new file mode 100644
index 0000000000000..fa438b35ec4b8
--- /dev/null
+++ b/flang/runtime/stack.h
@@ -0,0 +1,137 @@
+//===-- runtime/stack.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Trivial implementation of a stack that can be used on all targets.
+// The bottom N elements are kept in storage embedded in the Stack object;
+// elements above them are kept in a singly linked list whose nodes are
+// dynamically allocated and deallocated.
+
+#ifndef FORTRAN_RUNTIME_STACK_H
+#define FORTRAN_RUNTIME_STACK_H
+
+#include "terminator.h"
+#include "flang/Runtime/memory.h"
+#include <cstddef>
+#include <new>
+#include <utility>
+
+namespace Fortran::runtime {
+// Storage for the bottom N elements of type T of a Stack.
+template <typename T, unsigned N> struct StackStorage {
+  // Returns the address of reserved slot i, or nullptr when i is
+  // beyond the reserved storage.
+  RT_API_ATTRS void *getElement(std::size_t i) {
+    if (i < N) {
+      return storage[i];
+    } else {
+      return nullptr;
+    }
+  }
+  RT_API_ATTRS const void *getElement(std::size_t i) const {
+    if (i < N) {
+      return storage[i];
+    } else {
+      return nullptr;
+    }
+  }
+
+private:
+  // Storage to hold N elements of type T.
+  // It is declared as an array of bytes to avoid
+  // default construction (if any is implied by type T).
+  alignas(T) char storage[N][sizeof(T)];
+};
+
+// 0-size specialization that provides no storage.
+template <typename T> struct alignas(T) StackStorage<T, 0> {
+  RT_API_ATTRS void *getElement(std::size_t) { return nullptr; }
+  RT_API_ATTRS const void *getElement(std::size_t) const { return nullptr; }
+};
+
+// A LIFO stack of T objects. Elements at positions [0, N) (counted from
+// the bottom) are constructed in the storage inherited from
+// StackStorage<T, N>; elements above them are constructed in list nodes
+// allocated with New<List>. Accessing or popping an empty stack fails a
+// RUNTIME_CHECK.
+template <typename T, unsigned N = 0> class Stack : public StackStorage<T, N> {
+public:
+  Stack() = delete;
+  Stack(const Stack &) = delete;
+  Stack(Stack &&) = delete;
+  explicit RT_API_ATTRS Stack(Terminator &terminator)
+      : terminator_{terminator} {}
+  RT_API_ATTRS ~Stack() {
+    while (!empty()) {
+      pop();
+    }
+  }
+  RT_API_ATTRS void push(const T &object) { emplace(object); }
+  RT_API_ATTRS void push(T &&object) { emplace(std::move(object)); }
+  // Constructs a new top element as T{args...}. The same brace
+  // initialization is used for the reserved storage and for the list
+  // nodes, so an element is built the same way at any stack depth.
+  template <typename... Args> RT_API_ATTRS void emplace(Args &&...args) {
+    if (void *ptr{this->getElement(size_)}) {
+      new (ptr) T{std::forward<Args>(args)...};
+    } else {
+      top_ =
+          New<List>{terminator_}(top_, std::forward<Args>(args)...).release();
+    }
+    ++size_;
+  }
+  RT_API_ATTRS T &top() {
+    RUNTIME_CHECK(terminator_, size_ > 0);
+    if (void *ptr{this->getElement(size_ - 1)}) {
+      return *static_cast<T *>(ptr);
+    } else {
+      RUNTIME_CHECK(terminator_, top_);
+      return top_->object_;
+    }
+  }
+  RT_API_ATTRS const T &top() const {
+    RUNTIME_CHECK(terminator_, size_ > 0);
+    // In a const member function getElement() resolves to the const
+    // overload, which returns a pointer to const.
+    if (const void *ptr{this->getElement(size_ - 1)}) {
+      return *static_cast<const T *>(ptr);
+    } else {
+      RUNTIME_CHECK(terminator_, top_);
+      return top_->object_;
+    }
+  }
+  RT_API_ATTRS void pop() {
+    RUNTIME_CHECK(terminator_, size_ > 0);
+    if (void *ptr{this->getElement(size_ - 1)}) {
+      static_cast<T *>(ptr)->~T();
+    } else {
+      RUNTIME_CHECK(terminator_, top_);
+      List *next{top_->next_};
+      top_->~List();
+      FreeMemory(top_);
+      top_ = next;
+    }
+    --size_;
+  }
+  RT_API_ATTRS bool empty() const { return size_ == 0; }
+
+private:
+  // A dynamically allocated node holding one element that did not fit
+  // into the reserved storage; next_ points to the node below it.
+  struct List {
+    template <typename... Args>
+    RT_API_ATTRS List(List *next, Args &&...args)
+        : next_{next}, object_{std::forward<Args>(args)...} {}
+    List *next_{nullptr};
+    T object_;
+  };
+  List *top_{nullptr};
+  std::size_t size_{0};
+  Terminator &terminator_;
+};
+} // namespace Fortran::runtime
+#endif // FORTRAN_RUNTIME_STACK_H



More information about the flang-commits mailing list