[llvm] [SelectionDAG] implement computeKnownBits for all AVG* instructions (PR #86754)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 17:34:56 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Sizov Nikita (snikitav)

<details>
<summary>Changes</summary>

Implements KnownBits calculation for the **AVGFLOORU** / **AVGFLOORS** / **AVGCEILU** / **AVGCEILS** instructions, generalizing the previous AVGCEILU-only handling.

Closes #<!-- -->53622 

---
Full diff: https://github.com/llvm/llvm-project/pull/86754.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+10-5) 
- (modified) llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp (+48) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e8d1ac1d3a9167..e3b76b95eb86ad 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3419,13 +3419,20 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
       Known = KnownBits::mulhs(Known, Known2);
     break;
   }
-  case ISD::AVGCEILU: {
+  case ISD::AVGFLOORU:
+  case ISD::AVGFLOORS:
+  case ISD::AVGCEILU:
+  case ISD::AVGCEILS: {
+    bool Signed = Opcode == ISD::AVGFLOORS || Opcode == ISD::AVGCEILS;
+    bool RoundUp = Opcode == ISD::AVGCEILU || Opcode == ISD::AVGCEILS;
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = Known.zext(BitWidth + 1);
-    Known2 = Known2.zext(BitWidth + 1);
-    KnownBits One = KnownBits::makeConstant(APInt(1, 1));
-    Known = KnownBits::computeForAddCarry(Known, Known2, One);
+    // Widen by one bit so LHS + RHS (+1 when rounding up) cannot wrap; the
+    // average is then bits [BitWidth:1] of that exact sum.
+    Known = Signed ? Known.sext(BitWidth + 1) : Known.zext(BitWidth + 1);
+    Known2 = Signed ? Known2.sext(BitWidth + 1) : Known2.zext(BitWidth + 1);
+    Known = KnownBits::computeForAddCarry(
+        Known, Known2, KnownBits::makeConstant(APInt(1, RoundUp ? 1 : 0)));
     Known = Known.extractBits(BitWidth, 1);
     break;
   }
diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
index e0772684e3a954..27bcad7c24c4db 100644
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -796,4 +796,52 @@ TEST_F(AArch64SelectionDAGTest, computeKnownBits_extload_knownnegative) {
   EXPECT_EQ(Known.One, APInt(32, 0xfffffff0));
 }
 
+TEST_F(AArch64SelectionDAGTest,
+       computeKnownBits_AVGFLOORU_AVGFLOORS_AVGCEILU_AVGCEILS) {
+  SDLoc Loc;
+  auto Int8VT = EVT::getIntegerVT(Context, 8);
+  auto Int16VT = EVT::getIntegerVT(Context, 16);
+  auto Int8Vec8VT = EVT::getVectorVT(Context, Int8VT, 8);
+  auto Int16Vec8VT = EVT::getVectorVT(Context, Int16VT, 8);
+
+  SDValue UnknownOp0 = DAG->getRegister(0, Int8Vec8VT);
+  SDValue UnknownOp1 = DAG->getRegister(1, Int8Vec8VT);
+
+  SDValue ZextOp0 =
+      DAG->getNode(ISD::ZERO_EXTEND, Loc, Int16Vec8VT, UnknownOp0);
+  SDValue ZextOp1 =
+      DAG->getNode(ISD::ZERO_EXTEND, Loc, Int16Vec8VT, UnknownOp1);
+  // ZextOp0 = 00000000????????
+  // ZextOp1 = 00000000????????
+  // => every AVG* result (signed or unsigned) also lies in [0, 255]:
+  // Known.Zero = 1111111100000000 (0xFF00)
+  // Known.One  = 0000000000000000 (0x0000)
+  auto Zeroes = APInt(16, 0xFF00);
+  auto Ones = APInt(16, 0x0000);
+
+  SDValue AVGFLOORU =
+      DAG->getNode(ISD::AVGFLOORU, Loc, Int16Vec8VT, ZextOp0, ZextOp1);
+  KnownBits KnownAVGFLOORU = DAG->computeKnownBits(AVGFLOORU);
+  EXPECT_EQ(KnownAVGFLOORU.Zero, Zeroes);
+  EXPECT_EQ(KnownAVGFLOORU.One, Ones);
+
+  SDValue AVGFLOORS =
+      DAG->getNode(ISD::AVGFLOORS, Loc, Int16Vec8VT, ZextOp0, ZextOp1);
+  KnownBits KnownAVGFLOORS = DAG->computeKnownBits(AVGFLOORS);
+  EXPECT_EQ(KnownAVGFLOORS.Zero, Zeroes);
+  EXPECT_EQ(KnownAVGFLOORS.One, Ones);
+
+  SDValue AVGCEILU =
+      DAG->getNode(ISD::AVGCEILU, Loc, Int16Vec8VT, ZextOp0, ZextOp1);
+  KnownBits KnownAVGCEILU = DAG->computeKnownBits(AVGCEILU);
+  EXPECT_EQ(KnownAVGCEILU.Zero, Zeroes);
+  EXPECT_EQ(KnownAVGCEILU.One, Ones);
+
+  SDValue AVGCEILS =
+      DAG->getNode(ISD::AVGCEILS, Loc, Int16Vec8VT, ZextOp0, ZextOp1);
+  KnownBits KnownAVGCEILS = DAG->computeKnownBits(AVGCEILS);
+  EXPECT_EQ(KnownAVGCEILS.Zero, Zeroes);
+  EXPECT_EQ(KnownAVGCEILS.One, Ones);
+}
+
 } // end namespace llvm

``````````

</details>


https://github.com/llvm/llvm-project/pull/86754


More information about the llvm-commits mailing list