[libcxx-commits] [libcxx] [libc++] Fix `money_get::do_get` with huge input (PR #126273)

via libcxx-commits libcxx-commits at lists.llvm.org
Fri Feb 7 09:36:25 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-libcxx

Author: A. Jiang (frederick-vs-ja)

<details>
<summary>Changes</summary>

`money_get::do_get` needs to be fixed to handle extremely long input (e.g. more than 100 digits).
1. `__double_or_nothing` needs to copy the contents of the stack buffer into the newly allocated heap buffer on the initial allocation (i.e. when the old buffer is not heap-owned).
2. The `sscanf` call in `do_get` needs to scan the dynamically allocated buffer, rather than the stack buffer `__nbuf`, when dynamic allocation happens.

I think the fix should also be backported to the frozen cxx03 headers, because the previous incorrect handling caused core-language undefined behavior.

Fixes #<!-- -->121878.

---
Full diff: https://github.com/llvm/llvm-project/pull/126273.diff


3 Files Affected:

- (modified) libcxx/include/__cxx03/locale (+7-3) 
- (modified) libcxx/include/locale (+7-3) 
- (added) libcxx/test/std/localization/locale.categories/category.monetary/locale.money.get/locale.money.get.members/get_long_double_overlong.pass.cpp (+113) 


``````````diff
diff --git a/libcxx/include/__cxx03/locale b/libcxx/include/__cxx03/locale
index 6360bbc2f6b6082..a83ebacb5adbe1c 100644
--- a/libcxx/include/__cxx03/locale
+++ b/libcxx/include/__cxx03/locale
@@ -2460,6 +2460,8 @@ _LIBCPP_HIDE_FROM_ABI void __double_or_nothing(unique_ptr<_Tp, void (*)(void*)>&
     __throw_bad_alloc();
   if (__owns)
     __b.release();
+  else
+    std::memcpy(__t, __b.get(), __cur_cap);
   __b = unique_ptr<_Tp, void (*)(void*)>(__t, free);
   __new_cap /= sizeof(_Tp);
   __n = __b.get() + __n_off;
@@ -2655,20 +2657,22 @@ _InputIterator money_get<_CharT, _InputIterator>::do_get(
     char_type __atoms[sizeof(__src) - 1];
     __ct.widen(__src, __src + (sizeof(__src) - 1), __atoms);
     char __nbuf[__bz];
-    char* __nc = __nbuf;
+    char* __nc          = __nbuf;
+    const char* __nc_in = __nc;
     unique_ptr<char, void (*)(void*)> __h(nullptr, free);
     if (__wn - __wb.get() > __bz - 2) {
       __h.reset((char*)malloc(static_cast<size_t>(__wn - __wb.get() + 2)));
       if (__h.get() == nullptr)
         __throw_bad_alloc();
-      __nc = __h.get();
+      __nc    = __h.get();
+      __nc_in = __nc;
     }
     if (__neg)
       *__nc++ = '-';
     for (const char_type* __w = __wb.get(); __w < __wn; ++__w, ++__nc)
       *__nc = __src[std::find(__atoms, std::end(__atoms), *__w) - __atoms];
     *__nc = char();
-    if (sscanf(__nbuf, "%Lf", &__v) != 1)
+    if (sscanf(__nc_in, "%Lf", &__v) != 1)
       __throw_runtime_error("money_get error");
   }
   if (__b == __e)
diff --git a/libcxx/include/locale b/libcxx/include/locale
index be0f31cece671fb..919332a09bba1f0 100644
--- a/libcxx/include/locale
+++ b/libcxx/include/locale
@@ -2385,6 +2385,8 @@ _LIBCPP_HIDE_FROM_ABI void __double_or_nothing(unique_ptr<_Tp, void (*)(void*)>&
     __throw_bad_alloc();
   if (__owns)
     __b.release();
+  else
+    std::memcpy(__t, __b.get(), __cur_cap);
   __b = unique_ptr<_Tp, void (*)(void*)>(__t, free);
   __new_cap /= sizeof(_Tp);
   __n = __b.get() + __n_off;
@@ -2580,20 +2582,22 @@ _InputIterator money_get<_CharT, _InputIterator>::do_get(
     char_type __atoms[sizeof(__src) - 1];
     __ct.widen(__src, __src + (sizeof(__src) - 1), __atoms);
     char __nbuf[__bz];
-    char* __nc = __nbuf;
+    char* __nc          = __nbuf;
+    const char* __nc_in = __nc;
     unique_ptr<char, void (*)(void*)> __h(nullptr, free);
     if (__wn - __wb.get() > __bz - 2) {
       __h.reset((char*)malloc(static_cast<size_t>(__wn - __wb.get() + 2)));
       if (__h.get() == nullptr)
         __throw_bad_alloc();
-      __nc = __h.get();
+      __nc    = __h.get();
+      __nc_in = __nc;
     }
     if (__neg)
       *__nc++ = '-';
     for (const char_type* __w = __wb.get(); __w < __wn; ++__w, ++__nc)
       *__nc = __src[std::find(__atoms, std::end(__atoms), *__w) - __atoms];
     *__nc = char();
-    if (sscanf(__nbuf, "%Lf", &__v) != 1)
+    if (sscanf(__nc_in, "%Lf", &__v) != 1)
       __throw_runtime_error("money_get error");
   }
   if (__b == __e)
diff --git a/libcxx/test/std/localization/locale.categories/category.monetary/locale.money.get/locale.money.get.members/get_long_double_overlong.pass.cpp b/libcxx/test/std/localization/locale.categories/category.monetary/locale.money.get/locale.money.get.members/get_long_double_overlong.pass.cpp
new file mode 100644
index 000000000000000..5966f0312233862
--- /dev/null
+++ b/libcxx/test/std/localization/locale.categories/category.monetary/locale.money.get/locale.money.get.members/get_long_double_overlong.pass.cpp
@@ -0,0 +1,113 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// <locale>
+
+// class money_get<charT, InputIterator>
+
+// iter_type get(iter_type b, iter_type e, bool intl, ios_base& iob,
+//               ios_base::iostate& err, long double& v) const;
+
+#include <cassert>
+#include <cstddef>
+#include <ios>
+#include <locale>
+#include <streambuf>
+#include <string>
+
+#include "test_macros.h"
+#include "test_iterators.h"
+
+typedef std::money_get<char, cpp17_input_iterator<const char*> > Fn;
+
+class my_facet : public Fn {
+public:
+  explicit my_facet(std::size_t refs = 0) : Fn(refs) {}
+};
+
+#ifndef TEST_HAS_NO_WIDE_CHARACTERS
+typedef std::money_get<wchar_t, cpp17_input_iterator<const wchar_t*> > Fw;
+
+class my_facetw : public Fw {
+public:
+  explicit my_facetw(std::size_t refs = 0) : Fw(refs) {}
+};
+#endif
+
+int main(int, char**) {
+  struct digit_result_case {
+    std::size_t digit;
+    long double result;
+  };
+  const digit_result_case digit_result_cases[] = {
+      {60, 2.0E60L}, {120, 2.0E120L}, {180, 2.0E180L}, {240, 2.0E240L}, {300, 2.0E300L}};
+
+  std::ios ios(0);
+  {
+    const my_facet f(1);
+    for (std::size_t i = 0; i != sizeof(digit_result_cases) / sizeof(digit_result_cases[0]); ++i) {
+      {
+        std::string v = "2";
+        v.append(digit_result_cases[i].digit, '0');
+
+        typedef cpp17_input_iterator<const char*> I;
+        long double ex;
+        std::ios_base::iostate err = std::ios_base::goodbit;
+        I iter                     = f.get(I(v.data()), I(v.data() + v.size()), false, ios, err, ex);
+        assert(base(iter) == v.data() + v.size());
+        assert(err == std::ios_base::eofbit);
+        assert(ex == digit_result_cases[i].result);
+      }
+      {
+        std::string v = "-2";
+        v.append(digit_result_cases[i].digit, '0');
+
+        typedef cpp17_input_iterator<const char*> I;
+        long double ex;
+        std::ios_base::iostate err = std::ios_base::goodbit;
+        I iter                     = f.get(I(v.data()), I(v.data() + v.size()), false, ios, err, ex);
+        assert(base(iter) == v.data() + v.size());
+        assert(err == std::ios_base::eofbit);
+        assert(ex == -digit_result_cases[i].result);
+      }
+    }
+  }
+#ifndef TEST_HAS_NO_WIDE_CHARACTERS
+  {
+    const my_facetw f(1);
+    for (std::size_t i = 0; i != sizeof(digit_result_cases) / sizeof(digit_result_cases[0]); ++i) {
+      {
+        std::wstring v = L"2";
+        v.append(digit_result_cases[i].digit, L'0');
+
+        typedef cpp17_input_iterator<const wchar_t*> I;
+        long double ex;
+        std::ios_base::iostate err = std::ios_base::goodbit;
+        I iter                     = f.get(I(v.data()), I(v.data() + v.size()), false, ios, err, ex);
+        assert(base(iter) == v.data() + v.size());
+        assert(err == std::ios_base::eofbit);
+        assert(ex == digit_result_cases[i].result);
+      }
+      {
+        std::wstring v = L"-2";
+        v.append(digit_result_cases[i].digit, L'0');
+
+        typedef cpp17_input_iterator<const wchar_t*> I;
+        long double ex;
+        std::ios_base::iostate err = std::ios_base::goodbit;
+        I iter                     = f.get(I(v.data()), I(v.data() + v.size()), false, ios, err, ex);
+        assert(base(iter) == v.data() + v.size());
+        assert(err == std::ios_base::eofbit);
+        assert(ex == -digit_result_cases[i].result);
+      }
+    }
+  }
+#endif
+
+  return 0;
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/126273


More information about the libcxx-commits mailing list