[libcxx-commits] [libcxx] linear_congruential_engine: Fixes for __lce_alg_picker (PR #81080)

via libcxx-commits libcxx-commits at lists.llvm.org
Thu Feb 8 02:01:02 PST 2024


https://github.com/LRFLEW updated https://github.com/llvm/llvm-project/pull/81080

>From 7edc51a15575c647289aa9938fb95c989e7d681c Mon Sep 17 00:00:00 2001
From: LRFLEW <LRFLEW at aol.com>
Date: Wed, 7 Feb 2024 15:04:27 -0600
Subject: [PATCH] Fixes for __lce_alg_picker

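Fixes the _OverflowOK default and the static_assert in __lce_alg_picker
so that they actually check what they are intended to check: _OverflowOK
now tests that m is zero or a power of two, and the static_assert no
longer passes merely because a or m is nonzero.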
---
 .../__random/linear_congruential_engine.h     |  4 +-
 .../rand/rand.eng/rand.eng.lcong/alg.pass.cpp | 59 ++++++++++++-------
 .../rand.eng/rand.eng.lcong/assign.pass.cpp   | 49 +++++++++++----
 .../rand.eng/rand.eng.lcong/copy.pass.cpp     | 48 +++++++++++----
 .../rand.eng/rand.eng.lcong/default.pass.cpp  | 48 +++++++++++----
 .../rand.eng/rand.eng.lcong/values.pass.cpp   | 48 +++++++++++----
 6 files changed, 180 insertions(+), 76 deletions(-)

diff --git a/libcxx/include/__random/linear_congruential_engine.h b/libcxx/include/__random/linear_congruential_engine.h
index 51f6b248d8f974..fe9cb909b74d21 100644
--- a/libcxx/include/__random/linear_congruential_engine.h
+++ b/libcxx/include/__random/linear_congruential_engine.h
@@ -31,10 +31,10 @@ template <unsigned long long __a,
           unsigned long long __m,
           unsigned long long _Mp,
           bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
-          bool _OverflowOK    = ((__m | (__m - 1)) > __m),                    // m = 2^n
+          bool _OverflowOK    = ((__m & (__m - 1)) == 0ull),                  // m = 2^n
           bool _SchrageOK     = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
 struct __lce_alg_picker {
-  static_assert(__a != 0 || __m != 0 || !_MightOverflow || _OverflowOK || _SchrageOK,
+  static_assert(!_MightOverflow || _OverflowOK || _SchrageOK,
                 "The current values of a, c, and m cannot generate a number "
                 "within bounds of linear_congruential_engine.");
 
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
index 77b7c570f85a1d..c2b61988e34cdb 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
@@ -23,47 +23,62 @@ int main(int, char**)
     typedef unsigned long long T;
 
     // m might overflow, but the overflow is OK so it shouldn't use schrage's algorithm
-    typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull<<48)> E1;
+    typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull << 48)> E1;
     E1 e1;
     // make sure the right algorithm was used
     assert(e1() == 25214903918);
     assert(e1() == 205774354444503);
     assert(e1() == 158051849450892);
     // make sure result is in bounds
-    assert(e1() < (1ull<<48));
-    assert(e1() < (1ull<<48));
-    assert(e1() < (1ull<<48));
-    assert(e1() < (1ull<<48));
-    assert(e1() < (1ull<<48));
+    assert(e1() < (1ull << 48));
+    assert(e1() < (1ull << 48));
+    assert(e1() < (1ull << 48));
+    assert(e1() < (1ull << 48));
+    assert(e1() < (1ull << 48));
 
     // m might overflow. The overflow is not OK and result will be in bounds
     // so we should use shrage's algorithm
-    typedef std::linear_congruential_engine<T, (1ull<<2), 0, (1ull<<63) + 1> E2;
+    typedef std::linear_congruential_engine<T, (1ull << 2), 0, (1ull << 63) + 1> E2;
     E2 e2;
     // make sure shrage's algorithm is used (it would be 0s otherwise)
     assert(e2() == 4);
     assert(e2() == 16);
     assert(e2() == 64);
     // make sure result is in bounds
-    assert(e2() < (1ull<<48) + 1);
-    assert(e2() < (1ull<<48) + 1);
-    assert(e2() < (1ull<<48) + 1);
-    assert(e2() < (1ull<<48) + 1);
-    assert(e2() < (1ull<<48) + 1);
+    assert(e2() < (1ull << 48) + 1);
+    assert(e2() < (1ull << 48) + 1);
+    assert(e2() < (1ull << 48) + 1);
+    assert(e2() < (1ull << 48) + 1);
+    assert(e2() < (1ull << 48) + 1);
 
-    // m will not overflow so we should not use shrage's algorithm
-    typedef std::linear_congruential_engine<T, 1ull, 1, (1ull<<48)> E3;
+    // m might overflow. The overflow is not OK and result will be in bounds
+    // so we should use Schrage's algorithm. m is even
+    typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
     E3 e3;
+    // make sure Schrage's algorithm is used (a naive a * x would overflow by the third value)
+    assert(e3() == 402727752);
+    assert(e3() == 162159612030764687);
+    assert(e3() == 108176466184989142);
+    // make sure result is in bounds
+    assert(e3() < (3ull << 56));
+    assert(e3() < (3ull << 56));
+    assert(e3() < (3ull << 56));
+    assert(e3() < (3ull << 56));
+    assert(e3() < (3ull << 56));
+
+    // m will not overflow so we should not use Schrage's algorithm
+    typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E4;
+    E4 e4;
     // make sure the correct algorithm was used
-    assert(e3() == 2);
-    assert(e3() == 3);
-    assert(e3() == 4);
+    assert(e4() == 2);
+    assert(e4() == 3);
+    assert(e4() == 4);
     // make sure result is in bounds
-    assert(e3() < (1ull<<48));
-    assert(e3() < (1ull<<48));
-    assert(e3() < (1ull<<48));
-    assert(e3() < (1ull<<48));
-    assert(e2() < (1ull<<48));
+    assert(e4() < (1ull << 48));
+    assert(e4() < (1ull << 48));
+    assert(e4() < (1ull << 48));
+    assert(e4() < (1ull << 48));
+    assert(e4() < (1ull << 48));
 
     return 0;
 }
\ No newline at end of file
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
index 12620848626fc8..5317f171a98a79 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
@@ -15,6 +15,7 @@
 
 #include <random>
 #include <cassert>
+#include <climits>
 
 #include "test_macros.h"
 
@@ -35,19 +36,41 @@ template <class T>
 void
 test()
 {
-    test1<T, 0, 0, 0>();
-    test1<T, 0, 1, 2>();
-    test1<T, 1, 1, 2>();
-    const T M(static_cast<T>(-1));
-    test1<T, 0, 0, M>();
-    test1<T, 0, M-2, M>();
-    test1<T, 0, M-1, M>();
-    test1<T, M-2, 0, M>();
-    test1<T, M-2, M-2, M>();
-    test1<T, M-2, M-1, M>();
-    test1<T, M-1, 0, M>();
-    test1<T, M-1, M-2, M>();
-    test1<T, M-1, M-1, M>();
+  const int W = sizeof(T) * CHAR_BIT;
+  const T M(static_cast<T>(-1));
+  const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+  // Cases where m = 0
+  test1<T, 0, 0, 0>();
+  test1<T, A, 0, 0>();
+  test1<T, 0, 1, 0>();
+  test1<T, A, 1, 0>();
+
+  // Cases where m = 2^n for n < w
+  test1<T, 0, 0, 256>();
+  test1<T, 5, 0, 256>();
+  test1<T, 0, 1, 256>();
+  test1<T, 5, 1, 256>();
+
+  // Cases where m is odd and a = 0
+  test1<T, 0, 0, M>();
+  test1<T, 0, M - 2, M>();
+  test1<T, 0, M - 1, M>();
+
+  // Cases where m is odd and m % a <= m / a (Schrage)
+  test1<T, A, 0, M>();
+  test1<T, A, M - 2, M>();
+  test1<T, A, M - 1, M>();
+
+  /*
+  // Cases where m is odd and m % a > m / a (not implemented)
+  test1<T, M - 2, 0, M>();
+  test1<T, M - 2, M - 2, M>();
+  test1<T, M - 2, M - 1, M>();
+  test1<T, M - 1, 0, M>();
+  test1<T, M - 1, M - 2, M>();
+  test1<T, M - 1, M - 1, M>();
+  */
 }
 
 int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
index 5dac0772cb0e94..8e950043d594f9 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
@@ -35,19 +35,41 @@ template <class T>
 void
 test()
 {
-    test1<T, 0, 0, 0>();
-    test1<T, 0, 1, 2>();
-    test1<T, 1, 1, 2>();
-    const T M(static_cast<T>(-1));
-    test1<T, 0, 0, M>();
-    test1<T, 0, M-2, M>();
-    test1<T, 0, M-1, M>();
-    test1<T, M-2, 0, M>();
-    test1<T, M-2, M-2, M>();
-    test1<T, M-2, M-1, M>();
-    test1<T, M-1, 0, M>();
-    test1<T, M-1, M-2, M>();
-    test1<T, M-1, M-1, M>();
+  const int W = sizeof(T) * CHAR_BIT;
+  const T M(static_cast<T>(-1));
+  const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+  // Cases where m = 0
+  test1<T, 0, 0, 0>();
+  test1<T, A, 0, 0>();
+  test1<T, 0, 1, 0>();
+  test1<T, A, 1, 0>();
+
+  // Cases where m = 2^n for n < w
+  test1<T, 0, 0, 256>();
+  test1<T, 5, 0, 256>();
+  test1<T, 0, 1, 256>();
+  test1<T, 5, 1, 256>();
+
+  // Cases where m is odd and a = 0
+  test1<T, 0, 0, M>();
+  test1<T, 0, M - 2, M>();
+  test1<T, 0, M - 1, M>();
+
+  // Cases where m is odd and m % a <= m / a (Schrage)
+  test1<T, A, 0, M>();
+  test1<T, A, M - 2, M>();
+  test1<T, A, M - 1, M>();
+
+  /*
+  // Cases where m is odd and m % a > m / a (not implemented)
+  test1<T, M - 2, 0, M>();
+  test1<T, M - 2, M - 2, M>();
+  test1<T, M - 2, M - 1, M>();
+  test1<T, M - 1, 0, M>();
+  test1<T, M - 1, M - 2, M>();
+  test1<T, M - 1, M - 1, M>();
+  */
 }
 
 int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
index 10bc1d71d8e892..52126f7a200dbe 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
@@ -33,19 +33,41 @@ template <class T>
 void
 test()
 {
-    test1<T, 0, 0, 0>();
-    test1<T, 0, 1, 2>();
-    test1<T, 1, 1, 2>();
-    const T M(static_cast<T>(-1));
-    test1<T, 0, 0, M>();
-    test1<T, 0, M-2, M>();
-    test1<T, 0, M-1, M>();
-    test1<T, M-2, 0, M>();
-    test1<T, M-2, M-2, M>();
-    test1<T, M-2, M-1, M>();
-    test1<T, M-1, 0, M>();
-    test1<T, M-1, M-2, M>();
-    test1<T, M-1, M-1, M>();
+  const int W = sizeof(T) * CHAR_BIT;
+  const T M(static_cast<T>(-1));
+  const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+  // Cases where m = 0
+  test1<T, 0, 0, 0>();
+  test1<T, A, 0, 0>();
+  test1<T, 0, 1, 0>();
+  test1<T, A, 1, 0>();
+
+  // Cases where m = 2^n for n < w
+  test1<T, 0, 0, 256>();
+  test1<T, 5, 0, 256>();
+  test1<T, 0, 1, 256>();
+  test1<T, 5, 1, 256>();
+
+  // Cases where m is odd and a = 0
+  test1<T, 0, 0, M>();
+  test1<T, 0, M - 2, M>();
+  test1<T, 0, M - 1, M>();
+
+  // Cases where m is odd and m % a <= m / a (Schrage)
+  test1<T, A, 0, M>();
+  test1<T, A, M - 2, M>();
+  test1<T, A, M - 1, M>();
+
+  /*
+  // Cases where m is odd and m % a > m / a (not implemented)
+  test1<T, M - 2, 0, M>();
+  test1<T, M - 2, M - 2, M>();
+  test1<T, M - 2, M - 1, M>();
+  test1<T, M - 1, 0, M>();
+  test1<T, M - 1, M - 2, M>();
+  test1<T, M - 1, M - 1, M>();
+  */
 }
 
 int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
index d9d47c5d8db46c..28d8dfea01fab3 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
@@ -66,19 +66,41 @@ template <class T>
 void
 test()
 {
-    test1<T, 0, 0, 0>();
-    test1<T, 0, 1, 2>();
-    test1<T, 1, 1, 2>();
-    const T M(static_cast<T>(-1));
-    test1<T, 0, 0, M>();
-    test1<T, 0, M-2, M>();
-    test1<T, 0, M-1, M>();
-    test1<T, M-2, 0, M>();
-    test1<T, M-2, M-2, M>();
-    test1<T, M-2, M-1, M>();
-    test1<T, M-1, 0, M>();
-    test1<T, M-1, M-2, M>();
-    test1<T, M-1, M-1, M>();
+  const int W = sizeof(T) * CHAR_BIT;
+  const T M(static_cast<T>(-1));
+  const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+  // Cases where m = 0
+  test1<T, 0, 0, 0>();
+  test1<T, A, 0, 0>();
+  test1<T, 0, 1, 0>();
+  test1<T, A, 1, 0>();
+
+  // Cases where m = 2^n for n < w
+  test1<T, 0, 0, 256>();
+  test1<T, 5, 0, 256>();
+  test1<T, 0, 1, 256>();
+  test1<T, 5, 1, 256>();
+
+  // Cases where m is odd and a = 0
+  test1<T, 0, 0, M>();
+  test1<T, 0, M - 2, M>();
+  test1<T, 0, M - 1, M>();
+
+  // Cases where m is odd and m % a <= m / a (Schrage)
+  test1<T, A, 0, M>();
+  test1<T, A, M - 2, M>();
+  test1<T, A, M - 1, M>();
+
+  /*
+  // Cases where m is odd and m % a > m / a (not implemented)
+  test1<T, M - 2, 0, M>();
+  test1<T, M - 2, M - 2, M>();
+  test1<T, M - 2, M - 1, M>();
+  test1<T, M - 1, 0, M>();
+  test1<T, M - 1, M - 2, M>();
+  test1<T, M - 1, M - 1, M>();
+  */
 }
 
 int main(int, char**)



More information about the libcxx-commits mailing list