[flang-commits] [flang] [flang] Avoid recursion in runtime derived type initialization (PR #102394)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Wed Aug 7 15:03:45 PDT 2024


https://github.com/klausler created https://github.com/llvm/llvm-project/pull/102394

Adds a work engine that runs nested jobs from an explicit work stack instead of by recursion, and adapts derived type instance initialization to use it.  If this approach succeeds, the engine can be reused to replace recursion at many other sites in the runtime.

>From 5619fc94bc79c78257c00b3f0c0d3c010e3c3259 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 7 Aug 2024 15:01:29 -0700
Subject: [PATCH] [flang] Avoid recursion in runtime derived type
 initialization

Adds a work engine that runs nested jobs from an explicit work
stack instead of by recursion, and adapts derived type instance
initialization to use it.  If this approach succeeds, the engine
can be reused to replace recursion at many other sites in the
runtime.
---
 flang/runtime/CMakeLists.txt |   1 +
 flang/runtime/derived.cpp    | 106 ++++++++++++++++-----------------
 flang/runtime/engine.cpp     |  67 +++++++++++++++++++++
 flang/runtime/engine.h       | 112 +++++++++++++++++++++++++++++++++++
 4 files changed, 231 insertions(+), 55 deletions(-)
 create mode 100644 flang/runtime/engine.cpp
 create mode 100644 flang/runtime/engine.h

diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
index 4537b2d059d65..3d6278d831d43 100644
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -125,6 +125,7 @@ set(sources
   dot-product.cpp
   edit-input.cpp
   edit-output.cpp
+  engine.cpp
   environment.cpp
   exceptions.cpp
   execute.cpp
diff --git a/flang/runtime/derived.cpp b/flang/runtime/derived.cpp
index 659f54fa344bb..1e5c1c0dc05dd 100644
--- a/flang/runtime/derived.cpp
+++ b/flang/runtime/derived.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "derived.h"
+#include "engine.h"
 #include "stat.h"
 #include "terminator.h"
 #include "tools.h"
@@ -29,97 +30,92 @@ static RT_API_ATTRS void GetComponentExtents(SubscriptValue (&extents)[maxRank],
   }
 }
 
-RT_API_ATTRS int Initialize(const Descriptor &instance,
-    const typeInfo::DerivedType &derived, Terminator &terminator, bool hasStat,
-    const Descriptor *errMsg) {
-  const Descriptor &componentDesc{derived.component()};
-  std::size_t elements{instance.Elements()};
-  int stat{StatOk};
-  // Initialize data components in each element; the per-element iterations
-  // constitute the inner loops, not the outer ones
-  std::size_t myComponents{componentDesc.Elements()};
-  for (std::size_t k{0}; k < myComponents; ++k) {
+int engine::Initialization::Run(Engine &engine) { // resumed after nested Jobs
+  while (component_.Iterating(components_)) { // may resume mid-iteration
     const auto &comp{
-        *componentDesc.ZeroBasedIndexedElement<typeInfo::Component>(k)};
-    SubscriptValue at[maxRank];
-    instance.GetLowerBounds(at);
+        *componentDesc_->ZeroBasedIndexedElement<typeInfo::Component>(
+            component_.at)};
     if (comp.genre() == typeInfo::Component::Genre::Allocatable ||
         comp.genre() == typeInfo::Component::Genre::Automatic) {
-      for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
-        Descriptor &allocDesc{
-            *instance.ElementComponent<Descriptor>(at, comp.offset())};
-        comp.EstablishDescriptor(allocDesc, instance, terminator);
+      while (element_.Iterating(elements_, &instance_)) {
+        Descriptor &allocDesc{*instance_.ElementComponent<Descriptor>(
+            element_.subscripts, comp.offset())};
+        comp.EstablishDescriptor(allocDesc, instance_, engine.terminator());
         allocDesc.raw().attribute = CFI_attribute_allocatable;
         if (comp.genre() == typeInfo::Component::Genre::Automatic) {
-          stat = ReturnError(terminator, allocDesc.Allocate(), errMsg, hasStat);
-          if (stat == StatOk) {
-            if (const DescriptorAddendum * addendum{allocDesc.Addendum()}) {
-              if (const auto *derived{addendum->derivedType()}) {
-                if (!derived->noInitializationNeeded()) {
-                  stat = Initialize(
-                      allocDesc, *derived, terminator, hasStat, errMsg);
-                }
+          if (auto stat{ReturnError(engine.terminator(), allocDesc.Allocate(),
+                  engine.errMsg(), engine.hasStat())};
+              stat != StatOk) {
+            return stat; // Engine::Run() returns it; old code continued
+          }
+          if (const DescriptorAddendum * addendum{allocDesc.Addendum()}) {
+            if (const auto *derived{addendum->derivedType()}) {
+              if (!derived->noInitializationNeeded()) {
+                component_.ResumeAtSameIteration(); // revisit after nested init
+                return engine.Begin(Job::Initialize, allocDesc, derived);
               }
             }
           }
-          if (stat != StatOk) {
-            break;
-          }
         }
       }
     } else if (const void *init{comp.initialization()}) {
       // Explicit initialization of data pointers and
       // non-allocatable non-automatic components
-      std::size_t bytes{comp.SizeInBytes(instance)};
-      for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
-        char *ptr{instance.ElementComponent<char>(at, comp.offset())};
+      std::size_t bytes{comp.SizeInBytes(instance_)};
+      while (element_.Iterating(elements_, &instance_)) {
+        char *ptr{instance_.ElementComponent<char>(
+            element_.subscripts, comp.offset())};
         std::memcpy(ptr, init, bytes);
       }
     } else if (comp.genre() == typeInfo::Component::Genre::Pointer) {
       // Data pointers without explicit initialization are established
       // so that they are valid right-hand side targets of pointer
       // assignment statements.
-      for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
-        Descriptor &ptrDesc{
-            *instance.ElementComponent<Descriptor>(at, comp.offset())};
-        comp.EstablishDescriptor(ptrDesc, instance, terminator);
+      while (element_.Iterating(elements_, &instance_)) {
+        Descriptor &ptrDesc{*instance_.ElementComponent<Descriptor>(
+            element_.subscripts, comp.offset())};
+        comp.EstablishDescriptor(ptrDesc, instance_, engine.terminator());
         ptrDesc.raw().attribute = CFI_attribute_pointer;
       }
     } else if (comp.genre() == typeInfo::Component::Genre::Data &&
         comp.derivedType() && !comp.derivedType()->noInitializationNeeded()) {
       // Default initialization of non-pointer non-allocatable/automatic
-      // data component.  Handles parent component's elements.  Recursive.
-      SubscriptValue extents[maxRank];
-      GetComponentExtents(extents, comp, instance);
-      StaticDescriptor<maxRank, true, 0> staticDescriptor;
-      Descriptor &compDesc{staticDescriptor.descriptor()};
-      const typeInfo::DerivedType &compType{*comp.derivedType()};
-      for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
+      // data component.  Handles parent component's elements via a nested Job.
+      if (!element_.active) { // first visit, not a resumption
+        GetComponentExtents(extents_, comp, instance_);
+      }
+      while (element_.Iterating(elements_, &instance_)) {
+        Descriptor &compDesc{staticDescriptor_.descriptor()};
+        const typeInfo::DerivedType &compType{*comp.derivedType()};
         compDesc.Establish(compType,
-            instance.ElementComponent<char>(at, comp.offset()), comp.rank(),
-            extents);
-        stat = Initialize(compDesc, compType, terminator, hasStat, errMsg);
-        if (stat != StatOk) {
-          break;
-        }
+            instance_.ElementComponent<char>(
+                element_.subscripts, comp.offset()),
+            comp.rank(), extents_);
+        component_.ResumeAtSameIteration(); // revisit after nested init
+        return engine.Begin(Job::Initialize, compDesc, &compType); // suspend
       }
     }
   }
   // Initialize procedure pointer components in each element
-  const Descriptor &procPtrDesc{derived.procPtr()};
+  const Descriptor &procPtrDesc{derived_->procPtr()};
   std::size_t myProcPtrs{procPtrDesc.Elements()};
   for (std::size_t k{0}; k < myProcPtrs; ++k) {
     const auto &comp{
         *procPtrDesc.ZeroBasedIndexedElement<typeInfo::ProcPtrComponent>(k)};
-    SubscriptValue at[maxRank];
-    instance.GetLowerBounds(at);
-    for (std::size_t j{0}; j++ < elements; instance.IncrementSubscripts(at)) {
-      auto &pptr{*instance.ElementComponent<typeInfo::ProcedurePointer>(
-          at, comp.offset)};
+    while (element_.Iterating(elements_, &instance_)) { // never suspends
+      auto &pptr{*instance_.ElementComponent<typeInfo::ProcedurePointer>(
+          element_.subscripts, comp.offset)};
       pptr = comp.procInitialization;
     }
   }
-  return stat;
+  return engine.Done(); // pops this Work from the Engine
+}
+
+RT_API_ATTRS int Initialize(const Descriptor &instance,
+    const typeInfo::DerivedType &derived, Terminator &terminator, bool hasStat,
+    const Descriptor *errMsg) {
+  engine::Engine runner{terminator, hasStat, errMsg}; // one Engine per call
+  return runner.Do(engine::Job::Initialize, instance, &derived);
 }
 
 static RT_API_ATTRS const typeInfo::SpecialBinding *FindFinal(
diff --git a/flang/runtime/engine.cpp b/flang/runtime/engine.cpp
new file mode 100644
index 0000000000000..dc64202c72611
--- /dev/null
+++ b/flang/runtime/engine.cpp
@@ -0,0 +1,67 @@
+//===-- runtime/engine.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "engine.h"
+
+namespace Fortran::runtime::engine {
+
+// Sets up a Work item for `job` on `instance`; `derived` may be null.
+Engine::Work::Work(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived)
+    : job_{job}, u_{{instance}} {
+  auto &state{u_.commonState};
+  state.derived_ = derived;
+  state.elements_ = instance.Elements();
+  state.componentDesc_ = derived ? &derived->component() : nullptr;
+  state.components_ = derived ? state.componentDesc_->Elements() : 0;
+}
+
+// Runs (or resumes) one step of this Work; crashes on an unknown Job.
+int Engine::Work::Run(Engine &engine) {
+  switch (job_) {
+  case Job::Initialize:
+    return u_.initialization.Run(engine);
+  }
+  engine.terminator().Crash(
+      "Work::Run: bad job_ code %d", static_cast<int>(job_));
+}
+
+int Engine::Do(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
+  int status{Begin(job, instance, derived)}; // push the first Work
+  return status == StatOk ? Run() : status;
+}
+
+int Engine::Begin(
+    Job job, const Descriptor &instance, const typeInfo::DerivedType *derived) {
+  if (depth_ >= maxDepth) {
+    // Work stack full (TODO: heap allocation): rather than overrun workBuf_,
+    // finish the nested job in a fresh Engine; then the caller's Work resumes.
+    return Engine{terminator_, hasStat_, errMsg_}.Do(job, instance, derived);
+  }
+  new (workBuf_[depth_++]) Work{job, instance, derived};
+  return StatOk;
+}
+
+// Steps the innermost Work until the stack empties or a step fails.
+int Engine::Run() {
+  while (depth_) {
+    auto *w{reinterpret_cast<Work *>(workBuf_[depth_ - 1])};
+    if (int status{w->Run(*this)}; status != StatOk) {
+      return status;
+    }
+  }
+  return StatOk;
+}
+
+int Engine::Done() {
+  --depth_; // pop the finished Work
+  return StatOk;
+}
+
+} // namespace Fortran::runtime::engine
diff --git a/flang/runtime/engine.h b/flang/runtime/engine.h
new file mode 100644
index 0000000000000..53b4a11f6b245
--- /dev/null
+++ b/flang/runtime/engine.h
@@ -0,0 +1,112 @@
+//===-- runtime/engine.h --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_RUNTIME_ENGINE_H_
+#define FORTRAN_RUNTIME_ENGINE_H_
+
+#include "derived.h"
+#include "stat.h"
+#include "terminator.h"
+#include "type-info.h"
+#include "flang/Runtime/descriptor.h"
+
+namespace Fortran::runtime::engine {
+// Engine: runs nested jobs by suspending/resuming Work items, not by recursion.
+class Engine;
+// The kinds of work an Engine can run (one Work::Run case per Job).
+enum class Job { Initialize };
+// Per-Work state shared by every Job: the instance, its type, loop cursors.
+struct CommonState {
+  struct Iteration { // a loop cursor that can pause and later resume
+    bool Iterating(std::size_t iters, const Descriptor *dtor = nullptr) {
+      if (!active) { // begin a new loop at item 0, if iters > 0
+        if (iters > 0) {
+          active = true;
+          at = 0;
+          n = iters;
+          descriptor = dtor;
+          if (descriptor) {
+            descriptor->GetLowerBounds(subscripts);
+          }
+        }
+      } else if (resuming) { // repeat the current item once
+        resuming = false;
+      } else if (++at < n) { // advance to the next item
+        if (descriptor) {
+          descriptor->IncrementSubscripts(subscripts);
+        }
+      } else { // all n items done; the next call starts over
+        active = false;
+      }
+      return active; // true while an item is current
+    }
+    void ResumeAtSameIteration() { resuming = true; } // repeat item next time
+
+    bool active{false}, resuming{false};
+    std::size_t at, n; // current item index, item count
+    const Descriptor *descriptor; // if non-null, `subscripts` tracks `at`
+    SubscriptValue subscripts[maxRank];
+  };
+  // Filled in by Work's constructor (engine.cpp); derived_ may be null.
+  const Descriptor &instance_; // the object being processed
+  const typeInfo::DerivedType *derived_;
+  const Descriptor *componentDesc_; // derived_->component(), or null
+  std::size_t elements_, components_; // element and component counts
+  Iteration element_, component_; // resumable loop cursors
+  StaticDescriptor<maxRank, true, 8> staticDescriptor_; // nested job's instance
+};
+// Job::Initialize: initializes the components of a derived type instance.
+class Initialization : protected CommonState {
+public:
+  int Run(Engine &); // one resumable step; defined in derived.cpp
+private:
+  SubscriptValue extents_[maxRank]; // extents of the current component
+};
+// Holds a small stack of Work items (workBuf_) and runs the topmost one.
+class Engine {
+public:
+  Engine(Terminator &terminator, bool hasStat, const Descriptor *errMsg)
+      : terminator_{terminator}, hasStat_{hasStat}, errMsg_{errMsg} {}
+  // Starts a job and runs it, and any nested jobs it begins, to completion.
+  // Returns StatOk, or the first nonzero status from a job step.
+  int Do(Job, const Descriptor &instance, const typeInfo::DerivedType *);
+  // NOTE(review): RT_API_ATTRS Initialize() calls Do(), which lacks it.
+  Terminator &terminator() const { return terminator_; }
+  bool hasStat() const { return hasStat_; }
+  const Descriptor *errMsg() const { return errMsg_; }
+
+  // From a running job: suspend it and push a nested job to run first.
+  int Begin(Job, const Descriptor &instance, const typeInfo::DerivedType *);
+  // From a running job: report it finished, popping it off the stack.
+  int Done();
+
+private:
+  class Work { // one suspended or running job
+  public:
+    Work(Job job, const Descriptor &instance, const typeInfo::DerivedType *);
+    int Run(Engine &); // one step of the job; nonzero on fatal error
+  private:
+    Job job_;
+    union {
+      CommonState commonState;
+      Initialization initialization;
+    } u_;
+  };
+  // NOTE(review): u_ is set via commonState but read via initialization.
+  int Run(); // steps the top Work until the stack empties or a step fails
+
+  Terminator &terminator_;
+  bool hasStat_{false};
+  const Descriptor *errMsg_;
+  int depth_{0}; // number of Work items in workBuf_
+  static constexpr int maxDepth{4}; // capacity of workBuf_
+  alignas(Work) char workBuf_[maxDepth][sizeof(Work)]; // Work storage
+};
+
+} // namespace Fortran::runtime::engine
+#endif // FORTRAN_RUNTIME_ENGINE_H_



More information about the flang-commits mailing list