[flang-commits] [flang] [flang] Avoid recursion in runtime derived type initialization (PR #102394)
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Thu Aug 8 12:34:18 PDT 2024
https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/102394
>From b51ba303f213a682af6fea4c8d4eee1d1e9d387c Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 7 Aug 2024 15:01:29 -0700
Subject: [PATCH] [flang] Avoid recursion in runtime derived type
initialization
Adds a recursion-avoiding work engine for things like this, and
adapts derived type instance initialization to use it. If
successful, the engine can be reused to replace recursion in many
other runtime sites.
---
flang/runtime/CMakeLists.txt | 1 +
flang/runtime/derived.cpp | 106 ++++++++++++++--------------
flang/runtime/engine.cpp | 70 +++++++++++++++++++
flang/runtime/engine.h | 132 +++++++++++++++++++++++++++++++++++
4 files changed, 254 insertions(+), 55 deletions(-)
create mode 100644 flang/runtime/engine.cpp
create mode 100644 flang/runtime/engine.h
diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
index 4537b2d059d65b..3d6278d831d433 100644
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -125,6 +125,7 @@ set(sources
dot-product.cpp
edit-input.cpp
edit-output.cpp
+ engine.cpp
environment.cpp
exceptions.cpp
execute.cpp
diff --git a/flang/runtime/derived.cpp b/flang/runtime/derived.cpp
index 659f54fa344bb0..3fbfda076005cb 100644
--- a/flang/runtime/derived.cpp
+++ b/flang/runtime/derived.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "derived.h"
+#include "engine.h"
#include "stat.h"
#include "terminator.h"
#include "tools.h"
@@ -29,97 +30,92 @@ static RT_API_ATTRS void GetComponentExtents(SubscriptValue (&extents)[maxRank],
}
}
-RT_API_ATTRS int Initialize(const Descriptor &instance,
- const typeInfo::DerivedType &derived, Terminator &terminator, bool hasStat,
- const Descriptor *errMsg) {
- const Descriptor &componentDesc{derived.component()};
- std::size_t elements{instance.Elements()};
- int stat{StatOk};
- // Initialize data components in each element; the per-element iterations
- // constitute the inner loops, not the outer ones
- std::size_t myComponents{componentDesc.Elements()};
- for (std::size_t k{0}; k < myComponents; ++k) {
+auto engine::Initialization::Resume(Engine &engine) -> ResultType { // entered anew, and again each time a nested job it began has finished
+ while (component_.Iterating(components_)) { // position is kept in component_ across returns to the engine
const auto &comp{
- *componentDesc.ZeroBasedIndexedElement<typeInfo::Component>(k)};
- SubscriptValue at[maxRank];
- instance.GetLowerBounds(at);
+ *componentDesc_->ZeroBasedIndexedElement<typeInfo::Component>(
+ component_.at)};
if (comp.genre() == typeInfo::Component::Genre::Allocatable ||
comp.genre() == typeInfo::Component::Genre::Automatic) {
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
- Descriptor &allocDesc{
- *instance.ElementComponent<Descriptor>(at, comp.offset())};
- comp.EstablishDescriptor(allocDesc, instance, terminator);
+ while (element_.Iterating(elements_, &instance_)) { // on re-entry this steps to the next element
+ Descriptor &allocDesc{*instance_.ElementComponent<Descriptor>(
+ element_.subscripts, comp.offset())};
+ comp.EstablishDescriptor(allocDesc, instance_, engine.terminator());
allocDesc.raw().attribute = CFI_attribute_allocatable;
if (comp.genre() == typeInfo::Component::Genre::Automatic) {
- stat = ReturnError(terminator, allocDesc.Allocate(), errMsg, hasStat);
- if (stat == StatOk) {
- if (const DescriptorAddendum * addendum{allocDesc.Addendum()}) {
- if (const auto *derived{addendum->derivedType()}) {
- if (!derived->noInitializationNeeded()) {
- stat = Initialize(
- allocDesc, *derived, terminator, hasStat, errMsg);
- }
+ if (auto stat{ReturnError(engine.terminator(), allocDesc.Allocate(),
+ engine.errMsg(), engine.hasStat())};
+ stat != StatOk) {
+ return engine.Fail(stat); // records stat in the engine and pops this task
+ }
+ if (const DescriptorAddendum * addendum{allocDesc.Addendum()}) {
+ if (const auto *derived{addendum->derivedType()}) {
+ if (!derived->noInitializationNeeded()) {
+ component_.ResumeAtSameIteration(); // stay on this component when re-entered
+ return engine.Begin(Job::Initialization, allocDesc, derived); // nested job replaces the old recursive call
}
}
}
- if (stat != StatOk) {
- break;
- }
}
}
} else if (const void *init{comp.initialization()}) {
// Explicit initialization of data pointers and
// non-allocatable non-automatic components
- std::size_t bytes{comp.SizeInBytes(instance)};
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
- char *ptr{instance.ElementComponent<char>(at, comp.offset())};
+ std::size_t bytes{comp.SizeInBytes(instance_)};
+ while (element_.Iterating(elements_, &instance_)) { // runs to completion; this branch never returns to the engine
+ char *ptr{instance_.ElementComponent<char>(
+ element_.subscripts, comp.offset())};
std::memcpy(ptr, init, bytes);
}
} else if (comp.genre() == typeInfo::Component::Genre::Pointer) {
// Data pointers without explicit initialization are established
// so that they are valid right-hand side targets of pointer
// assignment statements.
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
- Descriptor &ptrDesc{
- *instance.ElementComponent<Descriptor>(at, comp.offset())};
- comp.EstablishDescriptor(ptrDesc, instance, terminator);
+ while (element_.Iterating(elements_, &instance_)) { // runs to completion; this branch never returns to the engine
+ Descriptor &ptrDesc{*instance_.ElementComponent<Descriptor>(
+ element_.subscripts, comp.offset())};
+ comp.EstablishDescriptor(ptrDesc, instance_, engine.terminator());
ptrDesc.raw().attribute = CFI_attribute_pointer;
}
} else if (comp.genre() == typeInfo::Component::Genre::Data &&
comp.derivedType() && !comp.derivedType()->noInitializationNeeded()) {
// Default initialization of non-pointer non-allocatable/automatic
- // data component. Handles parent component's elements. Recursive.
- SubscriptValue extents[maxRank];
- GetComponentExtents(extents, comp, instance);
- StaticDescriptor<maxRank, true, 0> staticDescriptor;
- Descriptor &compDesc{staticDescriptor.descriptor()};
- const typeInfo::DerivedType &compType{*comp.derivedType()};
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
+ // data component. Handles parent component's elements.
+ if (!element_.active) { // element loop not in progress: first visit to this component
+ GetComponentExtents(extents_, comp, instance_);
+ }
+ while (element_.Iterating(elements_, &instance_)) {
+ Descriptor &compDesc{staticDescriptor_.descriptor()}; // member storage: the nested job keeps a reference to it
+ const typeInfo::DerivedType &compType{*comp.derivedType()};
compDesc.Establish(compType,
- instance.ElementComponent<char>(at, comp.offset()), comp.rank(),
- extents);
- stat = Initialize(compDesc, compType, terminator, hasStat, errMsg);
- if (stat != StatOk) {
- break;
- }
+ instance_.ElementComponent<char>(
+ element_.subscripts, comp.offset()),
+ comp.rank(), extents_);
+ component_.ResumeAtSameIteration(); // stay on this component when re-entered
+ return engine.Begin(Job::Initialization, compDesc, &compType); // one nested job per element
}
}
}
// Initialize procedure pointer components in each element
- const Descriptor &procPtrDesc{derived.procPtr()};
+ const Descriptor &procPtrDesc{derived_->procPtr()}; // NOTE(review): derived_ is dereferenced unconditionally, yet Work's constructor accepts a null derived type
std::size_t myProcPtrs{procPtrDesc.Elements()};
for (std::size_t k{0}; k < myProcPtrs; ++k) {
const auto &comp{
*procPtrDesc.ZeroBasedIndexedElement<typeInfo::ProcPtrComponent>(k)};
- SubscriptValue at[maxRank];
- instance.GetLowerBounds(at);
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
- auto &pptr{*instance.ElementComponent<typeInfo::ProcedurePointer>(
- at, comp.offset)};
+ while (element_.Iterating(elements_, &instance_)) { // no nested jobs here, so a plain local index k suffices
+ auto &pptr{*instance_.ElementComponent<typeInfo::ProcedurePointer>(
+ element_.subscripts, comp.offset)};
pptr = comp.procInitialization;
}
}
- return stat;
+ return engine.Done(); // everything initialized: pop this task
+}
+
+RT_API_ATTRS int Initialize(const Descriptor &instance,
+ const typeInfo::DerivedType &derived, Terminator &terminator, bool hasStat,
+ const Descriptor *errMsg) {
+ engine::Engine worker{terminator, hasStat, errMsg}; // a fresh engine per call
+ return worker.Do(engine::Job::Initialization, instance, &derived); // status code of the whole job
}
static RT_API_ATTRS const typeInfo::SpecialBinding *FindFinal(
diff --git a/flang/runtime/engine.cpp b/flang/runtime/engine.cpp
new file mode 100644
index 00000000000000..525fb7ba56529e
--- /dev/null
+++ b/flang/runtime/engine.cpp
@@ -0,0 +1,70 @@
+//===-- runtime/engine.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "engine.h"
+
+namespace Fortran::runtime::engine {
+
+Engine::Work::Work(
+ Job job, const Descriptor &instance, const typeInfo::DerivedType *derived)
+ : job_{job}, u_{{instance}} {
+ // The u_ initializer above has already bound commonState.instance_; fill
+ // in the rest of the state that every kind of task shares.
+ auto &common{u_.commonState};
+ // The element count comes straight from the instance descriptor.
+ common.elements_ = instance.Elements();
+ common.derived_ = derived;
+ // Without a derived type there is no component table to walk.
+ common.componentDesc_ = derived ? &derived->component() : nullptr;
+ common.components_ =
+ common.componentDesc_ ? common.componentDesc_->Elements() : 0;
+}
+
+void Engine::Work::Resume(Engine &engine) {
+ switch (job_) { // dispatch on the job code to the matching member of u_
+ case Job::Initialization:
+ u_.initialization.Resume(engine);
+ return;
+ }
+ engine.terminator().Crash( // reached only when job_ matches no case above
+ "Work::Resume: bad job_ code %d", static_cast<int>(job_));
+}
+
+int Engine::Do(
+ Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
+ Begin(job, instance, derived);
+ while (depth_ > 0) { // keep driving the most recently begun job until none remain
+ if (status_ != StatOk) {
+ Done(); // a task failed: unwind one level without resuming it
+ continue;
+ }
+ Work &top{*reinterpret_cast<Work *>(workBuf_[depth_ - 1])};
+ top.Resume(*this);
+ }
+ return status_;
+}
+
+Task::ResultType Engine::Begin(
+ Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
+ RUNTIME_CHECK(terminator_, depth_ < maxDepth); // TODO: heap allocation on overflow
+ char *const slot{workBuf_[depth_++]}; // claim the next free entry of the work stack
+ new (slot) Work{job, instance, derived};
+ return Task::ResultType::ResultValue;
+}
+
+Task::ResultType Engine::Done() {
+ --depth_; // pop the current task; the next Begin() reuses its buffer slot
+ return Task::ResultType::ResultValue;
+}
+
+Task::ResultType Engine::Fail(int status) {
+ status_ = status; // once this is not StatOk, Do() only unwinds the remaining tasks
+ return Done(); // pop the failing task itself
+}
+
+} // namespace Fortran::runtime::engine
diff --git a/flang/runtime/engine.h b/flang/runtime/engine.h
new file mode 100644
index 00000000000000..fcac989747ed4d
--- /dev/null
+++ b/flang/runtime/engine.h
@@ -0,0 +1,132 @@
+//===-- runtime/engine.h --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Implements a work engine for restartable tasks iterating over elements,
+// components, &c. of arrays and derived types. Avoids recursion and
+// function pointers.
+
+#ifndef FORTRAN_RUNTIME_ENGINE_H_
+#define FORTRAN_RUNTIME_ENGINE_H_
+
+#include "derived.h"
+#include "stat.h"
+#include "terminator.h"
+#include "type-info.h"
+#include "flang/Runtime/descriptor.h"
+
+namespace Fortran::runtime::engine {
+
+class Engine;
+
+// Every task object derives from Task.
+struct Task {
+
+ enum class ResultType { ResultValue /*doesn't matter*/ };
+
+ struct Iteration {
+ bool Iterating(std::size_t iters, const Descriptor *dtor = nullptr) { // loop condition; iters and dtor are consulted only when a loop starts
+ if (!active) { // no loop in progress: begin one
+ if (iters > 0) { // an empty range never becomes active
+ active = true;
+ at = 0;
+ n = iters;
+ descriptor = dtor;
+ if (descriptor) {
+ descriptor->GetLowerBounds(subscripts);
+ }
+ }
+ } else if (resuming) { // re-entry after a nested job: keep the current position
+ resuming = false;
+ } else if (++at < n) { // ordinary advance to the next iteration
+ if (descriptor) {
+ descriptor->IncrementSubscripts(subscripts);
+ }
+ } else { // exhausted; a later call would begin a new loop
+ active = false;
+ }
+ return active;
+ }
+ // Call on all Iteration instances before calling Engine::Begin()
+ // that should not advance when the job is resumed.
+ void ResumeAtSameIteration() { resuming = true; }
+
+ bool active{false}, resuming{false};
+ std::size_t at, n; // at: zero-based position; n: iteration count; both assigned when a loop starts
+ const Descriptor *descriptor; // optional; when non-null, subscripts follows its elements
+ SubscriptValue subscripts[maxRank];
+ };
+
+ // For looping over the elements of the instance being processed
+ const Descriptor &instance_;
+ std::size_t elements_;
+ Iteration element_;
+
+ // For looping over the components of the derived type (which may be null)
+ const typeInfo::DerivedType *derived_;
+ const Descriptor *componentDesc_;
+ std::size_t components_;
+ Iteration component_;
+};
+
+enum class Job { Initialization }; // job codes that Engine::Work::Resume dispatches on
+
+class Initialization : protected Task {
+public:
+ ResultType Resume(Engine &); // defined in derived.cpp
+
+private:
+ SubscriptValue extents_[maxRank]; // extents of the derived-type data component being initialized
+ StaticDescriptor<maxRank, true, 8> staticDescriptor_; // descriptor handed to the nested job for one element's component
+};
+
+class Engine {
+public:
+ Engine(Terminator &terminator, bool hasStat, const Descriptor *errMsg)
+ : terminator_{terminator}, hasStat_{hasStat}, errMsg_{errMsg} {}
+
+ Terminator &terminator() const { return terminator_; }
+ bool hasStat() const { return hasStat_; }
+ const Descriptor *errMsg() const { return errMsg_; }
+
+ // Starts a job and runs it to completion; returns the status code.
+ int Do(Job, const Descriptor &instance, const typeInfo::DerivedType *);
+
+ // Callbacks from running tasks for use in their return statements.
+ // Suspends execution and starts a nested job
+ Task::ResultType Begin(
+ Job, const Descriptor &instance, const typeInfo::DerivedType *);
+ // Terminates task successfully
+ Task::ResultType Done();
+ // Terminates task unsuccessfully
+ Task::ResultType Fail(int status);
+
+private:
+ class Work {
+ public:
+ Work(Job job, const Descriptor &instance, const typeInfo::DerivedType *);
+ void Resume(Engine &);
+
+ private:
+ Job job_;
+ union { // NOTE(review): the constructor writes via commonState while Resume() reads via initialization; presumably relies on both sharing Task's layout -- confirm
+ Task commonState;
+ Initialization initialization;
+ } u_;
+ };
+
+ Terminator &terminator_;
+ bool hasStat_{false};
+ const Descriptor *errMsg_;
+ int status_{StatOk}; // set by Fail(); returned by Do()
+ int depth_{0}; // number of Work objects currently stacked in workBuf_
+ static constexpr int maxDepth{4}; // Begin() fails a RUNTIME_CHECK beyond this nesting
+ alignas(Work) char workBuf_[maxDepth][sizeof(Work)]; // raw storage; Work objects are placement-constructed here
+};
+
+} // namespace Fortran::runtime::engine
+#endif // FORTRAN_RUNTIME_ENGINE_H_
More information about the flang-commits
mailing list