[libc-commits] [libc] 8298424 - [libc] refactor atof string parsing

Michael Jones via libc-commits libc-commits at lists.llvm.org
Tue Nov 9 10:12:21 PST 2021


Author: Michael Jones
Date: 2021-11-09T10:12:18-08:00
New Revision: 8298424cae9b4d3d41dbe17857dc9cb247d90786

URL: https://github.com/llvm/llvm-project/commit/8298424cae9b4d3d41dbe17857dc9cb247d90786
DIFF: https://github.com/llvm/llvm-project/commit/8298424cae9b4d3d41dbe17857dc9cb247d90786.diff

LOG: [libc] refactor atof string parsing

Split the code for parsing hexadecimal floating point numbers from the
code for parsing the decimal floating point numbers so that the parsing
can be faster for both of them.

This decreases the time for the benchmark in release mode by about 15%,
which noticeably beats GLibc.

Old version: 2.299s
New version: 1.893s
GLibc: 2.133s

Tests run by running the following command 10 times for each version:
time ~/llvm-project/build/bin/libc_str_to_float_comparison_test ~/parse-number-fxx-test-data/data/*

the parse-number-fxx-test-data-repository is here:
https://github.com/nigeltao/parse-number-fxx-test-data/tree/fe94de252c691900982050c8e7c503d1efd1299a

It's important to build llvm-libc in Release mode for accurate
performance comparisons against glibc (set -DCMAKE_BUILD_TYPE=Release in
your cmake).
You also have to build the libc_str_to_float_comparison_test target.

Reviewed By: lntue

Differential Revision: https://reviews.llvm.org/D113036

Added: 
    

Modified: 
    libc/src/__support/str_to_float.h
    libc/test/src/stdlib/strtof_test.cpp

Removed: 
    


################################################################################
diff  --git a/libc/src/__support/str_to_float.h b/libc/src/__support/str_to_float.h
index 59bd1ec8e5da4..90408e77b557d 100644
--- a/libc/src/__support/str_to_float.h
+++ b/libc/src/__support/str_to_float.h
@@ -40,7 +40,7 @@ static inline T shiftRightAndRound(T numToShift, unsigned int amountToShift) {
   }
 }
 
-template <class T> uint32_t leadingZeroes(T inputNumber) {
+template <class T> uint32_t inline leadingZeroes(T inputNumber) {
   // TODO(michaelrj): investigate the portability of using something like
   // __builtin_clz for specific types.
   constexpr uint32_t bitsInT = sizeof(T) * 8;
@@ -71,6 +71,14 @@ template <class T> uint32_t leadingZeroes(T inputNumber) {
   return bitsInT - curGuess;
 }
 
+template <> uint32_t inline leadingZeroes<uint32_t>(uint32_t inputNumber) {
+  return inputNumber == 0 ? 32 : __builtin_clz(inputNumber);
+}
+
+template <> uint32_t inline leadingZeroes<uint64_t>(uint64_t inputNumber) {
+  return inputNumber == 0 ? 64 : __builtin_clzll(inputNumber);
+}
+
 static inline uint64_t low64(__uint128_t num) {
   return static_cast<uint64_t>(num & 0xffffffffffffffff);
 }
@@ -442,6 +450,81 @@ decimalExpToFloat(typename fputil::FPBits<T>::UIntType mantissa, int32_t exp10,
   return;
 }
 
+// Takes a mantissa and base 2 exponent and converts it into its closest
+// floating point type T equivalient. Since the exponent is already in the right
+// form, this is mostly just shifting and rounding. This is used for hexadecimal
+// numbers since a base 16 exponent multiplied by 4 is the base 2 exponent.
+template <class T>
+static inline void
+binaryExpToFloat(typename fputil::FPBits<T>::UIntType mantissa, int32_t exp2,
+                 typename fputil::FPBits<T>::UIntType *outputMantissa,
+                 uint32_t *outputExp2) {
+  using BitsType = typename fputil::FPBits<T>::UIntType;
+
+  // This is the number of leading zeroes a properly normalized float of type T
+  // should have.
+  constexpr uint32_t NORMALIZED_LEADING_ZEROES =
+      (sizeof(BitsType) * 8) - fputil::FloatProperties<T>::mantissaWidth - 1;
+  constexpr BitsType OVERFLOWED_MANTISSA =
+      BitsType(1) << (fputil::FloatProperties<T>::mantissaWidth + 1);
+
+  // Normalization
+  int32_t amountToShift =
+      NORMALIZED_LEADING_ZEROES - leadingZeroes<BitsType>(mantissa);
+  if (amountToShift < 0) {
+    mantissa <<= -amountToShift;
+  } else {
+    mantissa = shiftRightAndRound(mantissa, amountToShift);
+    if (mantissa == OVERFLOWED_MANTISSA) {
+      mantissa >>= 1;
+      exp2 += 1;
+    }
+  }
+  exp2 += amountToShift;
+
+  // Account for the fact that the mantissa represented an integer
+  // previously, but now represents the fractional part of a normalized
+  // number.
+  exp2 += fputil::FloatProperties<T>::mantissaWidth;
+
+  int32_t biasedExponent = exp2 + fputil::FPBits<T>::exponentBias;
+  // handle subnormals
+  if (biasedExponent <= 0) {
+
+    // the most mantissa is currently normalized, meaning that the msb is
+    // one bit left of where the decimal point should go.
+    amountToShift = 1;
+    BitsType mantissaCopy = mantissa >> 1;
+    while (biasedExponent < 0 && mantissaCopy > 0) {
+      mantissaCopy = mantissaCopy >> 1;
+      ++amountToShift;
+      ++biasedExponent;
+    }
+    // If we cut off any bits to fit this number into a subnormal, then it's
+    // out of range for this size of float.
+    if ((mantissa & ((1 << amountToShift) - 1)) > 0) {
+      errno = ERANGE; // NOLINT
+    }
+    mantissa = shiftRightAndRound(mantissa, amountToShift);
+    if (mantissa == OVERFLOWED_MANTISSA) {
+      mantissa >>= 1;
+      exp2 += 1;
+    } else if (mantissa == 0) {
+      biasedExponent = 0;
+    }
+  }
+  // handle numbers that're too large and get squashed to inf
+  else if (biasedExponent >
+           (1 << fputil::FloatProperties<T>::exponentWidth) - 1) {
+    // This indicates an overflow, so we make the result INF and set errno.
+    biasedExponent = (1 << fputil::FloatProperties<T>::exponentWidth) - 1;
+    mantissa = 0;
+    errno = ERANGE; // NOLINT
+  }
+  *outputMantissa = mantissa;
+  *outputExp2 = biasedExponent;
+}
+
 // checks if the next 4 characters of the string pointer are the start of a
 // hexadecimal floating point number. Does not advance the string pointer.
 static inline bool is_float_hex_start(const char *__restrict src,
@@ -456,190 +539,268 @@ static inline bool is_float_hex_start(const char *__restrict src,
   }
 }
 
-// Takes a pointer to a string and a pointer to a string pointer. This function
-// is used as the backend for all of the string to float functions.
+// Takes the start of a string representing a decimal float, as well as the
+// local decimalPoint. It returns if it suceeded in parsing any digits, and if
+// the return value is true then the outputs are pointer to the end of the
+// number, and the mantissa and exponent for the closest float T representation.
+// If the return value is false, then it is assumed that there is no number
+// here.
 template <class T>
-static inline T strtofloatingpoint(const char *__restrict src,
-                                   char **__restrict strEnd) {
+static inline bool
+decimalStringToFloat(const char *__restrict src, const char DECIMAL_POINT,
+                     char **__restrict strEnd,
+                     typename fputil::FPBits<T>::UIntType *outputMantissa,
+                     uint32_t *outputExponent) {
   using BitsType = typename fputil::FPBits<T>::UIntType;
-  fputil::FPBits<T> result = fputil::FPBits<T>();
-  const char *originalSrc = src;
+  constexpr uint32_t BASE = 10;
+  constexpr char EXPONENT_MARKER = 'e';
+
+  const char *__restrict numStart = src;
+  bool truncated = false;
   bool seenDigit = false;
-  src = first_non_whitespace(src);
+  bool afterDecimal = false;
+  BitsType mantissa = 0;
+  int32_t exponent = 0;
+
+  // The goal for the first step of parsing is to convert the number in src to
+  // the format mantissa * (base ^ exponent)
+
+  // The first loop fills the mantissa with as many digits as it can hold
+  const BitsType BITSTYPE_MAX_DIV_BY_BASE =
+      __llvm_libc::cpp::NumericLimits<BitsType>::max() / BASE;
+  while ((isdigit(*src) || *src == DECIMAL_POINT) &&
+         mantissa < BITSTYPE_MAX_DIV_BY_BASE) {
+    if (*src == DECIMAL_POINT) {
+      if (afterDecimal) {
+        break; // this means that *src points to a second decimal point, ending
+               // the number.
+      } else {
+        afterDecimal = true;
+        ++src;
+        continue;
+      }
+    }
+    uint32_t digit = *src - '0';
 
-  if (*src == '+' || *src == '-') {
-    if (*src == '-') {
-      result.setSign(true);
+    mantissa = (mantissa * BASE) + digit;
+    seenDigit = true;
+    if (afterDecimal) {
+      --exponent;
     }
+
     ++src;
   }
 
-  static constexpr char DECIMAL_POINT = '.';
-  static const char *INF_STRING = "infinity";
-  static const char *NAN_STRING = "nan";
-
-  bool truncated = false;
+  if (!seenDigit)
+    return false;
 
-  if (isdigit(*src) || *src == DECIMAL_POINT) { // regular number
-    int base = 10;
-    char exponentMarker = 'e';
-    if (is_float_hex_start(src, DECIMAL_POINT)) {
-      base = 16;
-      src += 2;
-      exponentMarker = 'p';
-      seenDigit = true;
-    }
-    const char *__restrict numStart = src;
-    bool afterDecimal = false;
-
-    BitsType mantissa = 0;
-    int32_t exponent = 0;
-
-    // The goal for the first step of parsing is to convert the number in src to
-    // the format mantissa * (base ^ exponent)
-
-    constexpr BitsType MANTISSA_MAX =
-        BitsType(1) << (fputil::FloatProperties<T>::mantissaWidth +
-                        1); // The extra bit is to give space for the implicit 1
-    const BitsType BITSTYPE_MAX_DIV_BY_BASE =
-        __llvm_libc::cpp::NumericLimits<BitsType>::max() / base;
-    while ((isalnum(*src) || *src == DECIMAL_POINT) &&
-           mantissa < BITSTYPE_MAX_DIV_BY_BASE) {
-      if (*src == DECIMAL_POINT && afterDecimal) {
+  // The second loop is to run through the remaining digits after we've filled
+  // the mantissa.
+  while (isdigit(*src) || *src == DECIMAL_POINT) {
+    if (*src == DECIMAL_POINT) {
+      if (afterDecimal) {
         break; // this means that *src points to a second decimal point, ending
                // the number.
-      } else if (*src == DECIMAL_POINT) {
+      } else {
         afterDecimal = true;
         ++src;
         continue;
       }
-      int digit = b36_char_to_int(*src);
-      if (digit >= base) {
-        break;
-      }
+    }
+    uint32_t digit = *src - '0';
 
-      mantissa = (mantissa * base) + digit;
-      seenDigit = true;
-      if (afterDecimal) {
-        --exponent;
-      }
+    if (digit > 0)
+      truncated = true;
+
+    if (!afterDecimal)
+      ++exponent;
+
+    ++src;
+  }
 
+  if ((*src | 32) == EXPONENT_MARKER) {
+    if (*(src + 1) == '+' || *(src + 1) == '-' || isdigit(*(src + 1))) {
       ++src;
+      char *tempStrEnd;
+      int32_t add_to_exponent = strtointeger<int32_t>(src, &tempStrEnd, 10);
+      if (add_to_exponent > 100000)
+        add_to_exponent = 100000;
+      else if (add_to_exponent < -100000)
+        add_to_exponent = -100000;
+
+      src = tempStrEnd;
+      exponent += add_to_exponent;
     }
+  }
+
+  *strEnd = const_cast<char *>(src);
+  if (mantissa == 0) { // if we have a 0, then also 0 the exponent.
+    *outputMantissa = 0;
+    *outputExponent = 0;
+  } else {
+    decimalExpToFloat<T>(mantissa, exponent, numStart, truncated,
+                         outputMantissa, outputExponent);
+  }
+  return true;
+}
 
-    // The second loop is to run through the remaining digits after we've filled
-    // the mantissa.
-    while (isalnum(*src) || *src == DECIMAL_POINT) {
-      if (*src == DECIMAL_POINT && afterDecimal) {
+// Takes the start of a string representing a hexadecimal float, as well as the
+// local decimal point. It returns if it suceeded in parsing any digits, and if
+// the return value is true then the outputs are pointer to the end of the
+// number, and the mantissa and exponent for the closest float T representation.
+// If the return value is false, then it is assumed that there is no number
+// here.
+template <class T>
+static inline bool
+hexadecimalStringToFloat(const char *__restrict src, const char DECIMAL_POINT,
+                         char **__restrict strEnd,
+                         typename fputil::FPBits<T>::UIntType *outputMantissa,
+                         uint32_t *outputExponent) {
+  using BitsType = typename fputil::FPBits<T>::UIntType;
+  constexpr uint32_t BASE = 16;
+  constexpr char EXPONENT_MARKER = 'p';
+
+  bool truncated = false;
+  bool seenDigit = false;
+  bool afterDecimal = false;
+  BitsType mantissa = 0;
+  int32_t exponent = 0;
+
+  // The goal for the first step of parsing is to convert the number in src to
+  // the format mantissa * (base ^ exponent)
+
+  // The first loop fills the mantissa with as many digits as it can hold
+  const BitsType BITSTYPE_MAX_DIV_BY_BASE =
+      __llvm_libc::cpp::NumericLimits<BitsType>::max() / BASE;
+  while ((isalnum(*src) || *src == DECIMAL_POINT) &&
+         mantissa < BITSTYPE_MAX_DIV_BY_BASE) {
+    if (*src == DECIMAL_POINT) {
+      if (afterDecimal) {
         break; // this means that *src points to a second decimal point, ending
                // the number.
-      } else if (*src == DECIMAL_POINT) {
+      } else {
         afterDecimal = true;
         ++src;
         continue;
       }
-      int digit = b36_char_to_int(*src);
-      if (digit >= base) {
-        break;
-      }
+    }
+    uint32_t digit = b36_char_to_int(*src);
+    if (digit >= BASE)
+      break;
 
-      if (digit > 0) {
-        truncated = true;
-      }
+    mantissa = (mantissa * BASE) + digit;
+    seenDigit = true;
+    if (afterDecimal)
+      --exponent;
+
+    ++src;
+  }
 
-      if (!afterDecimal) {
-        exponent++;
+  if (!seenDigit)
+    return false;
+
+  // The second loop is to run through the remaining digits after we've filled
+  // the mantissa.
+  while (isalnum(*src) || *src == DECIMAL_POINT) {
+    if (*src == DECIMAL_POINT) {
+      if (afterDecimal) {
+        break; // this means that *src points to a second decimal point, ending
+               // the number.
+      } else {
+        afterDecimal = true;
+        ++src;
+        continue;
       }
+    }
+    uint32_t digit = b36_char_to_int(*src);
+    if (digit >= BASE)
+      break;
+
+    if (digit > 0)
+      truncated = true;
+
+    if (!afterDecimal)
+      ++exponent;
 
+    ++src;
+  }
+
+  // Convert the exponent from having a base of 16 to having a base of 2.
+  exponent *= 4;
+
+  if ((*src | 32) == EXPONENT_MARKER) {
+    if (*(src + 1) == '+' || *(src + 1) == '-' || isdigit(*(src + 1))) {
       ++src;
+      char *tempStrEnd;
+      int32_t add_to_exponent = strtointeger<int32_t>(src, &tempStrEnd, 10);
+      if (add_to_exponent > 100000)
+        add_to_exponent = 100000;
+      else if (add_to_exponent < -100000)
+        add_to_exponent = -100000;
+      src = tempStrEnd;
+      exponent += add_to_exponent;
     }
+  }
+  *strEnd = const_cast<char *>(src);
+  if (mantissa == 0) { // if we have a 0, then also 0 the exponent.
+    *outputMantissa = 0;
+    *outputExponent = 0;
+  } else {
+    binaryExpToFloat<T>(mantissa, exponent, outputMantissa, outputExponent);
+  }
+  return true;
+}
 
-    // if our base is 16 then convert the exponent to base 2
-    if (base == 16) {
-      exponent *= 4;
-    }
+// Takes a pointer to a string and a pointer to a string pointer. This function
+// is used as the backend for all of the string to float functions.
+template <class T>
+static inline T strtofloatingpoint(const char *__restrict src,
+                                   char **__restrict strEnd) {
+  using BitsType = typename fputil::FPBits<T>::UIntType;
+  fputil::FPBits<T> result = fputil::FPBits<T>();
+  const char *originalSrc = src;
+  bool seenDigit = false;
+  src = first_non_whitespace(src);
 
-    if ((*src | 32) == exponentMarker) {
-      if (*(src + 1) == '+' || *(src + 1) == '-' || isdigit(*(src + 1))) {
-        ++src;
-        char *tempStrEnd;
-        int32_t add_to_exponent = strtointeger<int32_t>(src, &tempStrEnd, 10);
-        if (add_to_exponent > 100000) {
-          add_to_exponent = 100000;
-        } else if (add_to_exponent < -100000) {
-          add_to_exponent = -100000;
-        }
-        src += tempStrEnd - src;
-        exponent += add_to_exponent;
-      }
+  if (*src == '+' || *src == '-') {
+    if (*src == '-') {
+      result.setSign(true);
     }
+    ++src;
+  }
 
-    if (mantissa == 0) { // if we have a 0, then also 0 the exponent.
-      exponent = 0;
-    } else if (base == 16) {
-
-      // These two loops should normalize the number if we assume the decimal
-      // point is after the bit at mantissaWidth.
-      // For example if type T is a 32 bit float, this should result in a
-      // mantissa with its most significant 1 being at bit 23.
-      while (mantissa < (MANTISSA_MAX >> 1)) {
-        mantissa = mantissa << 1;
-        --exponent;
-      }
-      BitsType mantissaCopy = mantissa;
-      unsigned int amountToShift = 0;
-      while (mantissaCopy > MANTISSA_MAX) {
-        mantissaCopy = mantissaCopy >> 1;
-        ++amountToShift;
-      }
-      exponent += amountToShift;
-      mantissa = shiftRightAndRound(mantissa, amountToShift);
-
-      // Account for the fact that the mantissa represented an integer
-      // previously, but now represents the fractional part of a normalized
-      // number.
-      exponent += fputil::FloatProperties<T>::mantissaWidth;
-
-      int32_t biasedExponent = exponent + fputil::FPBits<T>::exponentBias;
-      if (biasedExponent <= 0) {
-        // handle subnormals here
-
-        // the most mantissa is currently normalized, meaning that the msb is
-        // one bit left of where the decimal point should go.
-        amountToShift = 1;
-        mantissaCopy = mantissa >> 1;
-        while (biasedExponent < 0 && mantissaCopy > 0) {
-          mantissaCopy = mantissaCopy >> 1;
-          ++amountToShift;
-          ++biasedExponent;
-        }
-        // If we cut off any bits to fit this number into a subnormal, then it's
-        // out of range for this size of float.
-        if ((mantissa & ((1 << amountToShift) - 1)) > 0) {
-          errno = ERANGE; // NOLINT
-        }
-        mantissa = shiftRightAndRound(mantissa, amountToShift);
-        if (mantissa == 0) {
-          biasedExponent = 0;
-        }
-      } else if (biasedExponent > result.maxExponent) {
-        // This indicates an overflow, so we make the result INF and set errno.
-        biasedExponent = result.maxExponent;
-        mantissa = 0;
-        errno = ERANGE; // NOLINT
-      }
+  static constexpr char DECIMAL_POINT = '.';
+  static const char *INF_STRING = "infinity";
+  static const char *NAN_STRING = "nan";
+
+  // bool truncated = false;
 
-      result.setUnbiasedExponent(biasedExponent);
-      result.setMantissa(mantissa);
+  if (isdigit(*src) || *src == DECIMAL_POINT) { // regular number
+    int base = 10;
+    char exponentMarker = 'e';
+    if (is_float_hex_start(src, DECIMAL_POINT)) {
+      base = 16;
+      src += 2;
+      exponentMarker = 'p';
+      seenDigit = true;
+    }
+    char *newStrEnd = nullptr;
+
+    BitsType outputMantissa = 0;
+    uint32_t outputExponent = 0;
+    if (base == 16) {
+      seenDigit = hexadecimalStringToFloat<T>(src, DECIMAL_POINT, &newStrEnd,
+                                              &outputMantissa, &outputExponent);
     } else { // base is 10
-      BitsType outputMantissa = 0;
-      uint32_t outputExponent = 0;
-      decimalExpToFloat<T>(mantissa, exponent, numStart, truncated,
-                           &outputMantissa, &outputExponent);
+      seenDigit = decimalStringToFloat<T>(src, DECIMAL_POINT, &newStrEnd,
+                                          &outputMantissa, &outputExponent);
+    }
+
+    if (seenDigit) {
+      src += newStrEnd - src;
       result.setMantissa(outputMantissa);
       result.setUnbiasedExponent(outputExponent);
     }
-
   } else if ((*src | 32) == 'n') { // NaN
     if ((src[1] | 32) == NAN_STRING[1] && (src[2] | 32) == NAN_STRING[2]) {
       seenDigit = true;

diff  --git a/libc/test/src/stdlib/strtof_test.cpp b/libc/test/src/stdlib/strtof_test.cpp
index f20cc0dfff5be..2109e7d19df69 100644
--- a/libc/test/src/stdlib/strtof_test.cpp
+++ b/libc/test/src/stdlib/strtof_test.cpp
@@ -132,6 +132,10 @@ TEST_F(LlvmLibcStrToFTest, HexadecimalNormalRoundingTests) {
   runTest("0x123456700", 11, 0x4f91a2b4);
 }
 
+TEST_F(LlvmLibcStrToFTest, HexadecimalsWithRoundingProblems) {
+  runTest("0xFFFFFFFF", 10, 0x4f800000);
+}
+
 TEST_F(LlvmLibcStrToFTest, HexadecimalOutOfRangeTests) {
   runTest("0x123456789123456789123456789123456789", 38, 0x7f800000, ERANGE);
   runTest("-0x123456789123456789123456789123456789", 39, 0xff800000, ERANGE);


        


More information about the libc-commits mailing list