[flang-commits] [flang] 8db4dc8 - [flang] Error recovery improvement in runtime (IOMSG=)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Fri Mar 18 17:28:18 PDT 2022


Author: Peter Klausler
Date: 2022-03-18T17:24:32-07:00
New Revision: 8db4dc86861211dfd72848cb9f35e78a1fe5c1d8

URL: https://github.com/llvm/llvm-project/commit/8db4dc86861211dfd72848cb9f35e78a1fe5c1d8
DIFF: https://github.com/llvm/llvm-project/commit/8db4dc86861211dfd72848cb9f35e78a1fe5c1d8.diff

LOG: [flang] Error recovery improvement in runtime (IOMSG=)

Some refactoring and related fixes for more accurate
user program error recovery in the I/O runtime, especially
for error recovery with IOMSG= character values.

1) Move any work in an EndIoStatement() implementation
that may raise an error into a new CompleteOperation()
member function.  This allows error handling APIs like
GetIoMsg() to complete a pending I/O statement and harvest
any errors that may result.

2) Move the pending error code from ErroneousIoStatementState
to a new pendingError_ data member in IoErrorHandler.
This allows IoErrorHandler::InError() to return a correct
result when there is a pending error that will be recovered
from, so that I/O list data transfers don't crash in the meantime.

3) Don't create and leak a unit for a failed OPEN(NEWUNIT=n)
with error recovery, and don't modify 'n'.  (Depends on
changes to API call ordering in lowering, in a separate patch;
the runtime now crashes if an OPEN statement control list
specifier, e.g. SetFile(), is passed after GetNewUnit().)

4) Fix the code that calls a form of strerror to fill an
IOMSG= variable so that it actually works for Fortran's
character type: the result is blank-filled, with no null or
newline terminator.

Differential Revision: https://reviews.llvm.org/D122036

Added: 
    

Modified: 
    flang/runtime/descriptor-io.h
    flang/runtime/io-api.cpp
    flang/runtime/io-error.cpp
    flang/runtime/io-error.h
    flang/runtime/io-stmt.cpp
    flang/runtime/io-stmt.h

Removed: 
    


################################################################################
diff  --git a/flang/runtime/descriptor-io.h b/flang/runtime/descriptor-io.h
index f6528762795a2..7e098d8cfca99 100644
--- a/flang/runtime/descriptor-io.h
+++ b/flang/runtime/descriptor-io.h
@@ -372,6 +372,10 @@ static bool UnformattedDescriptorIO(
 
 template <Direction DIR>
 static bool DescriptorIO(IoStatementState &io, const Descriptor &descriptor) {
+  IoErrorHandler &handler{io.GetIoErrorHandler()};
+  if (handler.InError()) {
+    return false;
+  }
   if (!io.get_if<IoDirectionState<DIR>>()) {
     io.GetIoErrorHandler().Crash(
         "DescriptorIO() called for wrong I/O direction");
@@ -385,7 +389,6 @@ static bool DescriptorIO(IoStatementState &io, const Descriptor &descriptor) {
   if (!io.get_if<FormattedIoStatementState<DIR>>()) {
     return UnformattedDescriptorIO<DIR>(io, descriptor);
   }
-  IoErrorHandler &handler{io.GetIoErrorHandler()};
   if (auto catAndKind{descriptor.type().GetCategoryAndKind()}) {
     TypeCategory cat{catAndKind->first};
     int kind{catAndKind->second};

diff  --git a/flang/runtime/io-api.cpp b/flang/runtime/io-api.cpp
index 8733463e63083..70bd1480a2662 100644
--- a/flang/runtime/io-api.cpp
+++ b/flang/runtime/io-api.cpp
@@ -646,6 +646,9 @@ bool IONAME(SetAccess)(Cookie cookie, const char *keyword, std::size_t length) {
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetAccess() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetAccess() called after GetNewUnit() for an OPEN statement");
   }
   static const char *keywords[]{
       "SEQUENTIAL", "DIRECT", "STREAM", "APPEND", nullptr};
@@ -675,6 +678,9 @@ bool IONAME(SetAction)(Cookie cookie, const char *keyword, std::size_t length) {
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetAction() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetAction() called after GetNewUnit() for an OPEN statement");
   }
   std::optional<Action> action;
   static const char *keywords[]{"READ", "WRITE", "READWRITE", nullptr};
@@ -711,6 +717,9 @@ bool IONAME(SetAsynchronous)(
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetAsynchronous() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetAsynchronous() called after GetNewUnit() for an OPEN statement");
   }
   static const char *keywords[]{"YES", "NO", nullptr};
   switch (IdentifyValue(keyword, length, keywords)) {
@@ -734,6 +743,9 @@ bool IONAME(SetCarriagecontrol)(
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetCarriageControl() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetCarriageControl() called after GetNewUnit() for an OPEN statement");
   }
   static const char *keywords[]{"LIST", "FORTRAN", "NONE", nullptr};
   switch (IdentifyValue(keyword, length, keywords)) {
@@ -759,6 +771,9 @@ bool IONAME(SetConvert)(
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetConvert() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetConvert() called after GetNewUnit() for an OPEN statement");
   }
   if (auto convert{GetConvertFromString(keyword, length)}) {
     open->set_convert(*convert);
@@ -777,6 +792,9 @@ bool IONAME(SetEncoding)(
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetEncoding() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetEncoding() called after GetNewUnit() for an OPEN statement");
   }
   bool isUTF8{false};
   static const char *keywords[]{"UTF-8", "DEFAULT", nullptr};
@@ -806,6 +824,9 @@ bool IONAME(SetForm)(Cookie cookie, const char *keyword, std::size_t length) {
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetForm() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetForm() called after GetNewUnit() for an OPEN statement");
   }
   static const char *keywords[]{"FORMATTED", "UNFORMATTED", nullptr};
   switch (IdentifyValue(keyword, length, keywords)) {
@@ -829,6 +850,9 @@ bool IONAME(SetPosition)(
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetPosition() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetPosition() called after GetNewUnit() for an OPEN statement");
   }
   static const char *positions[]{"ASIS", "REWIND", "APPEND", nullptr};
   switch (IdentifyValue(keyword, length, positions)) {
@@ -854,6 +878,9 @@ bool IONAME(SetRecl)(Cookie cookie, std::size_t n) {
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "SetRecl() called when not in an OPEN statement");
+  } else if (open->completedOperation()) {
+    io.GetIoErrorHandler().Crash(
+        "SetRecl() called after GetNewUnit() for an OPEN statement");
   }
   if (n <= 0) {
     io.GetIoErrorHandler().SignalError("RECL= must be greater than zero");
@@ -871,6 +898,10 @@ bool IONAME(SetRecl)(Cookie cookie, std::size_t n) {
 bool IONAME(SetStatus)(Cookie cookie, const char *keyword, std::size_t length) {
   IoStatementState &io{*cookie};
   if (auto *open{io.get_if<OpenStatementState>()}) {
+    if (open->completedOperation()) {
+      io.GetIoErrorHandler().Crash(
+          "SetStatus() called after GetNewUnit() for an OPEN statement");
+    }
     static const char *statuses[]{
         "OLD", "NEW", "SCRATCH", "REPLACE", "UNKNOWN", nullptr};
     switch (IdentifyValue(keyword, length, statuses)) {
@@ -920,6 +951,10 @@ bool IONAME(SetStatus)(Cookie cookie, const char *keyword, std::size_t length) {
 bool IONAME(SetFile)(Cookie cookie, const char *path, std::size_t chars) {
   IoStatementState &io{*cookie};
   if (auto *open{io.get_if<OpenStatementState>()}) {
+    if (open->completedOperation()) {
+      io.GetIoErrorHandler().Crash(
+          "SetFile() called after GetNewUnit() for an OPEN statement");
+    }
     open->set_path(path, chars);
     return true;
   }
@@ -934,6 +969,12 @@ bool IONAME(GetNewUnit)(Cookie cookie, int &unit, int kind) {
   if (!open) {
     io.GetIoErrorHandler().Crash(
         "GetNewUnit() called when not in an OPEN statement");
+  } else if (!open->InError()) {
+    open->CompleteOperation();
+  }
+  if (open->InError()) {
+    // A failed OPEN(NEWUNIT=n) does not modify 'n'
+    return false;
   }
   std::int64_t result{open->unit().unitNumber()};
   if (!SetInteger(unit, kind, result)) {
@@ -971,16 +1012,17 @@ bool IONAME(OutputUnformattedBlock)(Cookie cookie, const char *x,
 bool IONAME(InputUnformattedBlock)(
     Cookie cookie, char *x, std::size_t length, std::size_t elementBytes) {
   IoStatementState &io{*cookie};
+  IoErrorHandler &handler{io.GetIoErrorHandler()};
   io.BeginReadingRecord();
-  if (io.GetIoErrorHandler().InError()) {
+  if (handler.InError()) {
     return false;
   }
   if (auto *unf{
           io.get_if<ExternalUnformattedIoStatementState<Direction::Input>>()}) {
     return unf->Receive(x, length, elementBytes);
   }
-  io.GetIoErrorHandler().Crash("InputUnformattedBlock() called for an I/O "
-                               "statement that is not unformatted output");
+  handler.Crash("InputUnformattedBlock() called for an I/O statement that is "
+                "not unformatted input");
   return false;
 }
 
@@ -1157,27 +1199,39 @@ bool IONAME(InputLogical)(Cookie cookie, bool &truth) {
 
 std::size_t IONAME(GetSize)(Cookie cookie) {
   IoStatementState &io{*cookie};
+  IoErrorHandler &handler{io.GetIoErrorHandler()};
+  if (!handler.InError()) {
+    io.CompleteOperation();
+  }
   if (const auto *formatted{
           io.get_if<FormattedIoStatementState<Direction::Input>>()}) {
     return formatted->GetEditDescriptorChars();
   }
-  io.GetIoErrorHandler().Crash(
+  handler.Crash(
       "GetIoSize() called for an I/O statement that is not a formatted READ()");
   return 0;
 }
 
 std::size_t IONAME(GetIoLength)(Cookie cookie) {
   IoStatementState &io{*cookie};
+  IoErrorHandler &handler{io.GetIoErrorHandler()};
+  if (!handler.InError()) {
+    io.CompleteOperation();
+  }
   if (const auto *inq{io.get_if<InquireIOLengthState>()}) {
     return inq->bytes();
   }
-  io.GetIoErrorHandler().Crash("GetIoLength() called for an I/O statement that "
-                               "is not INQUIRE(IOLENGTH=)");
+  handler.Crash("GetIoLength() called for an I/O statement that is not "
+                "INQUIRE(IOLENGTH=)");
   return 0;
 }
 
 void IONAME(GetIoMsg)(Cookie cookie, char *msg, std::size_t length) {
-  IoErrorHandler &handler{cookie->GetIoErrorHandler()};
+  IoStatementState &io{*cookie};
+  IoErrorHandler &handler{io.GetIoErrorHandler()};
+  if (!handler.InError()) {
+    io.CompleteOperation();
+  }
   if (handler.InError()) { // leave "msg" alone when no error
     handler.GetIoMsg(msg, length);
   }

diff  --git a/flang/runtime/io-error.cpp b/flang/runtime/io-error.cpp
index e139e0649e503..790c579e1c43c 100644
--- a/flang/runtime/io-error.cpp
+++ b/flang/runtime/io-error.cpp
@@ -27,7 +27,7 @@ void IoErrorHandler::SignalError(int iostatOrErrno, const char *msg, ...) {
       ioStat_ = IostatEor; // least priority
     }
   } else if (iostatOrErrno != IostatOk) {
-    if (flags_ & (hasIoStat | hasErr)) {
+    if (flags_ & (hasIoStat | hasIoMsg | hasErr)) {
       if (ioStat_ <= 0) {
         ioStat_ = iostatOrErrno; // priority over END=/EOR=
         if (msg && (flags_ & hasIoMsg)) {
@@ -75,41 +75,57 @@ void IoErrorHandler::SignalEnd() { SignalError(IostatEnd); }
 
 void IoErrorHandler::SignalEor() { SignalError(IostatEor); }
 
+void IoErrorHandler::SignalPendingError() {
+  int error{pendingError_};
+  pendingError_ = IostatOk;
+  SignalError(error);
+}
+
 bool IoErrorHandler::GetIoMsg(char *buffer, std::size_t bufferLength) {
   const char *msg{ioMsg_.get()};
   if (!msg) {
-    msg = IostatErrorString(ioStat_);
+    msg = IostatErrorString(ioStat_ == IostatOk ? pendingError_ : ioStat_);
   }
   if (msg) {
     ToFortranDefaultCharacter(buffer, bufferLength, msg);
     return true;
   }
 
-  char *newBuf;
   // Following code is taken from llvm/lib/Support/Errno.cpp
-  // in LLVM v9.0.1
+  // in LLVM v9.0.1 with inadequate modification for Fortran,
+  // since rectified.
+  bool ok{false};
 #if HAVE_STRERROR_R
   // strerror_r is thread-safe.
 #if defined(__GLIBC__) && defined(_GNU_SOURCE)
   // glibc defines its own incompatible version of strerror_r
   // which may not use the buffer supplied.
-  newBuf = ::strerror_r(ioStat_, buffer, bufferLength);
+  msg = ::strerror_r(ioStat_, buffer, bufferLength);
 #else
-  return ::strerror_r(ioStat_, buffer, bufferLength) == 0;
+  ok = ::strerror_r(ioStat_, buffer, bufferLength) == 0;
 #endif
 #elif HAVE_DECL_STRERROR_S // "Windows Secure API"
-  return ::strerror_s(buffer, bufferLength, ioStat_) == 0;
+  ok = ::strerror_s(buffer, bufferLength, ioStat_) == 0;
 #elif HAVE_STRERROR
   // Copy the thread un-safe result of strerror into
   // the buffer as fast as possible to minimize impact
   // of collision of strerror in multiple threads.
-  newBuf = strerror(ioStat_);
+  msg = strerror(ioStat_);
 #else
   // Strange that this system doesn't even have strerror
   return false;
 #endif
-  ::strncpy(buffer, newBuf, bufferLength - 1);
-  buffer[bufferLength - 1] = '\n';
-  return true;
+  if (msg) {
+    ToFortranDefaultCharacter(buffer, bufferLength, msg);
+    return true;
+  } else if (ok) {
+    std::size_t copied{std::strlen(buffer)};
+    if (copied < bufferLength) {
+      std::memset(buffer + copied, ' ', bufferLength - copied);
+    }
+    return true;
+  } else {
+    return false;
+  }
 }
 } // namespace Fortran::runtime::io

diff  --git a/flang/runtime/io-error.h b/flang/runtime/io-error.h
index c530a120c381e..4186ee68af084 100644
--- a/flang/runtime/io-error.h
+++ b/flang/runtime/io-error.h
@@ -36,7 +36,15 @@ class IoErrorHandler : public Terminator {
     flags_ = hasIoStat | hasErr | hasEnd | hasEor | hasIoMsg;
   }
 
-  bool InError() const { return ioStat_ != IostatOk; }
+  bool InError() const {
+    return ioStat_ != IostatOk || pendingError_ != IostatOk;
+  }
+
+  // For I/O statements that detect fatal errors in their
+  // Begin...() API routines before it is known whether they
+  // have error handling control list items.  Such statements
+  // have an ErroneousIoStatementState with a pending error.
+  void SetPendingError(int iostat) { pendingError_ = iostat; }
 
   void SignalError(int iostatOrErrno, const char *msg, ...);
   void SignalError(int iostatOrErrno);
@@ -49,6 +57,7 @@ class IoErrorHandler : public Terminator {
   void SignalErrno(); // SignalError(errno)
   void SignalEnd(); // input only; EOF on internal write is an error
   void SignalEor(); // non-advancing input only; EOR on write is an error
+  void SignalPendingError();
 
   int GetIoStat() const { return ioStat_; }
   bool GetIoMsg(char *, std::size_t);
@@ -64,6 +73,7 @@ class IoErrorHandler : public Terminator {
   std::uint8_t flags_{0};
   int ioStat_{IostatOk};
   OwningPtr<char> ioMsg_;
+  int pendingError_{IostatOk};
 };
 
 } // namespace Fortran::runtime::io

diff  --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp
index 9415029201eae..1a8b06068802d 100644
--- a/flang/runtime/io-stmt.cpp
+++ b/flang/runtime/io-stmt.cpp
@@ -19,8 +19,6 @@
 
 namespace Fortran::runtime::io {
 
-int IoStatementBase::EndIoStatement() { return GetIoStat(); }
-
 bool IoStatementBase::Emit(const char *, std::size_t, std::size_t) {
   return false;
 }
@@ -163,10 +161,18 @@ InternalFormattedIoStatementState<DIR, CHAR>::InternalFormattedIoStatementState(
       ioStatementState_{*this}, format_{*this, format, formatLength} {}
 
 template <Direction DIR, typename CHAR>
-int InternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
-  if constexpr (DIR == Direction::Output) {
-    format_.Finish(*this); // ignore any remaining input positioning actions
+void InternalFormattedIoStatementState<DIR, CHAR>::CompleteOperation() {
+  if (!this->completedOperation()) {
+    if constexpr (DIR == Direction::Output) {
+      format_.Finish(*this); // ignore any remaining input positioning actions
+    }
+    IoStatementBase::CompleteOperation();
   }
+}
+
+template <Direction DIR, typename CHAR>
+int InternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
+  CompleteOperation();
   return InternalIoStatementState<DIR, CHAR>::EndIoStatement();
 }
 
@@ -191,12 +197,19 @@ MutableModes &ExternalIoStatementBase::mutableModes() { return unit_.modes; }
 
 ConnectionState &ExternalIoStatementBase::GetConnectionState() { return unit_; }
 
-int ExternalIoStatementBase::EndIoStatement() {
-  if (mutableModes().nonAdvancing) {
-    unit_.leftTabLimit = unit_.furthestPositionInRecord;
-  } else {
-    unit_.leftTabLimit.reset();
+void ExternalIoStatementBase::CompleteOperation() {
+  if (!completedOperation()) {
+    if (mutableModes().nonAdvancing) {
+      unit_.leftTabLimit = unit_.furthestPositionInRecord;
+    } else {
+      unit_.leftTabLimit.reset();
+    }
+    IoStatementBase::CompleteOperation();
   }
+}
+
+int ExternalIoStatementBase::EndIoStatement() {
+  CompleteOperation();
   auto result{IoStatementBase::EndIoStatement()};
   unit_.EndIoStatement(); // annihilates *this in unit_.u_
   return result;
@@ -207,7 +220,10 @@ void OpenStatementState::set_path(const char *path, std::size_t length) {
   path_ = SaveDefaultCharacter(path, pathLength_, *this);
 }
 
-int OpenStatementState::EndIoStatement() {
+void OpenStatementState::CompleteOperation() {
+  if (completedOperation()) {
+    return;
+  }
   if (position_) {
     if (access_ && *access_ == Access::Direct) {
       SignalError("POSITION= may not be set with ACCESS='DIRECT'");
@@ -246,17 +262,33 @@ int OpenStatementState::EndIoStatement() {
     // Set default format (C.7.4 point 2).
     unit().isUnformatted = unit().access != Access::Sequential;
   }
+  if (!wasExtant_ && InError()) {
+    // Release the new unit on failure
+    unit().CloseUnit(CloseStatus::Delete, *this);
+    unit().DestroyClosed();
+  }
+  IoStatementBase::CompleteOperation();
+}
+
+int OpenStatementState::EndIoStatement() {
+  CompleteOperation();
   return ExternalIoStatementBase::EndIoStatement();
 }
 
 int CloseStatementState::EndIoStatement() {
+  CompleteOperation();
   int result{ExternalIoStatementBase::EndIoStatement()};
   unit().CloseUnit(status_, *this);
   unit().DestroyClosed();
   return result;
 }
 
+void NoUnitIoStatementState::CompleteOperation() {
+  IoStatementBase::CompleteOperation();
+}
+
 int NoUnitIoStatementState::EndIoStatement() {
+  CompleteOperation();
   auto result{IoStatementBase::EndIoStatement()};
   FreeMemory(this);
   return result;
@@ -277,7 +309,11 @@ ExternalIoStatementState<DIR>::ExternalIoStatementState(
   }
 }
 
-template <Direction DIR> int ExternalIoStatementState<DIR>::EndIoStatement() {
+template <Direction DIR>
+void ExternalIoStatementState<DIR>::CompleteOperation() {
+  if (completedOperation()) {
+    return;
+  }
   if constexpr (DIR == Direction::Input) {
     BeginReadingRecord(); // in case there were no I/O items
     if (!mutableModes().nonAdvancing || GetIoStat() == IostatEor) {
@@ -289,6 +325,11 @@ template <Direction DIR> int ExternalIoStatementState<DIR>::EndIoStatement() {
     }
     unit().FlushIfTerminal(*this);
   }
+  return ExternalIoStatementBase::CompleteOperation();
+}
+
+template <Direction DIR> int ExternalIoStatementState<DIR>::EndIoStatement() {
+  CompleteOperation();
   return ExternalIoStatementBase::EndIoStatement();
 }
 
@@ -391,11 +432,20 @@ ExternalFormattedIoStatementState<DIR, CHAR>::ExternalFormattedIoStatementState(
       format_{*this, format, formatLength} {}
 
 template <Direction DIR, typename CHAR>
-int ExternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
+void ExternalFormattedIoStatementState<DIR, CHAR>::CompleteOperation() {
+  if (this->completedOperation()) {
+    return;
+  }
   if constexpr (DIR == Direction::Input) {
     this->BeginReadingRecord(); // in case there were no I/O items
   }
   format_.Finish(*this);
+  return ExternalIoStatementState<DIR>::CompleteOperation();
+}
+
+template <Direction DIR, typename CHAR>
+int ExternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
+  CompleteOperation();
   return ExternalIoStatementState<DIR>::EndIoStatement();
 }
 
@@ -448,6 +498,10 @@ void IoStatementState::HandleAbsolutePosition(std::int64_t n) {
   std::visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
 }
 
+void IoStatementState::CompleteOperation() {
+  std::visit([](auto &x) { x.get().CompleteOperation(); }, u_);
+}
+
 int IoStatementState::EndIoStatement() {
   return std::visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
 }
@@ -775,7 +829,12 @@ ExternalFileUnit *ChildIoStatementState<DIR>::GetExternalFileUnit() const {
   return child_.parent().GetExternalFileUnit();
 }
 
+template <Direction DIR> void ChildIoStatementState<DIR>::CompleteOperation() {
+  IoStatementBase::CompleteOperation();
+}
+
 template <Direction DIR> int ChildIoStatementState<DIR>::EndIoStatement() {
+  CompleteOperation();
   auto result{IoStatementBase::EndIoStatement()};
   child_.EndIoStatement(); // annihilates *this in child_.u_
   return result;
@@ -825,9 +884,17 @@ ChildFormattedIoStatementState<DIR, CHAR>::ChildFormattedIoStatementState(
       mutableModes_{child.parent().mutableModes()}, format_{*this, format,
                                                         formatLength} {}
 
+template <Direction DIR, typename CHAR>
+void ChildFormattedIoStatementState<DIR, CHAR>::CompleteOperation() {
+  if (!this->completedOperation()) {
+    format_.Finish(*this);
+    ChildIoStatementState<DIR>::CompleteOperation();
+  }
+}
+
 template <Direction DIR, typename CHAR>
 int ChildFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
-  format_.Finish(*this);
+  CompleteOperation();
   return ChildIoStatementState<DIR>::EndIoStatement();
 }
 
@@ -865,7 +932,10 @@ template class ChildListIoStatementState<Direction::Input>;
 template class ChildUnformattedIoStatementState<Direction::Output>;
 template class ChildUnformattedIoStatementState<Direction::Input>;
 
-int ExternalMiscIoStatementState::EndIoStatement() {
+void ExternalMiscIoStatementState::CompleteOperation() {
+  if (completedOperation()) {
+    return;
+  }
   ExternalFileUnit &ext{unit()};
   switch (which_) {
   case Flush:
@@ -882,6 +952,11 @@ int ExternalMiscIoStatementState::EndIoStatement() {
     ext.Rewind(*this);
     break;
   }
+  return ExternalIoStatementBase::CompleteOperation();
+}
+
+int ExternalMiscIoStatementState::EndIoStatement() {
+  CompleteOperation();
   return ExternalIoStatementBase::EndIoStatement();
 }
 
@@ -1366,7 +1441,7 @@ bool InquireIOLengthState::Emit(const char32_t *p, std::size_t n) {
 }
 
 int ErroneousIoStatementState::EndIoStatement() {
-  SignalError(iostat_);
+  SignalPendingError();
   return IoStatementBase::EndIoStatement();
 }
 

diff  --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index fef1e261cde29..2c43151296b8a 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -77,7 +77,15 @@ class IoStatementState {
   // to interact with the state of the I/O statement in progress.
   // This design avoids virtual member functions and function pointers,
   // which may not have good support in some runtime environments.
+
+  // CompleteOperation() is the last opportunity to raise an I/O error.
+  // It is called by EndIoStatement(), but it can be invoked earlier to
+  // catch errors for (e.g.) GetIoMsg() and GetNewUnit().  If called
+  // more than once, it is a no-op.
+  void CompleteOperation();
+  // Completes an I/O statement and reclaims storage.
   int EndIoStatement();
+
   bool Emit(const char *, std::size_t, std::size_t elementBytes);
   bool Emit(const char *, std::size_t);
   bool Emit(const char16_t *, std::size_t chars);
@@ -234,11 +242,16 @@ class IoStatementState {
 };
 
 // Base class for all per-I/O statement state classes.
-struct IoStatementBase : public IoErrorHandler {
+class IoStatementBase : public IoErrorHandler {
+public:
   using IoErrorHandler::IoErrorHandler;
 
+  bool completedOperation() const { return completedOperation_; }
+
+  void CompleteOperation() { completedOperation_ = true; }
+  int EndIoStatement() { return GetIoStat(); }
+
   // These are default no-op backstops that can be overridden by descendants.
-  int EndIoStatement();
   bool Emit(const char *, std::size_t, std::size_t elementBytes);
   bool Emit(const char *, std::size_t);
   bool Emit(const char16_t *, std::size_t chars);
@@ -260,6 +273,9 @@ struct IoStatementBase : public IoErrorHandler {
   bool Inquire(InquiryKeywordHash, std::int64_t &);
 
   void BadInquiryKeywordHashCrash(InquiryKeywordHash);
+
+protected:
+  bool completedOperation_{false};
 };
 
 // Common state for list-directed & NAMELIST I/O, both internal & external
@@ -354,6 +370,7 @@ class InternalFormattedIoStatementState
       std::size_t formatLength, const char *sourceFile = nullptr,
       int sourceLine = 0);
   IoStatementState &ioStatementState() { return ioStatementState_; }
+  void CompleteOperation();
   int EndIoStatement();
   std::optional<DataEdit> GetNextDataEdit(
       IoStatementState &, int maxRepeat = 1) {
@@ -392,6 +409,7 @@ class ExternalIoStatementBase : public IoStatementBase {
   ExternalFileUnit &unit() { return unit_; }
   MutableModes &mutableModes();
   ConnectionState &GetConnectionState();
+  void CompleteOperation();
   int EndIoStatement();
   ExternalFileUnit *GetExternalFileUnit() const { return &unit_; }
 
@@ -406,6 +424,7 @@ class ExternalIoStatementState : public ExternalIoStatementBase,
   ExternalIoStatementState(
       ExternalFileUnit &, const char *sourceFile = nullptr, int sourceLine = 0);
   MutableModes &mutableModes() { return mutableModes_; }
+  void CompleteOperation();
   int EndIoStatement();
   bool Emit(const char *, std::size_t, std::size_t elementBytes);
   bool Emit(const char *, std::size_t);
@@ -435,6 +454,7 @@ class ExternalFormattedIoStatementState
   ExternalFormattedIoStatementState(ExternalFileUnit &, const CharType *format,
       std::size_t formatLength, const char *sourceFile = nullptr,
       int sourceLine = 0);
+  void CompleteOperation();
   int EndIoStatement();
   std::optional<DataEdit> GetNextDataEdit(
       IoStatementState &, int maxRepeat = 1) {
@@ -471,6 +491,7 @@ class ChildIoStatementState : public IoStatementBase,
   MutableModes &mutableModes();
   ConnectionState &GetConnectionState();
   ExternalFileUnit *GetExternalFileUnit() const;
+  void CompleteOperation();
   int EndIoStatement();
   bool Emit(const char *, std::size_t, std::size_t elementBytes);
   bool Emit(const char *, std::size_t);
@@ -493,6 +514,7 @@ class ChildFormattedIoStatementState : public ChildIoStatementState<DIR>,
       std::size_t formatLength, const char *sourceFile = nullptr,
       int sourceLine = 0);
   MutableModes &mutableModes() { return mutableModes_; }
+  void CompleteOperation();
   int EndIoStatement();
   bool AdvanceRecord(int = 1);
   std::optional<DataEdit> GetNextDataEdit(
@@ -535,6 +557,8 @@ class OpenStatementState : public ExternalIoStatementBase {
   void set_convert(Convert convert) { convert_ = convert; } // CONVERT=
   void set_access(Access access) { access_ = access; } // ACCESS=
   void set_isUnformatted(bool yes = true) { isUnformatted_ = yes; } // FORM=
+
+  void CompleteOperation();
   int EndIoStatement();
 
 private:
@@ -567,6 +591,7 @@ class NoUnitIoStatementState : public IoStatementBase {
   IoStatementState &ioStatementState() { return ioStatementState_; }
   MutableModes &mutableModes() { return connection_.modes; }
   ConnectionState &GetConnectionState() { return connection_; }
+  void CompleteOperation();
   int EndIoStatement();
 
 protected:
@@ -674,6 +699,7 @@ class ExternalMiscIoStatementState : public ExternalIoStatementBase {
   ExternalMiscIoStatementState(ExternalFileUnit &unit, Which which,
       const char *sourceFile = nullptr, int sourceLine = 0)
       : ExternalIoStatementBase{unit, sourceFile, sourceLine}, which_{which} {}
+  void CompleteOperation();
   int EndIoStatement();
 
 private:
@@ -684,13 +710,14 @@ class ErroneousIoStatementState : public IoStatementBase {
 public:
   explicit ErroneousIoStatementState(
       Iostat iostat, const char *sourceFile = nullptr, int sourceLine = 0)
-      : IoStatementBase{sourceFile, sourceLine}, iostat_{iostat} {}
+      : IoStatementBase{sourceFile, sourceLine} {
+    SetPendingError(iostat);
+  }
   int EndIoStatement();
   ConnectionState &GetConnectionState() { return connection_; }
   MutableModes &mutableModes() { return connection_.modes; }
 
 private:
-  Iostat iostat_;
   ConnectionState connection_;
 };
 


        


More information about the flang-commits mailing list