[llvm] [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (PR #93182)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu May 30 11:49:57 PDT 2024


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/93182

>From dc6b14f5685c3c340980551f5f233513a76a64e8 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Thu, 23 May 2024 13:20:49 +0100
Subject: [PATCH 1/2] [DAG] Replace getValid*ShiftAmountConstant helpers with
 getValid*ShiftAmount helpers to support KnownBits analysis

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as Thumb2 MVE).

This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are unchanged: the helpers fail if any shift amount is out of bounds, and getValidShiftAmount still requires a single uniform shift amount, etc.

However, getValidMinimumShiftAmount/getValidMaximumShiftAmount can now return bound values that don't appear in the actual data, as they are derived from the common KnownBits of every vector element.

This addresses feedback on #92096
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  56 +++---
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 160 ++++++++++--------
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |  65 ++++---
 llvm/lib/Target/X86/X86ISelLowering.cpp       |   2 +-
 llvm/test/CodeGen/PowerPC/pr44183.ll          |   7 +-
 5 files changed, 155 insertions(+), 135 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 0dc237301abb4..b7b7443bdf6f3 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2159,36 +2159,38 @@ class SelectionDAG {
   /// splatted value it will return SDValue().
   SDValue getSplatValue(SDValue V, bool LegalTypes = false);
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has a uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V,
-                                           const APInt &DemandedElts) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              const APInt &DemandedElts,
+                                              unsigned Depth = 0) const;
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has a uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
 
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..e54fcc8373e6d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   return SDValue();
 }
 
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
-                                          const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
+                                  unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
@@ -3020,91 +3020,113 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
     // Shifting more than the bitwidth is not valid.
     const APInt &ShAmt = SA->getAPIntValue();
     if (ShAmt.ult(BitWidth))
-      return &ShAmt;
+      return ShAmt.getZExtValue();
+  } else {
+    // Use computeKnownBits to find a hidden constant (usually type legalized).
+    // e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
+    KnownBits KnownAmt =
+        computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+    if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+      return KnownAmt.getConstant().getZExtValue();
   }
-  return nullptr;
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidShiftAmountConstant(V, DemandedElts);
+  return getValidShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MinShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MinShAmt && MinShAmt->ule(ShAmt))
-      continue;
-    MinShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MinShAmt && MinShAmt->ule(ShAmt))
+        continue;
+      MinShAmt = &ShAmt;
+    }
+    if (MinShAmt)
+      return MinShAmt->getZExtValue();
   }
-  return MinShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMinValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMinimumShiftAmountConstant(V, DemandedElts);
+  return getValidMinimumShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MaxShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MaxShAmt && MaxShAmt->uge(ShAmt))
-      continue;
-    MaxShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MaxShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MaxShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MaxShAmt && MaxShAmt->uge(ShAmt))
+        continue;
+      MaxShAmt = &ShAmt;
+    }
+    if (MaxShAmt)
+      return MaxShAmt->getZExtValue();
   }
-  return MaxShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMaxValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMaximumShiftAmountConstant(V, DemandedElts);
+  return getValidMaximumShiftAmount(V, DemandedElts, Depth);
 }
 
 /// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3591,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
 
     // Minimum shift low bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setLowBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setLowBits(*ShMinAmt);
     break;
   }
   case ISD::SRL:
@@ -3581,9 +3603,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
                             Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setHighBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setHighBits(*ShMinAmt);
     break;
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4609,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
     // SRA X, C -> adds C sign bits.
-    if (const APInt *ShAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+    if (std::optional<uint64_t> ShAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
     return Tmp;
   case ISD::SHL:
-    if (const APInt *ShAmt =
-            getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> ShAmt =
+            getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      if (ShAmt->ult(Tmp))
-        return Tmp - ShAmt->getZExtValue();
+      if (*ShAmt < Tmp)
+        return Tmp - *ShAmt;
     }
     break;
   case ISD::AND:
@@ -5270,7 +5292,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
   case ISD::SRA:
     // If the max shift amount isn't in range, then the shift can create poison.
-    return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
+    return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
 
   case ISD::SCALAR_TO_VECTOR:
     // Check if we demand any upper (undef) elements.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 623b6343994a4..18ff4749dce91 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -796,10 +796,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> MaxSA =
+            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       SDValue Op0 = Op.getOperand(0);
-      unsigned ShAmt = MaxSA->getZExtValue();
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1736,9 +1736,9 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op1 = Op.getOperand(1);
     EVT ShiftVT = Op1.getValueType();
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
@@ -1748,10 +1748,9 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SRL) {
         if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
-          KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                       DemandedElts, Depth + 1);
-          if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
-            unsigned C1 = InnerSA.getConstant().getZExtValue();
+          if (std::optional<uint64_t> InnerSA =
+                  TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
+            unsigned C1 = *InnerSA;
             unsigned Opc = ISD::SHL;
             int Diff = ShAmt - C1;
             if (Diff < 0) {
@@ -1789,9 +1788,9 @@ bool TargetLowering::SimplifyDemandedBits(
         // TODO - support non-uniform vector amounts.
         if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
             InnerOp.hasOneUse()) {
-          if (const APInt *SA2 =
-                  TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
-            unsigned InnerShAmt = SA2->getZExtValue();
+          if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+                  InnerOp, DemandedElts, Depth + 1)) {
+            unsigned InnerShAmt = *SA2;
             if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
                 DemandedBits.getActiveBits() <=
                     (InnerBits - InnerShAmt + ShAmt) &&
@@ -1918,9 +1917,9 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = MaxSA->getZExtValue();
+    if (std::optional<uint64_t> MaxSA =
+            TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1934,9 +1933,9 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op1 = Op.getOperand(1);
     EVT ShiftVT = Op1.getValueType();
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
@@ -1946,10 +1945,9 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SHL) {
         if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
-          KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                       DemandedElts, Depth + 1);
-          if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
-            unsigned C1 = InnerSA.getConstant().getZExtValue();
+          if (std::optional<uint64_t> InnerSA =
+                  TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
+            unsigned C1 = *InnerSA;
             unsigned Opc = ISD::SRL;
             int Diff = ShAmt - C1;
             if (Diff < 0) {
@@ -2042,25 +2040,24 @@ bool TargetLowering::SimplifyDemandedBits(
     if (DemandedBits.isOne())
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
       // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
       // supports sext_inreg.
       if (Op0.getOpcode() == ISD::SHL) {
-        KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                     DemandedElts, Depth + 1);
-        if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
+        if (std::optional<uint64_t> InnerSA =
+                TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
           unsigned LowBits = BitWidth - ShAmt;
           EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
           if (VT.isVector())
             ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT,
                                      VT.getVectorElementCount());
 
-          if (InnerSA.getConstant() == ShAmt) {
+          if (*InnerSA == ShAmt) {
             if (!TLO.LegalOperations() ||
                 getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) == Legal)
               return TLO.CombineTo(
@@ -2598,11 +2595,11 @@ bool TargetLowering::SimplifyDemandedBits(
         break;
 
       if (Src.getNode()->hasOneUse()) {
-        const APInt *ShAmtC =
-            TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
-        if (!ShAmtC || ShAmtC->uge(BitWidth))
+        std::optional<uint64_t> ShAmtC =
+            TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 1);
+        if (!ShAmtC || *ShAmtC >= BitWidth)
           break;
-        uint64_t ShVal = ShAmtC->getZExtValue();
+        uint64_t ShVal = *ShAmtC;
 
         APInt HighBits =
             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 839006cbaed4c..0a87a7d0799fa 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20503,7 +20503,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
   // the truncation then we can use PACKSS by converting the srl to a sra.
   // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
   if (In.getOpcode() == ISD::SRL && In->hasOneUse())
-    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
+    if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
       if (*ShAmt == MinSignBits) {
         PackOpcode = X86ISD::PACKSS;
         return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
 ; CHECK-NEXT:    mflr r0
 ; CHECK-NEXT:    std r30, -16(r1) # 8-byte Folded Spill
 ; CHECK-NEXT:    stdu r1, -48(r1)
-; CHECK-NEXT:    std r0, 64(r1)
 ; CHECK-NEXT:    mr r30, r3
-; CHECK-NEXT:    ld r3, 8(r3)
+; CHECK-NEXT:    std r0, 64(r1)
+; CHECK-NEXT:    lwz r3, 8(r3)
 ; CHECK-NEXT:    lwz r4, 36(r30)
-; CHECK-NEXT:    rldicl r3, r3, 60, 4
+; CHECK-NEXT:    rlwinm r3, r3, 27, 0, 0
 ; CHECK-NEXT:    clrlwi r4, r4, 31
-; CHECK-NEXT:    slwi r3, r3, 31
 ; CHECK-NEXT:    rlwimi r4, r3, 0, 0, 0
 ; CHECK-NEXT:    bl _ZN1llsE1d
 ; CHECK-NEXT:    nop

>From 3879271a9f54e26d50948dc16b1b2ed0d0210431 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Thu, 30 May 2024 19:48:10 +0100
Subject: [PATCH 2/2] [DAG] Add getValidShiftAmountRange to determine the range
 of valid shift amount values.

---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |   7 +
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 127 +++++++++---------
 2 files changed, 67 insertions(+), 67 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index b7b7443bdf6f3..6d28273029bd0 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -32,6 +32,7 @@
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/Allocator.h"
@@ -2159,6 +2160,12 @@ class SelectionDAG {
   /// splatted value it will return SDValue().
   SDValue getSplatValue(SDValue V, bool LegalTypes = false);
 
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the valid constant range.
+  std::optional<ConstantRange>
+  getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
+                           unsigned Depth) const;
+
   /// If a SHL/SRA/SRL node \p V has a uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
   std::optional<uint64_t> getValidShiftAmount(SDValue V,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e54fcc8373e6d..c361bbdd2ea00 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -48,7 +48,6 @@
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
 #include "llvm/IR/Constant.h"
-#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
@@ -3009,26 +3008,66 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   return SDValue();
 }
 
+std::optional<ConstantRange>
+SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
+                                       unsigned Depth) const {
+  assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
+          V.getOpcode() == ISD::SRA) &&
+         "Unknown shift node");
+  // Shifting more than the bitwidth is not valid.
+  unsigned BitWidth = V.getScalarValueSizeInBits();
+  if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
+    const APInt &ShAmt = Cst->getAPIntValue();
+    if (ShAmt.uge(BitWidth))
+      return std::nullopt;
+    return ConstantRange(ShAmt);
+  }
+
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinAmt = nullptr, *MaxAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinAmt = MaxAmt = nullptr;
+        break;
+      }
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (!MinAmt || MinAmt->ugt(ShAmt))
+        MinAmt = &ShAmt;
+      if (!MaxAmt || MaxAmt->ult(ShAmt))
+        MaxAmt = &ShAmt;
+    }
+    assert(((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
+           "Failed to find matching min/max shift amounts");
+    // ConstantRange is half-open [Lower, Upper): MaxAmt must be included.
+    if (MinAmt && MaxAmt)
+      return ConstantRange::getNonEmpty(*MinAmt, *MaxAmt + 1);
+  }
+
+  // Use computeKnownBits to find a hidden constant/knownbits (usually type
+  // legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return ConstantRange::fromKnownBits(KnownAmt, /*IsSigned=*/false);
+
+  return std::nullopt;
+}
+
 std::optional<uint64_t>
 SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
                                   unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  unsigned BitWidth = V.getScalarValueSizeInBits();
-  if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1), DemandedElts)) {
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.ult(BitWidth))
-      return ShAmt.getZExtValue();
-  } else {
-    // Use computeKnownBits to find a hidden constant (usually type legalized).
-    // e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
-    KnownBits KnownAmt =
-        computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
-    if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
-      return KnownAmt.getConstant().getZExtValue();
-  }
+  if (std::optional<ConstantRange> AmtRange =
+          getValidShiftAmountRange(V, DemandedElts, Depth))
+    if (const APInt *ShAmt = AmtRange->getSingleElement())
+      return ShAmt->getZExtValue();
   return std::nullopt;
 }
 
@@ -3047,32 +3086,9 @@ SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  unsigned BitWidth = V.getScalarValueSizeInBits();
-  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
-    const APInt *MinShAmt = nullptr;
-    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-      if (!DemandedElts[i])
-        continue;
-      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-      if (!SA) {
-        MinShAmt = nullptr;
-        break;
-      }
-      // Shifting more than the bitwidth is not valid.
-      const APInt &ShAmt = SA->getAPIntValue();
-      if (ShAmt.uge(BitWidth))
-        return std::nullopt;
-      if (MinShAmt && MinShAmt->ule(ShAmt))
-        continue;
-      MinShAmt = &ShAmt;
-    }
-    if (MinShAmt)
-      return MinShAmt->getZExtValue();
-  }
-  KnownBits KnownAmt =
-      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
-  if (KnownAmt.getMaxValue().ult(BitWidth))
-    return KnownAmt.getMinValue().getZExtValue();
+  if (std::optional<ConstantRange> AmtRange =
+          getValidShiftAmountRange(V, DemandedElts, Depth))
+    return AmtRange->getUnsignedMin().getZExtValue();
   return std::nullopt;
 }
 
@@ -3091,32 +3107,9 @@ SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  unsigned BitWidth = V.getScalarValueSizeInBits();
-  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
-    const APInt *MaxShAmt = nullptr;
-    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-      if (!DemandedElts[i])
-        continue;
-      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-      if (!SA) {
-        MaxShAmt = nullptr;
-        break;
-      }
-      // Shifting more than the bitwidth is not valid.
-      const APInt &ShAmt = SA->getAPIntValue();
-      if (ShAmt.uge(BitWidth))
-        return std::nullopt;
-      if (MaxShAmt && MaxShAmt->uge(ShAmt))
-        continue;
-      MaxShAmt = &ShAmt;
-    }
-    if (MaxShAmt)
-      return MaxShAmt->getZExtValue();
-  }
-  KnownBits KnownAmt =
-      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
-  if (KnownAmt.getMaxValue().ult(BitWidth))
-    return KnownAmt.getMaxValue().getZExtValue();
+  if (std::optional<ConstantRange> AmtRange =
+          getValidShiftAmountRange(V, DemandedElts, Depth))
+    return AmtRange->getUnsignedMax().getZExtValue();
   return std::nullopt;
 }
 



More information about the llvm-commits mailing list