[llvm] 066b492 - [NFC] Add exhaustive test coverage for `{Un}signedDivisionByConstantInfo`

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 24 17:10:55 PST 2022


Author: Roman Lebedev
Date: 2022-12-25T04:10:32+03:00
New Revision: 066b492b747a7e00f537eab9f0196575522ec285

URL: https://github.com/llvm/llvm-project/commit/066b492b747a7e00f537eab9f0196575522ec285
DIFF: https://github.com/llvm/llvm-project/commit/066b492b747a7e00f537eab9f0196575522ec285.diff

LOG: [NFC] Add exhaustive test coverage for `{Un}signedDivisionByConstantInfo`

Use this wrapper if you want to try brute-forcing wider bit widths:
https://godbolt.org/z/3xGzTM881

I've brute-forced i16 for both signed and unsigned, and we're all good.
As mentioned in https://reviews.llvm.org/D140636

Added: 
    llvm/unittests/Support/DivisionByConstantTest.cpp

Modified: 
    llvm/lib/Support/DivisionByConstantInfo.cpp
    llvm/unittests/Support/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Support/DivisionByConstantInfo.cpp b/llvm/lib/Support/DivisionByConstantInfo.cpp
index 989f1f2ed219..c84224069cbc 100644
--- a/llvm/lib/Support/DivisionByConstantInfo.cpp
+++ b/llvm/lib/Support/DivisionByConstantInfo.cpp
@@ -19,6 +19,11 @@ using namespace llvm;
 /// the divisor not be 0, 1, or -1.  Taken from "Hacker's Delight", Henry S.
 /// Warren, Jr., Chapter 10.
 SignedDivisionByConstantInfo SignedDivisionByConstantInfo::get(const APInt &D) {
+  assert(!D.isZero() && "Precondition violation.");
+
+  // We'd be endlessly stuck in the loop.
+  assert(D.getBitWidth() >= 3 && "Does not work at smaller bitwidths.");
+
   APInt Delta;
   APInt SignedMin = APInt::getSignedMinValue(D.getBitWidth());
   struct SignedDivisionByConstantInfo Retval;
@@ -67,6 +72,9 @@ SignedDivisionByConstantInfo SignedDivisionByConstantInfo::get(const APInt &D) {
 /// of the divided value are known zero.
 UnsignedDivisionByConstantInfo
 UnsignedDivisionByConstantInfo::get(const APInt &D, unsigned LeadingZeros) {
+  assert(!D.isZero() && "Precondition violation.");
+  assert(D.getBitWidth() > 1 && "Does not work at smaller bitwidths.");
+
   APInt Delta;
   struct UnsignedDivisionByConstantInfo Retval;
   Retval.IsAdd = false; // initialize "add" indicator

diff  --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt
index 052a524dd42d..91c6c9874870 100644
--- a/llvm/unittests/Support/CMakeLists.txt
+++ b/llvm/unittests/Support/CMakeLists.txt
@@ -27,8 +27,9 @@ add_llvm_unittest(SupportTests
   CRCTest.cpp
   CSKYAttributeParserTest.cpp
   DataExtractorTest.cpp
-  DebugTest.cpp
   DebugCounterTest.cpp
+  DebugTest.cpp
+  DivisionByConstantTest.cpp
   DJBTest.cpp
   EndianStreamTest.cpp
   EndianTest.cpp

diff  --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp
new file mode 100644
index 000000000000..fa8492bd9a66
--- /dev/null
+++ b/llvm/unittests/Support/DivisionByConstantTest.cpp
@@ -0,0 +1,187 @@
+//===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/DivisionByConstantInfo.h"
+#include "gtest/gtest.h"
+#include <array>
+#include <optional>
+
+using namespace llvm;
+
+namespace {
+
+template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
+  // Visit every Bits-wide value once: start at 0, stop when ++ wraps to 0.
+  APInt Value = APInt::getZero(Bits);
+  for (bool Wrapped = false; !Wrapped; Wrapped = (++Value).isZero())
+    TestFn(Value);
+}
+
+APInt MULHS(APInt X, APInt Y) {
+  const unsigned NarrowBits = X.getBitWidth(); // Width of the returned value.
+  const APInt WideProduct = X.sext(2 * NarrowBits) * Y.sext(2 * NarrowBits);
+  return WideProduct.lshr(NarrowBits).trunc(NarrowBits); // High half only.
+}
+
+APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
+                             SignedDivisionByConstantInfo Magics) {
+  unsigned Bits = Numerator.getBitWidth();
+  // Returns the signed quotient computed from Magics via MULHS, add and shifts.
+  APInt Factor(Bits, 0);     // Stays 0 unless a branch below picks +1 or -1.
+  APInt ShiftMask(Bits, -1); // All-ones: keep the sign-bit fixup at the end.
+  if (Divisor.isOne() || Divisor.isAllOnes()) {
+    // If d is +1/-1, we just multiply the numerator by +1/-1.
+    Factor = Divisor.getSExtValue();
+    Magics.Magic = 0;       // MULHS below then contributes 0.
+    Magics.ShiftAmount = 0; // No arithmetic shift.
+    ShiftMask = 0;          // No sign-bit fixup.
+  } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
+    // If d > 0 and m < 0, add the numerator.
+    Factor = 1;
+  } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
+    // If d < 0 and m > 0, subtract the numerator.
+    Factor = -1;
+  }
+
+  // Multiply the numerator by the magic value, keeping the high half.
+  APInt Q = MULHS(Numerator, Magics.Magic);
+
+  // (Optionally) Add/subtract the numerator using Factor (no-op when it is 0).
+  Factor = Numerator * Factor;
+  Q = Q + Factor;
+
+  // Shift right algebraic by shift value.
+  Q = Q.ashr(Magics.ShiftAmount);
+
+  // Extract the sign bit of Q, mask it and add it to the quotient.
+  unsigned SignShift = Bits - 1;
+  APInt T = Q.lshr(SignShift); // 1 if Q is negative, else 0.
+  T = T & ShiftMask;           // Forced to 0 on the +1/-1 path.
+  return Q + T;
+}
+
+TEST(SignedDivisionByConstantTest, Test) { // Exhaustively checks vs sdiv().
+  for (unsigned Bits = 1; Bits <= 32; ++Bits) {
+    if (Bits < 3)
+      continue; // Not supported by `SignedDivisionByConstantInfo::get()`.
+    if (Bits > 12)
+      continue; // Unreasonably slow.
+    EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
+      if (Divisor.isZero())
+        return; // Division by zero is undefined behavior.
+      SignedDivisionByConstantInfo Magics; // Left default for d = +1/-1.
+      if (!(Divisor.isOne() || Divisor.isAllOnes()))
+        Magics = SignedDivisionByConstantInfo::get(Divisor);
+      EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
+        if (Numerator.isMinSignedValue() && Divisor.isAllOnes())
+          return; // Overflow is undefined behavior.
+        APInt NativeResult = Numerator.sdiv(Divisor); // Reference quotient.
+        APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics);
+        ASSERT_EQ(MagicResult, NativeResult)
+            << " ... given the operation:  sdiv i" << Bits << " " << Numerator
+            << ", " << Divisor;
+      });
+    });
+  }
+}
+
+APInt MULHU(APInt X, APInt Y) {
+  const unsigned NarrowBits = X.getBitWidth(); // Width of the returned value.
+  const APInt WideProduct = X.zext(2 * NarrowBits) * Y.zext(2 * NarrowBits);
+  return WideProduct.lshr(NarrowBits).trunc(NarrowBits); // High half only.
+}
+
+APInt UnsignedDivideUsingMagic(APInt Numerator, APInt Divisor,
+                               bool AllowEvenDivisorOptimization, bool ForceNPQ,
+                               UnsignedDivisionByConstantInfo Magics) {
+  unsigned Bits = Numerator.getBitWidth();
+  // Returns the unsigned quotient computed from Magics via MULHU and shifts.
+  bool UseNPQ = false;
+  unsigned PreShift = 0, PostShift = 0; // Shifts before/after the multiply.
+
+  if (AllowEvenDivisorOptimization) {
+    // If the divisor is even, we can avoid using the expensive fixup by
+    // shifting the divided value upfront.
+    if (Magics.IsAdd && !Divisor[0]) { // Bit 0 clear: divisor is even.
+      PreShift = Divisor.countTrailingZeros();
+      // Get magic number for the shifted divisor.
+      Magics =
+          UnsignedDivisionByConstantInfo::get(Divisor.lshr(PreShift), PreShift);
+      assert(!Magics.IsAdd && "Should use cheap fixup now");
+    }
+  }
+
+  if (!Magics.IsAdd || Divisor.isOne()) {
+    assert(Magics.ShiftAmount < Divisor.getBitWidth() &&
+           "We shouldn't generate an undefined shift!");
+    PostShift = Magics.ShiftAmount;
+    UseNPQ = false;
+  } else {
+    PostShift = Magics.ShiftAmount - 1;
+    UseNPQ = true;
+  }
+
+  APInt NPQFactor =
+      UseNPQ ? APInt::getOneBitSet(Bits, Bits - 1) : APInt::getZero(Bits);
+
+  APInt Q = Numerator.lshr(PreShift); // No-op unless PreShift was set above.
+
+  // Multiply the (pre-shifted) numerator by the magic value.
+  Q = MULHU(Q, Magics.Magic);
+
+  if (UseNPQ || ForceNPQ) {
+    APInt NPQ = Numerator - Q;
+
+    // For vectors we might have a mix of non-NPQ/NPQ paths, so use MULHU by
+    // NPQFactor to act as a SRL-by-1 for NPQ, else multiply by zero (ForceNPQ).
+    APInt NPQ_Scalar = NPQ.lshr(1);
+    (void)NPQ_Scalar; // Only read by the assert below.
+    NPQ = MULHU(NPQ, NPQFactor);
+    assert(!UseNPQ || NPQ == NPQ_Scalar);
+
+    Q = NPQ + Q;
+  }
+
+  Q = Q.lshr(PostShift);
+
+  return Divisor.isOne() ? Numerator : Q; // Divisor 1: return Numerator as-is.
+}
+
+TEST(UnsignedDivisionByConstantTest, Test) { // Exhaustively checks vs udiv().
+  for (unsigned Bits = 1; Bits <= 32; ++Bits) {
+    if (Bits < 2)
+      continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`.
+    if (Bits > 11)
+      continue; // Unreasonably slow.
+    EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
+      if (Divisor.isZero())
+        return; // Division by zero is undefined behavior.
+      const UnsignedDivisionByConstantInfo Magics =
+          UnsignedDivisionByConstantInfo::get(Divisor);
+      EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
+        APInt NativeResult = Numerator.udiv(Divisor); // Reference quotient.
+        for (bool AllowEvenDivisorOptimization : {true, false}) {
+          for (bool ForceNPQ : {false, true}) {
+            APInt MagicResult = UnsignedDivideUsingMagic(
+                Numerator, Divisor, AllowEvenDivisorOptimization, ForceNPQ,
+                Magics);
+            ASSERT_EQ(MagicResult, NativeResult)
+                << " ... given the operation:  udiv i" << Bits << " "
+                << Numerator << ", " << Divisor
+                << " (allow even divisor optimization = "
+                << AllowEvenDivisorOptimization << ", force NPQ = " << ForceNPQ
+                << ")";
+          }
+        }
+      });
+    });
+  }
+}
+
+} // end anonymous namespace


        


More information about the llvm-commits mailing list