[llvm] [SelectionDAG] Let ComputeKnownSignBits handle (shl (ext X), C) (PR #97695)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 4 01:48:34 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-llvm-selectiondag

Author: Björn Pettersson (bjope)

<details>
<summary>Changes</summary>

Add simple support for looking through ZEXT/ANYEXT/SEXT when computing ComputeNumSignBits for SHL. This is valid when all extended bits are shifted out, because the number of sign bits can then be found by analysing the EXT operand.
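The validity argument can be checked exhaustively on small types. The C++ sketch below is a model, not LLVM code. `countSignBits` and `shlOfExtRuleHolds` are hypothetical helpers, the types are fixed to i8 extended to i16, and the shift amount is a single constant. For every 8-bit X, every possible value of the 8 extended bits (which covers ZEXT, SEXT and ANYEXT) and every shift amount C from 8 to 15, it checks that the result keeps at least `NumSignBits(X) - (C - 8)` sign bits whenever that count is positive.

```cpp
#include <cassert>
#include <cstdint>

// Count the leading bits of the Width-bit value V that equal its top (sign)
// bit; the result is always at least 1. Bits at or above Width are ignored.
unsigned countSignBits(uint32_t V, unsigned Width) {
  assert(Width >= 1 && Width <= 32 && "unsupported width");
  unsigned Top = (V >> (Width - 1)) & 1;
  unsigned N = 1;
  while (N < Width && ((V >> (Width - 1 - N)) & 1) == Top)
    ++N;
  return N;
}

// Exhaustively check the rule for an i8 value X extended to i16 and shifted
// left by C, with 8 <= C < 16 so that all extended bits are shifted out.
// Looping over every possible high byte Hi covers ZEXT (Hi == 0x00),
// SEXT (Hi == 0x00 or 0xFF, matching the sign of X) and ANYEXT (any Hi).
// Returns false as soon as a counterexample is found.
bool shlOfExtRuleHolds() {
  const unsigned SrcBits = 8, DstBits = 16;
  const unsigned ExtendedWidth = DstBits - SrcBits;
  for (uint32_t X = 0; X < 256; ++X) {
    unsigned SrcSignBits = countSignBits(X, SrcBits);
    for (uint32_t Hi = 0; Hi < 256; ++Hi) {
      uint32_t Ext = (Hi << SrcBits) | X;
      for (unsigned C = ExtendedWidth; C < DstBits; ++C) {
        unsigned Lost = C - ExtendedWidth; // sign bits of X shifted out
        if (Lost >= SrcSignBits)
          continue; // the rule makes no claim in this case
        uint32_t Result = (Ext << C) & 0xFFFFu;
        if (countSignBits(Result, DstBits) < SrcSignBits - Lost)
          return false;
      }
    }
  }
  return true;
}
```

`shlOfExtRuleHolds()` is expected to return true. With C below 8, arbitrary ANYEXT bits would remain at the top of the result, which is why the look-through requires every extended bit to be shifted out.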

A future improvement could be to pass along the "shifted left by" information in the recursive calls to ComputeNumSignBits, allowing us to handle this more generically.
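The idea can be pictured with a toy model. In the hypothetical C++ sketch below (the `Expr` tree, `signBitsShifted`, `toyNumSignBits` and the `toy*` builders are illustrative names, not LLVM code, and only any-extends and constant shifts are modeled), the recursive query carries a pending "shifted left by" amount. A shift adds its amount to it, and an extend is looked through whenever the pending amount covers all of its extended bits. This also handles nested shifts such as `(shl (shl (anyext X), 4), 4)` with an i8 X extended to i16, which the one-level pattern match in the patch does not see, although a DAG combine would normally fold that into a single shift.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical toy IR (not LLVM's SDNode): a leaf with a known sign-bit
// count, an any-extend, or a shift left by a constant.
struct Expr {
  enum Kind { Leaf, AnyExt, Shl } K;
  unsigned Width;                 // scalar width in bits
  unsigned LeafSignBits;          // Leaf only
  unsigned ShAmt;                 // Shl only
  std::shared_ptr<const Expr> Op; // AnyExt and Shl only
};

// Sign bits known for (E << ShiftedLeftBy) at E's width, or 0 if nothing
// useful is known. Callers clamp the result to at least 1.
unsigned signBitsShifted(const Expr &E, unsigned ShiftedLeftBy) {
  switch (E.K) {
  case Expr::Leaf:
    return E.LeafSignBits > ShiftedLeftBy ? E.LeafSignBits - ShiftedLeftBy : 0;
  case Expr::Shl:
    // (X << A) << B == X << (A + B): accumulate the pending amount.
    return signBitsShifted(*E.Op, ShiftedLeftBy + E.ShAmt);
  case Expr::AnyExt: {
    assert(E.Width >= E.Op->Width && "extend must not narrow");
    unsigned ExtendedWidth = E.Width - E.Op->Width;
    // The extended bits are unknown, so only look through the extend if
    // the pending shift pushes all of them out.
    if (ShiftedLeftBy < ExtendedWidth)
      return 0;
    return signBitsShifted(*E.Op, ShiftedLeftBy - ExtendedWidth);
  }
  }
  return 0;
}

unsigned toyNumSignBits(const Expr &E) {
  return std::max(1u, signBitsShifted(E, 0));
}

// Small builders for the toy tree.
std::shared_ptr<const Expr> toyLeaf(unsigned Width, unsigned SignBits) {
  return std::make_shared<Expr>(Expr{Expr::Leaf, Width, SignBits, 0, nullptr});
}
std::shared_ptr<const Expr> toyAnyExt(std::shared_ptr<const Expr> Op,
                                      unsigned Width) {
  return std::make_shared<Expr>(Expr{Expr::AnyExt, Width, 0, 0, std::move(Op)});
}
std::shared_ptr<const Expr> toyShl(std::shared_ptr<const Expr> Op,
                                   unsigned Amt) {
  unsigned W = Op->Width;
  return std::make_shared<Expr>(Expr{Expr::Shl, W, 0, Amt, std::move(Op)});
}
```

For an i8 leaf with 5 sign bits, `toyNumSignBits` gives 5 for the nested `toyShl(toyShl(toyAnyExt(X, 16), 4), 4)`, 3 for `toyShl(toyAnyExt(X, 16), 10)`, and only 1 when the shift (7) leaves some any-extended bits in place.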

---
Full diff: https://github.com/llvm/llvm-project/pull/97695.diff


1 File Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+15) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 96242305e9eab..991df60a5e650 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4617,6 +4617,21 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SHL:
     if (std::optional<uint64_t> ShAmt =
             getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+      if (Op.getOperand(0).getOpcode() == ISD::ANY_EXTEND ||
+          Op.getOperand(0).getOpcode() == ISD::ZERO_EXTEND ||
+          Op.getOperand(0).getOpcode() == ISD::SIGN_EXTEND) {
+        SDValue Src = Op.getOperand(0);
+        EVT SrcVT = Src.getValueType();
+        SDValue ExtendedOp = Op.getOperand(0).getOperand(0);
+        EVT ExtendedOpVT = ExtendedOp.getValueType();
+        uint64_t ExtendedWidth =
+            SrcVT.getScalarSizeInBits() - ExtendedOpVT.getScalarSizeInBits();
+        if (ExtendedWidth <= *ShAmt) {
+          Tmp = ComputeNumSignBits(ExtendedOp, DemandedElts, Depth + 1);
+          if (*ShAmt - ExtendedWidth < Tmp)
+            return Tmp - (*ShAmt - ExtendedWidth);
+        }
+      }
       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
       if (*ShAmt < Tmp)

``````````

</details>
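A worked example of the arithmetic in the new SHL case, as a hedged C++ model (`shlOfExtSignBits` and `shlSignBitsFallback` are hypothetical stand-ins, not SelectionDAG APIs): take an i8 value X with 3 known sign bits, `any_extend` it to i32 and shift left by 25. Then ExtendedWidth = 32 - 8 = 24 <= 25, and the new code returns 3 - (25 - 24) = 2. The pre-existing path only sees the `any_extend`, whose 24 upper bits are undefined; assuming it can prove just 1 sign bit there, the check 25 < 1 fails and the model ends up with 1. For a `sign_extend` both paths agree, since the extend already has 24 + 3 = 27 sign bits and 27 - 25 = 2. The model takes one shift amount. If several amounts are possible (for example a non-splat vector shift), the "all extended bits are shifted out" test is only sound against the smallest amount, while the subtraction should use the largest.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model of the new SHL case: ExtendeeBits/ExtBits are the
// scalar widths of X and of (ext X), ExtendeeSignBits is the sign-bit count
// of X. Returns the sign-bit count of (shl (ext X), ShAmt), or std::nullopt
// when the look-through does not apply.
std::optional<unsigned> shlOfExtSignBits(unsigned ExtendeeBits,
                                         unsigned ExtBits, uint64_t ShAmt,
                                         unsigned ExtendeeSignBits) {
  assert(ExtBits >= ExtendeeBits && "extend must not narrow");
  uint64_t ExtendedWidth = ExtBits - ExtendeeBits;
  if (ExtendedWidth > ShAmt)
    return std::nullopt; // some extended bits survive the shift
  uint64_t Lost = ShAmt - ExtendedWidth; // sign bits of X shifted out
  if (Lost >= ExtendeeSignBits)
    return std::nullopt; // no sign bit of X is guaranteed to remain
  return ExtendeeSignBits - static_cast<unsigned>(Lost);
}

// Simplified model of the pre-existing path, which only looks at (ext X).
// Returning 1 otherwise is an assumption: the real code falls through to
// other analyses instead.
unsigned shlSignBitsFallback(unsigned ExtSignBits, uint64_t ShAmt) {
  if (ShAmt < ExtSignBits)
    return ExtSignBits - static_cast<unsigned>(ShAmt);
  return 1;
}
```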


https://github.com/llvm/llvm-project/pull/97695

