[flang-commits] [flang] da25f96 - [flang] Runtime performance improvements to real formatted input

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Fri Nov 12 11:40:10 PST 2021


Author: Peter Klausler
Date: 2021-11-12T11:40:02-08:00
New Revision: da25f968a90ad4560fc920a6d18fc2a0221d2750

URL: https://github.com/llvm/llvm-project/commit/da25f968a90ad4560fc920a6d18fc2a0221d2750
DIFF: https://github.com/llvm/llvm-project/commit/da25f968a90ad4560fc920a6d18fc2a0221d2750.diff

LOG: [flang] Runtime performance improvements to real formatted input

Profiling a basic internal real input read benchmark shows some
hot spots in the code used to prepare input for decimal-to-binary
conversion, whereas the time should of course be spent in the conversion itself.
The library that implements decimal to/from binary conversions has
been optimized, but not the code in the Fortran runtime that calls it,
and there are some obvious light changes worth making here.

Move some member functions from *.cpp files into the class definitions
of Descriptor and IoStatementState to enable inlining and specialization.

Make GetNextInputBytes() the new basic input API within the
runtime, replacing GetCurrentChar() -- which is rewritten in terms of
GetNextInputBytes -- so that input routines can have the
ability to acquire more than one input character at a time
and amortize overhead.

These changes reduce the time to read 1M random reals
using internal I/O from a character array from 1.29s to 0.54s
on my machine, which is on par with Intel Fortran and much faster than
GNU Fortran.

Differential Revision: https://reviews.llvm.org/D113697

Added: 
    

Modified: 
    flang/include/flang/Decimal/decimal.h
    flang/include/flang/Runtime/descriptor.h
    flang/lib/Decimal/big-radix-floating-point.h
    flang/lib/Decimal/decimal-to-binary.cpp
    flang/runtime/descriptor.cpp
    flang/runtime/edit-input.cpp
    flang/runtime/internal-unit.cpp
    flang/runtime/internal-unit.h
    flang/runtime/io-stmt.cpp
    flang/runtime/io-stmt.h
    flang/runtime/unit.cpp
    flang/runtime/unit.h
    flang/unittests/Runtime/NumericalFormatTest.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Decimal/decimal.h b/flang/include/flang/Decimal/decimal.h
index 6891d39159f44..45d60d8eedb87 100644
--- a/flang/include/flang/Decimal/decimal.h
+++ b/flang/include/flang/Decimal/decimal.h
@@ -101,21 +101,23 @@ template <int PREC> struct ConversionToBinaryResult {
 };
 
 template <int PREC>
-ConversionToBinaryResult<PREC> ConvertToBinary(
-    const char *&, enum FortranRounding = RoundNearest);
+ConversionToBinaryResult<PREC> ConvertToBinary(const char *&,
+    enum FortranRounding = RoundNearest, const char *end = nullptr);
 
+// Default arguments are supplied by the primary template declaration above;
+// like the rounding default, they are not repeated on these instantiations.
 extern template ConversionToBinaryResult<8> ConvertToBinary<8>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 extern template ConversionToBinaryResult<11> ConvertToBinary<11>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 extern template ConversionToBinaryResult<24> ConvertToBinary<24>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 extern template ConversionToBinaryResult<53> ConvertToBinary<53>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 extern template ConversionToBinaryResult<64> ConvertToBinary<64>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 extern template ConversionToBinaryResult<113> ConvertToBinary<113>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 } // namespace Fortran::decimal
 extern "C" {
 #define NS(x) Fortran::decimal::x

diff  --git a/flang/include/flang/Runtime/descriptor.h b/flang/include/flang/Runtime/descriptor.h
index 75c5e2176d929..dc2f2b77356d5 100644
--- a/flang/include/flang/Runtime/descriptor.h
+++ b/flang/include/flang/Runtime/descriptor.h
@@ -247,12 +247,51 @@ class Descriptor {
   // subscripts of the array, these wrap the subscripts around to
   // their first (or last) values and return false.
   bool IncrementSubscripts(
-      SubscriptValue[], const int *permutation = nullptr) const;
+      SubscriptValue subscript[], const int *permutation = nullptr) const {
+    for (int j{0}; j < raw_.rank; ++j) {
+      int k{permutation ? permutation[j] : j};
+      const Dimension &dim{GetDimension(k)};
+      if (subscript[k]++ < dim.UpperBound()) {
+        return true;
+      }
+      subscript[k] = dim.LowerBound(); // wrap and carry into the next dim
+    }
+    return false;
+  }
+
   bool DecrementSubscripts(
       SubscriptValue[], const int *permutation = nullptr) const;
+
   // False when out of range.
-  bool SubscriptsForZeroBasedElementNumber(SubscriptValue *,
-      std::size_t elementNumber, const int *permutation = nullptr) const;
+  bool SubscriptsForZeroBasedElementNumber(SubscriptValue subscript[],
+      std::size_t elementNumber, const int *permutation = nullptr) const {
+    if (raw_.rank == 0) {
+      return elementNumber == 0;
+    }
+    std::size_t dimCoefficient[maxRank]; // elements per step in dim j
+    int k0{permutation ? permutation[0] : 0};
+    dimCoefficient[0] = 1;
+    auto coefficient{static_cast<std::size_t>(GetDimension(k0).Extent())};
+    for (int j{1}; j < raw_.rank; ++j) {
+      int k{permutation ? permutation[j] : j};
+      const Dimension &dim{GetDimension(k)};
+      dimCoefficient[j] = coefficient;
+      coefficient *= dim.Extent();
+    }
+    if (elementNumber >= coefficient) {
+      return false; // out of range
+    }
+    for (int j{raw_.rank - 1}; j > 0; --j) {
+      int k{permutation ? permutation[j] : j};
+      const Dimension &dim{GetDimension(k)};
+      std::size_t quotient{elementNumber / dimCoefficient[j]};
+      subscript[k] = quotient + dim.LowerBound();
+      elementNumber -= quotient * dimCoefficient[j];
+    }
+    subscript[k0] = elementNumber + GetDimension(k0).LowerBound(); // coef 1
+    return true;
+  }
+
   std::size_t ZeroBasedElementNumber(
       const SubscriptValue *, const int *permutation = nullptr) const;
 

diff  --git a/flang/lib/Decimal/big-radix-floating-point.h b/flang/lib/Decimal/big-radix-floating-point.h
index 4ae417cd9263e..32563235a76cb 100644
--- a/flang/lib/Decimal/big-radix-floating-point.h
+++ b/flang/lib/Decimal/big-radix-floating-point.h
@@ -87,7 +87,8 @@ template <int PREC, int LOG10RADIX = 16> class BigRadixFloatingPointNumber {
   // spaces.
   // The argument is a reference to a pointer that is left
   // pointing to the first character that wasn't parsed.
-  ConversionToBinaryResult<PREC> ConvertToBinary(const char *&);
+  ConversionToBinaryResult<PREC> ConvertToBinary(
+      const char *&, const char *end = nullptr);
 
   // Formats a decimal floating-point number to a user buffer.
   // May emit "NaN" or "Inf", or an possibly-signed integer.
@@ -337,7 +338,12 @@ template <int PREC, int LOG10RADIX = 16> class BigRadixFloatingPointNumber {
   // Returns true when the the result has effectively been rounded down.
   bool Mean(const BigRadixFloatingPointNumber &);
 
-  bool ParseNumber(const char *&, bool &inexact);
+  // Parses a floating-point number; leaves the pointer reference
+  // argument pointing at the next character after what was recognized.
+  // The "end" argument can be left null if the caller is sure that the
+  // string is properly terminated with an addressable character that
+  // can't be part of a valid floating-point number.
+  bool ParseNumber(const char *&, bool &inexact, const char *end);
 
   using Raw = typename Real::RawType;
   constexpr Raw SignBit() const { return Raw{isNegative_} << (Real::bits - 1); }

diff  --git a/flang/lib/Decimal/decimal-to-binary.cpp b/flang/lib/Decimal/decimal-to-binary.cpp
index d6e30bed84c74..9e3fc5f882f00 100644
--- a/flang/lib/Decimal/decimal-to-binary.cpp
+++ b/flang/lib/Decimal/decimal-to-binary.cpp
@@ -19,10 +19,16 @@ namespace Fortran::decimal {
 
 template <int PREC, int LOG10RADIX>
 bool BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ParseNumber(
-    const char *&p, bool &inexact) {
+    const char *&p, bool &inexact, const char *end) {
   SetToZero();
-  while (*p == ' ') {
-    ++p;
+  if (end && p >= end) {
+    return false;
+  }
+  // Skip leading spaces
+  for (; p != end && *p == ' '; ++p) {
+  }
+  if (p == end) {
+    return false; // all blanks; note that p is left equal to end
   }
   const char *q{p};
   isNegative_ = *q == '-';
@@ -30,23 +36,22 @@ bool BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ParseNumber(
     ++q;
   }
   const char *start{q};
-  while (*q == '0') {
-    ++q;
+  for (; q != end && *q == '0'; ++q) {
   }
-  const char *first{q};
-  for (; *q >= '0' && *q <= '9'; ++q) {
+  const char *firstDigit{q};
+  for (; q != end && *q >= '0' && *q <= '9'; ++q) {
   }
   const char *point{nullptr};
-  if (*q == '.') {
+  if (q != end && *q == '.') {
     point = q;
-    for (++q; *q >= '0' && *q <= '9'; ++q) {
+    for (++q; q != end && *q >= '0' && *q <= '9'; ++q) {
     }
   }
-  if (q == start || (q == start + 1 && *start == '.')) {
+  if (q == start || (q == start + 1 && start == point)) {
     return false; // require at least one digit
   }
   // There's a valid number here; set the reference argument to point to
-  // the first character afterward.
+  // the first character afterward, which might be an exponent part.
   p = q;
   // Strip off trailing zeroes
   if (point) {
@@ -59,13 +64,13 @@ bool BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ParseNumber(
     }
   }
   if (!point) {
-    while (q > first && q[-1] == '0') {
+    while (q > firstDigit && q[-1] == '0') {
       --q;
       ++exponent_;
     }
   }
   // Trim any excess digits
-  const char *limit{first + maxDigits * log10Radix + (point != nullptr)};
+  const char *limit{firstDigit + maxDigits * log10Radix + (point != nullptr)};
   if (q > limit) {
     inexact = true;
     if (point >= limit) {
@@ -80,11 +85,11 @@ bool BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ParseNumber(
   if (point) {
     exponent_ -= static_cast<int>(q - point - 1);
   }
-  if (q == first) {
+  if (q == firstDigit) {
     exponent_ = 0; // all zeros
   }
   // Rack the decimal digits up into big Digits.
-  for (auto times{radix}; q-- > first;) {
+  for (auto times{radix}; q-- > firstDigit;) {
     if (*q != '.') {
       if (times == radix) {
         digit_[digits_++] = *q - '0';
@@ -96,6 +101,9 @@ bool BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ParseNumber(
     }
   }
   // Look for an optional exponent field.
+  if (p == end) {
+    return true;
+  }
   q = p;
   switch (*q) {
   case 'e':
@@ -104,18 +112,20 @@ bool BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ParseNumber(
   case 'D':
   case 'q':
   case 'Q': {
-    bool negExpo{*++q == '-'};
+    if (++q == end) {
+      break;
+    }
+    bool negExpo{*q == '-'};
     if (*q == '-' || *q == '+') {
       ++q;
     }
-    if (*q >= '0' && *q <= '9') {
+    if (q != end && *q >= '0' && *q <= '9') {
       int expo{0};
-      while (*q == '0') {
-        ++q;
+      for (; q != end && *q == '0'; ++q) {
       }
       const char *expDig{q};
-      while (*q >= '0' && *q <= '9') {
-        expo = 10 * expo + *q++ - '0';
+      for (; q != end && *q >= '0' && *q <= '9'; ++q) {
+        expo = 10 * expo + *q - '0';
       }
       if (q >= expDig + 8) {
         // There's a ridiculous number of nonzero exponent digits.
@@ -125,7 +135,7 @@ bool BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ParseNumber(
         expo = 10 * Real::decimalRange;
         exponent_ = 0;
       }
-      p = q; // exponent was valid
+      p = q; // exponent is valid; advance the termination pointer
       if (negExpo) {
         exponent_ -= expo;
       } else {
@@ -385,9 +395,10 @@ BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ConvertToBinary() {
 
 template <int PREC, int LOG10RADIX>
 ConversionToBinaryResult<PREC>
-BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ConvertToBinary(const char *&p) {
+BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ConvertToBinary(
+    const char *&p, const char *limit) {
   bool inexact{false};
-  if (ParseNumber(p, inexact)) {
+  if (ParseNumber(p, inexact, limit)) { // NOTE(review): on failure p may == limit
     auto result{ConvertToBinary()};
     if (inexact) {
       result.flags =
@@ -422,22 +433,22 @@ BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ConvertToBinary(const char *&p) {
 
 template <int PREC>
 ConversionToBinaryResult<PREC> ConvertToBinary(
-    const char *&p, enum FortranRounding rounding) {
-  return BigRadixFloatingPointNumber<PREC>{rounding}.ConvertToBinary(p);
+    const char *&p, enum FortranRounding rounding, const char *end) {
+  return BigRadixFloatingPointNumber<PREC>{rounding}.ConvertToBinary(p, end);
 }
 
 template ConversionToBinaryResult<8> ConvertToBinary<8>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 template ConversionToBinaryResult<11> ConvertToBinary<11>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 template ConversionToBinaryResult<24> ConvertToBinary<24>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 template ConversionToBinaryResult<53> ConvertToBinary<53>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 template ConversionToBinaryResult<64> ConvertToBinary<64>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 template ConversionToBinaryResult<113> ConvertToBinary<113>(
-    const char *&, enum FortranRounding);
+    const char *&, enum FortranRounding, const char *end);
 
 extern "C" {
 enum ConversionResultFlags ConvertDecimalToFloat(

diff  --git a/flang/runtime/descriptor.cpp b/flang/runtime/descriptor.cpp
index a5524bd55b228..d4f6de65965af 100644
--- a/flang/runtime/descriptor.cpp
+++ b/flang/runtime/descriptor.cpp
@@ -163,19 +163,6 @@ int Descriptor::Destroy(bool finalize) {
 
 int Descriptor::Deallocate() { return ISO::CFI_deallocate(&raw_); }
 
-bool Descriptor::IncrementSubscripts(
-    SubscriptValue *subscript, const int *permutation) const {
-  for (int j{0}; j < raw_.rank; ++j) {
-    int k{permutation ? permutation[j] : j};
-    const Dimension &dim{GetDimension(k)};
-    if (subscript[k]++ < dim.UpperBound()) {
-      return true;
-    }
-    subscript[k] = dim.LowerBound();
-  }
-  return false;
-}
-
 bool Descriptor::DecrementSubscripts(
     SubscriptValue *subscript, const int *permutation) const {
   for (int j{raw_.rank - 1}; j >= 0; --j) {
@@ -202,29 +189,6 @@ std::size_t Descriptor::ZeroBasedElementNumber(
   return result;
 }
 
-bool Descriptor::SubscriptsForZeroBasedElementNumber(SubscriptValue *subscript,
-    std::size_t elementNumber, const int *permutation) const {
-  std::size_t coefficient{1};
-  std::size_t dimCoefficient[maxRank];
-  for (int j{0}; j < raw_.rank; ++j) {
-    int k{permutation ? permutation[j] : j};
-    const Dimension &dim{GetDimension(k)};
-    dimCoefficient[j] = coefficient;
-    coefficient *= dim.Extent();
-  }
-  if (elementNumber >= coefficient) {
-    return false; // out of range
-  }
-  for (int j{raw_.rank - 1}; j >= 0; --j) {
-    int k{permutation ? permutation[j] : j};
-    const Dimension &dim{GetDimension(k)};
-    std::size_t quotient{elementNumber / dimCoefficient[j]};
-    subscript[k] = quotient + dim.LowerBound();
-    elementNumber -= quotient * dimCoefficient[j];
-  }
-  return true;
-}
-
 bool Descriptor::EstablishPointerSection(const Descriptor &source,
     const SubscriptValue *lower, const SubscriptValue *upper,
     const SubscriptValue *stride) {

diff  --git a/flang/runtime/edit-input.cpp b/flang/runtime/edit-input.cpp
index 19aa56a84932e..a14216d51deab 100644
--- a/flang/runtime/edit-input.cpp
+++ b/flang/runtime/edit-input.cpp
@@ -266,9 +266,72 @@ static int ScanRealInput(char *buffer, int bufferSize, IoStatementState &io,
   return got;
 }
 
+// If no special modes are in effect and the form of the input value
+// that's present in the input stream is acceptable to the decimal->binary
+// converter without modification, this fast path for real input
+// saves time by avoiding memory copies and reformatting of the exponent.
+// On success, stores the value through "n", advances the input position
+// past the consumed field, and returns true.  Otherwise returns false
+// without calling HandleRelativePosition() so the caller can fall back.
+template <int PRECISION>
+static bool TryFastPathRealInput(
+    IoStatementState &io, const DataEdit &edit, void *n) {
+  if (edit.modes.editingFlags & (blankZero | decimalComma)) {
+    return false;
+  }
+  if (edit.modes.scale != 0) {
+    return false;
+  }
+  const char *str{nullptr};
+  std::size_t got{io.GetNextInputBytes(str)};
+  if (got == 0 || str == nullptr ||
+      !io.GetConnectionState().recordLength.has_value()) {
+    return false; // could not access reliably-terminated input stream
+  }
+  const char *p{str};
+  std::int64_t maxConsume{
+      std::min<std::int64_t>(got, edit.width.value_or(got))};
+  const char *limit{str + maxConsume};
+  decimal::ConversionToBinaryResult<PRECISION> converted{
+      decimal::ConvertToBinary<PRECISION>(p, edit.modes.round, limit)};
+  if (converted.flags & decimal::Invalid) {
+    return false;
+  }
+  if (edit.digits.value_or(0) != 0 &&
+      std::memchr(str, '.', p - str) == nullptr) {
+    // No explicit decimal point, and edit descriptor is Fw.d (or other)
+    // with d != 0, which implies scaling.
+    return false;
+  }
+  for (; p < limit && (*p == ' ' || *p == '\t'); ++p) {
+  }
+  if (edit.descriptor == DataEdit::ListDirectedImaginaryPart) {
+    // Need a trailing ')'.  Consume exactly that one character, then any
+    // blanks that follow it.
+    if (p >= limit || *p != ')') {
+      return false;
+    }
+    for (++p; p < limit && (*p == ' ' || *p == '\t'); ++p) {
+    }
+  }
+  if (p < limit) {
+    return false; // unconverted characters remain in field
+  }
+  // Success on the fast path!
+  // TODO: raise converted.flags as exceptions?
+  *reinterpret_cast<decimal::BinaryFloatingPointNumber<PRECISION> *>(n) =
+      converted.binary;
+  io.HandleRelativePosition(p - str);
+  return true;
+}
+
 template <int KIND>
 bool EditCommonRealInput(IoStatementState &io, const DataEdit &edit, void *n) {
   constexpr int binaryPrecision{common::PrecisionOfRealKind(KIND)};
+  if (TryFastPathRealInput<binaryPrecision>(io, edit, n)) {
+    return true;
+  }
+  // Fast path wasn't available or didn't work; go the more general route
   static constexpr int maxDigits{
       common::MaxDecimalConversionDigits(binaryPrecision)};
   static constexpr int bufferSize{maxDigits + 18};
@@ -285,7 +348,38 @@ bool EditCommonRealInput(IoStatementState &io, const DataEdit &edit, void *n) {
   }
   bool hadExtra{got > maxDigits};
   if (exponent != 0) {
-    got += std::snprintf(&buffer[got], bufferSize - got, "e%d", exponent);
+    buffer[got++] = 'e';
+    if (exponent < 0) {
+      buffer[got++] = '-';
+      exponent = -exponent;
+    }
+    if (exponent > 9999) {
+      exponent = 9999; // will convert to +/-Inf
+    }
+    if (exponent > 999) {
+      int dig{exponent / 1000};
+      buffer[got++] = '0' + dig;
+      int rest{exponent - 1000 * dig};
+      dig = rest / 100;
+      buffer[got++] = '0' + dig;
+      rest -= 100 * dig;
+      dig = rest / 10;
+      buffer[got++] = '0' + dig;
+      buffer[got++] = '0' + (rest - 10 * dig);
+    } else if (exponent > 99) {
+      int dig{exponent / 100};
+      buffer[got++] = '0' + dig;
+      int rest{exponent - 100 * dig};
+      dig = rest / 10;
+      buffer[got++] = '0' + dig;
+      buffer[got++] = '0' + (rest - 10 * dig);
+    } else if (exponent > 9) {
+      int dig{exponent / 10};
+      buffer[got++] = '0' + dig;
+      buffer[got++] = '0' + (exponent - 10 * dig);
+    } else {
+      buffer[got++] = '0' + exponent;
+    }
   }
   buffer[got] = '\0';
   const char *p{buffer};

diff  --git a/flang/runtime/internal-unit.cpp b/flang/runtime/internal-unit.cpp
index 1cd0909331a9c..ba274b6b5ace5 100644
--- a/flang/runtime/internal-unit.cpp
+++ b/flang/runtime/internal-unit.cpp
@@ -88,25 +88,39 @@ bool InternalDescriptorUnit<DIR>::Emit(
 }
 
 template <Direction DIR>
-std::optional<char32_t> InternalDescriptorUnit<DIR>::GetCurrentChar(
-    IoErrorHandler &handler) {
+std::size_t InternalDescriptorUnit<DIR>::GetNextInputBytes(
+    const char *&p, IoErrorHandler &handler) {
   if constexpr (DIR == Direction::Output) {
-    handler.Crash(
-        "InternalDescriptorUnit<Direction::Output>::GetCurrentChar() called");
-    return std::nullopt;
-  }
-  const char *record{CurrentRecord()};
-  if (!record) {
-    handler.SignalEnd();
-    return std::nullopt;
+    handler.Crash("InternalDescriptorUnit<Direction::Output>::"
+                  "GetNextInputBytes() called");
+    return 0;
+  } else {
+    const char *record{CurrentRecord()};
+    if (!record) {
+      handler.SignalEnd();
+      return 0;
+    } else if (positionInRecord >= recordLength.value_or(positionInRecord)) {
+      return 0;
+    } else {
+      p = &record[positionInRecord];
+      return *recordLength - positionInRecord;
+    }
   }
-  if (positionInRecord >= recordLength.value_or(positionInRecord)) {
+}
+
+template <Direction DIR>
+std::optional<char32_t> InternalDescriptorUnit<DIR>::GetCurrentChar(
+    IoErrorHandler &handler) {
+  const char *p{nullptr};
+  std::size_t bytes{GetNextInputBytes(p, handler)};
+  if (bytes == 0) {
     return std::nullopt;
+  } else {
+    if (isUTF8) {
+      // TODO: UTF-8 decoding
+    }
+    return *p; // NOTE(review): sign-extends bytes >= 0x80 if char is signed
   }
-  if (isUTF8) {
-    // TODO: UTF-8 decoding
-  }
-  return record[positionInRecord];
 }
 
 template <Direction DIR>

diff  --git a/flang/runtime/internal-unit.h b/flang/runtime/internal-unit.h
index a26bcf64d97d4..9fc9c86bd2ccf 100644
--- a/flang/runtime/internal-unit.h
+++ b/flang/runtime/internal-unit.h
@@ -31,6 +31,7 @@ template <Direction DIR> class InternalDescriptorUnit : public ConnectionState {
   void EndIoStatement();
 
   bool Emit(const char *, std::size_t, IoErrorHandler &);
+  std::size_t GetNextInputBytes(const char *&, IoErrorHandler &);
   std::optional<char32_t> GetCurrentChar(IoErrorHandler &);
   bool AdvanceRecord(IoErrorHandler &);
   void BackspaceRecord(IoErrorHandler &);

diff  --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp
index 784d14496aac0..afb0935ac40ad 100644
--- a/flang/runtime/io-stmt.cpp
+++ b/flang/runtime/io-stmt.cpp
@@ -37,8 +37,9 @@ bool IoStatementBase::Emit(const char32_t *, std::size_t) {
   return false;
 }
 
-std::optional<char32_t> IoStatementBase::GetCurrentChar() {
-  return std::nullopt;
+std::size_t IoStatementBase::GetNextInputBytes(const char *&p) {
+  p = nullptr; // the base statement class provides no input bytes
+  return 0;
 }
 
 bool IoStatementBase::AdvanceRecord(int) { return false; }
@@ -110,13 +111,9 @@ bool InternalIoStatementState<DIR, CHAR>::Emit(
 }
 
 template <Direction DIR, typename CHAR>
-std::optional<char32_t> InternalIoStatementState<DIR, CHAR>::GetCurrentChar() {
-  if constexpr (DIR == Direction::Output) {
-    Crash(
-        "InternalIoStatementState<Direction::Output>::GetCurrentChar() called");
-    return std::nullopt;
-  }
-  return unit_.GetCurrentChar(*this);
+std::size_t InternalIoStatementState<DIR, CHAR>::GetNextInputBytes(
+    const char *&p) {
+  return unit_.GetNextInputBytes(p, *this);
 }
 
 template <Direction DIR, typename CHAR>
@@ -326,12 +323,8 @@ bool ExternalIoStatementState<DIR>::Emit(
 }
 
 template <Direction DIR>
-std::optional<char32_t> ExternalIoStatementState<DIR>::GetCurrentChar() {
-  if constexpr (DIR == Direction::Output) {
-    Crash(
-        "ExternalIoStatementState<Direction::Output>::GetCurrentChar() called");
-  }
-  return unit().GetCurrentChar(*this);
+std::size_t ExternalIoStatementState<DIR>::GetNextInputBytes(const char *&p) {
+  return unit().GetNextInputBytes(p, *this);
 }
 
 template <Direction DIR>
@@ -424,8 +417,8 @@ bool IoStatementState::Receive(
       [=](auto &x) { return x.get().Receive(data, n, elementBytes); }, u_);
 }
 
-std::optional<char32_t> IoStatementState::GetCurrentChar() {
-  return std::visit([&](auto &x) { return x.get().GetCurrentChar(); }, u_);
+std::size_t IoStatementState::GetNextInputBytes(const char *&p) {
+  return std::visit([&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
 }
 
 bool IoStatementState::AdvanceRecord(int n) {
@@ -501,100 +494,6 @@ bool IoStatementState::EmitField(
   }
 }
 
-std::optional<char32_t> IoStatementState::PrepareInput(
-    const DataEdit &edit, std::optional<int> &remaining) {
-  remaining.reset();
-  if (edit.descriptor == DataEdit::ListDirected) {
-    GetNextNonBlank();
-  } else {
-    if (edit.width.value_or(0) > 0) {
-      remaining = *edit.width;
-    }
-    SkipSpaces(remaining);
-  }
-  return NextInField(remaining);
-}
-
-std::optional<char32_t> IoStatementState::SkipSpaces(
-    std::optional<int> &remaining) {
-  while (!remaining || *remaining > 0) {
-    if (auto ch{GetCurrentChar()}) {
-      if (*ch != ' ' && *ch != '\t') {
-        return ch;
-      }
-      HandleRelativePosition(1);
-      if (remaining) {
-        GotChar();
-        --*remaining;
-      }
-    } else {
-      break;
-    }
-  }
-  return std::nullopt;
-}
-
-std::optional<char32_t> IoStatementState::NextInField(
-    std::optional<int> &remaining) {
-  if (!remaining) { // list-directed or NAMELIST: check for separators
-    if (auto next{GetCurrentChar()}) {
-      switch (*next) {
-      case ' ':
-      case '\t':
-      case ',':
-      case ';':
-      case '/':
-      case '(':
-      case ')':
-      case '\'':
-      case '"':
-      case '*':
-      case '\n': // for stream access
-        break;
-      default:
-        HandleRelativePosition(1);
-        return next;
-      }
-    }
-  } else if (*remaining > 0) {
-    if (auto next{GetCurrentChar()}) {
-      --*remaining;
-      HandleRelativePosition(1);
-      GotChar();
-      return next;
-    }
-    const ConnectionState &connection{GetConnectionState()};
-    if (!connection.IsAtEOF() && connection.recordLength &&
-        connection.positionInRecord >= *connection.recordLength) {
-      IoErrorHandler &handler{GetIoErrorHandler()};
-      if (mutableModes().nonAdvancing) {
-        handler.SignalEor();
-      } else if (connection.isFixedRecordLength && !connection.modes.pad) {
-        handler.SignalError(IostatRecordReadOverrun);
-      }
-      if (connection.modes.pad) { // PAD='YES'
-        --*remaining;
-        return std::optional<char32_t>{' '};
-      }
-    }
-  }
-  return std::nullopt;
-}
-
-std::optional<char32_t> IoStatementState::GetNextNonBlank() {
-  auto ch{GetCurrentChar()};
-  bool inNamelist{GetConnectionState().modes.inNamelist};
-  while (!ch || *ch == ' ' || *ch == '\t' || (inNamelist && *ch == '!')) {
-    if (ch && (*ch == ' ' || *ch == '\t')) {
-      HandleRelativePosition(1);
-    } else if (!AdvanceRecord()) {
-      return std::nullopt;
-    }
-    ch = GetCurrentChar();
-  }
-  return ch;
-}
-
 bool IoStatementState::Inquire(
     InquiryKeywordHash inquiry, char *out, std::size_t chars) {
   return std::visit(
@@ -827,8 +726,8 @@ bool ChildIoStatementState<DIR>::Emit(const char32_t *data, std::size_t chars) {
 }
 
 template <Direction DIR>
-std::optional<char32_t> ChildIoStatementState<DIR>::GetCurrentChar() {
-  return child_.parent().GetCurrentChar();
+std::size_t ChildIoStatementState<DIR>::GetNextInputBytes(const char *&p) {
+  return child_.parent().GetNextInputBytes(p);
 }
 
 template <Direction DIR>

diff  --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index 0006cab2b2ae1..ca3a6db264570 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -82,7 +82,7 @@ class IoStatementState {
   bool Emit(const char16_t *, std::size_t chars);
   bool Emit(const char32_t *, std::size_t chars);
   bool Receive(char *, std::size_t, std::size_t elementBytes = 0);
-  std::optional<char32_t> GetCurrentChar(); // vacant after end of record
+  std::size_t GetNextInputBytes(const char *&);
   bool AdvanceRecord(int = 1);
   void BackspaceRecord();
   void HandleRelativePosition(std::int64_t);
@@ -113,6 +113,18 @@ class IoStatementState {
         u_);
   }
 
+  // Vacant after the end of the current record
+  std::optional<char32_t> GetCurrentChar() {
+    const char *p{nullptr};
+    std::size_t bytes{GetNextInputBytes(p)};
+    if (bytes == 0) {
+      return std::nullopt;
+    } else {
+      // TODO: UTF-8 decoding; may have to get more bytes in a loop
+      return *p; // NOTE(review): sign-extends bytes >= 0x80 if char is signed
+    }
+  }
+
   bool EmitRepeated(char, std::size_t);
   bool EmitField(const char *, std::size_t length, std::size_t width);
 
@@ -120,12 +132,97 @@ class IoStatementState {
   // Skip over leading blanks, then return the first non-blank character (if
   // any).
   std::optional<char32_t> PrepareInput(
-      const DataEdit &edit, std::optional<int> &remaining);
+      const DataEdit &edit, std::optional<int> &remaining) {
+    remaining.reset();
+    if (edit.descriptor == DataEdit::ListDirected) {
+      GetNextNonBlank();
+    } else {
+      if (edit.width.value_or(0) > 0) {
+        remaining = *edit.width;
+      }
+      SkipSpaces(remaining);
+    }
+    return NextInField(remaining);
+  }
+
+  std::optional<char32_t> SkipSpaces(std::optional<int> &remaining) {
+    while (!remaining || *remaining > 0) {
+      if (auto ch{GetCurrentChar()}) {
+        if (*ch != ' ' && *ch != '\t') {
+          return ch;
+        }
+        HandleRelativePosition(1);
+        if (remaining) {
+          GotChar();
+          --*remaining;
+        }
+      } else {
+        break;
+      }
+    }
+    return std::nullopt;
+  }
+
+  std::optional<char32_t> NextInField(std::optional<int> &remaining) {
+    if (!remaining) { // list-directed or NAMELIST: check for separators
+      if (auto next{GetCurrentChar()}) {
+        switch (*next) {
+        case ' ':
+        case '\t':
+        case ',':
+        case ';':
+        case '/':
+        case '(':
+        case ')':
+        case '\'':
+        case '"':
+        case '*':
+        case '\n': // for stream access
+          break;
+        default:
+          HandleRelativePosition(1);
+          return next;
+        }
+      }
+    } else if (*remaining > 0) {
+      if (auto next{GetCurrentChar()}) {
+        --*remaining;
+        HandleRelativePosition(1);
+        GotChar();
+        return next;
+      }
+      const ConnectionState &connection{GetConnectionState()};
+      if (!connection.IsAtEOF() && connection.recordLength &&
+          connection.positionInRecord >= *connection.recordLength) {
+        IoErrorHandler &handler{GetIoErrorHandler()};
+        if (mutableModes().nonAdvancing) {
+          handler.SignalEor();
+        } else if (connection.isFixedRecordLength && !connection.modes.pad) {
+          handler.SignalError(IostatRecordReadOverrun);
+        }
+        if (connection.modes.pad) { // PAD='YES'
+          --*remaining;
+          return std::optional<char32_t>{' '};
+        }
+      }
+    }
+    return std::nullopt;
+  }
 
-  std::optional<char32_t> SkipSpaces(std::optional<int> &remaining);
-  std::optional<char32_t> NextInField(std::optional<int> &remaining);
   // Skips spaces, advances records, and ignores NAMELIST comments
-  std::optional<char32_t> GetNextNonBlank();
+  std::optional<char32_t> GetNextNonBlank() {
+    auto ch{GetCurrentChar()};
+    bool inNamelist{GetConnectionState().modes.inNamelist};
+    while (!ch || *ch == ' ' || *ch == '\t' || (inNamelist && *ch == '!')) {
+      if (ch && (*ch == ' ' || *ch == '\t')) {
+        HandleRelativePosition(1);
+      } else if (!AdvanceRecord()) {
+        return std::nullopt;
+      }
+      ch = GetCurrentChar();
+    }
+    return ch;
+  }
 
   template <Direction D> void CheckFormattedStmtType(const char *name) {
     if (!get_if<FormattedIoStatementState<D>>()) {
@@ -182,7 +279,7 @@ struct IoStatementBase : public IoErrorHandler {
   bool Emit(const char16_t *, std::size_t chars);
   bool Emit(const char32_t *, std::size_t chars);
   bool Receive(char *, std::size_t, std::size_t elementBytes = 0);
-  std::optional<char32_t> GetCurrentChar();
+  std::size_t GetNextInputBytes(const char *&);
   bool AdvanceRecord(int);
   void BackspaceRecord();
   void HandleRelativePosition(std::int64_t);
@@ -264,8 +361,7 @@ class InternalIoStatementState : public IoStatementBase,
   using IoStatementBase::Emit;
   bool Emit(
       const CharType *data, std::size_t chars /* not necessarily bytes */);
-
-  std::optional<char32_t> GetCurrentChar();
+  std::size_t GetNextInputBytes(const char *&);
   bool AdvanceRecord(int = 1);
   void BackspaceRecord();
   ConnectionState &GetConnectionState() { return unit_; }
@@ -349,7 +445,7 @@ class ExternalIoStatementState : public ExternalIoStatementBase,
   bool Emit(const char *, std::size_t);
   bool Emit(const char16_t *, std::size_t chars /* not bytes */);
   bool Emit(const char32_t *, std::size_t chars /* not bytes */);
-  std::optional<char32_t> GetCurrentChar();
+  std::size_t GetNextInputBytes(const char *&);
   bool AdvanceRecord(int = 1);
   void BackspaceRecord();
   void HandleRelativePosition(std::int64_t);
@@ -414,7 +510,7 @@ class ChildIoStatementState : public IoStatementBase,
   bool Emit(const char *, std::size_t);
   bool Emit(const char16_t *, std::size_t chars /* not bytes */);
   bool Emit(const char32_t *, std::size_t chars /* not bytes */);
-  std::optional<char32_t> GetCurrentChar();
+  std::size_t GetNextInputBytes(const char *&);
   void HandleRelativePosition(std::int64_t);
   void HandleAbsolutePosition(std::int64_t);
 

diff  --git a/flang/runtime/unit.cpp b/flang/runtime/unit.cpp
index 829b471f3424a..50b70c001db30 100644
--- a/flang/runtime/unit.cpp
+++ b/flang/runtime/unit.cpp
@@ -317,14 +317,23 @@ bool ExternalFileUnit::Receive(char *data, std::size_t bytes,
   }
 }
 
+std::size_t ExternalFileUnit::GetNextInputBytes(
+    const char *&p, IoErrorHandler &handler) {
+  RUNTIME_CHECK(handler, direction_ == Direction::Input);
+  p = FrameNextInput(handler, 1);
+  return p ? recordLength.value_or(positionInRecord + 1) - positionInRecord : 0;
+}
+
 std::optional<char32_t> ExternalFileUnit::GetCurrentChar(
     IoErrorHandler &handler) {
-  RUNTIME_CHECK(handler, direction_ == Direction::Input);
-  if (const char *p{FrameNextInput(handler, 1)}) {
+  const char *p{nullptr};
+  std::size_t bytes{GetNextInputBytes(p, handler)};
+  if (bytes == 0) {
+    return std::nullopt;
+  } else {
     // TODO: UTF-8 decoding; may have to get more bytes in a loop
     return *p;
   }
-  return std::nullopt;
 }
 
 const char *ExternalFileUnit::FrameNextInput(

diff  --git a/flang/runtime/unit.h b/flang/runtime/unit.h
index 797c9fda73165..4f7a18ebc1650 100644
--- a/flang/runtime/unit.h
+++ b/flang/runtime/unit.h
@@ -76,6 +76,7 @@ class ExternalFileUnit : public ConnectionState,
   bool Emit(
       const char *, std::size_t, std::size_t elementBytes, IoErrorHandler &);
   bool Receive(char *, std::size_t, std::size_t elementBytes, IoErrorHandler &);
+  std::size_t GetNextInputBytes(const char *&, IoErrorHandler &);
   std::optional<char32_t> GetCurrentChar(IoErrorHandler &);
   void SetLeftTabLimit();
   bool BeginReadingRecord(IoErrorHandler &);

diff  --git a/flang/unittests/Runtime/NumericalFormatTest.cpp b/flang/unittests/Runtime/NumericalFormatTest.cpp
index ee60956e8efda..a1f0a976b38ee 100644
--- a/flang/unittests/Runtime/NumericalFormatTest.cpp
+++ b/flang/unittests/Runtime/NumericalFormatTest.cpp
@@ -710,7 +710,6 @@ TEST(IOApiTests, FormatDoubleInputValues) {
 
     // Ensure raw uint64 value matches expected conversion from double
     ASSERT_EQ(u.raw, want) << '\'' << format << "' failed reading '" << data
-                           << "', want 0x" << std::hex << want << ", got 0x"
-                           << u.raw;
+                           << "', want " << want << ", got " << u.raw;
   }
 }


        


More information about the flang-commits mailing list