[libcxx-commits] [libcxx] linear_congruential_engine: use more precision to prevent overflow (PR #81583)

via libcxx-commits libcxx-commits at lists.llvm.org
Sat Apr 13 17:20:16 PDT 2024


https://github.com/LRFLEW updated https://github.com/llvm/llvm-project/pull/81583

>From 7dfdfb28f881673adfe5d14ca58d3f97fd222e84 Mon Sep 17 00:00:00 2001
From: LRFLEW <LRFLEW at aol.com>
Date: Tue, 13 Feb 2024 01:19:04 -0600
Subject: [PATCH] linear_congruential_engine: use more precision to prevent
 overflow

---
 .../__random/linear_congruential_engine.h     | 93 +++++++++++++++----
 .../rand/rand.eng/rand.eng.lcong/alg.pass.cpp | 65 +++++++++----
 .../rand.eng/rand.eng.lcong/assign.pass.cpp   | 18 +++-
 .../rand.eng/rand.eng.lcong/copy.pass.cpp     | 18 +++-
 .../rand.eng/rand.eng.lcong/default.pass.cpp  | 18 +++-
 .../rand.eng/rand.eng.lcong/values.pass.cpp   | 18 +++-
 6 files changed, 180 insertions(+), 50 deletions(-)

diff --git a/libcxx/include/__random/linear_congruential_engine.h b/libcxx/include/__random/linear_congruential_engine.h
index fe9cb909b74d21..d1d811060c10d8 100644
--- a/libcxx/include/__random/linear_congruential_engine.h
+++ b/libcxx/include/__random/linear_congruential_engine.h
@@ -26,32 +26,61 @@ _LIBCPP_PUSH_MACROS
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
+// Strategy used by __lce_ta to compute (a * x + c) % m exactly; chosen by __lce_alg_picker.
+enum __lce_alg_type {
+  _LCE_Full,    // compute (a * x + c) % m directly
+  _LCE_Part,    // compute (a * x) % m directly, then add c modulo m
+  _LCE_Schrage, // Schrage's algorithm (requires m % a <= m / a)
+  _LCE_Promote, // compute in a wider unsigned type
+};
+
 template <unsigned long long __a,
           unsigned long long __c,
           unsigned long long __m,
           unsigned long long _Mp,
-          bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
-          bool _OverflowOK    = ((__m & (__m - 1)) == 0ull),                  // m = 2^n
-          bool _SchrageOK     = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
+          bool _HasOverflow = (__a != 0ull && (__m & (__m - 1ull)) != 0ull),      // a != 0, m != 0, m != 2^n
+          bool _Full        = (!_HasOverflow || __m - 1ull <= (_Mp - __c) / __a), // (a * x + c) % m works
+          bool _Part        = (!_HasOverflow || __m - 1ull <= _Mp / __a),         // (a * x) % m works
+          bool _Schrage     = (_HasOverflow && __m % __a <= __m / __a)>           // r <= q
 struct __lce_alg_picker {
-  static_assert(!_MightOverflow || _OverflowOK || _SchrageOK,
-                "The current values of a, c, and m cannot generate a number "
-                "within bounds of linear_congruential_engine.");
+  static _LIBCPP_CONSTEXPR const __lce_alg_type __mode =
+      _Full      ? _LCE_Full
+      : _Part    ? _LCE_Part
+      : _Schrage ? _LCE_Schrage
+                 : _LCE_Promote;
 
-  static _LIBCPP_CONSTEXPR const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK;
+#ifdef _LIBCPP_HAS_NO_INT128
+  static_assert(_Mp != (unsigned long long)(~0) || _Full || _Part || _Schrage,
+                "The current values for a, c, and m are not currently supported on platforms without __int128");
+#endif
 };
 
 template <unsigned long long __a,
           unsigned long long __c,
           unsigned long long __m,
           unsigned long long _Mp,
-          bool _UseSchrage = __lce_alg_picker<__a, __c, __m, _Mp>::__use_schrage>
+          __lce_alg_type _Mode = __lce_alg_picker<__a, __c, __m, _Mp>::__mode>
 struct __lce_ta;
 
 // 64
 
+#ifndef _LIBCPP_HAS_NO_INT128
+template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
+struct __lce_ta<_Ap, _Cp, _Mp, (unsigned long long)(~0), _LCE_Promote> {
+  typedef unsigned long long result_type;
+  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __xp) {
+    __extension__ using __calc_type = unsigned __int128;
+    const __calc_type __a = static_cast<__calc_type>(_Ap);
+    const __calc_type __c = static_cast<__calc_type>(_Cp);
+    const __calc_type __m = static_cast<__calc_type>(_Mp);
+    const __calc_type __x = static_cast<__calc_type>(__xp);
+    return static_cast<result_type>((__a * __x + __c) % __m);
+  }
+};
+#endif
+
 template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
-struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), _LCE_Schrage> {
   typedef unsigned long long result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     // Schrage's algorithm
@@ -66,7 +95,7 @@ struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
 };
 
 template <unsigned long long __a, unsigned long long __m>
-struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
+struct __lce_ta<__a, 0ull, __m, (unsigned long long)(~0), _LCE_Schrage> {
   typedef unsigned long long result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     // Schrage's algorithm
@@ -80,21 +109,40 @@ struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
 };
 
 template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
-struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), false> {
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), _LCE_Part> {
+  typedef unsigned long long result_type;
+  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+    // Use (((a*x) % m) + c) % m; compare with m - c because x + c itself may not fit
+    __x = (__a * __x) % __m;
+    __x += __c - (__x >= __m - __c) * __m;
+    return __x;
+  }
+};
+
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), _LCE_Full> {
   typedef unsigned long long result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return (__a * __x + __c) % __m; }
 };
 
 template <unsigned long long __a, unsigned long long __c>
-struct __lce_ta<__a, __c, 0, (unsigned long long)(~0), false> {
+struct __lce_ta<__a, __c, 0ull, (unsigned long long)(~0), _LCE_Full> {
   typedef unsigned long long result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return __a * __x + __c; }
 };
 
 // 32
 
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
+struct __lce_ta<__a, __c, __m, unsigned(~0), _LCE_Promote> {
+  typedef unsigned result_type;
+  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+    return static_cast<result_type>(__lce_ta<__a, __c, __m, (unsigned long long)(~0)>::next(__x));
+  }
+};
+
 template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
-struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), _LCE_Schrage> {
   typedef unsigned result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     const result_type __a = static_cast<result_type>(_Ap);
@@ -112,7 +160,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
 };
 
 template <unsigned long long _Ap, unsigned long long _Mp>
-struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
+struct __lce_ta<_Ap, 0ull, _Mp, unsigned(~0), _LCE_Schrage> {
   typedef unsigned result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     const result_type __a = static_cast<result_type>(_Ap);
@@ -128,7 +176,21 @@ struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
 };
 
 template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
-struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), _LCE_Part> {
+  typedef unsigned result_type;
+  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+    const result_type __a = static_cast<result_type>(_Ap);
+    const result_type __c = static_cast<result_type>(_Cp);
+    const result_type __m = static_cast<result_type>(_Mp);
+    // Use (((a*x) % m) + c) % m; compare with m - c because x + c itself may not fit
+    __x = (__a * __x) % __m;
+    __x += __c - (__x >= __m - __c) * __m;
+    return __x;
+  }
+};
+
+template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), _LCE_Full> {
   typedef unsigned result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     const result_type __a = static_cast<result_type>(_Ap);
@@ -139,7 +201,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
 };
 
 template <unsigned long long _Ap, unsigned long long _Cp>
-struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
+struct __lce_ta<_Ap, _Cp, 0ull, unsigned(~0), _LCE_Full> {
   typedef unsigned result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     const result_type __a = static_cast<result_type>(_Ap);
@@ -150,8 +212,8 @@ struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
 
 // 16
 
-template <unsigned long long __a, unsigned long long __c, unsigned long long __m, bool __b>
-struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __b> {
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m, __lce_alg_type __mode>
+struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __mode> {
   typedef unsigned short result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(~0)>::next(__x));
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
index 8a9cae0e610c35..159cb19f65468b 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
@@ -38,12 +38,12 @@ int main(int, char**)
 
     // m might overflow. The overflow is not OK and result will be in bounds
     // so we should use Schrage's algorithm
-    typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1> E2;
+    typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1ull> E2;
     E2 e2;
     // make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
     assert(e2() == (1ull << 32));
     assert(e2() == (1ull << 63) - 1ull);
-    assert(e2() == (1ull << 63) - (1ull << 33) + 1ull);
+    assert(e2() == (1ull << 63) - 0x1ffffffffull);
     // make sure result is in bounds
     assert(e2() < (1ull << 63) + 1);
     assert(e2() < (1ull << 63) + 1);
@@ -56,9 +56,9 @@ int main(int, char**)
     typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
     E3 e3;
     // make sure Schrage's algorithm is used
-    assert(e3() == 402727752ull);
-    assert(e3() == 162159612030764687ull);
-    assert(e3() == 108176466184989142ull);
+    assert(e3() == 0x18012348ull);
+    assert(e3() == 0x2401b4ed802468full);
+    assert(e3() == 0x18051ec400369d6ull);
     // make sure result is in bounds
     assert(e3() < (3ull << 56));
     assert(e3() < (3ull << 56));
@@ -66,19 +66,52 @@ int main(int, char**)
     assert(e3() < (3ull << 56));
     assert(e3() < (3ull << 56));
 
-    // m will not overflow so we should not use Schrage's algorithm
-    typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E4;
+    // 32-bit case:
+    // a * x can exceed 32 bits and m is not a power of two, so the overflow is not OK.
+    // m % a > m / a rules out Schrage's algorithm, so 64-bit arithmetic is needed.
+    typedef std::linear_congruential_engine<unsigned, 0x10009u, 0u, 0x7fffffffu> E4;
     E4 e4;
+    // make sure enough precision is used (a wrapped 32-bit product gives different values)
+    assert(e4() == 0x10009u);
+    assert(e4() == 0x120053u);
+    assert(e4() == 0xf5030fu);
+    // make sure result is in bounds
+    assert(e4() < 0x7fffffffu);
+    assert(e4() < 0x7fffffffu);
+    assert(e4() < 0x7fffffffu);
+    assert(e4() < 0x7fffffffu);
+    assert(e4() < 0x7fffffffu);
+
+#ifndef _LIBCPP_HAS_NO_INT128
+    // a * x can exceed 64 bits and m is not a power of two, so the overflow is not OK.
+    // m % a > m / a rules out Schrage's algorithm, so 128-bit arithmetic is needed.
+    typedef std::linear_congruential_engine<T, 0x100000001ull, 0ull, (1ull << 61) - 1ull> E5;
+    E5 e5;
+    // make sure enough precision is used (a wrapped 64-bit product gives different values)
+    assert(e5() == 0x100000001ull);
+    assert(e5() == 0x200000009ull);
+    assert(e5() == 0xb00000019ull);
+    // make sure result is in bounds
+    assert(e5() < (1ull << 61) - 1ull);
+    assert(e5() < (1ull << 61) - 1ull);
+    assert(e5() < (1ull << 61) - 1ull);
+    assert(e5() < (1ull << 61) - 1ull);
+    assert(e5() < (1ull << 61) - 1ull);
+#endif
+
+    // m will not overflow so we should not use Schrage's algorithm
+    typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E6;
+    E6 e6;
     // make sure the correct algorithm was used
-    assert(e4() == 2ull);
-    assert(e4() == 3ull);
-    assert(e4() == 4ull);
+    assert(e6() == 2ull);
+    assert(e6() == 3ull);
+    assert(e6() == 4ull);
     // make sure result is in bounds
-    assert(e4() < (1ull << 48));
-    assert(e4() < (1ull << 48));
-    assert(e4() < (1ull << 48));
-    assert(e4() < (1ull << 48));
-    assert(e4() < (1ull << 48));
+    assert(e6() < (1ull << 48));
+    assert(e6() < (1ull << 48));
+    assert(e6() < (1ull << 48));
+    assert(e6() < (1ull << 48));
+    assert(e6() < (1ull << 48));
 
     return 0;
-}
\ No newline at end of file
+}
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
index 5317f171a98a79..73829071bd9580 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
@@ -61,24 +61,37 @@ test()
   test1<T, A, 0, M>();
   test1<T, A, M - 2, M>();
   test1<T, A, M - 1, M>();
+}
+
+template <class T>
+void test_ext() {
+  const T M(static_cast<T>(-1));
 
-  /*
-  // Cases where m is odd and m % a > m / a (not implemented)
+  // Cases where m is odd and a is close to m (for a = m - 2, m % a > m / a)
   test1<T, M - 2, 0, M>();
   test1<T, M - 2, M - 2, M>();
   test1<T, M - 2, M - 1, M>();
   test1<T, M - 1, 0, M>();
   test1<T, M - 1, M - 2, M>();
   test1<T, M - 1, M - 1, M>();
-  */
 }
 
 int main(int, char**)
 {
     test<unsigned short>();
+    test_ext<unsigned short>();
     test<unsigned int>();
+    test_ext<unsigned int>();
     test<unsigned long>();
+    // unsigned long may be 64 bits wide, which isn't implemented without __int128
+#ifndef _LIBCPP_HAS_NO_INT128
+    test_ext<unsigned long>();
+#endif
     test<unsigned long long>();
+    // This isn't implemented on platforms without __int128
+#ifndef _LIBCPP_HAS_NO_INT128
+    test_ext<unsigned long long>();
+#endif
 
-  return 0;
+    return 0;
 }
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
index 8e950043d594f9..8387a1763714f0 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
@@ -60,24 +60,37 @@ test()
   test1<T, A, 0, M>();
   test1<T, A, M - 2, M>();
   test1<T, A, M - 1, M>();
+}
+
+template <class T>
+void test_ext() {
+  const T M(static_cast<T>(-1));
 
-  /*
-  // Cases where m is odd and m % a > m / a (not implemented)
+  // Cases where m is odd and a is close to m (for a = m - 2, m % a > m / a)
   test1<T, M - 2, 0, M>();
   test1<T, M - 2, M - 2, M>();
   test1<T, M - 2, M - 1, M>();
   test1<T, M - 1, 0, M>();
   test1<T, M - 1, M - 2, M>();
   test1<T, M - 1, M - 1, M>();
-  */
 }
 
 int main(int, char**)
 {
     test<unsigned short>();
+    test_ext<unsigned short>();
     test<unsigned int>();
+    test_ext<unsigned int>();
     test<unsigned long>();
+    // unsigned long may be 64 bits wide, which isn't implemented without __int128
+#ifndef _LIBCPP_HAS_NO_INT128
+    test_ext<unsigned long>();
+#endif
     test<unsigned long long>();
+    // This isn't implemented on platforms without __int128
+#ifndef _LIBCPP_HAS_NO_INT128
+    test_ext<unsigned long long>();
+#endif
 
-  return 0;
+    return 0;
 }
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
index 52126f7a200dbe..c59afd7a3eb273 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
@@ -58,24 +58,37 @@ test()
   test1<T, A, 0, M>();
   test1<T, A, M - 2, M>();
   test1<T, A, M - 1, M>();
+}
+
+template <class T>
+void test_ext() {
+  const T M(static_cast<T>(-1));
 
-  /*
-  // Cases where m is odd and m % a > m / a (not implemented)
+  // Cases where m is odd and a is close to m (for a = m - 2, m % a > m / a)
   test1<T, M - 2, 0, M>();
   test1<T, M - 2, M - 2, M>();
   test1<T, M - 2, M - 1, M>();
   test1<T, M - 1, 0, M>();
   test1<T, M - 1, M - 2, M>();
   test1<T, M - 1, M - 1, M>();
-  */
 }
 
 int main(int, char**)
 {
     test<unsigned short>();
+    test_ext<unsigned short>();
     test<unsigned int>();
+    test_ext<unsigned int>();
     test<unsigned long>();
+    // unsigned long may be 64 bits wide, which isn't implemented without __int128
+#ifndef _LIBCPP_HAS_NO_INT128
+    test_ext<unsigned long>();
+#endif
     test<unsigned long long>();
+    // This isn't implemented on platforms without __int128
+#ifndef _LIBCPP_HAS_NO_INT128
+    test_ext<unsigned long long>();
+#endif
 
-  return 0;
+    return 0;
 }
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
index 28d8dfea01fab3..98b07e70f247af 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
@@ -91,24 +91,37 @@ test()
   test1<T, A, 0, M>();
   test1<T, A, M - 2, M>();
   test1<T, A, M - 1, M>();
+}
 
-  /*
-  // Cases where m is odd and m % a > m / a (not implemented)
+template <class T>
+void test_ext() {
+  const T M(static_cast<T>(-1));
+
+  // Cases where m is odd and a is close to m (for a = m - 2, m % a > m / a)
   test1<T, M - 2, 0, M>();
   test1<T, M - 2, M - 2, M>();
   test1<T, M - 2, M - 1, M>();
   test1<T, M - 1, 0, M>();
   test1<T, M - 1, M - 2, M>();
   test1<T, M - 1, M - 1, M>();
-  */
 }
 
 int main(int, char**)
 {
     test<unsigned short>();
+    test_ext<unsigned short>();
     test<unsigned int>();
+    test_ext<unsigned int>();
     test<unsigned long>();
+    // unsigned long may be 64 bits wide, which isn't implemented without __int128
+#ifndef _LIBCPP_HAS_NO_INT128
+    test_ext<unsigned long>();
+#endif
     test<unsigned long long>();
+    // This isn't implemented on platforms without __int128
+#ifndef _LIBCPP_HAS_NO_INT128
+    test_ext<unsigned long long>();
+#endif
 
-  return 0;
+    return 0;
 }



More information about the libcxx-commits mailing list