[llvm] [KnownBits] Make abdu and abds optimal (PR #89081)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 17 12:01:06 PDT 2024


================
@@ -232,41 +232,34 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
 }
 
 KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
-  // abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
-  KnownBits UMaxValue = umax(LHS, RHS);
-  KnownBits UMinValue = umin(LHS, RHS);
-  KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
-                                          /*NUW=*/true, UMaxValue, UMinValue);
+  // If we know which argument is larger, return (sub LHS, RHS) or
+  // (sub RHS, LHS) directly.
+  if (LHS.getMinValue().uge(RHS.getMaxValue()))
+    return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
+                            RHS);
+  if (RHS.getMinValue().uge(LHS.getMaxValue()))
+    return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
+                            LHS);
 
-  // find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
+  // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
   KnownBits Diff0 =
-      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
+      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
   KnownBits Diff1 =
-      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
-  KnownBits SubDiff = Diff0.intersectWith(Diff1);
-
-  KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
-  assert(!KnownAbsDiff.hasConflict() && "Bad Output");
-  return KnownAbsDiff;
+      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
+  return Diff0.intersectWith(Diff1);
 }
 
 KnownBits KnownBits::abds(const KnownBits &LHS, const KnownBits &RHS) {
-  // abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
-  KnownBits SMaxValue = smax(LHS, RHS);
-  KnownBits SMinValue = smin(LHS, RHS);
-  KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
-                                          /*NUW=*/false, SMaxValue, SMinValue);
-
-  // find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
-  KnownBits Diff0 =
-      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
-  KnownBits Diff1 =
-      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
-  KnownBits SubDiff = Diff0.intersectWith(Diff1);
-
-  KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
-  assert(!KnownAbsDiff.hasConflict() && "Bad Output");
-  return KnownAbsDiff;
+  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
+  auto Flip = [](const KnownBits &Val) {
+    unsigned SignBitPosition = Val.getBitWidth() - 1;
+    APInt Zero = Val.Zero;
+    APInt One = Val.One;
+    Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+    One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+    return KnownBits(Zero, One);
+  };
+  return abdu(Flip(LHS), Flip(RHS));
----------------
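The new `abds` reduces to `abdu` through the identity abds(X, Y) == abdu(X ^ SignBit, Y ^ SignBit). XOR with the sign bit adds 2^(N-1) modulo 2^N, which maps the signed range [-2^(N-1), 2^(N-1) - 1] onto the unsigned range [0, 2^N - 1] in order, so distances are unchanged. Flipping the sign bit of a known-bits value also maps its set of consistent concrete values exactly onto the flipped set. As a result, the optimal known bits of both sides coincide, and an optimal `abdu` yields an optimal `abds`. The sketch below checks this by brute force at a 4-bit width. `Known`, `optimal`, `flip` and the other helpers are a hypothetical toy model, not the `llvm::KnownBits` API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical toy stand-in for llvm::KnownBits at a small bit width W:
// bit i of Zero (One) is set when bit i is known to be 0 (1).
constexpr unsigned W = 4;
constexpr unsigned Mask = (1u << W) - 1;
constexpr unsigned SignBit = 1u << (W - 1);

struct Known {
  unsigned Zero = 0, One = 0;
  bool operator==(const Known &O) const { return Zero == O.Zero && One == O.One; }
};

// True if the concrete W-bit value V is consistent with K.
static bool matches(const Known &K, unsigned V) {
  return (V & K.Zero) == 0 && (V & K.One) == K.One;
}

// Reinterpret a W-bit pattern as a two's complement signed value.
static int toSigned(unsigned V) {
  return (V & SignBit) ? int(V) - int(1u << W) : int(V);
}

static unsigned abduVal(unsigned A, unsigned B) { return A > B ? A - B : B - A; }
static unsigned abdsVal(unsigned A, unsigned B) {
  int SA = toSigned(A), SB = toSigned(B);
  return unsigned(SA > SB ? SA - SB : SB - SA) & Mask;
}

// Brute-force optimal known bits of F(L, R): a result bit is known only if
// it has the same value for every pair of inputs consistent with L and R.
template <typename Fn>
static Known optimal(const Known &L, const Known &R, Fn F) {
  assert(!(L.Zero & L.One) && !(R.Zero & R.One) && "conflicting known bits");
  unsigned AlwaysOne = Mask, AlwaysZero = Mask;
  for (unsigned A = 0; A <= Mask; ++A)
    for (unsigned B = 0; B <= Mask; ++B)
      if (matches(L, A) && matches(R, B)) {
        unsigned V = F(A, B);
        AlwaysOne &= V;
        AlwaysZero &= ~V & Mask;
      }
  return {AlwaysZero, AlwaysOne};
}

// Mirrors the Flip lambda in the diff: swap the known-zero and known-one
// state of the sign bit, leaving the other bits alone.
static Known flip(const Known &K) {
  return {(K.Zero & ~SignBit) | (K.One & SignBit),
          (K.One & ~SignBit) | (K.Zero & SignBit)};
}

// Asserts optimal(abds)(L, R) == optimal(abdu)(flip(L), flip(R)) for every
// non-conflicting pair of inputs; returns the number of pairs checked.
static unsigned checkFlipReduction() {
  unsigned Checked = 0;
  for (unsigned LZ = 0; LZ <= Mask; ++LZ)
    for (unsigned LO = 0; LO <= Mask; ++LO)
      for (unsigned RZ = 0; RZ <= Mask; ++RZ)
        for (unsigned RO = 0; RO <= Mask; ++RO) {
          if ((LZ & LO) || (RZ & RO))
            continue; // a bit cannot be known 0 and known 1 at once
          Known L{LZ, LO}, R{RZ, RO};
          assert(optimal(L, R, abdsVal) == optimal(flip(L), flip(R), abduVal));
          ++Checked;
        }
  return Checked;
}
```

With W = 4 each operand has 3^4 = 81 non-conflicting states, so `checkFlipReduction()` should return 6561. The width is kept small so that exhaustive enumeration stays cheap.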
jayfoad wrote:

> What do you mean the input is signed?

Consider these definitions:

An operation has **signed overflow** if `(sext X) op (sext Y) != sext (X op Y)`, where the extensions extend to some suitably large (maybe infinite) bit width.

An operation has **unsigned overflow** if `(zext X) op (zext Y) != zext (X op Y)`.
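A small sketch of the two definitions above for 8-bit operands, using `int64_t` as the "suitably large" width. The helper names (`sext8`, `hasSignedOverflow`, and so on) are hypothetical. `add` is used only as a familiar example operation, which is modelled once on i8 (wrapping) and once on the wide type.

```cpp
#include <cassert>
#include <cstdint>

// int64_t stands in for the "suitably large" bit width.
static int64_t sext8(uint8_t V) { return V >= 0x80 ? int64_t(V) - 0x100 : int64_t(V); }
static int64_t zext8(uint8_t V) { return int64_t(V); }

using NarrowOp = uint8_t (*)(uint8_t, uint8_t); // "X op Y" on i8
using WideOp = int64_t (*)(int64_t, int64_t);   // "op" on the wide type

// (sext X) op (sext Y) != sext (X op Y)
static bool hasSignedOverflow(NarrowOp Op8, WideOp OpWide, uint8_t X, uint8_t Y) {
  return OpWide(sext8(X), sext8(Y)) != sext8(Op8(X, Y));
}
// (zext X) op (zext Y) != zext (X op Y)
static bool hasUnsignedOverflow(NarrowOp Op8, WideOp OpWide, uint8_t X, uint8_t Y) {
  return OpWide(zext8(X), zext8(Y)) != zext8(Op8(X, Y));
}

// Example operation: addition, which wraps modulo 2^8 on i8.
static uint8_t add8(uint8_t X, uint8_t Y) { return uint8_t(X + Y); }
static int64_t addWide(int64_t X, int64_t Y) { return X + Y; }

static void checkAddExamples() {
  // 100 + 100 = 200 wraps to 0xC8, which is -56 as signed: signed overflow only.
  assert(hasSignedOverflow(add8, addWide, 100, 100));
  assert(!hasUnsignedOverflow(add8, addWide, 100, 100));
  // 200 + 100 = 300 wraps to 44; as signed, -56 + 100 = 44: unsigned overflow only.
  assert(hasUnsignedOverflow(add8, addWide, 200, 100));
  assert(!hasSignedOverflow(add8, addWide, 200, 100));
}
```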

Fact: `abdu` never has unsigned overflow.
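This holds because both zero-extended operands lie in [0, 2^N - 1], so their exact distance also lies in that range and truncating it to N bits loses nothing. A brute-force sketch for N = 8, with hypothetical helper names and `int64_t` as the wide type:

```cpp
#include <cassert>
#include <cstdint>

// abdu on i8 and on the wide type.
static uint8_t abdu8(uint8_t X, uint8_t Y) { return uint8_t(X > Y ? X - Y : Y - X); }
static int64_t abduWide(int64_t X, int64_t Y) { return X > Y ? X - Y : Y - X; }

// Exhaustively asserts (zext X) abdu (zext Y) == zext (X abdu Y) for all
// 8-bit X and Y; returns the number of pairs checked.
static unsigned checkAbduNoUnsignedOverflow() {
  unsigned Checked = 0;
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y) {
      // X and Y already hold the zero-extended values.
      assert(abduWide(X, Y) == int64_t(abdu8(uint8_t(X), uint8_t(Y))));
      ++Checked;
    }
  return Checked;
}
```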

An operation has **signed-input-unsigned-output overflow** if `(sext X) op (sext Y) != zext (X op Y)`.

Fact: `abds` never has signed-input-unsigned-output overflow. (But it can have signed overflow, and it can have unsigned overflow.)
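A brute-force sketch of these `abds` facts for 8-bit operands, again with hypothetical helper names and `int64_t` as the wide type. For example, abds(-128, 127) on i8 is 255, i.e. 0xFF. Read as signed, that is -1 (signed overflow). Also, (zext 0x80) abds (zext 0x7F) = abds(128, 127) = 1 (unsigned overflow). Yet zext 0xFF = 255 matches the exact result, so there is no signed-input-unsigned-output overflow.

```cpp
#include <cassert>
#include <cstdint>

static int64_t sext8(uint8_t V) { return V >= 0x80 ? int64_t(V) - 0x100 : int64_t(V); }
static int64_t zext8(uint8_t V) { return int64_t(V); }

// abds on the wide type: the exact absolute difference.
static int64_t abdsWide(int64_t X, int64_t Y) { return X > Y ? X - Y : Y - X; }
// abds on i8: the exact result for the signed inputs, truncated to 8 bits.
static uint8_t abds8(uint8_t X, uint8_t Y) {
  return uint8_t(abdsWide(sext8(X), sext8(Y)));
}

struct AbdsOverflowCounts {
  unsigned Signed = 0;
  unsigned Unsigned = 0;
};

// Walks all 8-bit pairs: asserts that signed-input-unsigned-output overflow
// never happens, and counts the pairs with signed or unsigned overflow.
static AbdsOverflowCounts countAbdsOverflow() {
  AbdsOverflowCounts C;
  for (unsigned XI = 0; XI < 256; ++XI)
    for (unsigned YI = 0; YI < 256; ++YI) {
      uint8_t X = uint8_t(XI), Y = uint8_t(YI);
      uint8_t R = abds8(X, Y);
      int64_t Exact = abdsWide(sext8(X), sext8(Y)); // (sext X) abds (sext Y)
      assert(Exact == zext8(R)); // no signed-input-unsigned-output overflow
      if (Exact != sext8(R))
        ++C.Signed;
      if (abdsWide(zext8(X), zext8(Y)) != zext8(R))
        ++C.Unsigned;
    }
  return C;
}
```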

https://github.com/llvm/llvm-project/pull/89081

