[llvm] [llvm] Add KnownBits implementations for avgFloor and avgCeil (PR #86445)

Nhat Nguyen via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 4 18:20:25 PDT 2024


https://github.com/changkhothuychung updated https://github.com/llvm/llvm-project/pull/86445

>From e447f4c15ced2880f520d636ffc239d1368032b7 Mon Sep 17 00:00:00 2001
From: changkhothuychung <nhat7203 at gmail.com>
Date: Sun, 24 Mar 2024 13:54:21 -0400
Subject: [PATCH 1/4] initial attempt

---
 llvm/lib/Support/KnownBits.cpp | 40 ++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index d72355dab6f1d3..07c7ad0882387a 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -762,6 +762,46 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
   return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
 }
 
+KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
+  // (C1 & C2) + (C1 ^ C2).ashr(1)
+  KnownBits andResult = LHS & RHS;
+  KnownBits xorResult = LHS ^ RHS;
+  xorResult.Zero.ashrInPlace(1);
+  xorResult.One.ashrInPlace(1);
+  return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, andResult,
+                             xorResult);
+}
+
+KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
+  // (C1 & C2) + (C1 ^ C2).lshr(1)
+  KnownBits andResult = LHS & RHS;
+  KnownBits xorResult = LHS ^ RHS;
+  xorResult.Zero.lshrInPlace(1);
+  xorResult.One.lshrInPlace(1);
+  return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, andResult,
+                             xorResult);
+}
+
+KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
+  // (C1 | C2) - (C1 ^ C2).ashr(1)
+  KnownBits andResult = LHS & RHS;
+  KnownBits xorResult = LHS ^ RHS;
+  xorResult.Zero.ashrInPlace(1);
+  xorResult.One.ashrInPlace(1);
+  return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, andResult,
+                             xorResult);
+}
+
+KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
+  // (C1 | C2) - (C1 ^ C2).lshr(1)
+  KnownBits andResult = LHS & RHS;
+  KnownBits xorResult = LHS ^ RHS;
+  xorResult.Zero.lshrInPlace(1);
+  xorResult.One.lshrInPlace(1);
+  return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, andResult,
+                             xorResult);
+}
+
 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                          bool NoUndefSelfMultiply) {
   unsigned BitWidth = LHS.getBitWidth();

>From 8b6e8bbec1e782ef1da06bf2dcf69319c9e96813 Mon Sep 17 00:00:00 2001
From: changkhothuychung <nhat7203 at gmail.com>
Date: Sun, 24 Mar 2024 14:02:31 -0400
Subject: [PATCH 2/4] add definitions to header file

---
 llvm/include/llvm/Support/KnownBits.h | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 73cb01e0644a8d..575362dc18b0cc 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -354,6 +354,18 @@ struct KnownBits {
   /// Compute knownbits resulting from llvm.usub.sat(LHS, RHS)
   static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
 
+  /// Compute knownbits resulting from (C1 & C2) + (C1 ^ C2).ashr(1)
+  static KnownBits avgFloorS(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute knownbits resulting from (C1 & C2) + (C1 ^ C2).lshr(1)
+  static KnownBits avgFloorU(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute knownbits resulting from (C1 & C2) - (C1 ^ C2).ashr(1)
+  static KnownBits avgCeilS(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute knownbits resulting from (C1 & C2) - (C1 ^ C2).lshr(1)
+  static KnownBits avgCeilU(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Compute known bits resulting from multiplying LHS and RHS.
   static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
                        bool NoUndefSelfMultiply = false);

>From b0d81007840a354694e0238a0be96b13ce74b119 Mon Sep 17 00:00:00 2001
From: changkhothuychung <nhat7203 at gmail.com>
Date: Thu, 4 Apr 2024 21:10:27 -0400
Subject: [PATCH 3/4] address review comments

---
 llvm/include/llvm/Support/KnownBits.h         |  8 ++--
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 29 ++++++---
 llvm/lib/Support/KnownBits.cpp                | 68 ++++++++-----------
 3 files changed, 52 insertions(+), 53 deletions(-)

diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 575362dc18b0cc..93049d05eb0b5c 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -354,16 +354,16 @@ struct KnownBits {
   /// Compute knownbits resulting from llvm.usub.sat(LHS, RHS)
   static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
 
-  /// Compute knownbits resulting from (C1 & C2) + (C1 ^ C2).ashr(1)
+  /// Compute knownbits resulting from APIntOps::avgFloorS
   static KnownBits avgFloorS(const KnownBits &LHS, const KnownBits &RHS);
 
-  /// Compute knownbits resulting from (C1 & C2) + (C1 ^ C2).lshr(1)
+  /// Compute knownbits resulting from APIntOps::avgFloorU
   static KnownBits avgFloorU(const KnownBits &LHS, const KnownBits &RHS);
 
-  /// Compute knownbits resulting from (C1 & C2) - (C1 ^ C2).ashr(1)
+  /// Compute knownbits resulting from APIntOps::avgCeilS
   static KnownBits avgCeilS(const KnownBits &LHS, const KnownBits &RHS);
 
-  /// Compute knownbits resulting from (C1 & C2) - (C1 ^ C2).lshr(1)
+  /// Compute knownbits resulting from APIntOps::avgCeilU
   static KnownBits avgCeilU(const KnownBits &LHS, const KnownBits &RHS);
 
   /// Compute known bits resulting from multiplying LHS and RHS.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 1dd0fa49a460f8..f6af02ded36a2e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3419,18 +3419,27 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
       Known = KnownBits::mulhs(Known, Known2);
     break;
   }
-  case ISD::AVGFLOORU:
-  case ISD::AVGCEILU:
-  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU: {
+    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+    Known = KnownBits::avgFloorU(Known, Known2);
+    break;
+  }
+  case ISD::AVGCEILU: {
+    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+    Known = KnownBits::avgCeilU(Known, Known2);
+    break;
+  }
+  case ISD::AVGFLOORS: {
+    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+    Known = KnownBits::avgFloorS(Known, Known2);
+    break;
+  }
   case ISD::AVGCEILS: {
-    bool IsCeil = Opcode == ISD::AVGCEILU || Opcode == ISD::AVGCEILS;
-    bool IsSigned = Opcode == ISD::AVGFLOORS || Opcode == ISD::AVGCEILS;
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = IsSigned ? Known.sext(BitWidth + 1) : Known.zext(BitWidth + 1);
-    Known2 = IsSigned ? Known2.sext(BitWidth + 1) : Known2.zext(BitWidth + 1);
-    KnownBits Carry = KnownBits::makeConstant(APInt(1, IsCeil ? 1 : 0));
-    Known = KnownBits::computeForAddCarry(Known, Known2, Carry);
-    Known = Known.extractBits(BitWidth, 1);
+    Known = KnownBits::avgCeilS(Known, Known2);
     break;
   }
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 07c7ad0882387a..1ad30a72166d9d 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -762,45 +762,35 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
   return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
 }
 
-KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
-  // (C1 & C2) + (C1 ^ C2).ashr(1)
-  KnownBits andResult = LHS & RHS;
-  KnownBits xorResult = LHS ^ RHS;
-  xorResult.Zero.ashrInPlace(1);
-  xorResult.One.ashrInPlace(1);
-  return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, andResult,
-                             xorResult);
-}
-
-KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
-  // (C1 & C2) + (C1 ^ C2).lshr(1)
-  KnownBits andResult = LHS & RHS;
-  KnownBits xorResult = LHS ^ RHS;
-  xorResult.Zero.lshrInPlace(1);
-  xorResult.One.lshrInPlace(1);
-  return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, andResult,
-                             xorResult);
-}
-
-KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
-  // (C1 | C2) - (C1 ^ C2).ashr(1)
-  KnownBits andResult = LHS & RHS;
-  KnownBits xorResult = LHS ^ RHS;
-  xorResult.Zero.ashrInPlace(1);
-  xorResult.One.ashrInPlace(1);
-  return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, andResult,
-                             xorResult);
-}
-
-KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
-  // (C1 | C2) - (C1 ^ C2).lshr(1)
-  KnownBits andResult = LHS & RHS;
-  KnownBits xorResult = LHS ^ RHS;
-  xorResult.Zero.lshrInPlace(1);
-  xorResult.One.lshrInPlace(1);
-  return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, andResult,
-                             xorResult);
-}
+// Shared implementation of avgFloorS/avgFloorU/avgCeilS/avgCeilU. Both
+// operands are widened by one bit (sign- or zero-extended per IsSigned) so
+// that their sum, plus a carry-in of 1 for the ceiling variants, cannot
+// overflow; bits [1, BitWidth] of the (BitWidth + 1)-bit sum are the average.
+static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
+                            bool IsSigned) {
+  unsigned BitWidth = LHS.getBitWidth();
+  LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
+  RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
+  KnownBits Carry = KnownBits::makeConstant(APInt(1, IsCeil ? 1 : 0));
+  KnownBits Sum = KnownBits::computeForAddCarry(LHS, RHS, Carry);
+  return Sum.extractBits(BitWidth, 1);
+}
+
+KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
+  return avgCompute(LHS, RHS, /*IsCeil=*/false, /*IsSigned=*/true);
+}
+
+KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
+  return avgCompute(LHS, RHS, /*IsCeil=*/false, /*IsSigned=*/false);
+}
+
+KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
+  return avgCompute(LHS, RHS, /*IsCeil=*/true, /*IsSigned=*/true);
+}
+
+KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
+  return avgCompute(LHS, RHS, /*IsCeil=*/true, /*IsSigned=*/false);
+}
 
 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                          bool NoUndefSelfMultiply) {

>From ba43d228509f12093608ea3f75819b9893ea65d2 Mon Sep 17 00:00:00 2001
From: changkhothuychung <nhat7203 at gmail.com>
Date: Thu, 4 Apr 2024 21:19:58 -0400
Subject: [PATCH 4/4] add tests

---
 llvm/unittests/Support/KnownBitsTest.cpp | 28 ++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 027d6379af26b0..74432feea353ec 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -559,6 +559,34 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       KnownBits::mulhu,
       [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
       checkCorrectnessOnlyBinary);
+
+  testBinaryOpExhaustive(
+      KnownBits::avgFloorS,
+      [](const APInt &N1, const APInt &N2) {
+        return APIntOps::avgFloorS(N1, N2);
+      },
+      checkCorrectnessOnlyBinary);
+
+  testBinaryOpExhaustive(
+      KnownBits::avgFloorU,
+      [](const APInt &N1, const APInt &N2) {
+        return APIntOps::avgFloorU(N1, N2);
+      },
+      checkCorrectnessOnlyBinary);
+
+  testBinaryOpExhaustive(
+      KnownBits::avgCeilS,
+      [](const APInt &N1, const APInt &N2) {
+        return APIntOps::avgCeilS(N1, N2);
+      },
+      checkCorrectnessOnlyBinary);
+
+  testBinaryOpExhaustive(
+      KnownBits::avgCeilU,
+      [](const APInt &N1, const APInt &N2) {
+        return APIntOps::avgCeilU(N1, N2);
+      },
+      checkCorrectnessOnlyBinary);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {



More information about the llvm-commits mailing list