[llvm] [KnownBits] Make abdu and abds optimal (PR #89081)
Jay Foad via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 18 02:07:03 PDT 2024
================
@@ -232,41 +232,34 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
}
KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
- // abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
- KnownBits UMaxValue = umax(LHS, RHS);
- KnownBits UMinValue = umin(LHS, RHS);
- KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
- /*NUW=*/true, UMaxValue, UMinValue);
+ // If we know which argument is larger, return (sub LHS, RHS) or
+ // (sub RHS, LHS) directly.
+ if (LHS.getMinValue().uge(RHS.getMaxValue()))
+ return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
+ RHS);
+ if (RHS.getMinValue().uge(LHS.getMaxValue()))
+ return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
+ LHS);
- // find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
+ // abdu is a non-wrapping LHS-RHS or RHS-LHS, so intersect the nuw subs.
KnownBits Diff0 =
- computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
+ computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
KnownBits Diff1 =
- computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
- KnownBits SubDiff = Diff0.intersectWith(Diff1);
-
- KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
- assert(!KnownAbsDiff.hasConflict() && "Bad Output");
- return KnownAbsDiff;
+ computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
+ return Diff0.intersectWith(Diff1);
}
KnownBits KnownBits::abds(const KnownBits &LHS, const KnownBits &RHS) {
- // abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
- KnownBits SMaxValue = smax(LHS, RHS);
- KnownBits SMinValue = smin(LHS, RHS);
- KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
- /*NUW=*/false, SMaxValue, SMinValue);
-
- // find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
- KnownBits Diff0 =
- computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
- KnownBits Diff1 =
- computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
- KnownBits SubDiff = Diff0.intersectWith(Diff1);
-
- KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
- assert(!KnownAbsDiff.hasConflict() && "Bad Output");
- return KnownAbsDiff;
+ // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
+ auto Flip = [](const KnownBits &Val) {
+ unsigned SignBitPosition = Val.getBitWidth() - 1;
+ APInt Zero = Val.Zero;
+ APInt One = Val.One;
+ Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+ One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+ return KnownBits(Zero, One);
+ };
+ return abdu(Flip(LHS), Flip(RHS));
----------------
jayfoad wrote:
> But either way, if you feel the need to explain it in the PR, it probably deserves a comment in the code :)
Fair enough. I've refactored the code and tried to explain the tricky part.
https://github.com/llvm/llvm-project/pull/89081
More information about the llvm-commits mailing list