[libcxx-commits] [libcxx] linear_congruential_engine: add using more precision to prevent overflow (PR #81583)

via libcxx-commits libcxx-commits at lists.llvm.org
Tue Feb 13 00:54:06 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-libcxx

Author: None (LRFLEW)

<details>
<summary>Changes</summary>

This PR is a follow-up to #81080 and, as such, includes the commit from that PR. Either this PR should be merged after that one (once this one is rebased), or it can replace that PR outright. Since the other PR is already well along in review and this one is more involved, the other PR should probably be merged first.

This PR makes two major changes to how the LCG operation is computed:

The first is that I added a case where `ax + c` might overflow the intermediate variable but `ax` by itself won't. In this case, computing `((ax mod m) + c) mod m` is much cheaper than falling back to Schrage's algorithm, as the code did before. The modular addition is done the same way as in the Schrage's-algorithm path (i.e. `x += c - (x >= m - c)*m`), but the modular multiplication is computed directly, which is faster.
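
A minimal sketch of the new case described above, written with `uint32_t` to keep the numbers small. The names `lcg_step_part`, `lcg_step_wide`, and `agree` are hypothetical (not libc++ identifiers), and the sketch assumes `0 < m`, `c < m`, `x < m`, and `a * (m - 1) <= UINT32_MAX`. The parameters `a = 2, c = 10, m = 2^31 - 1` are chosen for illustration only, so that `ax` always fits in 32 bits while `ax + c` can exceed it; a 64-bit reference step checks the result.

```cpp
#include <cstdint>

// Hypothetical sketch (not libc++'s code) of the case where a*x fits but a*x + c may not.
// Assumed preconditions: 0 < m, c < m, x < m, and a * (m - 1) <= UINT32_MAX.
constexpr std::uint32_t lcg_step_part(std::uint32_t a, std::uint32_t c, std::uint32_t m, std::uint32_t x) {
  x = (a * x) % m;  // a * x cannot overflow under the preconditions
  // Add c modulo m without forming x + c, which could exceed 32 bits:
  // when x + c would reach m, add c - m instead; uint32_t wrap-around keeps this exact.
  x += c - static_cast<std::uint32_t>(x >= m - c) * m;
  return x;
}

// Reference step through a 64-bit intermediate, used only for comparison.
constexpr std::uint32_t lcg_step_wide(std::uint32_t a, std::uint32_t c, std::uint32_t m, std::uint32_t x) {
  return static_cast<std::uint32_t>((std::uint64_t{a} * x + c) % m);
}

// Runs n steps of both versions from seed x and reports whether they always agree.
constexpr bool agree(std::uint32_t a, std::uint32_t c, std::uint32_t m, std::uint32_t x, int n) {
  std::uint32_t y = x;
  for (int i = 0; i < n; ++i) {
    x = lcg_step_part(a, c, m, x);
    y = lcg_step_wide(a, c, m, y);
    if (x != y) {
      return false;
    }
  }
  return true;
}

// a = 2, c = 10, m = 2^31 - 1: a * (m - 1) = 2^32 - 4 fits, but a * (m - 1) + c = 2^32 + 6 does not.
static_assert(lcg_step_part(2u, 10u, 0x7fffffffu, 0x7ffffffeu) == 8u, "2^32 + 6 is 8 mod 2^31 - 1");
static_assert(agree(2u, 10u, 0x7fffffffu, 1u, 100), "matches the 64-bit reference");
```

Because everything is `constexpr`, the `static_assert`s check the arithmetic at compile time; no separate test driver is needed.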

The second is that I added handling for the case where the `ax` intermediate might overflow but Schrage's algorithm doesn't apply (i.e. r > q, where q = m / a and r = m % a). In this case, the only real option is to increase the precision of the intermediate values. The good news is that, for n-bit values `x`, `a`, and `c`, `ax + c` can never overflow a 2n-bit intermediate, since its maximum is `(2^n - 1)^2 + (2^n - 1) = 2^(2n) - 2^n`. So at most one promotion is ever needed, and the promoted computation can always use the simplest implementation, `(ax + c) mod m`. This is already the case for 16-bit LCGs, as libcxx computes them with 32-bit intermediate values. For 32-bit LCGs, I added code similar to the 16-bit case that forwards to the existing 64-bit implementations. Lastly, for 64-bit LCGs, I added a case that computes the result using `unsigned __int128` when the compiler provides it.
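
The 2n-bit claim above, and the promoted computation it enables, can be sketched as follows. The helper names are hypothetical, the expected values are the ones asserted for `E4` and `E5` in the updated `alg.pass.cpp`, and the 64-bit part assumes a GCC/Clang-style `unsigned __int128`, guarded by `__SIZEOF_INT128__` as in the patch.

```cpp
#include <cstdint>

// For n-bit a, x, c: a*x + c <= (2^n - 1)^2 + (2^n - 1) = 2^(2n) - 2^n < 2^(2n).
// Checked here for n = 32: the worst case leaves exactly 2^32 - 1 of headroom in 64 bits.
static_assert(UINT64_MAX - (std::uint64_t{UINT32_MAX} * UINT32_MAX + UINT32_MAX) == UINT32_MAX,
              "a 64-bit intermediate cannot overflow for 32-bit operands");

// Hypothetical helper: a 32-bit LCG step through a 64-bit intermediate.
constexpr std::uint32_t lcg_step_promoted32(std::uint32_t a, std::uint32_t c, std::uint32_t m, std::uint32_t x) {
  return static_cast<std::uint32_t>((std::uint64_t{a} * x + c) % m);
}

// E4 from alg.pass.cpp: a = 0x10009, c = 0, m = 0x7fffffff, default seed 1.
static_assert(lcg_step_promoted32(0x10009u, 0u, 0x7fffffffu, 1u) == 0x10009u, "");
static_assert(lcg_step_promoted32(0x10009u, 0u, 0x7fffffffu, 0x10009u) == 0x120053u, "");
static_assert(lcg_step_promoted32(0x10009u, 0u, 0x7fffffffu, 0x120053u) == 0xf5030fu, "");

#ifdef __SIZEOF_INT128__
// Hypothetical helper: a 64-bit LCG step through unsigned __int128 (a GCC/Clang extension).
constexpr std::uint64_t lcg_step_promoted64(std::uint64_t a, std::uint64_t c, std::uint64_t m, std::uint64_t x) {
  __extension__ typedef unsigned __int128 wide_type;
  return static_cast<std::uint64_t>((static_cast<wide_type>(a) * x + c) % m);
}

// E5 from alg.pass.cpp: a = 2^32 + 1, c = 0, m = 2^61 - 1.
static_assert(lcg_step_promoted64(0x100000001ull, 0, (1ull << 61) - 1, 0x100000001ull) == 0x200000009ull, "");
static_assert(lcg_step_promoted64(0x100000001ull, 0, (1ull << 61) - 1, 0x200000009ull) == 0xb00000019ull, "");
#endif
```

Since a single promotion always suffices, neither helper needs an overflow-avoiding formula; the plain `(a*x + c) % m` is exact in the wider type.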

While this implementation covers a *lot* of the cases missing from #81080, it still won't compile **every** possible `linear_congruential_engine`. Specifically, if `a`, `c`, and `m` are chosen such that the computation needs 128-bit integers, but the platform doesn't support `__int128` (e.g. 32-bit x86), it will fail to compile (the new `static_assert` in `__lce_alg_picker` fires). However, such parameters are rarely used in practice, and libcxx would be in good company here, as [libstdc++ also fails to compile under these circumstances](https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87744). Closing **this** gap would require considerably **more** complex work, so it is probably best handled in a separate PR (I'll put more details on what that PR would entail in a comment).
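
To see which parameter sets end up needing `__int128`, the picker's conditions from the diff can be transcribed into a single `constexpr` function. This is a hypothetical mirror for illustration (`lce_mode` is not a libc++ name); `Mp` stands for the largest value of the intermediate type, as in the patch. Mode 3 with a 64-bit `Mp` is exactly the combination the new `static_assert` rejects when `__SIZEOF_INT128__` is not defined.

```cpp
#include <cstdint>

// Hypothetical constexpr transcription of the __lce_alg_picker conditions in the diff.
// Mp is the largest value of the intermediate type. Returned mode:
//   0: (a*x + c) % m directly       1: (a*x) % m, then add c modulo m
//   2: Schrage's algorithm          3: needs a wider intermediate type
constexpr int lce_mode(std::uint64_t a, std::uint64_t c, std::uint64_t m, std::uint64_t Mp) {
  const bool has_overflow = a != 0 && (m & (m - 1)) != 0;   // a != 0 and m is neither 0 nor 2^n
  const bool full = !has_overflow || m - 1 <= (Mp - c) / a; // a*x + c <= Mp for every x < m
  const bool part = !has_overflow || m - 1 <= Mp / a;       // a*x <= Mp for every x < m
  const bool schrage = has_overflow && m % a <= m / a;      // r <= q
  return full ? 0 : part ? 1 : schrage ? 2 : 3;
}

// Engines from alg.pass.cpp with a 64-bit intermediate.
static_assert(lce_mode(25214903917ull, 1, 1ull << 48, UINT64_MAX) == 0, "E1: m = 2^48");
static_assert(lce_mode(1ull << 32, 0, (1ull << 63) + 1, UINT64_MAX) == 2, "E2: Schrage");
static_assert(lce_mode(0x18000001ull, 0x12347ull, 3ull << 56, UINT64_MAX) == 2, "E3: Schrage, even m");
static_assert(lce_mode(0x100000001ull, 0, (1ull << 61) - 1, UINT64_MAX) == 3, "E5: needs unsigned __int128");
// E4 is a 32-bit engine: mode 3 with a 32-bit intermediate, mode 0 once promoted to 64 bits.
static_assert(lce_mode(0x10009, 0, 0x7fffffff, UINT32_MAX) == 3, "E4 in 32 bits");
static_assert(lce_mode(0x10009, 0, 0x7fffffff, UINT64_MAX) == 0, "E4 promoted to 64 bits");
// The M - 2 multiplier from the updated assign/copy/default/values tests, with M = 2^64 - 1.
static_assert(lce_mode(UINT64_MAX - 2, 0, UINT64_MAX, UINT64_MAX) == 3, "needs unsigned __int128");
```

The last assertion suggests why the updated `assign`/`copy`/`default`/`values` tests cannot compile without `__int128` when instantiated with a 64-bit `T`.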

I currently consider this PR a WIP because a) I haven't yet written the test cases so that they don't fail when `__int128` isn't available, and b) I'm not 100% sure about how I've structured and formatted the changes, and may still want to tweak them before merging. I'm opening the PR now so I can start getting feedback, if anyone has any.

---
Full diff: https://github.com/llvm/llvm-project/pull/81583.diff


6 Files Affected:

- (modified) libcxx/include/__random/linear_congruential_engine.h (+68-18) 
- (modified) libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp (+79-31) 
- (modified) libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp (+34-13) 
- (modified) libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp (+33-13) 
- (modified) libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp (+33-13) 
- (modified) libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp (+33-13) 


``````````diff
diff --git a/libcxx/include/__random/linear_congruential_engine.h b/libcxx/include/__random/linear_congruential_engine.h
index 51f6b248d8f974..e32f0a9d05395a 100644
--- a/libcxx/include/__random/linear_congruential_engine.h
+++ b/libcxx/include/__random/linear_congruential_engine.h
@@ -30,28 +30,45 @@ template <unsigned long long __a,
           unsigned long long __c,
           unsigned long long __m,
           unsigned long long _Mp,
-          bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
-          bool _OverflowOK    = ((__m | (__m - 1)) > __m),                    // m = 2^n
-          bool _SchrageOK     = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
+          bool _HasOverflow = (__a != 0ull && (__m & (__m - 1ull)) != 0ull), // a != 0, m != 0, m != 2^n
+          bool _Full = (!_HasOverflow || __m - 1ull <= (_Mp - __c) / __a), // (a * x + c) % m works
+          bool _Part = (!_HasOverflow || __m - 1ull <= _Mp / __a), // (a * x) % m works
+          bool _Schrage = (_HasOverflow && __m % __a <= __m / __a)> // r <= q
 struct __lce_alg_picker {
-  static_assert(__a != 0 || __m != 0 || !_MightOverflow || _OverflowOK || _SchrageOK,
-                "The current values of a, c, and m cannot generate a number "
-                "within bounds of linear_congruential_engine.");
+  static _LIBCPP_CONSTEXPR const int __mode = _Full ? 0 : _Part ? 1 : _Schrage ? 2 : 3;
 
-  static _LIBCPP_CONSTEXPR const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK;
+#ifndef __SIZEOF_INT128__
+  static_assert(_Mp != (unsigned long long)(~0) || _Full || _Part || _Schrage,
+                "The current values for a, c, and m are not currently supported on platforms without __int128");
+#endif
 };
 
 template <unsigned long long __a,
           unsigned long long __c,
           unsigned long long __m,
           unsigned long long _Mp,
-          bool _UseSchrage = __lce_alg_picker<__a, __c, __m, _Mp>::__use_schrage>
+          int _Mode = __lce_alg_picker<__a, __c, __m, _Mp>::__mode>
 struct __lce_ta;
 
 // 64
 
+#ifdef __SIZEOF_INT128__
+template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
+struct __lce_ta<_Ap, _Cp, _Mp, (unsigned long long)(~0), 3> {
+  typedef unsigned long long result_type;
+  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type _Xp) {
+    __extension__ typedef unsigned __int128 calc_type;
+    const calc_type __a = static_cast<calc_type>(_Ap);
+    const calc_type __c = static_cast<calc_type>(_Cp);
+    const calc_type __m = static_cast<calc_type>(_Mp);
+    const calc_type __x = static_cast<calc_type>(_Xp);
+    return static_cast<result_type>((__a * __x + __c) % __m);
+  }
+};
+#endif
+
 template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
-struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), 2> {
   typedef unsigned long long result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     // Schrage's algorithm
@@ -66,7 +83,7 @@ struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
 };
 
 template <unsigned long long __a, unsigned long long __m>
-struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
+struct __lce_ta<__a, 0ull, __m, (unsigned long long)(~0), 2> {
   typedef unsigned long long result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     // Schrage's algorithm
@@ -80,21 +97,40 @@ struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
 };
 
 template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
-struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), false> {
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), 1> {
+  typedef unsigned long long result_type;
+  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+    // Use (((a*x) % m) + c) % m
+    __x = (__a * __x) % __m;
+    __x += __c - (__x >= __m - __c) * __m;
+    return __x;
+  }
+};
+
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), 0> {
   typedef unsigned long long result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return (__a * __x + __c) % __m; }
 };
 
 template <unsigned long long __a, unsigned long long __c>
-struct __lce_ta<__a, __c, 0, (unsigned long long)(~0), false> {
+struct __lce_ta<__a, __c, 0ull, (unsigned long long)(~0), 0> {
   typedef unsigned long long result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return __a * __x + __c; }
 };
 
 // 32
 
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
+struct __lce_ta<__a, __c, __m, unsigned(~0), 3> {
+  typedef unsigned result_type;
+  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+    return static_cast<result_type>(__lce_ta<__a, __c, __m, (unsigned long long)(~0)>::next(__x));
+  }
+};
+
 template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
-struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), 2> {
   typedef unsigned result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     const result_type __a = static_cast<result_type>(_Ap);
@@ -112,7 +148,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
 };
 
 template <unsigned long long _Ap, unsigned long long _Mp>
-struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
+struct __lce_ta<_Ap, 0ull, _Mp, unsigned(~0), 2> {
   typedef unsigned result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     const result_type __a = static_cast<result_type>(_Ap);
@@ -128,7 +164,21 @@ struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
 };
 
 template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
-struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), 1> {
+  typedef unsigned result_type;
+  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+    const result_type __a = static_cast<result_type>(_Ap);
+    const result_type __c = static_cast<result_type>(_Cp);
+    const result_type __m = static_cast<result_type>(_Mp);
+    // Use (((a*x) % m) + c) % m
+    __x = (__a * __x) % __m;
+    __x += __c - (__x >= __m - __c) * __m;
+    return __x;
+  }
+};
+
+template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), 0> {
   typedef unsigned result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     const result_type __a = static_cast<result_type>(_Ap);
@@ -139,7 +189,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
 };
 
 template <unsigned long long _Ap, unsigned long long _Cp>
-struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
+struct __lce_ta<_Ap, _Cp, 0ull, unsigned(~0), 0> {
   typedef unsigned result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     const result_type __a = static_cast<result_type>(_Ap);
@@ -150,8 +200,8 @@ struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
 
 // 16
 
-template <unsigned long long __a, unsigned long long __c, unsigned long long __m, bool __b>
-struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __b> {
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m, int __mode>
+struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __mode> {
   typedef unsigned short result_type;
   _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
     return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(~0)>::next(__x));
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
index 77b7c570f85a1d..fff93a895f8955 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
@@ -22,48 +22,96 @@ int main(int, char**)
 {
     typedef unsigned long long T;
 
-    // m might overflow, but the overflow is OK so it shouldn't use schrage's algorithm
-    typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull<<48)> E1;
+    // m might overflow, but the overflow is OK so it shouldn't use Schrage's algorithm
+    typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull << 48)> E1;
     E1 e1;
     // make sure the right algorithm was used
-    assert(e1() == 25214903918);
-    assert(e1() == 205774354444503);
-    assert(e1() == 158051849450892);
+    assert(e1() == 25214903918ull);
+    assert(e1() == 205774354444503ull);
+    assert(e1() == 158051849450892ull);
     // make sure result is in bounds
-    assert(e1() < (1ull<<48));
-    assert(e1() < (1ull<<48));
-    assert(e1() < (1ull<<48));
-    assert(e1() < (1ull<<48));
-    assert(e1() < (1ull<<48));
+    assert(e1() < (1ull << 48));
+    assert(e1() < (1ull << 48));
+    assert(e1() < (1ull << 48));
+    assert(e1() < (1ull << 48));
+    assert(e1() < (1ull << 48));
 
     // m might overflow. The overflow is not OK and result will be in bounds
-    // so we should use shrage's algorithm
-    typedef std::linear_congruential_engine<T, (1ull<<2), 0, (1ull<<63) + 1> E2;
+    // so we should use Schrage's algorithm
+    typedef std::linear_congruential_engine<T, 0x100000000ull, 0, (1ull << 63) + 1ull> E2;
     E2 e2;
-    // make sure shrage's algorithm is used (it would be 0s otherwise)
-    assert(e2() == 4);
-    assert(e2() == 16);
-    assert(e2() == 64);
+    // make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
+    assert(e2() == 0x100000000ull);
+    assert(e2() == (1ull << 63) - 1ull);
+    assert(e2() == (1ull << 63) - 0x1ffffffffull);
     // make sure result is in bounds
-    assert(e2() < (1ull<<48) + 1);
-    assert(e2() < (1ull<<48) + 1);
-    assert(e2() < (1ull<<48) + 1);
-    assert(e2() < (1ull<<48) + 1);
-    assert(e2() < (1ull<<48) + 1);
+    assert(e2() < (1ull << 63) + 1);
+    assert(e2() < (1ull << 63) + 1);
+    assert(e2() < (1ull << 63) + 1);
+    assert(e2() < (1ull << 63) + 1);
+    assert(e2() < (1ull << 63) + 1);
 
-    // m will not overflow so we should not use shrage's algorithm
-    typedef std::linear_congruential_engine<T, 1ull, 1, (1ull<<48)> E3;
+    // m might overflow. The overflow is not OK and result will be in bounds
+    // so we should use Schrage's algorithm. m is even
+    typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
     E3 e3;
+    // make sure Schrage's algorithm is used
+    assert(e3() == 0x18012348ull);
+    assert(e3() == 0x2401b4ed802468full);
+    assert(e3() == 0x18051ec400369d6ull);
+    // make sure result is in bounds
+    assert(e3() < (3ull << 56));
+    assert(e3() < (3ull << 56));
+    assert(e3() < (3ull << 56));
+    assert(e3() < (3ull << 56));
+    assert(e3() < (3ull << 56));
+
+    // 32-bit case:
+    // m might overflow. The overflow is not OK, result will be in bounds,
+    // and Schrage's algorithm is incompatible here. Need to use 64 bit arithmetic.
+    typedef std::linear_congruential_engine<unsigned, 0x10009u, 0u, 0x7fffffffu> E4;
+    E4 e4;
+    // make sure enough precision is used
+    assert(e4() == 0x10009u);
+    assert(e4() == 0x120053u);
+    assert(e4() == 0xf5030fu);
+    // make sure result is in bounds
+    assert(e4() < 0x7fffffffu);
+    assert(e4() < 0x7fffffffu);
+    assert(e4() < 0x7fffffffu);
+    assert(e4() < 0x7fffffffu);
+    assert(e4() < 0x7fffffffu);
+
+#ifdef __SIZEOF_INT128__
+    // m might overflow. The overflow is not OK, result will be in bounds,
+    // and Schrage's algorithm is incompatible here. Need to use 128 bit arithmetic.
+    typedef std::linear_congruential_engine<T, 0x100000001ull, 0ull, (1ull << 61) - 1ull> E5;
+    E5 e5;
+    // make sure enough precision is used
+    assert(e5() == 0x100000001ull);
+    assert(e5() == 0x200000009ull);
+    assert(e5() == 0xb00000019ull);
+    // make sure result is in bounds
+    assert(e5() < (1ull << 61) - 1ull);
+    assert(e5() < (1ull << 61) - 1ull);
+    assert(e5() < (1ull << 61) - 1ull);
+    assert(e5() < (1ull << 61) - 1ull);
+    assert(e5() < (1ull << 61) - 1ull);
+#endif
+
+    // m will not overflow so we should not use Schrage's algorithm
+    typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E6;
+    E6 e6;
     // make sure the correct algorithm was used
-    assert(e3() == 2);
-    assert(e3() == 3);
-    assert(e3() == 4);
+    assert(e6() == 2ull);
+    assert(e6() == 3ull);
+    assert(e6() == 4ull);
     // make sure result is in bounds
-    assert(e3() < (1ull<<48));
-    assert(e3() < (1ull<<48));
-    assert(e3() < (1ull<<48));
-    assert(e3() < (1ull<<48));
-    assert(e2() < (1ull<<48));
+    assert(e6() < (1ull << 48));
+    assert(e6() < (1ull << 48));
+    assert(e6() < (1ull << 48));
+    assert(e6() < (1ull << 48));
+    assert(e6() < (1ull << 48));
 
     return 0;
 }
\ No newline at end of file
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
index 12620848626fc8..8f5a861cbff563 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
@@ -15,6 +15,7 @@
 
 #include <random>
 #include <cassert>
+#include <climits>
 
 #include "test_macros.h"
 
@@ -35,19 +36,39 @@ template <class T>
 void
 test()
 {
-    test1<T, 0, 0, 0>();
-    test1<T, 0, 1, 2>();
-    test1<T, 1, 1, 2>();
-    const T M(static_cast<T>(-1));
-    test1<T, 0, 0, M>();
-    test1<T, 0, M-2, M>();
-    test1<T, 0, M-1, M>();
-    test1<T, M-2, 0, M>();
-    test1<T, M-2, M-2, M>();
-    test1<T, M-2, M-1, M>();
-    test1<T, M-1, 0, M>();
-    test1<T, M-1, M-2, M>();
-    test1<T, M-1, M-1, M>();
+  const int W = sizeof(T) * CHAR_BIT;
+  const T M(static_cast<T>(-1));
+  const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+  // Cases where m = 0
+  test1<T, 0, 0, 0>();
+  test1<T, A, 0, 0>();
+  test1<T, 0, 1, 0>();
+  test1<T, A, 1, 0>();
+
+  // Cases where m = 2^n for n < w
+  test1<T, 0, 0, 256>();
+  test1<T, 5, 0, 256>();
+  test1<T, 0, 1, 256>();
+  test1<T, 5, 1, 256>();
+
+  // Cases where m is odd and a = 0
+  test1<T, 0, 0, M>();
+  test1<T, 0, M - 2, M>();
+  test1<T, 0, M - 1, M>();
+
+  // Cases where m is odd and m % a <= m / a (Schrage)
+  test1<T, A, 0, M>();
+  test1<T, A, M - 2, M>();
+  test1<T, A, M - 1, M>();
+
+  // Cases where m is odd and m % a > m / a (not implemented)
+  test1<T, M - 2, 0, M>();
+  test1<T, M - 2, M - 2, M>();
+  test1<T, M - 2, M - 1, M>();
+  test1<T, M - 1, 0, M>();
+  test1<T, M - 1, M - 2, M>();
+  test1<T, M - 1, M - 1, M>();
 }
 
 int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
index 5dac0772cb0e94..654352cd13fa8e 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
@@ -35,19 +35,39 @@ template <class T>
 void
 test()
 {
-    test1<T, 0, 0, 0>();
-    test1<T, 0, 1, 2>();
-    test1<T, 1, 1, 2>();
-    const T M(static_cast<T>(-1));
-    test1<T, 0, 0, M>();
-    test1<T, 0, M-2, M>();
-    test1<T, 0, M-1, M>();
-    test1<T, M-2, 0, M>();
-    test1<T, M-2, M-2, M>();
-    test1<T, M-2, M-1, M>();
-    test1<T, M-1, 0, M>();
-    test1<T, M-1, M-2, M>();
-    test1<T, M-1, M-1, M>();
+  const int W = sizeof(T) * CHAR_BIT;
+  const T M(static_cast<T>(-1));
+  const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+  // Cases where m = 0
+  test1<T, 0, 0, 0>();
+  test1<T, A, 0, 0>();
+  test1<T, 0, 1, 0>();
+  test1<T, A, 1, 0>();
+
+  // Cases where m = 2^n for n < w
+  test1<T, 0, 0, 256>();
+  test1<T, 5, 0, 256>();
+  test1<T, 0, 1, 256>();
+  test1<T, 5, 1, 256>();
+
+  // Cases where m is odd and a = 0
+  test1<T, 0, 0, M>();
+  test1<T, 0, M - 2, M>();
+  test1<T, 0, M - 1, M>();
+
+  // Cases where m is odd and m % a <= m / a (Schrage)
+  test1<T, A, 0, M>();
+  test1<T, A, M - 2, M>();
+  test1<T, A, M - 1, M>();
+
+  // Cases where m is odd and m % a > m / a (not implemented)
+  test1<T, M - 2, 0, M>();
+  test1<T, M - 2, M - 2, M>();
+  test1<T, M - 2, M - 1, M>();
+  test1<T, M - 1, 0, M>();
+  test1<T, M - 1, M - 2, M>();
+  test1<T, M - 1, M - 1, M>();
 }
 
 int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
index 10bc1d71d8e892..caee6b89571d79 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
@@ -33,19 +33,39 @@ template <class T>
 void
 test()
 {
-    test1<T, 0, 0, 0>();
-    test1<T, 0, 1, 2>();
-    test1<T, 1, 1, 2>();
-    const T M(static_cast<T>(-1));
-    test1<T, 0, 0, M>();
-    test1<T, 0, M-2, M>();
-    test1<T, 0, M-1, M>();
-    test1<T, M-2, 0, M>();
-    test1<T, M-2, M-2, M>();
-    test1<T, M-2, M-1, M>();
-    test1<T, M-1, 0, M>();
-    test1<T, M-1, M-2, M>();
-    test1<T, M-1, M-1, M>();
+  const int W = sizeof(T) * CHAR_BIT;
+  const T M(static_cast<T>(-1));
+  const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+  // Cases where m = 0
+  test1<T, 0, 0, 0>();
+  test1<T, A, 0, 0>();
+  test1<T, 0, 1, 0>();
+  test1<T, A, 1, 0>();
+
+  // Cases where m = 2^n for n < w
+  test1<T, 0, 0, 256>();
+  test1<T, 5, 0, 256>();
+  test1<T, 0, 1, 256>();
+  test1<T, 5, 1, 256>();
+
+  // Cases where m is odd and a = 0
+  test1<T, 0, 0, M>();
+  test1<T, 0, M - 2, M>();
+  test1<T, 0, M - 1, M>();
+
+  // Cases where m is odd and m % a <= m / a (Schrage)
+  test1<T, A, 0, M>();
+  test1<T, A, M - 2, M>();
+  test1<T, A, M - 1, M>();
+
+  // Cases where m is odd and m % a > m / a (not implemented)
+  test1<T, M - 2, 0, M>();
+  test1<T, M - 2, M - 2, M>();
+  test1<T, M - 2, M - 1, M>();
+  test1<T, M - 1, 0, M>();
+  test1<T, M - 1, M - 2, M>();
+  test1<T, M - 1, M - 1, M>();
 }
 
 int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
index d9d47c5d8db46c..1af116e529156f 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
@@ -66,19 +66,39 @@ template <class T>
 void
 test()
 {
-    test1<T, 0, 0, 0>();
-    test1<T, 0, 1, 2>();
-    test1<T, 1, 1, 2>();
-    const T M(static_cast<T>(-1));
-    test1<T, 0, 0, M>();
-    test1<T, 0, M-2, M>();
-    test1<T, 0, M-1, M>();
-    test1<T, M-2, 0, M>();
-    test1<T, M-2, M-2, M>();
-    test1<T, M-2, M-1, M>();
-    test1<T, M-1, 0, M>();
-    test1<T, M-1, M-2, M>();
-    test1<T, M-1, M-1, M>();
+  const int W = sizeof(T) * CHAR_BIT;
+  const T M(static_cast<T>(-1));
+  const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+  // Cases where m = 0
+  test1<T, 0, 0, 0>();
+  test1<T, A, 0, 0>();
+  test1<T, 0, 1, 0>();
+  test1<T, A, 1, 0>();
+
+  // Cases where m = 2^n for n < w
+  test1<T, 0, 0, 256>();
+  test1<T, 5, 0, 256>();
+  test1<T, 0, 1, 256>();
+  test1<T, 5, 1, 256>();
+
+  // Cases where m is odd and a = 0
+  test1<T, 0, 0, M>();
+  test1<T, 0, M - 2, M>();
+  test1<T, 0, M - 1, M>();
+
+  // Cases where m is odd and m % a <= m / a (Schrage)
+  test1<T, A, 0, M>();
+  test1<T, A, M - 2, M>();
+  test1<T, A, M - 1, M>();
+
+  // Cases where m is odd and m % a > m / a (not implemented)
+  test1<T, M - 2, 0, M>();
+  test1<T, M - 2, M - 2, M>();
+  test1<T, M - 2, M - 1, M>();
+  test1<T, M - 1, 0, M>();
+  test1<T, M - 1, M - 2, M>();
+  test1<T, M - 1, M - 1, M>();
 }
 
 int main(int, char**)

``````````
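
The hunks above show only the first lines of the Schrage's-algorithm (mode 2) specializations. For reference, here is a sketch of the textbook form of one step under the `r <= q` precondition the picker checks. It is a hypothetical helper, not a copy of the libc++ code, and it reuses the modular addition `x += c - (x >= m - c)*m` from the PR summary. The expected values are the `E2` and `E3` outputs asserted in the updated `alg.pass.cpp`.

```cpp
#include <cstdint>

// Hypothetical helper: one LCG step x -> (a*x + c) mod m via Schrage's algorithm.
// Assumed preconditions: m > 0, a != 0, c < m, x < m, and r = m % a <= q = m / a.
// Under them every intermediate below is < m, so nothing overflows 64 bits.
constexpr std::uint64_t schrage_step(std::uint64_t a, std::uint64_t c, std::uint64_t m, std::uint64_t x) {
  const std::uint64_t q = m / a;
  const std::uint64_t r = m % a;
  const std::uint64_t t0 = a * (x % q);  // <= a * (q - 1) <= m - a
  const std::uint64_t t1 = r * (x / q);  // <= q * ((m - 1) / q) <= m - 1, using r <= q
  // Since m = a*q + r, a*x = (m - r)*(x / q) + t0, so a*x mod m is t0 - t1, plus m if negative.
  x = (t0 >= t1) ? t0 - t1 : t0 + (m - t1);
  // Modular addition of c without an if, as in the PR description.
  x += c - static_cast<std::uint64_t>(x >= m - c) * m;
  return x;
}

// E2 from alg.pass.cpp: a = 2^32, c = 0, m = 2^63 + 1 (q = 2^31, r = 1).
static_assert(schrage_step(1ull << 32, 0, (1ull << 63) + 1, 1) == 0x100000000ull, "");
static_assert(schrage_step(1ull << 32, 0, (1ull << 63) + 1, 0x100000000ull) == (1ull << 63) - 1, "");
// E3: a = 0x18000001, c = 0x12347, m = 3 * 2^56 (an even modulus).
static_assert(schrage_step(0x18000001ull, 0x12347ull, 3ull << 56, 1) == 0x18012348ull, "");
static_assert(schrage_step(0x18000001ull, 0x12347ull, 3ull << 56, 0x18012348ull) == 0x2401b4ed802468full, "");
```

The `r <= q` condition is what bounds `t1`; when it fails (mode 3), neither partial product is guaranteed to fit, which is why the patch promotes to a wider type instead.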

</details>


https://github.com/llvm/llvm-project/pull/81583

