[llvm] 066b492 - [NFC] Add exhaustive test coverage for `{Un}signedDivisionByConstantInfo`
Roman Lebedev via llvm-commits
llvm-commits at lists.llvm.org
Sat Dec 24 17:10:55 PST 2022
Author: Roman Lebedev
Date: 2022-12-25T04:10:32+03:00
New Revision: 066b492b747a7e00f537eab9f0196575522ec285
URL: https://github.com/llvm/llvm-project/commit/066b492b747a7e00f537eab9f0196575522ec285
DIFF: https://github.com/llvm/llvm-project/commit/066b492b747a7e00f537eab9f0196575522ec285.diff
LOG: [NFC] Add exhaustive test coverage for `{Un}signedDivisionByConstantInfo`
Use this wrapper if you want to try brute-forcing wider bit widths:
https://godbolt.org/z/3xGzTM881
I've brute-forced i16 for both signed and unsigned, and we're all good.
As mentioned in https://reviews.llvm.org/D140636
Added:
llvm/unittests/Support/DivisionByConstantTest.cpp
Modified:
llvm/lib/Support/DivisionByConstantInfo.cpp
llvm/unittests/Support/CMakeLists.txt
Removed:
################################################################################
diff --git a/llvm/lib/Support/DivisionByConstantInfo.cpp b/llvm/lib/Support/DivisionByConstantInfo.cpp
index 989f1f2ed219..c84224069cbc 100644
--- a/llvm/lib/Support/DivisionByConstantInfo.cpp
+++ b/llvm/lib/Support/DivisionByConstantInfo.cpp
@@ -19,6 +19,11 @@ using namespace llvm;
/// the divisor not be 0, 1, or -1. Taken from "Hacker's Delight", Henry S.
/// Warren, Jr., Chapter 10.
SignedDivisionByConstantInfo SignedDivisionByConstantInfo::get(const APInt &D) {
+ assert(!D.isZero() && "Precondition violation.");
+
+ // We'd be endlessly stuck in the loop.
+ assert(D.getBitWidth() >= 3 && "Does not work at smaller bitwidths.");
+
APInt Delta;
APInt SignedMin = APInt::getSignedMinValue(D.getBitWidth());
struct SignedDivisionByConstantInfo Retval;
@@ -67,6 +72,9 @@ SignedDivisionByConstantInfo SignedDivisionByConstantInfo::get(const APInt &D) {
/// of the divided value are known zero.
UnsignedDivisionByConstantInfo
UnsignedDivisionByConstantInfo::get(const APInt &D, unsigned LeadingZeros) {
+ assert(!D.isZero() && "Precondition violation.");
+ assert(D.getBitWidth() > 1 && "Does not work at smaller bitwidths.");
+
APInt Delta;
struct UnsignedDivisionByConstantInfo Retval;
Retval.IsAdd = false; // initialize "add" indicator
diff --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt
index 052a524dd42d..91c6c9874870 100644
--- a/llvm/unittests/Support/CMakeLists.txt
+++ b/llvm/unittests/Support/CMakeLists.txt
@@ -27,8 +27,9 @@ add_llvm_unittest(SupportTests
CRCTest.cpp
CSKYAttributeParserTest.cpp
DataExtractorTest.cpp
- DebugTest.cpp
DebugCounterTest.cpp
+ DebugTest.cpp
+ DivisionByConstantTest.cpp
DJBTest.cpp
EndianStreamTest.cpp
EndianTest.cpp
diff --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp
new file mode 100644
index 000000000000..fa8492bd9a66
--- /dev/null
+++ b/llvm/unittests/Support/DivisionByConstantTest.cpp
@@ -0,0 +1,187 @@
+//===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/DivisionByConstantInfo.h"
+#include "gtest/gtest.h"
+#include <array>
+#include <optional>
+
+using namespace llvm;
+
+namespace {
+
+ /// Invoke \p TestFn once for every value of width \p Bits, in increasing
+ /// unsigned order starting at zero. The walk stops as soon as incrementing
+ /// the current value wraps it back around to zero.
+ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
+   for (APInt Value(Bits, 0);;) {
+     TestFn(Value);
+     ++Value;
+     if (Value.isZero())
+       break;
+   }
+ }
+
+ /// Signed multiply-high: the upper half of the double-width signed product
+ /// of \p X and \p Y, returned at the operands' bit width.
+ APInt MULHS(APInt X, APInt Y) {
+   const unsigned Width = X.getBitWidth();
+   // Sign-extend both operands so the full product fits without overflow.
+   const APInt WideProduct = X.sext(2 * Width) * Y.sext(2 * Width);
+   // Drop the low half and keep the high half.
+   return WideProduct.lshr(Width).trunc(Width);
+ }
+
+ /// Compute \p Numerator / \p Divisor (signed) the way a magic-number lowering
+ /// does it:
+ ///   1. signed multiply-high by the magic;
+ ///   2. optionally add or subtract the numerator;
+ ///   3. arithmetic shift right;
+ ///   4. add the sign bit of the intermediate quotient.
+ /// The sign bit in step 4 is masked off for divisors +1/-1.
+ ///
+ /// For a divisor of +1/-1 the incoming \p Magics are ignored and replaced.
+ /// The numerator is simply multiplied by the divisor.
+ ///
+ /// All constants are built at the numerator's exact bit width, so the helper
+ /// is also correct for widths above 64 bits. Integer assignment to an APInt
+ /// only fills the low 64-bit word.
+ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
+                              SignedDivisionByConstantInfo Magics) {
+   unsigned Bits = Numerator.getBitWidth();
+
+   APInt Factor = APInt::getZero(Bits);
+   APInt ShiftMask = APInt::getAllOnes(Bits);
+   if (Divisor.isOne() || Divisor.isAllOnes()) {
+     // If d is +1/-1, we just multiply the numerator by +1/-1.
+     Factor = Divisor;
+     // Set the magic at full width: the caller may pass default-constructed
+     // magics here, whose Magic is only 1 bit wide.
+     Magics.Magic = APInt::getZero(Bits);
+     Magics.ShiftAmount = 0;
+     ShiftMask.clearAllBits();
+   } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
+     // If d > 0 and m < 0, add the numerator.
+     Factor = APInt(Bits, 1);
+   } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
+     // If d < 0 and m > 0, subtract the numerator.
+     Factor = APInt::getAllOnes(Bits);
+   }
+
+   // Multiply the numerator by the magic value.
+   APInt Q = MULHS(Numerator, Magics.Magic);
+
+   // (Optionally) Add/subtract the numerator using Factor.
+   Q += Numerator * Factor;
+
+   // Shift right algebraic by shift value.
+   Q = Q.ashr(Magics.ShiftAmount);
+
+   // Extract the sign bit, mask it and add it to the quotient.
+   APInt T = Q.lshr(Bits - 1);
+   T &= ShiftMask;
+   return Q + T;
+ }
+
+ // Exhaustively compare SignedDivideUsingMagic() against APInt::sdiv().
+ // Every divisor/numerator pair is checked at each tested bit width.
+ TEST(SignedDivisionByConstantTest, Test) {
+   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
+     if (Bits < 3)
+       continue; // Not supported by `SignedDivisionByConstantInfo::get()`.
+     if (Bits > 12)
+       continue; // Unreasonably slow.
+     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
+       if (Divisor.isZero())
+         return; // Division by zero is undefined behavior.
+       SignedDivisionByConstantInfo Magics;
+       if (Divisor.isOne() || Divisor.isAllOnes()) {
+         // get() requires the divisor not be 0, 1, or -1, so no magics are
+         // computed here. Still give every field a defined, full-width value,
+         // so that copying `Magics` below never reads an indeterminate
+         // ShiftAmount.
+         Magics.Magic = APInt::getZero(Bits);
+         Magics.ShiftAmount = 0;
+       } else {
+         Magics = SignedDivisionByConstantInfo::get(Divisor);
+       }
+       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
+         if (Numerator.isMinSignedValue() && Divisor.isAllOnes())
+           return; // Overflow is undefined behavior.
+         APInt NativeResult = Numerator.sdiv(Divisor);
+         APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics);
+         ASSERT_EQ(MagicResult, NativeResult)
+             << " ... given the operation: sdiv i" << Bits << " " << Numerator
+             << ", " << Divisor;
+       });
+     });
+   }
+ }
+
+ /// Unsigned multiply-high: the upper half of the double-width unsigned
+ /// product of \p X and \p Y, returned at the operands' bit width.
+ APInt MULHU(APInt X, APInt Y) {
+   const unsigned Width = X.getBitWidth();
+   // Zero-extend both operands so the full product fits without overflow.
+   const APInt WideProduct = X.zext(2 * Width) * Y.zext(2 * Width);
+   // Drop the low half and keep the high half.
+   return WideProduct.lshr(Width).trunc(Width);
+ }
+
+ /// Compute \p Numerator / \p Divisor (unsigned) the way a magic-number
+ /// lowering does it, starting from \p Magics computed for \p Divisor.
+ ///
+ /// \p AllowEvenDivisorOptimization lets an even divisor whose magic needs the
+ /// "add" fixup be handled by pre-shifting the numerator instead.
+ /// \p ForceNPQ executes the NPQ fixup code even when it is not required. In
+ /// that case the fixup factor is zero, so the result must be unaffected.
+ APInt UnsignedDivideUsingMagic(APInt Numerator, APInt Divisor,
+                                bool AllowEvenDivisorOptimization, bool ForceNPQ,
+                                UnsignedDivisionByConstantInfo Magics) {
+   const unsigned Width = Numerator.getBitWidth();
+
+   // x / (d * 2^k) == (x >> k) / d. After the pre-shift the top k bits of the
+   // divided value are known zero, which is passed to get() as LeadingZeros.
+   unsigned PreShift = 0;
+   if (AllowEvenDivisorOptimization && Magics.IsAdd && !Divisor[0]) {
+     PreShift = Divisor.countTrailingZeros();
+     Magics =
+         UnsignedDivisionByConstantInfo::get(Divisor.lshr(PreShift), PreShift);
+     assert(!Magics.IsAdd && "Should use cheap fixup now");
+   }
+
+   // The NPQ fixup is required whenever the magic is of the "add" form.
+   // A divisor of one is the exception: its result is the numerator itself.
+   const bool UseNPQ = Magics.IsAdd && !Divisor.isOne();
+   unsigned PostShift;
+   if (UseNPQ) {
+     // The NPQ path halves (N - Q) itself, so one bit less is shifted here.
+     PostShift = Magics.ShiftAmount - 1;
+   } else {
+     assert(Magics.ShiftAmount < Divisor.getBitWidth() &&
+            "We shouldn't generate an undefined shift!");
+     PostShift = Magics.ShiftAmount;
+   }
+
+   // MULHU by 2^(Width-1) acts as a logical shift right by one, and MULHU by
+   // zero yields zero. Vector lowering uses this because lanes may mix NPQ
+   // and non-NPQ divisors.
+   const APInt NPQFactor =
+       UseNPQ ? APInt::getOneBitSet(Width, Width - 1) : APInt::getZero(Width);
+
+   // High half of (pre-shifted numerator * magic).
+   APInt Q = MULHU(Numerator.lshr(PreShift), Magics.Magic);
+
+   if (UseNPQ || ForceNPQ) {
+     APInt NPQ = Numerator - Q;
+     const APInt HalvedNPQ = NPQ.lshr(1);
+     (void)HalvedNPQ;
+     NPQ = MULHU(NPQ, NPQFactor);
+     assert(!UseNPQ || NPQ == HalvedNPQ);
+     Q += NPQ;
+   }
+
+   Q.lshrInPlace(PostShift);
+
+   if (Divisor.isOne())
+     return Numerator;
+   return Q;
+ }
+
+ // Exhaustively compare UnsignedDivideUsingMagic() against APInt::udiv().
+ // Every divisor/numerator pair is checked at each tested bit width, under
+ // all four combinations of the even-divisor optimization and forced NPQ.
+ TEST(UnsignedDivisionByConstantTest, Test) {
+   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
+     if (Bits < 2)
+       continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`.
+     if (Bits > 11)
+       continue; // Unreasonably slow.
+     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
+       if (Divisor.isZero())
+         return; // Division by zero is undefined behavior.
+       const UnsignedDivisionByConstantInfo Magics =
+           UnsignedDivisionByConstantInfo::get(Divisor);
+       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
+         APInt NativeResult = Numerator.udiv(Divisor);
+         for (bool AllowEvenDivisorOptimization : {true, false}) {
+           for (bool ForceNPQ : {false, true}) {
+             APInt MagicResult = UnsignedDivideUsingMagic(
+                 Numerator, Divisor, AllowEvenDivisorOptimization, ForceNPQ,
+                 Magics);
+             ASSERT_EQ(MagicResult, NativeResult)
+                 << " ... given the operation: udiv i" << Bits << " "
+                 << Numerator << ", " << Divisor
+                 << " (allow even divisor optimization = "
+                 << AllowEvenDivisorOptimization << ", force NPQ = " << ForceNPQ
+                 << ")";
+           }
+         }
+       });
+     });
+   }
+ }
+
+} // end anonymous namespace
More information about the llvm-commits
mailing list