[flang-commits] [flang] [flang] Avoid recursion in runtime derived type initialization (PR #102394)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Sat Aug 17 13:59:48 PDT 2024


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/102394

>From 0aa4e123ed7f3ce4a9bbe3be32b88173b0b143cf Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 7 Aug 2024 15:01:29 -0700
Subject: [PATCH 1/3] [flang] Avoid recursion in runtime derived type
 initialization

Adds a recursion-avoiding work engine for runtime operations that
traverse nested derived types, and adapts derived type instance
initialization to use it.  If successful, the engine can be reused
to replace recursion in many other runtime sites.
---
 flang/runtime/CMakeLists.txt |   1 +
 flang/runtime/derived.cpp    | 106 +++++++++++++-------------
 flang/runtime/engine.cpp     |  85 +++++++++++++++++++++
 flang/runtime/engine.h       | 144 +++++++++++++++++++++++++++++++++++
 4 files changed, 281 insertions(+), 55 deletions(-)
 create mode 100644 flang/runtime/engine.cpp
 create mode 100644 flang/runtime/engine.h

diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
index 4537b2d059d65b..3d6278d831d433 100644
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -125,6 +125,7 @@ set(sources
   dot-product.cpp
   edit-input.cpp
   edit-output.cpp
+  engine.cpp
   environment.cpp
   exceptions.cpp
   execute.cpp
diff --git a/flang/runtime/derived.cpp b/flang/runtime/derived.cpp
index 659f54fa344bb0..a55a51996afbaf 100644
--- a/flang/runtime/derived.cpp
+++ b/flang/runtime/derived.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "derived.h"
+#include "engine.h"
 #include "stat.h"
 #include "terminator.h"
 #include "tools.h"
@@ -29,97 +30,92 @@ static RT_API_ATTRS void GetComponentExtents(SubscriptValue (&extents)[maxRank],
   }
 }
 
-RT_API_ATTRS int Initialize(const Descriptor &instance,
-    const typeInfo::DerivedType &derived, Terminator &terminator, bool hasStat,
-    const Descriptor *errMsg) {
-  const Descriptor &componentDesc{derived.component()};
-  std::size_t elements{instance.Elements()};
-  int stat{StatOk};
-  // Initialize data components in each element; the per-element iterations
-  // constitute the inner loops, not the outer ones
-  std::size_t myComponents{componentDesc.Elements()};
-  for (std::size_t k{0}; k < myComponents; ++k) {
+RT_API_ATTRS auto engine::Initialization::Resume(Engine &engine) -> ResultType {
+  while (component_.Iterating(components_)) {
     const auto &comp{
-        *componentDesc.ZeroBasedIndexedElement<typeInfo::Component>(k)};
-    SubscriptValue at[maxRank];
-    instance.GetLowerBounds(at);
+        *componentDesc_->ZeroBasedIndexedElement<typeInfo::Component>(
+            component_.at)};
     if (comp.genre() == typeInfo::Component::Genre::Allocatable ||
         comp.genre() == typeInfo::Component::Genre::Automatic) {
-      for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
-        Descriptor &allocDesc{
-            *instance.ElementComponent<Descriptor>(at, comp.offset())};
-        comp.EstablishDescriptor(allocDesc, instance, terminator);
+      while (element_.Iterating(elements_, &instance_)) {
+        Descriptor &allocDesc{*instance_.ElementComponent<Descriptor>(
+            element_.subscripts, comp.offset())};
+        comp.EstablishDescriptor(allocDesc, instance_, engine.terminator());
         allocDesc.raw().attribute = CFI_attribute_allocatable;
         if (comp.genre() == typeInfo::Component::Genre::Automatic) {
-          stat = ReturnError(terminator, allocDesc.Allocate(), errMsg, hasStat);
-          if (stat == StatOk) {
-            if (const DescriptorAddendum * addendum{allocDesc.Addendum()}) {
-              if (const auto *derived{addendum->derivedType()}) {
-                if (!derived->noInitializationNeeded()) {
-                  stat = Initialize(
-                      allocDesc, *derived, terminator, hasStat, errMsg);
-                }
+          if (auto stat{ReturnError(engine.terminator(), allocDesc.Allocate(),
+                  engine.errMsg(), engine.hasStat())};
+              stat != StatOk) {
+            return engine.Fail(stat);
+          }
+          if (const DescriptorAddendum * addendum{allocDesc.Addendum()}) {
+            if (const auto *derived{addendum->derivedType()}) {
+              if (!derived->noInitializationNeeded()) {
+                component_.ResumeAtSameIteration();
+                return engine.Begin(Job::Initialization, allocDesc, derived);
               }
             }
           }
-          if (stat != StatOk) {
-            break;
-          }
         }
       }
     } else if (const void *init{comp.initialization()}) {
       // Explicit initialization of data pointers and
       // non-allocatable non-automatic components
-      std::size_t bytes{comp.SizeInBytes(instance)};
-      for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
-        char *ptr{instance.ElementComponent<char>(at, comp.offset())};
+      std::size_t bytes{comp.SizeInBytes(instance_)};
+      while (element_.Iterating(elements_, &instance_)) {
+        char *ptr{instance_.ElementComponent<char>(
+            element_.subscripts, comp.offset())};
         std::memcpy(ptr, init, bytes);
       }
     } else if (comp.genre() == typeInfo::Component::Genre::Pointer) {
       // Data pointers without explicit initialization are established
       // so that they are valid right-hand side targets of pointer
       // assignment statements.
-      for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
-        Descriptor &ptrDesc{
-            *instance.ElementComponent<Descriptor>(at, comp.offset())};
-        comp.EstablishDescriptor(ptrDesc, instance, terminator);
+      while (element_.Iterating(elements_, &instance_)) {
+        Descriptor &ptrDesc{*instance_.ElementComponent<Descriptor>(
+            element_.subscripts, comp.offset())};
+        comp.EstablishDescriptor(ptrDesc, instance_, engine.terminator());
         ptrDesc.raw().attribute = CFI_attribute_pointer;
       }
     } else if (comp.genre() == typeInfo::Component::Genre::Data &&
         comp.derivedType() && !comp.derivedType()->noInitializationNeeded()) {
       // Default initialization of non-pointer non-allocatable/automatic
-      // data component.  Handles parent component's elements.  Recursive.
-      SubscriptValue extents[maxRank];
-      GetComponentExtents(extents, comp, instance);
-      StaticDescriptor<maxRank, true, 0> staticDescriptor;
-      Descriptor &compDesc{staticDescriptor.descriptor()};
-      const typeInfo::DerivedType &compType{*comp.derivedType()};
-      for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
+      // data component.  Handles parent component's elements.
+      if (!element_.active) {
+        GetComponentExtents(extents_, comp, instance_);
+      }
+      while (element_.Iterating(elements_, &instance_)) {
+        Descriptor &compDesc{staticDescriptor_.descriptor()};
+        const typeInfo::DerivedType &compType{*comp.derivedType()};
         compDesc.Establish(compType,
-            instance.ElementComponent<char>(at, comp.offset()), comp.rank(),
-            extents);
-        stat = Initialize(compDesc, compType, terminator, hasStat, errMsg);
-        if (stat != StatOk) {
-          break;
-        }
+            instance_.ElementComponent<char>(
+                element_.subscripts, comp.offset()),
+            comp.rank(), extents_);
+        component_.ResumeAtSameIteration();
+        return engine.Begin(Job::Initialization, compDesc, &compType);
       }
     }
   }
   // Initialize procedure pointer components in each element
-  const Descriptor &procPtrDesc{derived.procPtr()};
+  const Descriptor &procPtrDesc{derived_->procPtr()};
   std::size_t myProcPtrs{procPtrDesc.Elements()};
   for (std::size_t k{0}; k < myProcPtrs; ++k) {
     const auto &comp{
         *procPtrDesc.ZeroBasedIndexedElement<typeInfo::ProcPtrComponent>(k)};
-    SubscriptValue at[maxRank];
-    instance.GetLowerBounds(at);
-    for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
-      auto &pptr{*instance.ElementComponent<typeInfo::ProcedurePointer>(
-          at, comp.offset)};
+    while (element_.Iterating(elements_, &instance_)) {
+      auto &pptr{*instance_.ElementComponent<typeInfo::ProcedurePointer>(
+          element_.subscripts, comp.offset)};
       pptr = comp.procInitialization;
     }
   }
-  return stat;
+  return engine.Done();
+}
+
+RT_API_ATTRS int Initialize(const Descriptor &instance,
+    const typeInfo::DerivedType &derived, Terminator &terminator, bool hasStat,
+    const Descriptor *errMsg) {
+  return engine::Engine{terminator, hasStat, errMsg}.Do(
+      engine::Job::Initialization, instance, &derived);
 }
 
 static RT_API_ATTRS const typeInfo::SpecialBinding *FindFinal(
diff --git a/flang/runtime/engine.cpp b/flang/runtime/engine.cpp
new file mode 100644
index 00000000000000..6ac683950728c6
--- /dev/null
+++ b/flang/runtime/engine.cpp
@@ -0,0 +1,85 @@
+//===-- runtime/engine.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "engine.h"
+#include "flang/Runtime/memory.h"
+
+namespace Fortran::runtime::engine {
+
+RT_API_ATTRS Engine::Work::Work(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived)
+    : job_{job}, u_{{instance}} {
+  auto &state{u_.commonState};
+  state.derived_ = derived;
+  state.elements_ = instance.Elements();
+  if (derived) {
+    state.componentDesc_ = &derived->component();
+    state.components_ = state.componentDesc_->Elements();
+  } else {
+    state.componentDesc_ = nullptr;
+    state.components_ = 0;
+  }
+}
+
+RT_API_ATTRS void Engine::Work::Resume(Engine &engine) {
+  switch (job_) {
+  case Job::Initialization:
+    u_.initialization.Resume(engine);
+    return;
+  }
+  engine.terminator().Crash(
+      "Work::Run: bad job_ code %d", static_cast<int>(job_));
+}
+
+RT_API_ATTRS Engine::~Engine() {
+  // deletes list owned by bottomWorkBlock_.next
+}
+
+RT_API_ATTRS int Engine::Do(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
+  Begin(job, instance, derived);
+  while (topWorkBlock_ != &bottomWorkBlock_ && topWorkBlock_->depth > 0) {
+    if (status_ == StatOk) {
+      auto *w{reinterpret_cast<Work *>(
+          topWorkBlock_->workBuf[topWorkBlock_->depth - 1])};
+      w->Resume(*this);
+    } else {
+      Done();
+    }
+  }
+  return status_;
+}
+
+RT_API_ATTRS Task::ResultType Engine::Begin(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
+  if (topWorkBlock_->depth == topWorkBlock_->maxDepth) {
+    if (!topWorkBlock_->next) {
+      topWorkBlock_->next = New<WorkBlock>{terminator_}(topWorkBlock_);
+    }
+    topWorkBlock_ = topWorkBlock_->next.get();
+  }
+  new (topWorkBlock_->workBuf[topWorkBlock_->depth++])
+      Work{job, instance, derived};
+  return Task::ResultType::ResultValue;
+}
+
+RT_API_ATTRS Task::ResultType Engine::Done() {
+  if (!--topWorkBlock_->depth) {
+    if (auto *previous{topWorkBlock_->previous}) {
+      topWorkBlock_ = previous;
+    }
+  }
+  return Task::ResultType::ResultValue;
+}
+
+RT_API_ATTRS Task::ResultType Engine::Fail(int status) {
+  status_ = status;
+  return Done();
+}
+
+} // namespace Fortran::runtime::engine
diff --git a/flang/runtime/engine.h b/flang/runtime/engine.h
new file mode 100644
index 00000000000000..40157adddb5f6f
--- /dev/null
+++ b/flang/runtime/engine.h
@@ -0,0 +1,144 @@
+//===-- runtime/engine.h --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Implements a work engine for restartable tasks iterating over elements,
+// components, &c. of arrays and derived types.  Avoids recursion and
+// function pointers.
+
+#ifndef FORTRAN_RUNTIME_ENGINE_H_
+#define FORTRAN_RUNTIME_ENGINE_H_
+
+#include "derived.h"
+#include "stat.h"
+#include "terminator.h"
+#include "type-info.h"
+#include "flang/Runtime/descriptor.h"
+
+namespace Fortran::runtime::engine {
+
+class Engine;
+
+// Every task object derives from Task.
+struct Task {
+
+  enum class ResultType { ResultValue /*doesn't matter*/ };
+
+  struct Iteration {
+    RT_API_ATTRS bool Iterating(
+        std::size_t iters, const Descriptor *dtor = nullptr) {
+      if (!active) {
+        if (iters > 0) {
+          active = true;
+          at = 0;
+          n = iters;
+          descriptor = dtor;
+          if (descriptor) {
+            descriptor->GetLowerBounds(subscripts);
+          }
+        }
+      } else if (resuming) {
+        resuming = false;
+      } else if (++at < n) {
+        if (descriptor) {
+          descriptor->IncrementSubscripts(subscripts);
+        }
+      } else {
+        active = false;
+      }
+      return active;
+    }
+    // Call before Engine::Begin() on each Iteration instance that must
+    // not advance when the suspended job is later resumed.
+    RT_API_ATTRS void ResumeAtSameIteration() { resuming = true; }
+
+    bool active{false}, resuming{false};
+    std::size_t at, n;
+    const Descriptor *descriptor;
+    SubscriptValue subscripts[maxRank];
+  };
+
+  // For looping over elements
+  const Descriptor &instance_;
+  std::size_t elements_;
+  Iteration element_;
+
+  // For looping over components
+  const typeInfo::DerivedType *derived_;
+  const Descriptor *componentDesc_;
+  std::size_t components_;
+  Iteration component_;
+};
+
+enum class Job { Initialization };
+
+class Initialization : protected Task {
+public:
+  RT_API_ATTRS ResultType Resume(Engine &);
+
+private:
+  SubscriptValue extents_[maxRank];
+  StaticDescriptor<maxRank, true, 8> staticDescriptor_;
+};
+
+class Engine {
+public:
+  RT_API_ATTRS Engine(
+      Terminator &terminator, bool hasStat, const Descriptor *errMsg)
+      : terminator_{terminator}, hasStat_{hasStat}, errMsg_{errMsg} {}
+  RT_API_ATTRS ~Engine();
+
+  RT_API_ATTRS Terminator &terminator() const { return terminator_; }
+  RT_API_ATTRS bool hasStat() const { return hasStat_; }
+  RT_API_ATTRS const Descriptor *errMsg() const { return errMsg_; }
+
+  // Start and run a job to completion; returns status code.
+  RT_API_ATTRS int Do(
+      Job, const Descriptor &instance, const typeInfo::DerivedType *);
+
+  // Callbacks from running tasks for use in their return statements.
+  // Suspends execution and starts a nested job
+  RT_API_ATTRS Task::ResultType Begin(
+      Job, const Descriptor &instance, const typeInfo::DerivedType *);
+  // Terminates task successfully
+  RT_API_ATTRS Task::ResultType Done();
+  // Terminates task unsuccessfully
+  RT_API_ATTRS Task::ResultType Fail(int status);
+
+private:
+  class Work {
+  public:
+    RT_API_ATTRS Work(
+        Job job, const Descriptor &instance, const typeInfo::DerivedType *);
+    RT_API_ATTRS void Resume(Engine &);
+
+  private:
+    Job job_;
+    union {
+      Task commonState;
+      Initialization initialization;
+    } u_;
+  };
+
+  struct WorkBlock {
+    WorkBlock *previous{nullptr};
+    OwningPtr<WorkBlock> next;
+    int depth{0};
+    static constexpr int maxDepth{4};
+    alignas(Work) char workBuf[maxDepth][sizeof(Work)];
+  };
+
+  Terminator &terminator_;
+  bool hasStat_{false};
+  const Descriptor *errMsg_;
+  int status_{StatOk};
+  WorkBlock bottomWorkBlock_;
+  WorkBlock *topWorkBlock_{&bottomWorkBlock_};
+};
+
+} // namespace Fortran::runtime::engine
+#endif // FORTRAN_RUNTIME_ENGINE_H_

>From 24981b73ebfd74464984d4d09be4e290bfb99f70 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Fri, 9 Aug 2024 07:57:34 -0700
Subject: [PATCH 2/3] wip

---
 flang/runtime/derived.cpp | 188 +++++++++++++++++++++-----------------
 flang/runtime/engine.cpp  |  19 ++--
 flang/runtime/engine.h    |  39 ++++++--
 3 files changed, 145 insertions(+), 101 deletions(-)

diff --git a/flang/runtime/derived.cpp b/flang/runtime/derived.cpp
index a55a51996afbaf..e545c75c41fff7 100644
--- a/flang/runtime/derived.cpp
+++ b/flang/runtime/derived.cpp
@@ -81,12 +81,12 @@ RT_API_ATTRS auto engine::Initialization::Resume(Engine &engine) -> ResultType {
         comp.derivedType() && !comp.derivedType()->noInitializationNeeded()) {
       // Default initialization of non-pointer non-allocatable/automatic
       // data component.  Handles parent component's elements.
-      if (!element_.active) {
-        GetComponentExtents(extents_, comp, instance_);
-      }
       while (element_.Iterating(elements_, &instance_)) {
         Descriptor &compDesc{staticDescriptor_.descriptor()};
         const typeInfo::DerivedType &compType{*comp.derivedType()};
+        if (element_.at == 0) {
+          GetComponentExtents(extents_, comp, instance_);
+        }
         compDesc.Establish(compType,
             instance_.ElementComponent<char>(
                 element_.subscripts, comp.offset()),
@@ -133,7 +133,7 @@ static RT_API_ATTRS const typeInfo::SpecialBinding *FindFinal(
 }
 
 static RT_API_ATTRS void CallFinalSubroutine(const Descriptor &descriptor,
-    const typeInfo::DerivedType &derived, Terminator *terminator) {
+    const typeInfo::DerivedType &derived, Terminator &terminator) {
   if (const auto *special{FindFinal(derived, descriptor.rank())}) {
     if (special->which() == typeInfo::SpecialBinding::Which::ElementalFinal) {
       std::size_t elements{descriptor.Elements()};
@@ -170,9 +170,7 @@ static RT_API_ATTRS void CallFinalSubroutine(const Descriptor &descriptor,
         copy = descriptor;
         copy.set_base_addr(nullptr);
         copy.raw().attribute = CFI_attribute_allocatable;
-        Terminator stubTerminator{"CallFinalProcedure() in Fortran runtime", 0};
-        RUNTIME_CHECK(terminator ? *terminator : stubTerminator,
-            copy.Allocate() == CFI_SUCCESS);
+        RUNTIME_CHECK(terminator, copy.Allocate() == CFI_SUCCESS);
         ShallowCopyDiscontiguousToContiguous(copy, descriptor);
         argDescriptor = &copy;
       }
@@ -197,42 +195,42 @@ static RT_API_ATTRS void CallFinalSubroutine(const Descriptor &descriptor,
 }
 
 // Fortran 2018 subclause 7.5.6.2
-RT_API_ATTRS void Finalize(const Descriptor &descriptor,
-    const typeInfo::DerivedType &derived, Terminator *terminator) {
-  if (derived.noFinalizationNeeded() || !descriptor.IsAllocated()) {
-    return;
+RT_API_ATTRS auto engine::Finalization::Resume(Engine &engine) -> ResultType {
+  if (!derived_ || derived_->noFinalizationNeeded() ||
+      !instance_.IsAllocated()) {
+    return engine.Done();
+  }
+  if (phase_ == 0) {
+    ++phase_;
+    CallFinalSubroutine(instance_, *derived_, engine.terminator());
   }
-  CallFinalSubroutine(descriptor, derived, terminator);
-  const auto *parentType{derived.GetParentType()};
+  const auto *parentType{derived_->GetParentType()};
   bool recurse{parentType && !parentType->noFinalizationNeeded()};
   // If there's a finalizable parent component, handle it last, as required
   // by the Fortran standard (7.5.6.2), and do so recursively with the same
   // descriptor so that the rank is preserved.
-  const Descriptor &componentDesc{derived.component()};
-  std::size_t myComponents{componentDesc.Elements()};
-  std::size_t elements{descriptor.Elements()};
-  for (auto k{recurse ? std::size_t{1}
-                      /* skip first component, it's the parent */
-                      : 0};
-       k < myComponents; ++k) {
+  while (phase_ == 1 && component_.Iterating(components_)) {
+    if (recurse && component_.at == 0) {
+      continue; // skip first component, which is the parent
+    }
     const auto &comp{
-        *componentDesc.ZeroBasedIndexedElement<typeInfo::Component>(k)};
-    SubscriptValue at[maxRank];
-    descriptor.GetLowerBounds(at);
+        *componentDesc_->ZeroBasedIndexedElement<typeInfo::Component>(
+            component_.at)};
     if (comp.genre() == typeInfo::Component::Genre::Allocatable &&
         comp.category() == TypeCategory::Derived) {
       // Component may be polymorphic or unlimited polymorphic. Need to use the
       // dynamic type to check whether finalization is needed.
-      for (std::size_t j{0}; j++ < elements;
-           descriptor.IncrementSubscripts(at)) {
-        const Descriptor &compDesc{
-            *descriptor.ElementComponent<Descriptor>(at, comp.offset())};
+      while (element_.Iterating(elements_, &instance_)) {
+        const Descriptor &compDesc{*instance_.ElementComponent<Descriptor>(
+            element_.subscripts, comp.offset())};
         if (compDesc.IsAllocated()) {
           if (const DescriptorAddendum * addendum{compDesc.Addendum()}) {
             if (const typeInfo::DerivedType *
                 compDynamicType{addendum->derivedType()}) {
               if (!compDynamicType->noFinalizationNeeded()) {
-                Finalize(compDesc, *compDynamicType, terminator);
+                component_.ResumeAtSameIteration();
+                return engine.Begin(
+                    Job::Finalization, compDesc, compDynamicType);
               }
             }
           }
@@ -242,94 +240,120 @@ RT_API_ATTRS void Finalize(const Descriptor &descriptor,
         comp.genre() == typeInfo::Component::Genre::Automatic) {
       if (const typeInfo::DerivedType * compType{comp.derivedType()}) {
         if (!compType->noFinalizationNeeded()) {
-          for (std::size_t j{0}; j++ < elements;
-               descriptor.IncrementSubscripts(at)) {
-            const Descriptor &compDesc{
-                *descriptor.ElementComponent<Descriptor>(at, comp.offset())};
+          while (element_.Iterating(elements_, &instance_)) {
+            const Descriptor &compDesc{*instance_.ElementComponent<Descriptor>(
+                element_.subscripts, comp.offset())};
             if (compDesc.IsAllocated()) {
-              Finalize(compDesc, *compType, terminator);
+              component_.ResumeAtSameIteration();
+              return engine.Begin(Job::Finalization, compDesc, compType);
             }
           }
         }
       }
     } else if (comp.genre() == typeInfo::Component::Genre::Data &&
         comp.derivedType() && !comp.derivedType()->noFinalizationNeeded()) {
-      SubscriptValue extents[maxRank];
-      GetComponentExtents(extents, comp, descriptor);
-      StaticDescriptor<maxRank, true, 0> staticDescriptor;
-      Descriptor &compDesc{staticDescriptor.descriptor()};
+      Descriptor &compDesc{staticDescriptor_.descriptor()};
       const typeInfo::DerivedType &compType{*comp.derivedType()};
-      for (std::size_t j{0}; j++ < elements;
-           descriptor.IncrementSubscripts(at)) {
+      while (element_.Iterating(elements_, &instance_)) {
+        if (element_.at == 0) {
+          GetComponentExtents(extents_, comp, instance_);
+        }
         compDesc.Establish(compType,
-            descriptor.ElementComponent<char>(at, comp.offset()), comp.rank(),
-            extents);
-        Finalize(compDesc, compType, terminator);
+            instance_.ElementComponent<char>(
+                element_.subscripts, comp.offset()),
+            comp.rank(), extents_);
+        component_.ResumeAtSameIteration();
+        return engine.Begin(Job::Finalization, compDesc, &compType);
       }
     }
   }
-  if (recurse) {
-    StaticDescriptor<maxRank, true, 8 /*?*/> statDesc;
-    Descriptor &tmpDesc{statDesc.descriptor()};
-    tmpDesc = descriptor;
-    tmpDesc.raw().attribute = CFI_attribute_pointer;
-    tmpDesc.Addendum()->set_derivedType(parentType);
-    tmpDesc.raw().elem_len = parentType->sizeInBytes();
-    Finalize(tmpDesc, *parentType, terminator);
+  if (phase_ == 1) { // done with all non-parent components
+    ++phase_;
   }
+  if (phase_ == 2) {
+    ++phase_;
+    if (recurse) { // now finalize parent component
+      Descriptor &tmpDesc{staticDescriptor_.descriptor()};
+      tmpDesc = instance_;
+      tmpDesc.raw().attribute = CFI_attribute_pointer;
+      tmpDesc.Addendum()->set_derivedType(parentType);
+      tmpDesc.raw().elem_len = parentType->sizeInBytes();
+      return engine.Begin(Job::Finalization, tmpDesc, parentType);
+    }
+  }
+  return engine.Done();
 }
 
-// The order of finalization follows Fortran 2018 7.5.6.2, with
-// elementwise finalization of non-parent components taking place
-// before parent component finalization, and with all finalization
-// preceding any deallocation.
-RT_API_ATTRS void Destroy(const Descriptor &descriptor, bool finalize,
+// Fortran 2018 subclause 7.5.6.2
+RT_API_ATTRS void Finalize(const Descriptor &descriptor,
     const typeInfo::DerivedType &derived, Terminator *terminator) {
-  if (derived.noDestructionNeeded() || !descriptor.IsAllocated()) {
-    return;
-  }
-  if (finalize && !derived.noFinalizationNeeded()) {
-    Finalize(descriptor, derived, terminator);
+  if (!derived.noFinalizationNeeded() && descriptor.IsAllocated()) {
+    Terminator defaultTerminator{"Finalize() in Fortran runtime"};
+    if (!terminator) {
+      terminator = &defaultTerminator;
+    }
+    engine::Engine engine{*terminator, /*hasStat=*/false, /*errMsg=*/nullptr};
+    engine.Do(engine::Job::Finalization, descriptor, &derived);
   }
+}
+
+RT_API_ATTRS auto engine::Destruction::Resume(Engine &engine) -> ResultType {
   // Deallocate all direct and indirect allocatable and automatic components.
   // Contrary to finalization, the order of deallocation does not matter.
-  const Descriptor &componentDesc{derived.component()};
-  std::size_t myComponents{componentDesc.Elements()};
-  std::size_t elements{descriptor.Elements()};
-  SubscriptValue at[maxRank];
-  descriptor.GetLowerBounds(at);
-  for (std::size_t k{0}; k < myComponents; ++k) {
+  while (component_.Iterating(components_)) {
     const auto &comp{
-        *componentDesc.ZeroBasedIndexedElement<typeInfo::Component>(k)};
+        *componentDesc_->ZeroBasedIndexedElement<typeInfo::Component>(
+            component_.at)};
     const bool destroyComp{
         comp.derivedType() && !comp.derivedType()->noDestructionNeeded()};
     if (comp.genre() == typeInfo::Component::Genre::Allocatable ||
         comp.genre() == typeInfo::Component::Genre::Automatic) {
-      for (std::size_t j{0}; j < elements; ++j) {
-        Descriptor *d{
-            descriptor.ElementComponent<Descriptor>(at, comp.offset())};
+      while (element_.Iterating(elements_, &instance_)) {
+        Descriptor &d{*instance_.ElementComponent<Descriptor>(
+            element_.subscripts, comp.offset())};
         if (destroyComp) {
-          Destroy(*d, /*finalize=*/false, *comp.derivedType(), terminator);
+          component_.ResumeAtSameIteration();
+          return engine.Begin(Job::Destruction, d, comp.derivedType());
         }
-        d->Deallocate();
-        descriptor.IncrementSubscripts(at);
+        d.Deallocate();
       }
     } else if (destroyComp &&
         comp.genre() == typeInfo::Component::Genre::Data) {
-      SubscriptValue extents[maxRank];
-      GetComponentExtents(extents, comp, descriptor);
-      StaticDescriptor<maxRank, true, 0> staticDescriptor;
-      Descriptor &compDesc{staticDescriptor.descriptor()};
+      Descriptor &compDesc{staticDescriptor_.descriptor()};
       const typeInfo::DerivedType &compType{*comp.derivedType()};
-      for (std::size_t j{0}; j++ < elements;
-           descriptor.IncrementSubscripts(at)) {
+      while (element_.Iterating(elements_, &instance_)) {
+        if (element_.at == 0) {
+          GetComponentExtents(extents_, comp, instance_);
+        }
         compDesc.Establish(compType,
-            descriptor.ElementComponent<char>(at, comp.offset()), comp.rank(),
-            extents);
-        Destroy(compDesc, /*finalize=*/false, *comp.derivedType(), terminator);
+            instance_.ElementComponent<char>(
+                element_.subscripts, comp.offset()),
+            comp.rank(), extents_);
+        component_.ResumeAtSameIteration();
+        return engine.Begin(Job::Destruction, compDesc, &compType);
       }
     }
   }
+  return engine.Done();
+}
+
+// The order of finalization follows Fortran 2018 7.5.6.2, with
+// elementwise finalization of non-parent components taking place
+// before parent component finalization, and with all finalization
+// preceding any deallocation.
+RT_API_ATTRS void Destroy(const Descriptor &descriptor, bool finalize,
+    const typeInfo::DerivedType &derived, Terminator *terminator) {
+  if (!derived.noDestructionNeeded() && descriptor.IsAllocated()) {
+    Terminator defaultTerminator{"Destroy() in Fortran runtime"};
+    if (!terminator) {
+      terminator = &defaultTerminator;
+    }
+    engine::Engine engine{*terminator, /*hasStat=*/false, /*errMsg=*/nullptr};
+    if (finalize && !derived.noFinalizationNeeded()) {
+      engine.Do(engine::Job::Finalization, descriptor, &derived);
+    }
+    engine.Do(engine::Job::Destruction, descriptor, &derived);
+  }
 }
 
 RT_API_ATTRS bool HasDynamicComponent(const Descriptor &descriptor) {
diff --git a/flang/runtime/engine.cpp b/flang/runtime/engine.cpp
index 6ac683950728c6..816d3f9939da29 100644
--- a/flang/runtime/engine.cpp
+++ b/flang/runtime/engine.cpp
@@ -13,24 +13,19 @@ namespace Fortran::runtime::engine {
 
 RT_API_ATTRS Engine::Work::Work(
     Job job, const Descriptor &instance, const typeInfo::DerivedType *derived)
-    : job_{job}, u_{{instance}} {
-  auto &state{u_.commonState};
-  state.derived_ = derived;
-  state.elements_ = instance.Elements();
-  if (derived) {
-    state.componentDesc_ = &derived->component();
-    state.components_ = state.componentDesc_->Elements();
-  } else {
-    state.componentDesc_ = nullptr;
-    state.components_ = 0;
-  }
-}
+    : job_{job}, u_{{instance, derived}} {}
 
 RT_API_ATTRS void Engine::Work::Resume(Engine &engine) {
   switch (job_) {
   case Job::Initialization:
     u_.initialization.Resume(engine);
     return;
+  case Job::Finalization:
+    u_.finalization.Resume(engine);
+    return;
+  case Job::Destruction:
+    u_.destruction.Resume(engine);
+    return;
   }
   engine.terminator().Crash(
       "Work::Run: bad job_ code %d", static_cast<int>(job_));
diff --git a/flang/runtime/engine.h b/flang/runtime/engine.h
index 40157adddb5f6f..06ae7d8e04fe29 100644
--- a/flang/runtime/engine.h
+++ b/flang/runtime/engine.h
@@ -28,6 +28,9 @@ struct Task {
 
   enum class ResultType { ResultValue /*doesn't matter*/ };
 
+  Task(const Descriptor &instance, const typeInfo::DerivedType *derived)
+      : instance_{instance}, derived_{derived} {}
+
   struct Iteration {
     RT_API_ATTRS bool Iterating(
         std::size_t iters, const Descriptor *dtor = nullptr) {
@@ -62,19 +65,21 @@ struct Task {
     SubscriptValue subscripts[maxRank];
   };
 
-  // For looping over elements
   const Descriptor &instance_;
-  std::size_t elements_;
+  const typeInfo::DerivedType *derived_;
+  int phase_{0};
+
+  // For looping over elements
+  std::size_t elements_{instance_.Elements()};
   Iteration element_;
 
   // For looping over components
-  const typeInfo::DerivedType *derived_;
-  const Descriptor *componentDesc_;
-  std::size_t components_;
+  const Descriptor *componentDesc_{derived_ ? &derived_->component() : nullptr};
+  std::size_t components_{componentDesc_ ? componentDesc_->Elements() : 0};
   Iteration component_;
 };
 
-enum class Job { Initialization };
+enum class Job { Initialization, Finalization, Destruction };
 
 class Initialization : protected Task {
 public:
@@ -82,7 +87,25 @@ class Initialization : protected Task {
 
 private:
   SubscriptValue extents_[maxRank];
-  StaticDescriptor<maxRank, true, 8> staticDescriptor_;
+  StaticDescriptor<maxRank, true, 0> staticDescriptor_;
+};
+
+class Finalization : protected Task {
+public:
+  RT_API_ATTRS ResultType Resume(Engine &);
+
+private:
+  SubscriptValue extents_[maxRank];
+  StaticDescriptor<maxRank, true, 0> staticDescriptor_;
+};
+
+class Destruction : protected Task {
+public:
+  RT_API_ATTRS ResultType Resume(Engine &);
+
+private:
+  SubscriptValue extents_[maxRank];
+  StaticDescriptor<maxRank, true, 0> staticDescriptor_;
 };
 
 class Engine {
@@ -121,6 +144,8 @@ class Engine {
     union {
       Task commonState;
       Initialization initialization;
+      Finalization finalization;
+      Destruction destruction;
     } u_;
   };
 

>From 0bb843a8ea44cdcad146b2db785801aa8b05afb4 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Sat, 17 Aug 2024 13:59:34 -0700
Subject: [PATCH 3/3] fixes

---
 flang/runtime/engine.cpp | 2 +-
 flang/runtime/engine.h   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/runtime/engine.cpp b/flang/runtime/engine.cpp
index 816d3f9939da29..7c399e17823a04 100644
--- a/flang/runtime/engine.cpp
+++ b/flang/runtime/engine.cpp
@@ -38,7 +38,7 @@ RT_API_ATTRS Engine::~Engine() {
 RT_API_ATTRS int Engine::Do(
     Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
   Begin(job, instance, derived);
-  while (topWorkBlock_ != &bottomWorkBlock_ && topWorkBlock_->depth > 0) {
+  while (topWorkBlock_ != &bottomWorkBlock_ || bottomWorkBlock_.depth > 0) {
     if (status_ == StatOk) {
       auto *w{reinterpret_cast<Work *>(
           topWorkBlock_->workBuf[topWorkBlock_->depth - 1])};
diff --git a/flang/runtime/engine.h b/flang/runtime/engine.h
index 06ae7d8e04fe29..2a2878d952315a 100644
--- a/flang/runtime/engine.h
+++ b/flang/runtime/engine.h
@@ -1,4 +1,4 @@
-//===-- runtime/engine.h --------------------------------------------------===//
+//===-- runtime/engine.h ---------------------------------------*- C++ -*- ===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.



More information about the flang-commits mailing list