[flang-commits] [flang] [llvm] [flang][runtime] Replace recursion with iterative work queue (WORK IN PROGRESS) (PR #137727)
via flang-commits
flang-commits at lists.llvm.org
Tue Apr 29 14:03:55 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-flang-runtime
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
Recursion, both direct and indirect, prevents accurate stack size calculation at link time for GPU device code. This patch restructures the following recursive (often mutually recursive) routines in the Fortran runtime as new implementations based on an iterative work queue with suspendable/resumable work tickets: Assign, Initialize, InitializeClone, Finalize, and Destroy.
Default derived type I/O is also recursive, but it is already disabled. It can be added to this new framework later if the overall approach succeeds.
Note that derived type FINAL subroutine calls, defined assignments, and defined I/O procedures all perform callbacks into user code, which may well reenter the runtime library. This kind of recursion is not handled by this change, although it may be possible to do so in the future using thread-local work queues.
The effects of this restructuring on CPU performance are yet to be measured.
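
As a rough illustration of the idea summarized above (not the code in this patch), the sketch below replaces a recursive tree walk with a last-in-first-out list of resumable tickets. Every name in it (`WorkStack`, `Ticket`, `VisitTicket`, `Node`, `kStatOk`, `kStatContinue`) is hypothetical; only the value 1234 and the "run the last ticket, insert nested tickets right after the active one" policy are taken from the header comment in the diff, and the heap-allocated `std::vector` of `std::unique_ptr` is a simplification that the real runtime would likely avoid on GPU devices.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical status codes, loosely modeled on StatOk / StatOkContinue.
constexpr int kStatOk{0};
constexpr int kStatContinue{1234};

class WorkStack;

// A resumable unit of work. Returning kStatContinue suspends the ticket
// until every ticket it pushed has run to completion.
struct Ticket {
  virtual ~Ticket() = default;
  virtual int Continue(WorkStack &) = 0;
};

class WorkStack {
public:
  // New tickets go immediately after the active ticket, so the first
  // nested ticket pushed ends up last in the list and therefore runs first.
  void Push(std::unique_ptr<Ticket> ticket) {
    assert(insertAt_ <= tickets_.size());
    tickets_.insert(
        tickets_.begin() + static_cast<std::ptrdiff_t>(insertAt_),
        std::move(ticket));
  }

  // Always runs the last ticket; no recursion is involved.
  int Run() {
    while (!tickets_.empty()) {
      std::size_t at{tickets_.size() - 1};
      insertAt_ = at + 1;
      int status{tickets_[at]->Continue(*this)};
      if (status == kStatContinue) {
        continue; // suspended; it resumes once it is last again
      }
      tickets_.erase(tickets_.begin() + static_cast<std::ptrdiff_t>(at));
      if (status != kStatOk) {
        tickets_.clear(); // an error shuts the whole queue down
        insertAt_ = 0;
        return status;
      }
    }
    insertAt_ = 0;
    return kStatOk;
  }

private:
  std::vector<std::unique_ptr<Ticket>> tickets_;
  std::size_t insertAt_{0};
};

struct Node {
  std::string name;
  std::vector<Node> children;
};

// Logs "enter", pushes one ticket per child, suspends, then logs "leave".
class VisitTicket : public Ticket {
public:
  VisitTicket(const Node &node, std::vector<std::string> &log)
      : node_{node}, log_{log} {}
  int Continue(WorkStack &stack) override {
    if (phase_ == 0) {
      log_.push_back("enter " + node_.name);
      for (const Node &child : node_.children) {
        stack.Push(std::make_unique<VisitTicket>(child, log_));
      }
      phase_ = 1; // resumption state lives in the ticket, not on the call stack
      return kStatContinue;
    }
    log_.push_back("leave " + node_.name);
    return kStatOk;
  }

private:
  const Node &node_;
  std::vector<std::string> &log_;
  int phase_{0};
};

// Produces the same enter/leave order as a recursive depth-first walk.
std::vector<std::string> VisitIteratively(const Node &root) {
  std::vector<std::string> log;
  WorkStack stack;
  stack.Push(std::make_unique<VisitTicket>(root, log));
  int status{stack.Run()};
  assert(status == kStatOk);
  (void)status;
  return log;
}
```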
---
Patch is 73.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137727.diff
8 Files Affected:
- (added) flang-rt/include/flang-rt/runtime/work-queue.h (+278)
- (modified) flang-rt/lib/runtime/CMakeLists.txt (+10)
- (modified) flang-rt/lib/runtime/assign.cpp (+303-239)
- (modified) flang-rt/lib/runtime/derived.cpp (+244-243)
- (modified) flang-rt/lib/runtime/type-info.cpp (+3-3)
- (added) flang-rt/lib/runtime/work-queue.cpp (+172)
- (modified) flang/include/flang/Runtime/assign.h (+1-1)
- (modified) flang/runtime/CMakeLists.txt (+2)
``````````diff
diff --git a/flang-rt/include/flang-rt/runtime/work-queue.h b/flang-rt/include/flang-rt/runtime/work-queue.h
new file mode 100644
index 0000000000000..00c4562d08105
--- /dev/null
+++ b/flang-rt/include/flang-rt/runtime/work-queue.h
@@ -0,0 +1,278 @@
+//===-- include/flang-rt/runtime/work-queue.h -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Internal runtime utilities for work queues that replace the use of recursion
+// for better GPU device support.
+//
+// A work queue is a list of tickets. Each ticket class has a Begin()
+// member function that is called once, and a Continue() member function
+// that can be called zero or more times. A ticket's execution terminates
+// when either of these member functions returns a status other than
+// StatOkContinue, and if that status is not StatOk, then the whole queue
+// is shut down.
+//
+// By returning StatOkContinue from its Continue() member function,
+// a ticket suspends its execution so that any nested tickets that it
+// may have created can be run to completion. It is the responsibility
+// of each ticket class to maintain resumption information in its state
+// and manage its own progress. Most ticket classes inherit from
+// class ComponentTicketBase, which implements an outer loop over all
+// components of a derived type, and an inner loop over all elements
+// of a descriptor, possibly with multiple phases of execution per element.
+//
+// Tickets are created by WorkQueue::Begin...() member functions.
+// There is one of these for each "top level" recursive function in the
+// Fortran runtime support library that has been restructured into this
+// ticket framework.
+//
+// When the work queue is running tickets, it always selects the last ticket
+// on the list for execution -- "work stack" might have been a more accurate
+// name for this framework. This ticket may, while doing its job, create
+// new tickets, and since those are pushed after the active one, the first
+// such nested ticket will be the next one executed to completion -- i.e.,
+// the order of nested WorkQueue::Begin...() calls is respected.
+// Note that a ticket's Continue() member function won't be called again
+// until all nested tickets have run to completion and it is once again
+// the last ticket on the queue.
+//
+// Example for an assignment to a derived type:
+// 1. Assign() is called, and its work queue is created. It calls
+// WorkQueue::BeginAssign() and then WorkQueue::Run().
+// 2. Run calls AssignTicket::Begin(), which pushes a ticket via
+// BeginFinalize() and returns StatOkContinue.
+// 3. FinalizeTicket::Begin() and FinalizeTicket::Continue() are called
+// until one of them returns StatOk, which ends the finalization ticket.
+// 4. AssignTicket::Continue() is then called; it creates a DerivedAssignTicket
+// and then returns StatOk, which ends the ticket.
+// 5. At this point, only one ticket remains. DerivedAssignTicket::Begin()
+// and ::Continue() are called until they are done (not StatOkContinue).
+// Along the way, it may create nested AssignTickets for components,
+// and suspend itself so that they may each run to completion.
+
+#ifndef FLANG_RT_RUNTIME_WORK_QUEUE_H_
+#define FLANG_RT_RUNTIME_WORK_QUEUE_H_
+
+#include "flang-rt/runtime/descriptor.h"
+#include "flang-rt/runtime/stat.h"
+#include "flang/Common/api-attrs.h"
+#include "flang/Runtime/freestanding-tools.h"
+#include <flang/Common/variant.h>
+
+namespace Fortran::runtime {
+class Terminator;
+class WorkQueue;
+namespace typeInfo {
+class DerivedType;
+class Component;
+} // namespace typeInfo
+
+// Ticket workers
+
+// Ticket workers return status codes. Returning StatOkContinue means
+// that the ticket is incomplete and must be resumed; any other value
+// means that the ticket is complete, and if not StatOk, the whole
+// queue can be shut down due to an error.
+static constexpr int StatOkContinue{1234};
+
+struct NullTicket {
+ RT_API_ATTRS int Begin(WorkQueue &) const { return StatOk; }
+ RT_API_ATTRS int Continue(WorkQueue &) const { return StatOk; }
+};
+
+// Base class for ticket workers that operate elementwise over descriptors
+// TODO: if ComponentTicketBase remains this class' only client,
+// merge them for better comprehensibility.
+class ElementalTicketBase {
+protected:
+ RT_API_ATTRS ElementalTicketBase(const Descriptor &instance)
+ : instance_{instance} {
+ instance_.GetLowerBounds(subscripts_);
+ }
+ RT_API_ATTRS bool CueUpNextItem() const { return elementAt_ < elements_; }
+ RT_API_ATTRS void AdvanceToNextElement() {
+ phase_ = 0;
+ ++elementAt_;
+ instance_.IncrementSubscripts(subscripts_);
+ }
+ RT_API_ATTRS void Reset() {
+ phase_ = 0;
+ elementAt_ = 0;
+ instance_.GetLowerBounds(subscripts_);
+ }
+
+ const Descriptor &instance_;
+ std::size_t elements_{instance_.Elements()};
+ std::size_t elementAt_{0};
+ int phase_{0};
+ SubscriptValue subscripts_[common::maxRank];
+};
+
+// Base class for ticket workers that operate over derived type components
+// in an outer loop, and elements in an inner loop.
+class ComponentTicketBase : protected ElementalTicketBase {
+protected:
+ RT_API_ATTRS ComponentTicketBase(
+ const Descriptor &instance, const typeInfo::DerivedType &derived);
+ RT_API_ATTRS bool CueUpNextItem();
+ RT_API_ATTRS void AdvanceToNextComponent() { elementAt_ = elements_; }
+
+ const typeInfo::DerivedType &derived_;
+ const typeInfo::Component *component_{nullptr};
+ std::size_t components_{0}, componentAt_{0};
+ StaticDescriptor<common::maxRank, true, 0> componentDescriptor_;
+};
+
+// Implements derived type instance initialization
+class InitializeTicket : private ComponentTicketBase {
+public:
+ RT_API_ATTRS InitializeTicket(
+ const Descriptor &instance, const typeInfo::DerivedType &derived)
+ : ComponentTicketBase{instance, derived} {}
+ RT_API_ATTRS int Begin(WorkQueue &);
+ RT_API_ATTRS int Continue(WorkQueue &);
+};
+
+// Initializes one derived type instance from the value of another
+class InitializeCloneTicket : private ComponentTicketBase {
+public:
+ RT_API_ATTRS InitializeCloneTicket(const Descriptor &clone,
+ const Descriptor &original, const typeInfo::DerivedType &derived,
+ bool hasStat, const Descriptor *errMsg)
+ : ComponentTicketBase{original, derived}, clone_{clone},
+ hasStat_{hasStat}, errMsg_{errMsg} {}
+ RT_API_ATTRS int Begin(WorkQueue &) { return StatOkContinue; }
+ RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+ const Descriptor &clone_;
+ bool hasStat_{false};
+ const Descriptor *errMsg_{nullptr};
+ StaticDescriptor<common::maxRank, true, 0> cloneComponentDescriptor_;
+};
+
+// Implements derived type instance finalization
+class FinalizeTicket : private ComponentTicketBase {
+public:
+ RT_API_ATTRS FinalizeTicket(
+ const Descriptor &instance, const typeInfo::DerivedType &derived)
+ : ComponentTicketBase{instance, derived} {}
+ RT_API_ATTRS int Begin(WorkQueue &);
+ RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+ const typeInfo::DerivedType *finalizableParentType_{nullptr};
+};
+
+// Implements derived type instance destruction
+class DestroyTicket : private ComponentTicketBase {
+public:
+ RT_API_ATTRS DestroyTicket(const Descriptor &instance,
+ const typeInfo::DerivedType &derived, bool finalize)
+ : ComponentTicketBase{instance, derived}, finalize_{finalize} {}
+ RT_API_ATTRS int Begin(WorkQueue &);
+ RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+ bool finalize_{false};
+};
+
+// Implements general intrinsic assignment
+class AssignTicket {
+public:
+ RT_API_ATTRS AssignTicket(
+ Descriptor &to, const Descriptor &from, int flags, MemmoveFct memmoveFct)
+ : to_{to}, from_{&from}, flags_{flags}, memmoveFct_{memmoveFct} {}
+ RT_API_ATTRS int Begin(WorkQueue &);
+ RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+ RT_API_ATTRS bool IsSimpleMemmove() const {
+ return !toDerived_ && to_.rank() == from_->rank() && to_.IsContiguous() &&
+ from_->IsContiguous() && to_.ElementBytes() == from_->ElementBytes();
+ }
+ RT_API_ATTRS Descriptor &GetTempDescriptor();
+
+ Descriptor &to_;
+ const Descriptor *from_{nullptr};
+ int flags_{0}; // enum AssignFlags
+ MemmoveFct memmoveFct_{nullptr};
+ StaticDescriptor<common::maxRank, true, 0> tempDescriptor_;
+ const typeInfo::DerivedType *toDerived_{nullptr};
+ Descriptor *toDeallocate_{nullptr};
+ bool persist_{false};
+ bool done_{false};
+};
+
+// Implements derived type intrinsic assignment
+class DerivedAssignTicket : private ComponentTicketBase {
+public:
+ RT_API_ATTRS DerivedAssignTicket(const Descriptor &to, const Descriptor &from,
+ const typeInfo::DerivedType &derived, int flags, MemmoveFct memmoveFct,
+ Descriptor *deallocateAfter)
+ : ComponentTicketBase{to, derived}, from_{from}, flags_{flags},
+ memmoveFct_{memmoveFct}, deallocateAfter_{deallocateAfter} {}
+ RT_API_ATTRS int Begin(WorkQueue &);
+ RT_API_ATTRS int Continue(WorkQueue &);
+ RT_API_ATTRS void AdvanceToNextElement();
+
+private:
+ const Descriptor &from_;
+ int flags_{0};
+ MemmoveFct memmoveFct_{nullptr};
+ Descriptor *deallocateAfter_{nullptr};
+ SubscriptValue fromSubscripts_[common::maxRank];
+ StaticDescriptor<common::maxRank, true, 0> fromComponentDescriptor_;
+};
+
+struct Ticket {
+ RT_API_ATTRS int Continue(WorkQueue &);
+ bool begun{false};
+ std::variant<NullTicket, InitializeTicket, InitializeCloneTicket,
+ FinalizeTicket, DestroyTicket, AssignTicket, DerivedAssignTicket>
+ u{NullTicket{}};
+};
+
+class WorkQueue {
+public:
+ RT_API_ATTRS explicit WorkQueue(Terminator &terminator)
+ : terminator_{terminator} {}
+ RT_API_ATTRS ~WorkQueue();
+ RT_API_ATTRS Terminator &terminator() { return terminator_; };
+
+ RT_API_ATTRS void BeginInitialize(
+ const Descriptor &descriptor, const typeInfo::DerivedType &derived);
+ RT_API_ATTRS void BeginInitializeClone(const Descriptor &clone,
+ const Descriptor &original, const typeInfo::DerivedType &derived,
+ bool hasStat, const Descriptor *errMsg);
+ RT_API_ATTRS void BeginFinalize(
+ const Descriptor &descriptor, const typeInfo::DerivedType &derived);
+ RT_API_ATTRS void BeginDestroy(const Descriptor &descriptor,
+ const typeInfo::DerivedType &derived, bool finalize);
+ RT_API_ATTRS void BeginAssign(
+ Descriptor &to, const Descriptor &from, int flags, MemmoveFct memmoveFct);
+ RT_API_ATTRS void BeginDerivedAssign(Descriptor &to, const Descriptor &from,
+ const typeInfo::DerivedType &derived, int flags, MemmoveFct memmoveFct,
+ Descriptor *deallocateAfter);
+
+ RT_API_ATTRS int Run();
+
+private:
+ struct TicketList {
+ Ticket ticket;
+ TicketList *previous{nullptr}, *next{nullptr};
+ };
+
+ RT_API_ATTRS Ticket &StartTicket();
+ RT_API_ATTRS void Stop();
+ Terminator &terminator_;
+ TicketList *first_{nullptr}, *last_{nullptr}, *insertAfter_{nullptr};
+ TicketList *firstFree_{nullptr};
+};
+
+} // namespace Fortran::runtime
+#endif // FLANG_RT_RUNTIME_WORK_QUEUE_H_
diff --git a/flang-rt/lib/runtime/CMakeLists.txt b/flang-rt/lib/runtime/CMakeLists.txt
index c5e7bdce5b2fd..a0ebcf4157ff5 100644
--- a/flang-rt/lib/runtime/CMakeLists.txt
+++ b/flang-rt/lib/runtime/CMakeLists.txt
@@ -12,6 +12,14 @@ find_package(Backtrace)
set(HAVE_BACKTRACE ${Backtrace_FOUND})
set(BACKTRACE_HEADER ${Backtrace_HEADER})
+# BE ADVISED: If you are about to add a new source file to one or more
+# of "supported_sources", "host_sources", or "gpu_sources" lists, you
+# probably need to also add that file to "flang/runtime/CMakeLists.txt",
+# which still exists and is still used for some purposes. If you do not,
+# you will get confusing unsatisfied external references when unit tests
+# are linked. I don't know why things are this way or whether anybody
+# is going to fix it. Hope this helps!
+
# List of files that are buildable for all devices.
set(supported_sources
${FLANG_SOURCE_DIR}/lib/Decimal/binary-to-decimal.cpp
@@ -67,6 +75,7 @@ set(supported_sources
type-info.cpp
unit.cpp
utf.cpp
+ work-queue.cpp
)
# List of source not used for GPU offloading.
@@ -130,6 +139,7 @@ set(gpu_sources
type-code.cpp
type-info.cpp
utf.cpp
+ work-queue.cpp
complex-powi.cpp
reduce.cpp
reduction.cpp
diff --git a/flang-rt/lib/runtime/assign.cpp b/flang-rt/lib/runtime/assign.cpp
index 4a813cd489022..054d9672bcf11 100644
--- a/flang-rt/lib/runtime/assign.cpp
+++ b/flang-rt/lib/runtime/assign.cpp
@@ -14,6 +14,7 @@
#include "flang-rt/runtime/terminator.h"
#include "flang-rt/runtime/tools.h"
#include "flang-rt/runtime/type-info.h"
+#include "flang-rt/runtime/work-queue.h"
namespace Fortran::runtime {
@@ -99,11 +100,7 @@ static RT_API_ATTRS int AllocateAssignmentLHS(
toDim.SetByteStride(stride);
stride *= toDim.Extent();
}
- int result{ReturnError(terminator, to.Allocate(kNoAsyncId))};
- if (result == StatOk && derived && !derived->noInitializationNeeded()) {
- result = ReturnError(terminator, Initialize(to, *derived, terminator));
- }
- return result;
+ return ReturnError(terminator, to.Allocate(kNoAsyncId));
}
// least <= 0, most >= 0
@@ -228,6 +225,8 @@ static RT_API_ATTRS void BlankPadCharacterAssignment(Descriptor &to,
}
}
+RT_OFFLOAD_API_GROUP_BEGIN
+
// Common implementation of assignments, both intrinsic assignments and
// those cases of polymorphic user-defined ASSIGNMENT(=) TBPs that could not
// be resolved in semantics. Most assignment statements do not need any
@@ -241,274 +240,339 @@ static RT_API_ATTRS void BlankPadCharacterAssignment(Descriptor &to,
// dealing with array constructors.
RT_API_ATTRS void Assign(Descriptor &to, const Descriptor &from,
Terminator &terminator, int flags, MemmoveFct memmoveFct) {
- bool mustDeallocateLHS{(flags & DeallocateLHS) ||
- MustDeallocateLHS(to, from, terminator, flags)};
- DescriptorAddendum *toAddendum{to.Addendum()};
- const typeInfo::DerivedType *toDerived{
- toAddendum ? toAddendum->derivedType() : nullptr};
- if (toDerived && (flags & NeedFinalization) &&
- toDerived->noFinalizationNeeded()) {
- flags &= ~NeedFinalization;
- }
- std::size_t toElementBytes{to.ElementBytes()};
- std::size_t fromElementBytes{from.ElementBytes()};
- // The following lambda definition violates the conding style,
- // but cuda-11.8 nvcc hits an internal error with the brace initialization.
- auto isSimpleMemmove = [&]() {
- return !toDerived && to.rank() == from.rank() && to.IsContiguous() &&
- from.IsContiguous() && toElementBytes == fromElementBytes;
- };
- StaticDescriptor<maxRank, true, 10 /*?*/> deferredDeallocStatDesc;
- Descriptor *deferDeallocation{nullptr};
- if (MayAlias(to, from)) {
+ WorkQueue workQueue{terminator};
+ workQueue.BeginAssign(to, from, flags, memmoveFct);
+ workQueue.Run();
+}
+
+RT_API_ATTRS int AssignTicket::Begin(WorkQueue &workQueue) {
+ bool mustDeallocateLHS{(flags_ & DeallocateLHS) ||
+ MustDeallocateLHS(to_, *from_, workQueue.terminator(), flags_)};
+ DescriptorAddendum *toAddendum{to_.Addendum()};
+ toDerived_ = toAddendum ? toAddendum->derivedType() : nullptr;
+ if (toDerived_ && (flags_ & NeedFinalization) &&
+ toDerived_->noFinalizationNeeded()) {
+ flags_ &= ~NeedFinalization;
+ }
+ const typeInfo::SpecialBinding *scalarDefinedAssignment{nullptr};
+ const typeInfo::SpecialBinding *elementalDefinedAssignment{nullptr};
+ if (toDerived_ && (flags_ & CanBeDefinedAssignment)) {
+ // Check for a user-defined assignment type-bound procedure;
+ // see 10.2.1.4-5. A user-defined assignment TBP defines all of
+ // the semantics, including allocatable (re)allocation and any
+ // finalization.
+ //
+ // Note that the aliasing and LHS (re)allocation handling below
+ // needs to run even with CanBeDefinedAssignment flag, since
+ // Assign() can be invoked recursively for component-wise assignments.
+ if (to_.rank() == 0) {
+ scalarDefinedAssignment = toDerived_->FindSpecialBinding(
+ typeInfo::SpecialBinding::Which::ScalarAssignment);
+ }
+ if (!scalarDefinedAssignment) {
+ elementalDefinedAssignment = toDerived_->FindSpecialBinding(
+ typeInfo::SpecialBinding::Which::ElementalAssignment);
+ }
+ }
+ if (MayAlias(to_, *from_)) {
if (mustDeallocateLHS) {
- deferDeallocation = &deferredDeallocStatDesc.descriptor();
+ // Convert the LHS into a temporary, then make it look deallocated.
+ toDeallocate_ = &tempDescriptor_.descriptor();
+ persist_ = true; // tempDescriptor_ state must outlive child tickets
std::memcpy(
- reinterpret_cast<void *>(deferDeallocation), &to, to.SizeInBytes());
- to.set_base_addr(nullptr);
- } else if (!isSimpleMemmove()) {
+ reinterpret_cast<void *>(toDeallocate_), &to_, to_.SizeInBytes());
+ to_.set_base_addr(nullptr);
+ } else if (!IsSimpleMemmove() || scalarDefinedAssignment ||
+ elementalDefinedAssignment) {
// Handle LHS/RHS aliasing by copying RHS into a temp, then
// recursively assigning from that temp.
- auto descBytes{from.SizeInBytes()};
- StaticDescriptor<maxRank, true, 16> staticDesc;
- Descriptor &newFrom{staticDesc.descriptor()};
- std::memcpy(reinterpret_cast<void *>(&newFrom), &from, descBytes);
+ auto descBytes{from_->SizeInBytes()};
+ Descriptor &newFrom{tempDescriptor_.descriptor()};
+ persist_ = true; // tempDescriptor_ state must outlive child tickets
+ std::memcpy(reinterpret_cast<void *>(&newFrom), from_, descBytes);
// Pretend the temporary descriptor is for an ALLOCATABLE
// entity, otherwise, the Deallocate() below will not
// free the descriptor memory.
newFrom.raw().attribute = CFI_attribute_allocatable;
- auto stat{ReturnError(terminator, newFrom.Allocate(kNoAsyncId))};
- if (stat == StatOk) {
- if (HasDynamicComponent(from)) {
- // If 'from' has allocatable/automatic component, we cannot
- // just make a shallow copy of the descriptor member.
- // This will still leave data overlap in 'to' and 'newFrom'.
- // For example:
- // type t
- // character, allocatable :: c(:)
- // end type t
- // type(t) :: x(3)
- // x(2:3) = x(1:2)
- // We have to make a deep copy into 'newFrom' in this case.
- RTNAME(AssignTemporary)
- (newFrom, from, terminator.sourceFileName(), terminator.sourceLine());
- } else {
- ShallowCopy(newFrom, from, true, from.IsContiguous());
+ if (int stat{ReturnError(
+ workQueue.terminator(), newFrom.Allocate(kNoAsyncId))};
+ stat != StatOk) {
+ return stat;
+ }
+ if (HasDynamicComponent(*from_)) {
+ // If 'from' has allocatable/automatic component, we cannot
+ // just make a shallow copy of the descriptor member.
+ // This will still leave data overlap in 'to' and 'newFrom'.
+ // For example:
+ // type t
+ // character, allocatable :: c(:)
+ // end type t
+ // type(t) :: x(3)
+ // x(2:3) = x(1:2)
+ // We have to make a deep copy into 'newFrom' in this case.
+ if (const DescriptorAddendum * addendum{newFrom.Addendum()}) {
+ if (const auto *derived{addendum->derivedType()}) {
+ if (!derived->noInitializationNeeded()) {
+ workQueue.BeginInitialize(newFrom, *derived);
+ }
+ }
}
- Assign(to, newFrom, terminator,
- flags &
- (NeedFinalization | ComponentCanBeDefinedAssignment |
- ExplicitLengthCharacterLHS | CanBeDefinedAssignment));
- newFrom.Deallocate();
+ workQueue.BeginAssign(
+ newFrom, *from_, MaybeReallocate | PolymorphicLHS, memmoveFct_);
+ } else {
+ ShallowCopy(newFrom, *from_, true, from_->IsContiguous());
}
- return;
+ from_ = &newFrom;
+ flags_ &= ...
[truncated]
``````````
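
The `Ticket` struct in `work-queue.h` holds a `begun` flag and a `std::variant` of the concrete ticket classes, but the body of `Ticket::Continue()` is not visible in the truncated diff. The following is a minimal sketch of how such a dispatch could look, assuming that `Begin()` is invoked on the first step and `Continue()` on every later one. The types `NullWork`, `CountdownWork`, and `AnyTicket` and the helper `StepsToFinish` are invented for illustration and are not part of the patch.

```cpp
#include <cassert>
#include <variant>

// Hypothetical status codes, loosely modeled on StatOk / StatOkContinue.
constexpr int kStatOk{0};
constexpr int kStatContinue{1234};

// Placeholder alternative, analogous in spirit to NullTicket.
struct NullWork {
  int Begin() { return kStatOk; }
  int Continue() { return kStatOk; }
};

// Needs `remaining` calls to Continue() before it reports completion.
struct CountdownWork {
  int remaining;
  int beginCalls{0};
  int Begin() {
    ++beginCalls;
    return kStatContinue;
  }
  int Continue() { return --remaining > 0 ? kStatContinue : kStatOk; }
};

// Type-erased ticket: no virtual functions and no heap allocation.
struct AnyTicket {
  bool begun{false};
  std::variant<NullWork, CountdownWork> u{NullWork{}};

  // Calls Begin() exactly once, then Continue() on each later step.
  int Step() {
    assert(!u.valueless_by_exception());
    if (!begun) {
      begun = true;
      return std::visit([](auto &work) { return work.Begin(); }, u);
    }
    return std::visit([](auto &work) { return work.Continue(); }, u);
  }
};

// Drives one ticket until it stops asking to be continued and
// returns the number of steps that took.
int StepsToFinish(AnyTicket &ticket) {
  int steps{0};
  int status{kStatContinue};
  while (status == kStatContinue) {
    status = ticket.Step();
    ++steps;
  }
  return steps;
}
```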
</details>
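
The header comment describes `ComponentTicketBase` as an outer loop over components with an inner loop over elements, where the position is kept in the ticket's state so that work can be suspended and resumed. The definition of `ComponentTicketBase::CueUpNextItem()` is not included in the truncated diff, so the sketch below is only a guess at the general shape: the member names mirror the patch, while the class `ComponentElementCursor`, the helper `VisitSome`, and all function bodies are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical cursor over (component, element) pairs. The loop position
// is stored in members so iteration can stop and pick up again later.
class ComponentElementCursor {
public:
  ComponentElementCursor(std::size_t components, std::size_t elements)
      : components_{components}, elements_{elements} {}

  // Returns true when (componentAt(), elementAt()) names a valid item,
  // moving on to the next component when the current one is exhausted.
  bool CueUpNextItem() {
    while (componentAt_ < components_) {
      if (elementAt_ < elements_) {
        return true;
      }
      elementAt_ = 0;
      ++componentAt_;
    }
    return false;
  }
  void AdvanceToNextElement() { ++elementAt_; }
  // Skips the remaining elements of the current component.
  void AdvanceToNextComponent() { elementAt_ = elements_; }

  std::size_t componentAt() const { return componentAt_; }
  std::size_t elementAt() const { return elementAt_; }

private:
  std::size_t components_, elements_;
  std::size_t componentAt_{0}, elementAt_{0};
};

// Visits at most `budget` items and then returns, which stands in for a
// ticket suspending itself; a later call resumes where this one stopped.
std::vector<std::pair<std::size_t, std::size_t>> VisitSome(
    ComponentElementCursor &cursor, std::size_t budget) {
  std::vector<std::pair<std::size_t, std::size_t>> visited;
  while (visited.size() < budget && cursor.CueUpNextItem()) {
    visited.emplace_back(cursor.componentAt(), cursor.elementAt());
    cursor.AdvanceToNextElement();
  }
  return visited;
}
```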
https://github.com/llvm/llvm-project/pull/137727