[libc-commits] [libc] [libc] Add max length argument to decimal to float (PR #84091)

Michael Jones via libc-commits libc-commits at lists.llvm.org
Wed Mar 6 10:51:48 PST 2024


https://github.com/michaelrj-google updated https://github.com/llvm/llvm-project/pull/84091

>From 788026bb4b19df63c3136c13fde2aea3afcc633e Mon Sep 17 00:00:00 2001
From: Michael Jones <michaelrj at google.com>
Date: Tue, 5 Mar 2024 15:01:15 -0800
Subject: [PATCH 1/2] [libc] Add max length argument to decimal to float

The implementation for from_chars in libcxx is possibly going to use our
decimal to float utilities, but to do that we need to support limiting
the length of the string to be parsed. This patch adds support for that
length limiting to decimal_exp_to_float, as well as the functions it
calls (high precision decimal, str to integer). Note that the parameters
of decimal_exp_to_float are reordered so that numStart and the new,
defaulted num_len argument come last.
---
 libc/src/__support/high_precision_decimal.h   | 116 +++++----
 libc/src/__support/str_to_float.h             |  40 +--
 libc/src/__support/str_to_integer.h           |  71 ++++--
 libc/test/src/CMakeLists.txt                  |   4 +-
 libc/test/src/__support/CMakeLists.txt        |  13 +
 .../__support/high_precision_decimal_test.cpp |  28 ++
 .../src/__support/str_to_integer_test.cpp     | 240 ++++++++++++++++++
 7 files changed, 417 insertions(+), 95 deletions(-)
 create mode 100644 libc/test/src/__support/str_to_integer_test.cpp

diff --git a/libc/src/__support/high_precision_decimal.h b/libc/src/__support/high_precision_decimal.h
index d29f8c4cd932f4..2c5a349e4495eb 100644
--- a/libc/src/__support/high_precision_decimal.h
+++ b/libc/src/__support/high_precision_decimal.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H
 #define LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H
 
+#include "src/__support/CPP/limits.h"
 #include "src/__support/ctype_utils.h"
 #include "src/__support/str_to_integer.h"
 #include <stdint.h>
@@ -115,9 +116,10 @@ class HighPrecisionDecimal {
   uint8_t digits[MAX_NUM_DIGITS];
 
 private:
-  bool should_round_up(int32_t roundToDigit, RoundDirection round) {
-    if (roundToDigit < 0 ||
-        static_cast<uint32_t>(roundToDigit) >= this->num_digits) {
+  LIBC_INLINE bool should_round_up(int32_t round_to_digit,
+                                   RoundDirection round) {
+    if (round_to_digit < 0 ||
+        static_cast<uint32_t>(round_to_digit) >= this->num_digits) {
       return false;
     }
 
@@ -133,8 +135,8 @@ class HighPrecisionDecimal {
     // Else round to nearest.
 
     // If we're right in the middle and there are no extra digits
-    if (this->digits[roundToDigit] == 5 &&
-        static_cast<uint32_t>(roundToDigit + 1) == this->num_digits) {
+    if (this->digits[round_to_digit] == 5 &&
+        static_cast<uint32_t>(round_to_digit + 1) == this->num_digits) {
 
       // Round up if we've truncated (since that means the result is slightly
       // higher than what's represented.)
@@ -143,22 +145,22 @@ class HighPrecisionDecimal {
       }
 
       // If this exactly halfway, round to even.
-      if (roundToDigit == 0)
+      if (round_to_digit == 0)
         // When the input is ".5".
         return false;
-      return this->digits[roundToDigit - 1] % 2 != 0;
+      return this->digits[round_to_digit - 1] % 2 != 0;
     }
-    // If there are digits after roundToDigit, they must be non-zero since we
+    // If there are digits after round_to_digit, they must be non-zero since we
     // trim trailing zeroes after all operations that change digits.
-    return this->digits[roundToDigit] >= 5;
+    return this->digits[round_to_digit] >= 5;
   }
 
   // Takes an amount to left shift and returns the number of new digits needed
   // to store the result based on LEFT_SHIFT_DIGIT_TABLE.
-  uint32_t get_num_new_digits(uint32_t lShiftAmount) {
+  LIBC_INLINE uint32_t get_num_new_digits(uint32_t lshift_amount) {
     const char *power_of_five =
-        LEFT_SHIFT_DIGIT_TABLE[lShiftAmount].power_of_five;
-    uint32_t new_digits = LEFT_SHIFT_DIGIT_TABLE[lShiftAmount].new_digits;
+        LEFT_SHIFT_DIGIT_TABLE[lshift_amount].power_of_five;
+    uint32_t new_digits = LEFT_SHIFT_DIGIT_TABLE[lshift_amount].new_digits;
     uint32_t digit_index = 0;
     while (power_of_five[digit_index] != 0) {
       if (digit_index >= this->num_digits) {
@@ -176,7 +178,7 @@ class HighPrecisionDecimal {
   }
 
   // Trim all trailing 0s
-  void trim_trailing_zeroes() {
+  LIBC_INLINE void trim_trailing_zeroes() {
     while (this->num_digits > 0 && this->digits[this->num_digits - 1] == 0) {
       --this->num_digits;
     }
@@ -186,19 +188,19 @@ class HighPrecisionDecimal {
   }
 
   // Perform a digitwise binary non-rounding right shift on this value by
-  // shiftAmount. The shiftAmount can't be more than MAX_SHIFT_AMOUNT to prevent
-  // overflow.
-  void right_shift(uint32_t shiftAmount) {
+  // shift_amount. The shift_amount can't be more than MAX_SHIFT_AMOUNT to
+  // prevent overflow.
+  LIBC_INLINE void right_shift(uint32_t shift_amount) {
     uint32_t read_index = 0;
     uint32_t write_index = 0;
 
     uint64_t accumulator = 0;
 
-    const uint64_t shift_mask = (uint64_t(1) << shiftAmount) - 1;
+    const uint64_t shift_mask = (uint64_t(1) << shift_amount) - 1;
 
     // Warm Up phase: we don't have enough digits to start writing, so just
     // read them into the accumulator.
-    while (accumulator >> shiftAmount == 0) {
+    while (accumulator >> shift_amount == 0) {
       uint64_t read_digit = 0;
       // If there are still digits to read, read the next one, else the digit is
       // assumed to be 0.
@@ -217,7 +219,7 @@ class HighPrecisionDecimal {
     // read. Keep reading until we run out of digits.
     while (read_index < this->num_digits) {
       uint64_t read_digit = this->digits[read_index];
-      uint64_t write_digit = accumulator >> shiftAmount;
+      uint64_t write_digit = accumulator >> shift_amount;
       accumulator &= shift_mask;
       this->digits[write_index] = static_cast<uint8_t>(write_digit);
       accumulator = accumulator * 10 + read_digit;
@@ -228,7 +230,7 @@ class HighPrecisionDecimal {
     // Cool Down phase: All of the readable digits have been read, so just write
     // the remainder, while treating any more digits as 0.
     while (accumulator > 0) {
-      uint64_t write_digit = accumulator >> shiftAmount;
+      uint64_t write_digit = accumulator >> shift_amount;
       accumulator &= shift_mask;
       if (write_index < MAX_NUM_DIGITS) {
         this->digits[write_index] = static_cast<uint8_t>(write_digit);
@@ -243,10 +245,10 @@ class HighPrecisionDecimal {
   }
 
   // Perform a digitwise binary non-rounding left shift on this value by
-  // shiftAmount. The shiftAmount can't be more than MAX_SHIFT_AMOUNT to prevent
-  // overflow.
-  void left_shift(uint32_t shiftAmount) {
-    uint32_t new_digits = this->get_num_new_digits(shiftAmount);
+  // shift_amount. The shift_amount can't be more than MAX_SHIFT_AMOUNT to
+  // prevent overflow.
+  LIBC_INLINE void left_shift(uint32_t shift_amount) {
+    uint32_t new_digits = this->get_num_new_digits(shift_amount);
 
     int32_t read_index = this->num_digits - 1;
     uint32_t write_index = this->num_digits + new_digits;
@@ -260,7 +262,7 @@ class HighPrecisionDecimal {
     // writing.
     while (read_index >= 0) {
       accumulator += static_cast<uint64_t>(this->digits[read_index])
-                     << shiftAmount;
+                     << shift_amount;
       uint64_t next_accumulator = accumulator / 10;
       uint64_t write_digit = accumulator - (10 * next_accumulator);
       --write_index;
@@ -296,45 +298,52 @@ class HighPrecisionDecimal {
   }
 
 public:
-  // numString is assumed to be a string of numeric characters. It doesn't
+  // num_string is assumed to be a string of numeric characters. It doesn't
   // handle leading spaces.
-  HighPrecisionDecimal(const char *__restrict numString) {
+  LIBC_INLINE
+  HighPrecisionDecimal(
+      const char *__restrict num_string,
+      const size_t num_len = cpp::numeric_limits<size_t>::max()) {
     bool saw_dot = false;
+    size_t num_cur = 0;
     // This counts the digits in the number, even if there isn't space to store
     // them all.
     uint32_t total_digits = 0;
-    while (isdigit(*numString) || *numString == '.') {
-      if (*numString == '.') {
+    while (num_cur < num_len &&
+           (isdigit(num_string[num_cur]) || num_string[num_cur] == '.')) {
+      if (num_string[num_cur] == '.') {
         if (saw_dot) {
           break;
         }
         this->decimal_point = total_digits;
         saw_dot = true;
       } else {
-        if (*numString == '0' && this->num_digits == 0) {
+        if (num_string[num_cur] == '0' && this->num_digits == 0) {
           --this->decimal_point;
-          ++numString;
+          ++num_cur;
           continue;
         }
         ++total_digits;
         if (this->num_digits < MAX_NUM_DIGITS) {
           this->digits[this->num_digits] =
-              static_cast<uint8_t>(*numString - '0');
+              static_cast<uint8_t>(num_string[num_cur] - '0');
           ++this->num_digits;
-        } else if (*numString != '0') {
+        } else if (num_string[num_cur] != '0') {
           this->truncated = true;
         }
       }
-      ++numString;
+      ++num_cur;
     }
 
     if (!saw_dot)
       this->decimal_point = total_digits;
 
-    if ((*numString | 32) == 'e') {
-      ++numString;
-      if (isdigit(*numString) || *numString == '+' || *numString == '-') {
-        auto result = strtointeger<int32_t>(numString, 10);
+    if (num_cur < num_len && ((num_string[num_cur] | 32) == 'e')) {
+      ++num_cur;
+      const char exp_start = num_cur < num_len ? num_string[num_cur] : '\0';
+      if (isdigit(exp_start) || exp_start == '+' || exp_start == '-') {
+        auto result =
+            strtointeger<int32_t>(num_string + num_cur, 10, num_len - num_cur);
         if (result.has_error()) {
           // TODO: handle error
         }
@@ -358,33 +367,34 @@ class HighPrecisionDecimal {
     this->trim_trailing_zeroes();
   }
 
-  // Binary shift left (shiftAmount > 0) or right (shiftAmount < 0)
-  void shift(int shiftAmount) {
-    if (shiftAmount == 0) {
+  // Binary shift left (shift_amount > 0) or right (shift_amount < 0)
+  LIBC_INLINE void shift(int shift_amount) {
+    if (shift_amount == 0) {
       return;
     }
     // Left
-    else if (shiftAmount > 0) {
-      while (static_cast<uint32_t>(shiftAmount) > MAX_SHIFT_AMOUNT) {
+    else if (shift_amount > 0) {
+      while (static_cast<uint32_t>(shift_amount) > MAX_SHIFT_AMOUNT) {
         this->left_shift(MAX_SHIFT_AMOUNT);
-        shiftAmount -= MAX_SHIFT_AMOUNT;
+        shift_amount -= MAX_SHIFT_AMOUNT;
       }
-      this->left_shift(shiftAmount);
+      this->left_shift(shift_amount);
     }
     // Right
     else {
-      while (static_cast<uint32_t>(shiftAmount) < -MAX_SHIFT_AMOUNT) {
+      while (static_cast<uint32_t>(shift_amount) < -MAX_SHIFT_AMOUNT) {
         this->right_shift(MAX_SHIFT_AMOUNT);
-        shiftAmount += MAX_SHIFT_AMOUNT;
+        shift_amount += MAX_SHIFT_AMOUNT;
       }
-      this->right_shift(-shiftAmount);
+      this->right_shift(-shift_amount);
     }
   }
 
   // Round the number represented to the closest value of unsigned int type T.
   // This is done ignoring overflow.
   template <class T>
-  T round_to_integer_type(RoundDirection round = RoundDirection::Nearest) {
+  LIBC_INLINE T
+  round_to_integer_type(RoundDirection round = RoundDirection::Nearest) {
     T result = 0;
     uint32_t cur_digit = 0;
 
@@ -404,10 +414,10 @@ class HighPrecisionDecimal {
 
   // Extra functions for testing.
 
-  uint8_t *get_digits() { return this->digits; }
-  uint32_t get_num_digits() { return this->num_digits; }
-  int32_t get_decimal_point() { return this->decimal_point; }
-  void set_truncated(bool trunc) { this->truncated = trunc; }
+  LIBC_INLINE uint8_t *get_digits() { return this->digits; }
+  LIBC_INLINE uint32_t get_num_digits() { return this->num_digits; }
+  LIBC_INLINE int32_t get_decimal_point() { return this->decimal_point; }
+  LIBC_INLINE void set_truncated(bool trunc) { this->truncated = trunc; }
 };
 
 } // namespace internal
diff --git a/libc/src/__support/str_to_float.h b/libc/src/__support/str_to_float.h
index d2bf3f85b2709a..6caf8e62a454f2 100644
--- a/libc/src/__support/str_to_float.h
+++ b/libc/src/__support/str_to_float.h
@@ -313,14 +313,15 @@ constexpr int32_t NUM_POWERS_OF_TWO =
 // on the Simple Decimal Conversion algorithm by Nigel Tao, described at this
 // link: https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html
 template <class T>
-LIBC_INLINE FloatConvertReturn<T>
-simple_decimal_conversion(const char *__restrict numStart,
-                          RoundDirection round = RoundDirection::Nearest) {
+LIBC_INLINE FloatConvertReturn<T> simple_decimal_conversion(
+    const char *__restrict numStart,
+    const size_t num_len = cpp::numeric_limits<size_t>::max(),
+    RoundDirection round = RoundDirection::Nearest) {
   using FPBits = typename fputil::FPBits<T>;
   using StorageType = typename FPBits::StorageType;
 
   int32_t exp2 = 0;
-  HighPrecisionDecimal hpd = HighPrecisionDecimal(numStart);
+  HighPrecisionDecimal hpd = HighPrecisionDecimal(numStart, num_len);
 
   FloatConvertReturn<T> output;
 
@@ -600,13 +601,17 @@ clinger_fast_path(ExpandedFloat<T> init_num,
 // non-inf result for this size of float. The value is
 // log10(2^(exponent bias)).
 // The generic approximation uses the fact that log10(2^x) ~= x/3
-template <typename T> constexpr int32_t get_upper_bound() {
+template <typename T> LIBC_INLINE constexpr int32_t get_upper_bound() {
   return fputil::FPBits<T>::EXP_BIAS / 3;
 }
 
-template <> constexpr int32_t get_upper_bound<float>() { return 39; }
+template <> LIBC_INLINE constexpr int32_t get_upper_bound<float>() {
+  return 39;
+}
 
-template <> constexpr int32_t get_upper_bound<double>() { return 309; }
+template <> LIBC_INLINE constexpr int32_t get_upper_bound<double>() {
+  return 309;
+}
 
 // The lower bound is the largest negative base-10 exponent that could possibly
 // give a non-zero result for this size of float. The value is
@@ -616,18 +621,18 @@ template <> constexpr int32_t get_upper_bound<double>() { return 309; }
 // low base 10 exponent with a very high intermediate mantissa can cancel each
 // other out, and subnormal numbers allow for the result to be at the very low
 // end of the final mantissa.
-template <typename T> constexpr int32_t get_lower_bound() {
+template <typename T> LIBC_INLINE constexpr int32_t get_lower_bound() {
   using FPBits = typename fputil::FPBits<T>;
   return -((FPBits::EXP_BIAS +
             static_cast<int32_t>(FPBits::FRACTION_LEN + FPBits::STORAGE_LEN)) /
            3);
 }
 
-template <> constexpr int32_t get_lower_bound<float>() {
+template <> LIBC_INLINE constexpr int32_t get_lower_bound<float>() {
   return -(39 + 6 + 10);
 }
 
-template <> constexpr int32_t get_lower_bound<double>() {
+template <> LIBC_INLINE constexpr int32_t get_lower_bound<double>() {
   return -(309 + 15 + 20);
 }
 
@@ -637,9 +642,10 @@ template <> constexpr int32_t get_lower_bound<double>() {
 // accuracy. The resulting mantissa and exponent are placed in outputMantissa
 // and outputExp2.
 template <class T>
-LIBC_INLINE FloatConvertReturn<T>
-decimal_exp_to_float(ExpandedFloat<T> init_num, const char *__restrict numStart,
-                     bool truncated, RoundDirection round) {
+LIBC_INLINE FloatConvertReturn<T> decimal_exp_to_float(
+    ExpandedFloat<T> init_num, bool truncated, RoundDirection round,
+    const char *__restrict numStart,
+    const size_t num_len = cpp::numeric_limits<size_t>::max()) {
   using FPBits = typename fputil::FPBits<T>;
   using StorageType = typename FPBits::StorageType;
 
@@ -701,7 +707,7 @@ decimal_exp_to_float(ExpandedFloat<T> init_num, const char *__restrict numStart,
 #endif // LIBC_COPT_STRTOFLOAT_DISABLE_EISEL_LEMIRE
 
 #ifndef LIBC_COPT_STRTOFLOAT_DISABLE_SIMPLE_DECIMAL_CONVERSION
-  output = simple_decimal_conversion<T>(numStart, round);
+  output = simple_decimal_conversion<T>(numStart, num_len, round);
 #else
 #warning "Simple decimal conversion is disabled, result may not be correct."
 #endif // LIBC_COPT_STRTOFLOAT_DISABLE_SIMPLE_DECIMAL_CONVERSION
@@ -894,6 +900,8 @@ decimal_string_to_float(const char *__restrict src, const char DECIMAL_POINT,
   if (!seen_digit)
     return output;
 
+  // TODO: When adding max length argument, handle the case of a trailing
+  // EXPONENT MARKER, see scanf for more details.
   if (tolower(src[index]) == EXPONENT_MARKER) {
     bool has_sign = false;
     if (src[index + 1] == '+' || src[index + 1] == '-') {
@@ -928,7 +936,7 @@ decimal_string_to_float(const char *__restrict src, const char DECIMAL_POINT,
     output.value = {0, 0};
   } else {
     auto temp =
-        decimal_exp_to_float<T>({mantissa, exponent}, src, truncated, round);
+        decimal_exp_to_float<T>({mantissa, exponent}, truncated, round, src);
     output.value = temp.num;
     output.error = temp.error;
   }
@@ -1071,6 +1079,8 @@ nan_mantissa_from_ncharseq(const cpp::string_view ncharseq) {
 
 // Takes a pointer to a string and a pointer to a string pointer. This function
 // is used as the backend for all of the string to float functions.
+// TODO: Add src_len member to match strtointeger.
+// TODO: Next, move from char* and length to string_view
 template <class T>
 LIBC_INLINE StrToNumResult<T> strtofloatingpoint(const char *__restrict src) {
   using FPBits = typename fputil::FPBits<T>;
diff --git a/libc/src/__support/str_to_integer.h b/libc/src/__support/str_to_integer.h
index e83a508e086b18..516b768b9cccd7 100644
--- a/libc/src/__support/str_to_integer.h
+++ b/libc/src/__support/str_to_integer.h
@@ -21,11 +21,15 @@ namespace internal {
 
 // Returns a pointer to the first character in src that is not a whitespace
 // character (as determined by isspace())
-LIBC_INLINE const char *first_non_whitespace(const char *__restrict src) {
-  while (internal::isspace(*src)) {
-    ++src;
+// TODO: Change from returning a pointer to returning a length.
+LIBC_INLINE const char *
+first_non_whitespace(const char *__restrict src,
+                     size_t src_len = cpp::numeric_limits<size_t>::max()) {
+  size_t src_cur = 0;
+  while (src_cur < src_len && internal::isspace(src[src_cur])) {
+    ++src_cur;
   }
-  return src;
+  return src + src_cur;
 }
 
 LIBC_INLINE int b36_char_to_int(char input) {
@@ -38,60 +42,75 @@ LIBC_INLINE int b36_char_to_int(char input) {
 
 // checks if the next 3 characters of the string pointer are the start of a
 // hexadecimal number. Does not advance the string pointer.
-LIBC_INLINE bool is_hex_start(const char *__restrict src) {
+LIBC_INLINE bool
+is_hex_start(const char *__restrict src,
+             size_t src_len = cpp::numeric_limits<size_t>::max()) {
+  if (src_len < 3)
+    return false;
   return *src == '0' && (*(src + 1) | 32) == 'x' && isalnum(*(src + 2)) &&
          b36_char_to_int(*(src + 2)) < 16;
 }
 
+struct BaseAndLen {
+  int base;
+  size_t len;
+};
+
 // Takes the address of the string pointer and parses the base from the start of
 // it. This function will advance |src| to the first valid digit in the inferred
 // base.
-LIBC_INLINE int infer_base(const char *__restrict *__restrict src) {
+LIBC_INLINE BaseAndLen infer_base(const char *__restrict src, size_t src_len) {
   // A hexadecimal number is defined as "the prefix 0x or 0X followed by a
   // sequence of the decimal digits and the letters a (or A) through f (or F)
   // with values 10 through 15 respectively." (C standard 6.4.4.1)
-  if (is_hex_start(*src)) {
-    (*src) += 2;
-    return 16;
+  if (is_hex_start(src, src_len)) {
+    return {16, 2};
   } // An octal number is defined as "the prefix 0 optionally followed by a
     // sequence of the digits 0 through 7 only" (C standard 6.4.4.1) and so any
     // number that starts with 0, including just 0, is an octal number.
-  else if (**src == '0') {
-    return 8;
+  else if (src_len > 0 && src[0] == '0') {
+    return {8, 0};
   } // A decimal number is defined as beginning "with a nonzero digit and
     // consist[ing] of a sequence of decimal digits." (C standard 6.4.4.1)
   else {
-    return 10;
+    return {10, 0};
   }
 }
 
 // Takes a pointer to a string and the base to convert to. This function is used
 // as the backend for all of the string to int functions.
 template <class T>
-LIBC_INLINE StrToNumResult<T> strtointeger(const char *__restrict src,
-                                           int base) {
+LIBC_INLINE StrToNumResult<T>
+strtointeger(const char *__restrict src, int base,
+             const size_t src_len = cpp::numeric_limits<size_t>::max()) {
+  // TODO: Rewrite to support numbers longer than long long
   unsigned long long result = 0;
   bool is_number = false;
-  const char *original_src = src;
+  size_t src_cur = 0;
   int error_val = 0;
 
+  if (src_len == 0)
+    return {0, 0, 0};
+
   if (base < 0 || base == 1 || base > 36) {
     error_val = EINVAL;
     return {0, 0, error_val};
   }
 
-  src = first_non_whitespace(src);
+  src_cur = first_non_whitespace(src, src_len) - src;
 
   char result_sign = '+';
-  if (*src == '+' || *src == '-') {
-    result_sign = *src;
-    ++src;
+  if (src_cur < src_len && (src[src_cur] == '+' || src[src_cur] == '-')) {
+    result_sign = src[src_cur];
+    ++src_cur;
   }
 
   if (base == 0) {
-    base = infer_base(&src);
-  } else if (base == 16 && is_hex_start(src)) {
-    src = src + 2;
+    auto base_and_len = infer_base(src + src_cur, src_len - src_cur);
+    base = base_and_len.base;
+    src_cur += base_and_len.len;
+  } else if (base == 16 && is_hex_start(src + src_cur, src_len - src_cur)) {
+    src_cur = src_cur + 2;
   }
 
   constexpr bool IS_UNSIGNED = (cpp::numeric_limits<T>::min() == 0);
@@ -103,13 +122,13 @@ LIBC_INLINE StrToNumResult<T> strtointeger(const char *__restrict src,
   unsigned long long const abs_max =
       (is_positive ? cpp::numeric_limits<T>::max() : NEGATIVE_MAX);
   unsigned long long const abs_max_div_by_base = abs_max / base;
-  while (isalnum(*src)) {
-    int cur_digit = b36_char_to_int(*src);
+  while (src_cur < src_len && isalnum(src[src_cur])) {
+    int cur_digit = b36_char_to_int(src[src_cur]);
     if (cur_digit >= base)
       break;
 
     is_number = true;
-    ++src;
+    ++src_cur;
 
     // If the number has already hit the maximum value for the current type then
     // the result cannot change, but we still need to advance src to the end of
@@ -133,7 +152,7 @@ LIBC_INLINE StrToNumResult<T> strtointeger(const char *__restrict src,
     }
   }
 
-  ptrdiff_t str_len = is_number ? (src - original_src) : 0;
+  ptrdiff_t str_len = is_number ? (src_cur) : 0;
 
   if (error_val == ERANGE) {
     if (is_positive || IS_UNSIGNED)
diff --git a/libc/test/src/CMakeLists.txt b/libc/test/src/CMakeLists.txt
index 9ad868551f071c..f70ffda3f700e5 100644
--- a/libc/test/src/CMakeLists.txt
+++ b/libc/test/src/CMakeLists.txt
@@ -40,7 +40,6 @@ add_subdirectory(__support)
 add_subdirectory(ctype)
 add_subdirectory(errno)
 add_subdirectory(fenv)
-add_subdirectory(inttypes)
 add_subdirectory(math)
 add_subdirectory(search)
 add_subdirectory(stdbit)
@@ -50,6 +49,9 @@ add_subdirectory(stdlib)
 add_subdirectory(string)
 add_subdirectory(wchar)
 
+# Depends on utilities in stdlib
+add_subdirectory(inttypes)
+
 if(${LIBC_TARGET_OS} STREQUAL "linux")
   add_subdirectory(fcntl)
   add_subdirectory(sched)
diff --git a/libc/test/src/__support/CMakeLists.txt b/libc/test/src/__support/CMakeLists.txt
index 7200ac276fe502..8c861b576f9b1b 100644
--- a/libc/test/src/__support/CMakeLists.txt
+++ b/libc/test/src/__support/CMakeLists.txt
@@ -56,6 +56,19 @@ add_libc_test(
     libc.src.errno.errno
 )
 
+
+add_libc_test(
+  str_to_integer_test
+  SUITE
+    libc-support-tests
+  SRCS
+    str_to_integer_test.cpp
+  DEPENDS
+    libc.src.__support.integer_literals
+    libc.src.__support.str_to_integer
+    libc.src.errno.errno
+)
+
 add_libc_test(
   integer_to_string_test
   SUITE
diff --git a/libc/test/src/__support/high_precision_decimal_test.cpp b/libc/test/src/__support/high_precision_decimal_test.cpp
index a9c039e45774e9..2bb28bcdab0212 100644
--- a/libc/test/src/__support/high_precision_decimal_test.cpp
+++ b/libc/test/src/__support/high_precision_decimal_test.cpp
@@ -406,3 +406,31 @@ TEST(LlvmLibcHighPrecisionDecimalTest, BigExpTest) {
   // Same, but since the number is negative the net result is -123456788
   EXPECT_EQ(big_negative_hpd.get_decimal_point(), -123456789 + 1);
 }
+
+TEST(LlvmLibcHighPrecisionDecimalTest, NumLenExpTest) {
+  LIBC_NAMESPACE::internal::HighPrecisionDecimal hpd =
+      LIBC_NAMESPACE::internal::HighPrecisionDecimal("1e123456789", 5);
+
+  // The length of 5 includes things like the "e" so it only gets 3 digits of
+  // exponent.
+  EXPECT_EQ(hpd.get_decimal_point(), 123 + 1);
+
+  LIBC_NAMESPACE::internal::HighPrecisionDecimal negative_hpd =
+      LIBC_NAMESPACE::internal::HighPrecisionDecimal("1e-123456789", 5);
+
+  // The negative sign also counts as a character.
+  EXPECT_EQ(negative_hpd.get_decimal_point(), -12 + 1);
+}
+
+TEST(LlvmLibcHighPrecisionDecimalTest, NumLenDigitsTest) {
+  LIBC_NAMESPACE::internal::HighPrecisionDecimal hpd =
+      LIBC_NAMESPACE::internal::HighPrecisionDecimal("123456789e1", 5);
+
+  EXPECT_EQ(hpd.round_to_integer_type<uint64_t>(), uint64_t(12345));
+
+  LIBC_NAMESPACE::internal::HighPrecisionDecimal longer_hpd =
+      LIBC_NAMESPACE::internal::HighPrecisionDecimal("123456789e1", 10);
+
+  // With 10 characters it should see the e, but not actually act on it.
+  EXPECT_EQ(longer_hpd.round_to_integer_type<uint64_t>(), uint64_t(123456789));
+}
diff --git a/libc/test/src/__support/str_to_integer_test.cpp b/libc/test/src/__support/str_to_integer_test.cpp
new file mode 100644
index 00000000000000..34b645b4b38c83
--- /dev/null
+++ b/libc/test/src/__support/str_to_integer_test.cpp
@@ -0,0 +1,240 @@
+//===-- Unittests for str_to_integer --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/str_to_integer.h"
+#include "src/errno/libc_errno.h"
+#include <stddef.h>
+
+#include "test/UnitTest/Test.h"
+
+// This file is for testing the src_len argument and other internal interface
+// features. Primary testing is done in stdlib/StrolTest.cpp through the public
+// interface.
+
+TEST(LlvmLibcStrToIntegerTest, SimpleLength) {
+  auto result = LIBC_NAMESPACE::internal::strtointeger<int>("12345", 10, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(5));
+  ASSERT_EQ(result.value, 12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("12345", 10, 2);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(2));
+  ASSERT_EQ(result.value, 12);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("12345", 10, 0);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+}
+
+TEST(LlvmLibcStrToIntegerTest, LeadingSpaces) {
+  auto result =
+      LIBC_NAMESPACE::internal::strtointeger<int>("     12345", 10, 15);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(10));
+  ASSERT_EQ(result.value, 12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("     12345", 10, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(10));
+  ASSERT_EQ(result.value, 12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("     12345", 10, 7);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(7));
+  ASSERT_EQ(result.value, 12);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("     12345", 10, 5);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("     12345", 10, 0);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+}
+
+TEST(LlvmLibcStrToIntegerTest, LeadingSign) {
+  auto result = LIBC_NAMESPACE::internal::strtointeger<int>("+12345", 10, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, 12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("-12345", 10, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, -12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("+12345", 10, 6);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, 12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("-12345", 10, 6);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, -12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("+12345", 10, 3);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(3));
+  ASSERT_EQ(result.value, 12);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("-12345", 10, 3);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(3));
+  ASSERT_EQ(result.value, -12);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("+12345", 10, 1);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("-12345", 10, 1);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("+12345", 10, 0);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("-12345", 10, 0);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+}
+
+TEST(LlvmLibcStrToIntegerTest, Base16PrefixAutoSelect) {
+  auto result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 0, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(7));
+  ASSERT_EQ(result.value, 0x12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 0, 7);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(7));
+  ASSERT_EQ(result.value, 0x12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 0, 5);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(5));
+  ASSERT_EQ(result.value, 0x123);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 0, 2);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(1));
+  ASSERT_EQ(result.value, 0);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 0, 0);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+}
+
+TEST(LlvmLibcStrToIntegerTest, Base16PrefixManualSelect) {
+  auto result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 16, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(7));
+  ASSERT_EQ(result.value, 0x12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 16, 7);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(7));
+  ASSERT_EQ(result.value, 0x12345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 16, 5);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(5));
+  ASSERT_EQ(result.value, 0x123);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 16, 2);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(1));
+  ASSERT_EQ(result.value, 0);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("0x12345", 16, 0);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+}
+
+TEST(LlvmLibcStrToIntegerTest, Base8PrefixAutoSelect) {
+  auto result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 0, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, 012345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 0, 6);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, 012345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 0, 4);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(4));
+  ASSERT_EQ(result.value, 0123);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 0, 1);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(1));
+  ASSERT_EQ(result.value, 0);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 0, 0);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+}
+
+TEST(LlvmLibcStrToIntegerTest, Base8PrefixManualSelect) {
+  auto result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 8, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, 012345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 8, 6);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, 012345);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 8, 4);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(4));
+  ASSERT_EQ(result.value, 0123);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 8, 1);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(1));
+  ASSERT_EQ(result.value, 0);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("012345", 8, 0);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(0));
+  ASSERT_EQ(result.value, 0);
+}
+
+TEST(LlvmLibcStrToIntegerTest, CombinedTests) {
+  auto result =
+      LIBC_NAMESPACE::internal::strtointeger<int>("    -0x123", 0, 10);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(10));
+  ASSERT_EQ(result.value, -0x123);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("    -0x123", 0, 8);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(8));
+  ASSERT_EQ(result.value, -0x1);
+
+  result = LIBC_NAMESPACE::internal::strtointeger<int>("    -0x123", 0, 7);
+  EXPECT_FALSE(result.has_error());
+  EXPECT_EQ(result.parsed_len, ptrdiff_t(6));
+  ASSERT_EQ(result.value, 0);
+}

>From 82e0868de2cfe6a8c5a47b4c0e1efb1e1867173a Mon Sep 17 00:00:00 2001
From: Michael Jones <michaelrj at google.com>
Date: Wed, 6 Mar 2024 10:51:30 -0800
Subject: [PATCH 2/2] str to integer cleanup

---
 libc/src/__support/str_to_integer.h | 48 +++++++++++------------------
 1 file changed, 18 insertions(+), 30 deletions(-)

diff --git a/libc/src/__support/str_to_integer.h b/libc/src/__support/str_to_integer.h
index 516b768b9cccd7..b87808993fee50 100644
--- a/libc/src/__support/str_to_integer.h
+++ b/libc/src/__support/str_to_integer.h
@@ -51,30 +51,22 @@ is_hex_start(const char *__restrict src,
          b36_char_to_int(*(src + 2)) < 16;
 }
 
-struct BaseAndLen {
-  int base;
-  size_t len;
-};
-
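-// Takes the address of the string pointer and parses the base from the start of
-// it. This function will advance |src| to the first valid digit in the inferred
-// base.
-LIBC_INLINE BaseAndLen infer_base(const char *__restrict src, size_t src_len) {
+// Returns the base (16, 8 or 10) implied by the prefix of |src|, reading at
+// most |src_len| chars. Does not advance |src|; callers skip any "0x" prefix.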
+LIBC_INLINE int infer_base(const char *__restrict src, size_t src_len) {
   // A hexadecimal number is defined as "the prefix 0x or 0X followed by a
   // sequence of the decimal digits and the letters a (or A) through f (or F)
   // with values 10 through 15 respectively." (C standard 6.4.4.1)
-  if (is_hex_start(src, src_len)) {
-    return {16, 2};
-  } // An octal number is defined as "the prefix 0 optionally followed by a
-    // sequence of the digits 0 through 7 only" (C standard 6.4.4.1) and so any
-    // number that starts with 0, including just 0, is an octal number.
-  else if (src_len > 0 && src[0] == '0') {
-    return {8, 0};
-  } // A decimal number is defined as beginning "with a nonzero digit and
-    // consist[ing] of a sequence of decimal digits." (C standard 6.4.4.1)
-  else {
-    return {10, 0};
-  }
+  if (is_hex_start(src, src_len))
+    return 16;
+  // An octal number is defined as "the prefix 0 optionally followed by a
+  // sequence of the digits 0 through 7 only" (C standard 6.4.4.1) and so any
+  // number that starts with 0, including just 0, is an octal number.
+  if (src_len > 0 && src[0] == '0')
+    return 8;
+  // A decimal number is defined as beginning "with a nonzero digit and
+  // consist[ing] of a sequence of decimal digits." (C standard 6.4.4.1)
+  return 10;
 }
 
 // Takes a pointer to a string and the base to convert to. This function is used
@@ -92,10 +84,8 @@ strtointeger(const char *__restrict src, int base,
   if (src_len == 0)
     return {0, 0, 0};
 
-  if (base < 0 || base == 1 || base > 36) {
-    error_val = EINVAL;
-    return {0, 0, error_val};
-  }
+  if (base < 0 || base == 1 || base > 36)
+    return {0, 0, EINVAL};
 
   src_cur = first_non_whitespace(src, src_len) - src;
 
@@ -105,13 +95,11 @@ strtointeger(const char *__restrict src, int base,
     ++src_cur;
   }
 
-  if (base == 0) {
-    auto base_and_len = infer_base(src + src_cur, src_len - src_cur);
-    base = base_and_len.base;
-    src_cur += base_and_len.len;
-  } else if (base == 16 && is_hex_start(src + src_cur, src_len - src_cur)) {
+  if (base == 0)
+    base = infer_base(src + src_cur, src_len - src_cur);
+
+  if (base == 16 && is_hex_start(src + src_cur, src_len - src_cur))
     src_cur = src_cur + 2;
-  }
 
   constexpr bool IS_UNSIGNED = (cpp::numeric_limits<T>::min() == 0);
   const bool is_positive = (result_sign == '+');



More information about the libc-commits mailing list