[flang-commits] [flang] 2b86fb2 - [flang][runtime] Avoid recursive calls in F18 runtime CUDA build. (#87428)

via flang-commits flang-commits at lists.llvm.org
Tue Apr 2 21:03:52 PDT 2024


Author: Slava Zakharin
Date: 2024-04-02T21:03:49-07:00
New Revision: 2b86fb21f8402f19da7e5887a9572b3d55052991

URL: https://github.com/llvm/llvm-project/commit/2b86fb21f8402f19da7e5887a9572b3d55052991
DIFF: https://github.com/llvm/llvm-project/commit/2b86fb21f8402f19da7e5887a9572b3d55052991.diff

LOG: [flang][runtime] Avoid recursive calls in F18 runtime CUDA build. (#87428)

Recursion in the call graph (even if it is never executed)
prevents computing the minimal stack size required for a kernel
execution. This change disables some functionality of F18 IO
to avoid recursive calls. A couple of functions are rewritten
to work without using recursion.

Added: 
    

Modified: 
    flang/include/flang/Common/api-attrs.h
    flang/runtime/descriptor-io.h
    flang/runtime/edit-output.cpp
    flang/runtime/emit-encoded.h
    flang/runtime/io-stmt.cpp
    flang/runtime/io-stmt.h
    flang/runtime/unit.cpp
    flang/runtime/unit.h

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Common/api-attrs.h b/flang/include/flang/Common/api-attrs.h
index 4d069c6097ddfe..04ee307326ac92 100644
--- a/flang/include/flang/Common/api-attrs.h
+++ b/flang/include/flang/Common/api-attrs.h
@@ -133,6 +133,18 @@
 #undef RT_DEVICE_COMPILATION
 #endif
 
+/*
+ * Recursion in the call graph prevents computing the minimal stack size
+ * required for a kernel execution. This macro can be used to disable
+ * some F18 runtime functionality that is implemented using recursive
+ * function calls, or to select an alternative implementation.
+ */
+#if (defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
+#define RT_DEVICE_AVOID_RECURSION 1
+#else
+#undef RT_DEVICE_AVOID_RECURSION
+#endif
+
 #if defined(__CUDACC__)
 #define RT_DIAG_PUSH _Pragma("nv_diagnostic push")
 #define RT_DIAG_POP _Pragma("nv_diagnostic pop")

diff  --git a/flang/runtime/descriptor-io.h b/flang/runtime/descriptor-io.h
index 7063858d619196..0b188a12a0198e 100644
--- a/flang/runtime/descriptor-io.h
+++ b/flang/runtime/descriptor-io.h
@@ -250,6 +250,7 @@ static RT_API_ATTRS bool DefaultComponentIO(IoStatementState &io,
     const typeInfo::Component &component, const Descriptor &origDescriptor,
     const SubscriptValue origSubscripts[], Terminator &terminator,
     const NonTbpDefinedIoTable *table) {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   if (component.genre() == typeInfo::Component::Genre::Data) {
     // Create a descriptor for the component
     StaticDescriptor<maxRank, true, 16 /*?*/> statDesc;
@@ -266,6 +267,9 @@ static RT_API_ATTRS bool DefaultComponentIO(IoStatementState &io,
     const Descriptor &compDesc{*reinterpret_cast<const Descriptor *>(pointer)};
     return DescriptorIO<DIR>(io, compDesc, table);
   }
+#else
+  terminator.Crash("not yet implemented: component IO");
+#endif
 }
 
 template <Direction DIR>

diff  --git a/flang/runtime/edit-output.cpp b/flang/runtime/edit-output.cpp
index b710c298babebf..a06ed258f0f1d2 100644
--- a/flang/runtime/edit-output.cpp
+++ b/flang/runtime/edit-output.cpp
@@ -751,43 +751,50 @@ RT_API_ATTRS bool RealOutputEditing<KIND>::EditEXOutput(const DataEdit &edit) {
 
 template <int KIND>
 RT_API_ATTRS bool RealOutputEditing<KIND>::Edit(const DataEdit &edit) {
-  switch (edit.descriptor) {
+  const DataEdit *editPtr{&edit};
+  DataEdit newEdit;
+  if (editPtr->descriptor == 'G') {
+    // Avoid recursive call as in Edit(EditForGOutput(edit)).
+    newEdit = EditForGOutput(*editPtr);
+    editPtr = &newEdit;
+    RUNTIME_CHECK(io_.GetIoErrorHandler(), editPtr->descriptor != 'G');
+  }
+  switch (editPtr->descriptor) {
   case 'D':
-    return EditEorDOutput(edit);
+    return EditEorDOutput(*editPtr);
   case 'E':
-    if (edit.variation == 'X') {
-      return EditEXOutput(edit);
+    if (editPtr->variation == 'X') {
+      return EditEXOutput(*editPtr);
     } else {
-      return EditEorDOutput(edit);
+      return EditEorDOutput(*editPtr);
     }
   case 'F':
-    return EditFOutput(edit);
+    return EditFOutput(*editPtr);
   case 'B':
-    return EditBOZOutput<1>(io_, edit,
+    return EditBOZOutput<1>(io_, *editPtr,
         reinterpret_cast<const unsigned char *>(&x_),
         common::BitsForBinaryPrecision(common::PrecisionOfRealKind(KIND)) >> 3);
   case 'O':
-    return EditBOZOutput<3>(io_, edit,
+    return EditBOZOutput<3>(io_, *editPtr,
         reinterpret_cast<const unsigned char *>(&x_),
         common::BitsForBinaryPrecision(common::PrecisionOfRealKind(KIND)) >> 3);
   case 'Z':
-    return EditBOZOutput<4>(io_, edit,
+    return EditBOZOutput<4>(io_, *editPtr,
         reinterpret_cast<const unsigned char *>(&x_),
         common::BitsForBinaryPrecision(common::PrecisionOfRealKind(KIND)) >> 3);
-  case 'G':
-    return Edit(EditForGOutput(edit));
   case 'L':
-    return EditLogicalOutput(io_, edit, *reinterpret_cast<const char *>(&x_));
+    return EditLogicalOutput(
+        io_, *editPtr, *reinterpret_cast<const char *>(&x_));
   case 'A': // legacy extension
     return EditCharacterOutput(
-        io_, edit, reinterpret_cast<char *>(&x_), sizeof x_);
+        io_, *editPtr, reinterpret_cast<char *>(&x_), sizeof x_);
   default:
-    if (edit.IsListDirected()) {
-      return EditListDirectedOutput(edit);
+    if (editPtr->IsListDirected()) {
+      return EditListDirectedOutput(*editPtr);
     }
     io_.GetIoErrorHandler().SignalError(IostatErrorInFormat,
         "Data edit descriptor '%c' may not be used with a REAL data item",
-        edit.descriptor);
+        editPtr->descriptor);
     return false;
   }
   return false;

diff  --git a/flang/runtime/emit-encoded.h b/flang/runtime/emit-encoded.h
index ac8c7d758a0d00..4b5e3900788357 100644
--- a/flang/runtime/emit-encoded.h
+++ b/flang/runtime/emit-encoded.h
@@ -18,22 +18,26 @@
 
 namespace Fortran::runtime::io {
 
-template <typename CONTEXT, typename CHAR>
+template <typename CONTEXT, typename CHAR, bool NL_ADVANCES_RECORD = true>
 RT_API_ATTRS bool EmitEncoded(
     CONTEXT &to, const CHAR *data, std::size_t chars) {
   ConnectionState &connection{to.GetConnectionState()};
-  if (connection.access == Access::Stream &&
-      connection.internalIoCharKind == 0) {
-    // Stream output: treat newlines as record advancements so that the left tab
-    // limit is correctly managed
-    while (const CHAR * nl{FindCharacter(data, CHAR{'\n'}, chars)}) {
-      auto pos{static_cast<std::size_t>(nl - data)};
-      if (!EmitEncoded(to, data, pos)) {
-        return false;
+  if constexpr (NL_ADVANCES_RECORD) {
+    if (connection.access == Access::Stream &&
+        connection.internalIoCharKind == 0) {
+      // Stream output: treat newlines as record advancements so that the left
+      // tab limit is correctly managed
+      while (const CHAR * nl{FindCharacter(data, CHAR{'\n'}, chars)}) {
+        auto pos{static_cast<std::size_t>(nl - data)};
+        // The range [data, data + pos) does not contain a newline,
+        // so we can avoid the recursion by calling the proper instantiation.
+        if (!EmitEncoded<CONTEXT, CHAR, false>(to, data, pos)) {
+          return false;
+        }
+        data += pos + 1;
+        chars -= pos + 1;
+        to.AdvanceRecord();
       }
-      data += pos + 1;
-      chars -= pos + 1;
-      to.AdvanceRecord();
     }
   }
   if (connection.useUTF8<CHAR>()) {

diff  --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp
index 022e4c806bf63b..1a5d32ecd8c5a1 100644
--- a/flang/runtime/io-stmt.cpp
+++ b/flang/runtime/io-stmt.cpp
@@ -220,7 +220,11 @@ ExternalIoStatementBase::ExternalIoStatementBase(
 
 MutableModes &ExternalIoStatementBase::mutableModes() {
   if (const ChildIo * child{unit_.GetChildIo()}) {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
     return child->parent().mutableModes();
+#else
+    ReportUnsupportedChildIo();
+#endif
   }
   return unit_.modes;
 }
@@ -891,17 +895,29 @@ ChildIoStatementState<DIR>::ChildIoStatementState(
 
 template <Direction DIR>
 MutableModes &ChildIoStatementState<DIR>::mutableModes() {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return child_.parent().mutableModes();
+#else
+  ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR>
 ConnectionState &ChildIoStatementState<DIR>::GetConnectionState() {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return child_.parent().GetConnectionState();
+#else
+  ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR>
 ExternalFileUnit *ChildIoStatementState<DIR>::GetExternalFileUnit() const {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return child_.parent().GetExternalFileUnit();
+#else
+  ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR> int ChildIoStatementState<DIR>::EndIoStatement() {
@@ -914,22 +930,38 @@ template <Direction DIR> int ChildIoStatementState<DIR>::EndIoStatement() {
 template <Direction DIR>
 bool ChildIoStatementState<DIR>::Emit(
     const char *data, std::size_t bytes, std::size_t elementBytes) {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return child_.parent().Emit(data, bytes, elementBytes);
+#else
+  ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR>
 std::size_t ChildIoStatementState<DIR>::GetNextInputBytes(const char *&p) {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return child_.parent().GetNextInputBytes(p);
+#else
+  ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR>
 void ChildIoStatementState<DIR>::HandleAbsolutePosition(std::int64_t n) {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return child_.parent().HandleAbsolutePosition(n);
+#else
+  ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR>
 void ChildIoStatementState<DIR>::HandleRelativePosition(std::int64_t n) {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return child_.parent().HandleRelativePosition(n);
+#else
+  ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR, typename CHAR>
@@ -957,13 +989,21 @@ int ChildFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
 
 template <Direction DIR, typename CHAR>
 bool ChildFormattedIoStatementState<DIR, CHAR>::AdvanceRecord(int n) {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return this->child().parent().AdvanceRecord(n);
+#else
+  this->ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR>
 bool ChildUnformattedIoStatementState<DIR>::Receive(
     char *data, std::size_t bytes, std::size_t elementBytes) {
+#if !defined(RT_DEVICE_AVOID_RECURSION)
   return this->child().parent().Receive(data, bytes, elementBytes);
+#else
+  this->ReportUnsupportedChildIo();
+#endif
 }
 
 template <Direction DIR> int ChildListIoStatementState<DIR>::EndIoStatement() {

diff  --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index 8b5752311de5c3..6053aeb777b7a5 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -296,6 +296,10 @@ class IoStatementBase : public IoErrorHandler {
 
   RT_API_ATTRS void BadInquiryKeywordHashCrash(InquiryKeywordHash);
 
+  RT_API_ATTRS void ReportUnsupportedChildIo() const {
+    Crash("not yet implemented: child IO");
+  }
+
 protected:
   bool completedOperation_{false};
 };

diff  --git a/flang/runtime/unit.cpp b/flang/runtime/unit.cpp
index 6c648d3bd83467..0e38cffdf907d7 100644
--- a/flang/runtime/unit.cpp
+++ b/flang/runtime/unit.cpp
@@ -206,7 +206,7 @@ bool ExternalFileUnit::BeginReadingRecord(IoErrorHandler &handler) {
       if (anyWriteSinceLastPositioning_ && access == Access::Sequential) {
         // Most Fortran implementations allow a READ after a WRITE;
         // the read then just hits an EOF.
-        DoEndfile(handler);
+        DoEndfile<false, Direction::Input>(handler);
       }
       recordLength.reset();
       RUNTIME_CHECK(handler, isUnformatted.has_value());
@@ -671,13 +671,23 @@ void ExternalFileUnit::DoImpliedEndfile(IoErrorHandler &handler) {
   impliedEndfile_ = false;
 }
 
+template <bool ANY_DIR, Direction DIR>
 void ExternalFileUnit::DoEndfile(IoErrorHandler &handler) {
   if (IsRecordFile() && access != Access::Direct) {
     furthestPositionInRecord =
         std::max(positionInRecord, furthestPositionInRecord);
     if (leftTabLimit) { // last I/O was non-advancing
       if (access == Access::Sequential && direction_ == Direction::Output) {
-        AdvanceRecord(handler);
+        if constexpr (ANY_DIR || DIR == Direction::Output) {
+          // When DoEndfile() is called from BeginReadingRecord(),
+          // this call to AdvanceRecord() may appear as a recursion
+          // though it may never happen. Expose the call only
+          // under the constexpr direction check.
+          AdvanceRecord(handler);
+        } else {
+          // This check always fails if we are here.
+          RUNTIME_CHECK(handler, direction_ != Direction::Output);
+        }
       } else { // Access::Stream or input
         leftTabLimit.reset();
         ++currentRecordNumber;
@@ -695,6 +705,12 @@ void ExternalFileUnit::DoEndfile(IoErrorHandler &handler) {
   anyWriteSinceLastPositioning_ = false;
 }
 
+template void ExternalFileUnit::DoEndfile(IoErrorHandler &handler);
+template void ExternalFileUnit::DoEndfile<false, Direction::Output>(
+    IoErrorHandler &handler);
+template void ExternalFileUnit::DoEndfile<false, Direction::Input>(
+    IoErrorHandler &handler);
+
 void ExternalFileUnit::CommitWrites() {
   frameOffsetInFile_ +=
       recordOffsetInFrame_ + recordLength.value_or(furthestPositionInRecord);

diff  --git a/flang/runtime/unit.h b/flang/runtime/unit.h
index a6ee5971a16524..e59fbbce2b5771 100644
--- a/flang/runtime/unit.h
+++ b/flang/runtime/unit.h
@@ -204,6 +204,7 @@ class ExternalFileUnit : public ConnectionState,
   RT_API_ATTRS void BackspaceVariableFormattedRecord(IoErrorHandler &);
   RT_API_ATTRS bool SetVariableFormattedRecordLength();
   RT_API_ATTRS void DoImpliedEndfile(IoErrorHandler &);
+  template <bool ANY_DIR = true, Direction DIR = Direction::Output>
   RT_API_ATTRS void DoEndfile(IoErrorHandler &);
   RT_API_ATTRS void CommitWrites();
   RT_API_ATTRS bool CheckDirectAccess(IoErrorHandler &);


        


More information about the flang-commits mailing list