[flang-commits] [flang] [flang] Avoid recursion in runtime derived type initialization (PR #102394)
via flang-commits
flang-commits at lists.llvm.org
Wed Aug 7 15:04:15 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
Adds a recursion-avoiding work engine for runtime operations that walk nested derived types, and adapts derived type instance initialization to use it. If successful, the engine can be reused to replace recursion at many other runtime sites.
---
Full diff: https://github.com/llvm/llvm-project/pull/102394.diff
4 Files Affected:
- (modified) flang/runtime/CMakeLists.txt (+1)
- (modified) flang/runtime/derived.cpp (+51-55)
- (added) flang/runtime/engine.cpp (+67)
- (added) flang/runtime/engine.h (+112)
``````````diff
diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
index 4537b2d059d65b..3d6278d831d433 100644
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -125,6 +125,7 @@ set(sources
dot-product.cpp
edit-input.cpp
edit-output.cpp
+ engine.cpp
environment.cpp
exceptions.cpp
execute.cpp
diff --git a/flang/runtime/derived.cpp b/flang/runtime/derived.cpp
index 659f54fa344bb0..1e5c1c0dc05dd7 100644
--- a/flang/runtime/derived.cpp
+++ b/flang/runtime/derived.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "derived.h"
+#include "engine.h"
#include "stat.h"
#include "terminator.h"
#include "tools.h"
@@ -29,97 +30,92 @@ static RT_API_ATTRS void GetComponentExtents(SubscriptValue (&extents)[maxRank],
}
}
-RT_API_ATTRS int Initialize(const Descriptor &instance,
- const typeInfo::DerivedType &derived, Terminator &terminator, bool hasStat,
- const Descriptor *errMsg) {
- const Descriptor &componentDesc{derived.component()};
- std::size_t elements{instance.Elements()};
- int stat{StatOk};
- // Initialize data components in each element; the per-element iterations
- // constitute the inner loops, not the outer ones
- std::size_t myComponents{componentDesc.Elements()};
- for (std::size_t k{0}; k < myComponents; ++k) {
+// engine::Initialization::Run: one step of default initialization for a
+// derived type instance, written so that it never calls itself.  When a
+// component needs initialization of its own, Run() records where it is (the
+// component_ and element_ loop state), starts a nested job with
+// engine.Begin(), and returns; the engine calls Run() again after the nested
+// job has called engine.Done(), and the loops continue from that point.
+// Returns the result of Begin() or Done(), or the first nonzero status from
+// ReturnError() when allocating an automatic component fails.
+// NOTE(review): the Initialize() entry point is RT_API_ATTRS but this member
+// is not; confirm device (offload) builds of the runtime.
+int engine::Initialization::Run(Engine &engine) {
+ while (component_.Iterating(components_)) {
const auto &comp{
- *componentDesc.ZeroBasedIndexedElement<typeInfo::Component>(k)};
- SubscriptValue at[maxRank];
- instance.GetLowerBounds(at);
+ *componentDesc_->ZeroBasedIndexedElement<typeInfo::Component>(
+ component_.at)};
if (comp.genre() == typeInfo::Component::Genre::Allocatable ||
comp.genre() == typeInfo::Component::Genre::Automatic) {
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
- Descriptor &allocDesc{
- *instance.ElementComponent<Descriptor>(at, comp.offset())};
- comp.EstablishDescriptor(allocDesc, instance, terminator);
+ while (element_.Iterating(elements_, &instance_)) {
+ Descriptor &allocDesc{*instance_.ElementComponent<Descriptor>(
+ element_.subscripts, comp.offset())};
+ comp.EstablishDescriptor(allocDesc, instance_, engine.terminator());
allocDesc.raw().attribute = CFI_attribute_allocatable;
if (comp.genre() == typeInfo::Component::Genre::Automatic) {
- stat = ReturnError(terminator, allocDesc.Allocate(), errMsg, hasStat);
- if (stat == StatOk) {
- if (const DescriptorAddendum * addendum{allocDesc.Addendum()}) {
- if (const auto *derived{addendum->derivedType()}) {
- if (!derived->noInitializationNeeded()) {
- stat = Initialize(
- allocDesc, *derived, terminator, hasStat, errMsg);
- }
+ if (auto stat{ReturnError(engine.terminator(), allocDesc.Allocate(),
+ engine.errMsg(), engine.hasStat())};
+ stat != StatOk) {
+ return stat;
+ }
+ if (const DescriptorAddendum * addendum{allocDesc.Addendum()}) {
+ if (const auto *derived{addendum->derivedType()}) {
+ if (!derived->noInitializationNeeded()) {
+ // Suspend: on re-entry the component loop repeats this component
+ // and the element loop advances past this (already allocated)
+ // element.
+ component_.ResumeAtSameIteration();
+ return engine.Begin(Job::Initialize, allocDesc, derived);
}
}
}
- if (stat != StatOk) {
- break;
- }
}
}
} else if (const void *init{comp.initialization()}) {
// Explicit initialization of data pointers and
// non-allocatable non-automatic components
- std::size_t bytes{comp.SizeInBytes(instance)};
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
- char *ptr{instance.ElementComponent<char>(at, comp.offset())};
+ std::size_t bytes{comp.SizeInBytes(instance_)};
+ while (element_.Iterating(elements_, &instance_)) {
+ char *ptr{instance_.ElementComponent<char>(
+ element_.subscripts, comp.offset())};
std::memcpy(ptr, init, bytes);
}
} else if (comp.genre() == typeInfo::Component::Genre::Pointer) {
// Data pointers without explicit initialization are established
// so that they are valid right-hand side targets of pointer
// assignment statements.
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
- Descriptor &ptrDesc{
- *instance.ElementComponent<Descriptor>(at, comp.offset())};
- comp.EstablishDescriptor(ptrDesc, instance, terminator);
+ while (element_.Iterating(elements_, &instance_)) {
+ Descriptor &ptrDesc{*instance_.ElementComponent<Descriptor>(
+ element_.subscripts, comp.offset())};
+ comp.EstablishDescriptor(ptrDesc, instance_, engine.terminator());
ptrDesc.raw().attribute = CFI_attribute_pointer;
}
} else if (comp.genre() == typeInfo::Component::Genre::Data &&
comp.derivedType() && !comp.derivedType()->noInitializationNeeded()) {
// Default initialization of non-pointer non-allocatable/automatic
- // data component. Handles parent component's elements. Recursive.
- SubscriptValue extents[maxRank];
- GetComponentExtents(extents, comp, instance);
- StaticDescriptor<maxRank, true, 0> staticDescriptor;
- Descriptor &compDesc{staticDescriptor.descriptor()};
- const typeInfo::DerivedType &compType{*comp.derivedType()};
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
+ // data component. Handles parent component's elements.
+ // Extents are computed on first entry only; after a suspension the
+ // element loop is still active and extents_ is reused.
+ if (!element_.active) {
+ GetComponentExtents(extents_, comp, instance_);
+ }
+ while (element_.Iterating(elements_, &instance_)) {
+ Descriptor &compDesc{staticDescriptor_.descriptor()};
+ const typeInfo::DerivedType &compType{*comp.derivedType()};
compDesc.Establish(compType,
- instance.ElementComponent<char>(at, comp.offset()), comp.rank(),
- extents);
- stat = Initialize(compDesc, compType, terminator, hasStat, errMsg);
- if (stat != StatOk) {
- break;
- }
+ instance_.ElementComponent<char>(
+ element_.subscripts, comp.offset()),
+ comp.rank(), extents_);
+ // Suspend to initialize this element's component as a nested job,
+ // resuming afterwards at the next element.
+ component_.ResumeAtSameIteration();
+ return engine.Begin(Job::Initialize, compDesc, &compType);
}
}
}
// Initialize procedure pointer components in each element
+ // This loop never suspends, so an ordinary index loop suffices.
- const Descriptor &procPtrDesc{derived.procPtr()};
+ const Descriptor &procPtrDesc{derived_->procPtr()};
std::size_t myProcPtrs{procPtrDesc.Elements()};
for (std::size_t k{0}; k < myProcPtrs; ++k) {
const auto &comp{
*procPtrDesc.ZeroBasedIndexedElement<typeInfo::ProcPtrComponent>(k)};
- SubscriptValue at[maxRank];
- instance.GetLowerBounds(at);
- for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
- auto &pptr{*instance.ElementComponent<typeInfo::ProcedurePointer>(
- at, comp.offset)};
+ while (element_.Iterating(elements_, &instance_)) {
+ auto &pptr{*instance_.ElementComponent<typeInfo::ProcedurePointer>(
+ element_.subscripts, comp.offset)};
pptr = comp.procInitialization;
}
}
- return stat;
+ // This job is finished; pop it so that its caller (if any) resumes.
+ return engine.Done();
+}
+
+// Default-initializes every element of `instance`, whose type is `derived`.
+// Nested component initializations are driven by an engine::Engine work
+// stack. Returns StatOk or the first nonzero status reported while
+// initializing (e.g. a failed allocation reported through ReturnError()).
+RT_API_ATTRS int Initialize(const Descriptor &instance,
+    const typeInfo::DerivedType &derived, Terminator &terminator, bool hasStat,
+    const Descriptor *errMsg) {
+  engine::Engine initEngine{terminator, hasStat, errMsg};
+  return initEngine.Do(engine::Job::Initialize, instance, &derived);
+}
static RT_API_ATTRS const typeInfo::SpecialBinding *FindFinal(
diff --git a/flang/runtime/engine.cpp b/flang/runtime/engine.cpp
new file mode 100644
index 00000000000000..dc64202c72611e
--- /dev/null
+++ b/flang/runtime/engine.cpp
@@ -0,0 +1,67 @@
+//===-- runtime/engine.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "engine.h"
+
+namespace Fortran::runtime::engine {
+
+// Builds the record for one job over `instance`. Only the state common to
+// all job kinds is set here. The element count and the component table and
+// count of `derived` are cached for use by the job's Run(). `derived` may
+// be null, in which case the job has no components to visit.
+// Work is a member class of Engine, so its out-of-class definitions must be
+// qualified with Engine::.
+Engine::Work::Work(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived)
+    : job_{job}, u_{{instance}} {
+  auto &state{u_.commonState};
+  state.derived_ = derived;
+  state.elements_ = instance.Elements();
+  if (derived) {
+    state.componentDesc_ = &derived->component();
+    state.components_ = state.componentDesc_->Elements();
+  } else {
+    state.componentDesc_ = nullptr;
+    state.components_ = 0;
+  }
+}
+
+// Performs one step of this job by dispatching to the Run() of its job
+// kind, and returns that step's status. Crashes on an unknown job code.
+// Qualified with Engine:: because Work is a member class of Engine.
+int Engine::Work::Run(Engine &engine) {
+  switch (job_) {
+  case Job::Initialize:
+    return u_.initialization.Run(engine);
+  }
+  engine.terminator().Crash(
+      "Work::Run: bad job_ code %d", static_cast<int>(job_));
+}
+
+// Runs a job to completion. The job is pushed with Begin(), and then the
+// engine is driven until its work stack is empty. Returns StatOk or the
+// first nonzero status.
+int Engine::Do(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
+  int status{Begin(job, instance, derived)};
+  if (status == StatOk) {
+    status = Run();
+  }
+  return status;
+}
+
+// Pushes a nested job onto the work stack. That job, and any jobs it
+// begins, runs before the calling job is stepped again.
+//
+// Callers must already have recorded how to resume. They must not assume the
+// nested job is still pending when Begin() returns: if the stack is full
+// (maxDepth jobs), the nested job is run to completion right here on a fresh
+// Engine instead of overrunning workBuf_. This supports arbitrarily deep
+// nesting at the cost of one native recursion per maxDepth levels.
+//
+// Returns StatOk, or the nested job's nonzero status when it was run
+// immediately.
+int Engine::Begin(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
+  if (depth_ >= maxDepth) {
+    // TODO: heap allocation on overflow, to avoid even this recursion
+    return Engine{terminator_, hasStat_, errMsg_}.Do(job, instance, derived);
+  }
+  new (workBuf_[depth_++]) Work{job, instance, derived};
+  return StatOk;
+}
+
+// Drives the work stack by repeatedly stepping the innermost (most recently
+// begun) job until every job has called Done(). Stops early and returns
+// the first nonzero status reported by a job step.
+int Engine::Run() {
+  while (depth_ != 0) {
+    Work &top{*reinterpret_cast<Work *>(workBuf_[depth_ - 1])};
+    int status{top.Run(*this)};
+    if (status != StatOk) {
+      return status;
+    }
+  }
+  return StatOk;
+}
+
+// Called by the running job when it has finished. Pops it off the work
+// stack so that the job below it (if any) is stepped next. Always returns
+// StatOk.
+int Engine::Done() {
+  depth_ -= 1;
+  return StatOk;
+}
+
+} // namespace Fortran::runtime::engine
diff --git a/flang/runtime/engine.h b/flang/runtime/engine.h
new file mode 100644
index 00000000000000..53b4a11f6b245a
--- /dev/null
+++ b/flang/runtime/engine.h
@@ -0,0 +1,112 @@
+//===-- runtime/engine.h --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_RUNTIME_ENGINE_H_
+#define FORTRAN_RUNTIME_ENGINE_H_
+
+#include "derived.h"
+#include "stat.h"
+#include "terminator.h"
+#include "type-info.h"
+#include "flang/Runtime/descriptor.h"
+
+namespace Fortran::runtime::engine {
+
+class Engine;
+
+// Kinds of work the engine can perform; Work::Run() dispatches on this value.
+enum class Job { Initialize };
+
+// State common to every kind of job: the instance being processed, its
+// derived type and component table, and resumable loop state that lets a
+// job's Run() return in the middle of its loops and later continue them.
+struct CommonState {
+ // A loop counter whose state survives a return from Run().  Typical use:
+ //   while (it.Iterating(n, &descriptor)) { ... }
+ // The first call activates the loop (when n > 0) at iteration 0; each
+ // later call advances by one and returns false, deactivating the loop,
+ // once all n iterations are done.  When a descriptor is supplied on
+ // activation, `subscripts` holds the element subscripts matching `at`,
+ // starting from that descriptor's lower bounds.
+ struct Iteration {
+ // `dtor` is a descriptor (not a destructor).  `iters` and `dtor` are only
+ // read on the call that activates the loop.
+ bool Iterating(std::size_t iters, const Descriptor *dtor = nullptr) {
+ if (!active) {
+ if (iters > 0) {
+ active = true;
+ at = 0;
+ n = iters;
+ descriptor = dtor;
+ if (descriptor) {
+ descriptor->GetLowerBounds(subscripts);
+ }
+ }
+ } else if (resuming) {
+ // Re-entering after a suspension: repeat the current iteration.
+ resuming = false;
+ } else if (++at < n) {
+ // Next iteration; keep subscripts in step with `at`.
+ if (descriptor) {
+ descriptor->IncrementSubscripts(subscripts);
+ }
+ } else {
+ // All iterations done; the next call starts a fresh loop.
+ active = false;
+ }
+ return active;
+ }
+ // Makes the next Iterating() call on this active loop return true
+ // without advancing, so a resumed job re-enters the same iteration.
+ void ResumeAtSameIteration() { resuming = true; }
+
+ bool active{false}, resuming{false};
+ // at, n, descriptor, and subscripts are meaningful only while active.
+ std::size_t at, n;
+ const Descriptor *descriptor;
+ SubscriptValue subscripts[maxRank];
+ };
+
+ // Object being processed by this job.
+ const Descriptor &instance_;
+ // Its derived type; may be null.
+ const typeInfo::DerivedType *derived_;
+ // &derived_->component(), or null when derived_ is null.
+ const Descriptor *componentDesc_;
+ // instance_.Elements(), and the number of components (0 if none).
+ std::size_t elements_, components_;
+ // Loops over the elements of instance_ and over its components.
+ Iteration element_, component_;
+ // Storage for a descriptor handed to a nested job.  It lives here rather
+ // than in a local variable of Run() because the nested job keeps a
+ // reference to it after Run() has returned.
+ StaticDescriptor<maxRank, true, 8> staticDescriptor_;
+};
+
+// Job::Initialize: default initialization of each element of a derived
+// type instance (allocatable/automatic, explicitly initialized, pointer,
+// nested derived type data, and procedure pointer components).
+class Initialization : protected CommonState {
+public:
+ int Run(Engine &); // in derived.cpp
+private:
+ // Extents of the data component currently being initialized; kept across
+ // suspensions of Run().
+ SubscriptValue extents_[maxRank];
+};
+
+// Drives jobs kept on a small fixed-size stack of Work records (workBuf_);
+// Run() repeatedly steps the innermost job.  A job step either finishes the
+// job (Done()), starts a nested job (Begin()) and returns so that the nested
+// job runs first, or returns a nonzero status, which stops the engine.
+class Engine {
+public:
+ Engine(Terminator &terminator, bool hasStat, const Descriptor *errMsg)
+ : terminator_{terminator}, hasStat_{hasStat}, errMsg_{errMsg} {}
+
+ // Start and run a job to completion; returns StatOk or the first nonzero
+ // status produced by a job.
+ int Do(Job, const Descriptor &instance, const typeInfo::DerivedType *);
+
+ // Error-reporting context shared by all jobs run by this engine.
+ Terminator &terminator() const { return terminator_; }
+ bool hasStat() const { return hasStat_; }
+ const Descriptor *errMsg() const { return errMsg_; }
+
+ // Call from running job to suspend execution and start a nested job.
+ // The caller should return Begin()'s result, and must already have
+ // recorded how to resume when it is stepped again.
+ int Begin(Job, const Descriptor &instance, const typeInfo::DerivedType *);
+ // Call from a running job to terminate successfully; returns StatOk.
+ int Done();
+
+private:
+ // One job on the work stack: its kind and its job-specific state.
+ class Work {
+ public:
+ Work(Job job, const Descriptor &instance, const typeInfo::DerivedType *);
+ int Run(Engine &); // nonzero on fatal error
+ private:
+ Job job_;
+ // NOTE(review): the constructor initializes u_.commonState but Run()
+ // calls u_.initialization.Run(); confirm that using this non-active
+ // union member is well-defined (Initialization adds a private member).
+ union {
+ CommonState commonState;
+ Initialization initialization;
+ } u_;
+ };
+
+ // Steps the innermost job until the stack is empty or a job fails.
+ int Run();
+
+ Terminator &terminator_;
+ bool hasStat_{false};
+ const Descriptor *errMsg_;
+ // Number of jobs currently on the work stack.
+ int depth_{0};
+ // Capacity of the work stack, in jobs.
+ static constexpr int maxDepth{4};
+ // Raw storage for the stack; Work records are placement-new'd here.
+ alignas(Work) char workBuf_[maxDepth][sizeof(Work)];
+};
+
+} // namespace Fortran::runtime::engine
+#endif // FORTRAN_RUNTIME_ENGINE_H_
``````````
</details>
https://github.com/llvm/llvm-project/pull/102394
More information about the flang-commits
mailing list