[llvm] [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (PR #93182)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 23 05:23:50 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-powerpc

Author: Simon Pilgrim (RKSimon)

<details>
<summary>Changes</summary>

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as 32-bit x86, Thumb2 MVE).

This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are unchanged: each helper still fails if any shift amount is out of bounds, getValidShiftAmount still requires a single uniform constant amount, etc.
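
As a rough illustration of the uniform case, here is a minimal standalone sketch. It assumes a toy `ToyKnownBits` struct and a `toyValidShiftAmount` function, both hypothetical stand-ins that model `llvm::KnownBits` and the helper on plain 64-bit masks; it is not the code from this patch. It only shows the rule that the query succeeds when every bit of the amount is known and the resulting constant is below the element bit width.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Toy stand-in for llvm::KnownBits on values of 1..64 bits.
// Hypothetical type for illustration only, not the LLVM class.
struct ToyKnownBits {
  uint64_t Zero;  // bits known to be 0
  uint64_t One;   // bits known to be 1
  unsigned Width; // number of meaningful bits

  uint64_t mask() const {
    return Width == 64 ? ~0ULL : ((1ULL << Width) - 1);
  }
  // Every bit is known, so the value is one specific constant.
  bool isConstant() const { return ((Zero | One) & mask()) == mask(); }
  uint64_t getConstant() const { return One & mask(); }
};

// Hypothetical helper with the same shape as the uniform query: succeed only
// if the amount is a single known constant that is a valid shift amount.
std::optional<uint64_t> toyValidShiftAmount(const ToyKnownBits &Amt,
                                            unsigned BitWidth) {
  assert(Amt.Width >= 1 && Amt.Width <= 64 && "toy model covers 1..64 bits");
  if (Amt.isConstant() && Amt.getConstant() < BitWidth)
    return Amt.getConstant();
  return std::nullopt;
}
```

In this model, a fully known amount of 5 on a 32-bit shift yields 5, while a fully known amount of 40, or an amount with any unknown bit, yields `std::nullopt`.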

However, when they fall back to KnownBits, getValidMinimumShiftAmount/getValidMaximumShiftAmount can now return bounds that do not occur as actual shift amounts in the data, because the bounds are derived from the KnownBits common to every demanded vector element.
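
To make the point about bounds concrete, here is a small standalone sketch. It assumes a toy `ToyKnownBits` struct and the functions `commonKnownBits`, `toyMinShiftAmount` and `toyMaxShiftAmount`, all hypothetical names that model the KnownBits fallback on plain 64-bit masks rather than calling LLVM. It intersects the known bits of each element, then reports the minimum and maximum values consistent with that shared knowledge, provided the maximum is still a valid shift amount.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Toy stand-in for llvm::KnownBits (hypothetical, for illustration only).
struct ToyKnownBits {
  uint64_t Zero; // bits known to be 0 in every element
  uint64_t One;  // bits known to be 1 in every element
  uint64_t Mask; // the meaningful bits of the value

  uint64_t minValue() const { return One & Mask; }   // unknown bits as 0
  uint64_t maxValue() const { return ~Zero & Mask; } // unknown bits as 1
};

// Keep only the bits that all elements agree on.
ToyKnownBits commonKnownBits(const std::vector<uint64_t> &Elts,
                             unsigned Width) {
  assert(!Elts.empty() && Width >= 1 && Width <= 64);
  uint64_t Mask = Width == 64 ? ~0ULL : ((1ULL << Width) - 1);
  ToyKnownBits K{Mask, Mask, Mask};
  for (uint64_t E : Elts) {
    K.One &= E;   // a bit stays known-one only if set in every element
    K.Zero &= ~E; // a bit stays known-zero only if clear in every element
  }
  return K;
}

// Both bounds are only reported if the largest possible amount is in range.
std::optional<uint64_t> toyMinShiftAmount(const ToyKnownBits &K,
                                          unsigned BitWidth) {
  if (K.maxValue() < BitWidth)
    return K.minValue();
  return std::nullopt;
}

std::optional<uint64_t> toyMaxShiftAmount(const ToyKnownBits &K,
                                          unsigned BitWidth) {
  if (K.maxValue() < BitWidth)
    return K.maxValue();
  return std::nullopt;
}
```

For per-element amounts {1, 2} on a 32-bit shift, the two elements share no set bit and differ in the low two bits, so this model reports a minimum of 0 and a maximum of 3. Neither value is an actual element, but both are safe conservative bounds. For amounts {1, 33} the possible maximum is 33, which is out of range, so both queries return `std::nullopt`.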

This addresses feedback on #92096

---
Full diff: https://github.com/llvm/llvm-project/pull/93182.diff


5 Files Affected:

- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+29-27) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+89-69) 
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+13-13) 
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1) 
- (modified) llvm/test/CodeGen/PowerPC/pr44183.ll (+3-4) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 96a6270690468..95afbeb5dd6ec 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2159,36 +2159,38 @@ class SelectionDAG {
   /// splatted value it will return SDValue().
   SDValue getSplatValue(SDValue V, bool LegalTypes = false);
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has an uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V,
-                                           const APInt &DemandedElts) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              const APInt &DemandedElts,
+                                              unsigned Depth = 0) const;
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has an uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
 
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..b71b496c0aa84 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   return SDValue();
 }
 
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
-                                          const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
+                                  unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
@@ -3020,91 +3020,111 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
     // Shifting more than the bitwidth is not valid.
     const APInt &ShAmt = SA->getAPIntValue();
     if (ShAmt.ult(BitWidth))
-      return &ShAmt;
+      return ShAmt.getZExtValue();
+  } else {
+    KnownBits KnownAmt =
+        computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+    if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+      return KnownAmt.getConstant().getZExtValue();
   }
-  return nullptr;
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidShiftAmountConstant(V, DemandedElts);
+  return getValidShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MinShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MinShAmt && MinShAmt->ule(ShAmt))
-      continue;
-    MinShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MinShAmt && MinShAmt->ule(ShAmt))
+        continue;
+      MinShAmt = &ShAmt;
+    }
+    if (MinShAmt)
+      return MinShAmt->getZExtValue();
   }
-  return MinShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMinValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMinimumShiftAmountConstant(V, DemandedElts);
+  return getValidMinimumShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MaxShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MaxShAmt && MaxShAmt->uge(ShAmt))
-      continue;
-    MaxShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MaxShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MaxShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MaxShAmt && MaxShAmt->uge(ShAmt))
+        continue;
+      MaxShAmt = &ShAmt;
+    }
+    if (MaxShAmt)
+      return MaxShAmt->getZExtValue();
   }
-  return MaxShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMaxValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMaximumShiftAmountConstant(V, DemandedElts);
+  return getValidMaximumShiftAmount(V, DemandedElts, Depth);
 }
 
 /// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3589,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
 
     // Minimum shift low bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setLowBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setLowBits(*ShMinAmt);
     break;
   }
   case ISD::SRL:
@@ -3581,9 +3601,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
                             Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setHighBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setHighBits(*ShMinAmt);
     break;
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4607,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
     // SRA X, C -> adds C sign bits.
-    if (const APInt *ShAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+    if (std::optional<uint64_t> ShAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
     return Tmp;
   case ISD::SHL:
-    if (const APInt *ShAmt =
-            getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> ShAmt =
+            getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      if (ShAmt->ult(Tmp))
-        return Tmp - ShAmt->getZExtValue();
+      if (*ShAmt < Tmp)
+        return Tmp - *ShAmt;
     }
     break;
   case ISD::AND:
@@ -5270,7 +5290,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
   case ISD::SRA:
     // If the max shift amount isn't in range, then the shift can create poison.
-    return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
+    return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
 
   case ISD::SCALAR_TO_VECTOR:
     // Check if we demand any upper (undef) elements.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 37c72339fe295..dfcd5439b8a9d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -796,10 +796,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> MaxSA =
+            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       SDValue Op0 = Op.getOperand(0);
-      unsigned ShAmt = MaxSA->getZExtValue();
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1789,9 +1789,9 @@ bool TargetLowering::SimplifyDemandedBits(
         // TODO - support non-uniform vector amounts.
         if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
             InnerOp.hasOneUse()) {
-          if (const APInt *SA2 =
-                  TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
-            unsigned InnerShAmt = SA2->getZExtValue();
+          if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+                  InnerOp, DemandedElts, Depth + 1)) {
+            unsigned InnerShAmt = *SA2;
             if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
                 DemandedBits.getActiveBits() <=
                     (InnerBits - InnerShAmt + ShAmt) &&
@@ -1918,9 +1918,9 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = MaxSA->getZExtValue();
+    if (std::optional<uint64_t> MaxSA =
+            TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -2598,11 +2598,11 @@ bool TargetLowering::SimplifyDemandedBits(
         break;
 
       if (Src.getNode()->hasOneUse()) {
-        const APInt *ShAmtC =
-            TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
-        if (!ShAmtC || ShAmtC->uge(BitWidth))
+        std::optional<uint64_t> ShAmtC =
+            TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 1);
+        if (!ShAmtC || *ShAmtC >= BitWidth)
           break;
-        uint64_t ShVal = ShAmtC->getZExtValue();
+        uint64_t ShVal = *ShAmtC;
 
         APInt HighBits =
             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 215cbc308e43d..fd99f0e345d14 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20490,7 +20490,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
   // the truncation then we can use PACKSS by converting the srl to a sra.
   // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
   if (In.getOpcode() == ISD::SRL && In->hasOneUse())
-    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
+    if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
       if (*ShAmt == MinSignBits) {
         PackOpcode = X86ISD::PACKSS;
         return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
 ; CHECK-NEXT:    mflr r0
 ; CHECK-NEXT:    std r30, -16(r1) # 8-byte Folded Spill
 ; CHECK-NEXT:    stdu r1, -48(r1)
-; CHECK-NEXT:    std r0, 64(r1)
 ; CHECK-NEXT:    mr r30, r3
-; CHECK-NEXT:    ld r3, 8(r3)
+; CHECK-NEXT:    std r0, 64(r1)
+; CHECK-NEXT:    lwz r3, 8(r3)
 ; CHECK-NEXT:    lwz r4, 36(r30)
-; CHECK-NEXT:    rldicl r3, r3, 60, 4
+; CHECK-NEXT:    rlwinm r3, r3, 27, 0, 0
 ; CHECK-NEXT:    clrlwi r4, r4, 31
-; CHECK-NEXT:    slwi r3, r3, 31
 ; CHECK-NEXT:    rlwimi r4, r3, 0, 0, 0
 ; CHECK-NEXT:    bl _ZN1llsE1d
 ; CHECK-NEXT:    nop

``````````

</details>


https://github.com/llvm/llvm-project/pull/93182

