[llvm] 2b1dfd2 - [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (#93182)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 1 08:48:30 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-01T16:48:26+01:00
New Revision: 2b1dfd2b35b5684c8af85206e199152bd6ac3a8d

URL: https://github.com/llvm/llvm-project/commit/2b1dfd2b35b5684c8af85206e199152bd6ac3a8d
DIFF: https://github.com/llvm/llvm-project/commit/2b1dfd2b35b5684c8af85206e199152bd6ac3a8d.diff

LOG: [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (#93182)

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be partially scalarized or split into v4i32 on some targets such as 32-bit x86, Thumb2 MVE).

This patch proposes we generalize these helpers to work with ConstantRange+KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are the same - the helpers fail if any shift amount is out of bounds, getValidShiftAmount must resolve to a single uniform constant, etc.

However, getValidMinimumShiftAmount/getValidMaximumShiftAmount can now return bound values that aren't values in the actual data, as they are based on the common KnownBits of every vector element.

This addresses feedback on #92096

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAG.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/PowerPC/pr44183.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 0dc237301abb4..6d28273029bd0 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -32,6 +32,7 @@
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/Allocator.h"
@@ -2159,36 +2160,44 @@ class SelectionDAG {
   /// splatted value it will return SDValue().
   SDValue getSplatValue(SDValue V, bool LegalTypes = false);
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the valid constant range.
+  std::optional<ConstantRange>
+  getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
+                           unsigned Depth) const;
+
+  /// If a SHL/SRA/SRL node \p V has a uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V,
-                                           const APInt &DemandedElts) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              const APInt &DemandedElts,
+                                              unsigned Depth = 0) const;
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has a uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
 
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..0ea33c1c699bf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -48,7 +48,6 @@
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
 #include "llvm/IR/Constant.h"
-#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
@@ -3009,102 +3008,117 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   return SDValue();
 }
 
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
-                                          const APInt &DemandedElts) const {
+std::optional<ConstantRange>
+SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
+                                       unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
+  // Shifting more than the bitwidth is not valid.
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1), DemandedElts)) {
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.ult(BitWidth))
-      return &ShAmt;
+
+  if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
+    const APInt &ShAmt = Cst->getAPIntValue();
+    if (ShAmt.uge(BitWidth))
+      return std::nullopt;
+    return ConstantRange(ShAmt);
   }
-  return nullptr;
+
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinAmt = nullptr, *MaxAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinAmt = MaxAmt = nullptr;
+        break;
+      }
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (!MinAmt || MinAmt->ugt(ShAmt))
+        MinAmt = &ShAmt;
+      if (!MaxAmt || MaxAmt->ult(ShAmt))
+        MaxAmt = &ShAmt;
+    }
+    assert(((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
+           "Failed to find matching min/max shift amounts");
+    if (MinAmt && MaxAmt)
+      return ConstantRange(*MinAmt, *MaxAmt + 1);
+  }
+
+  // Use computeKnownBits to find a hidden constant/knownbits (usually type
+  // legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
+  KnownBits KnownAmt = computeKnownBits(V.getOperand(1), DemandedElts, Depth);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return ConstantRange::fromKnownBits(KnownAmt, /*IsSigned=*/false);
+
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
+                                  unsigned Depth) const {
+  assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
+          V.getOpcode() == ISD::SRA) &&
+         "Unknown shift node");
+  if (std::optional<ConstantRange> AmtRange =
+          getValidShiftAmountRange(V, DemandedElts, Depth))
+    if (const APInt *ShAmt = AmtRange->getSingleElement())
+      return ShAmt->getZExtValue();
+  return std::nullopt;
+}
+
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidShiftAmountConstant(V, DemandedElts);
+  return getValidShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
-  unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MinShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MinShAmt && MinShAmt->ule(ShAmt))
-      continue;
-    MinShAmt = &ShAmt;
-  }
-  return MinShAmt;
+  if (std::optional<ConstantRange> AmtRange =
+          getValidShiftAmountRange(V, DemandedElts, Depth))
+    return AmtRange->getUnsignedMin().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMinimumShiftAmountConstant(V, DemandedElts);
+  return getValidMinimumShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
-  unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MaxShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MaxShAmt && MaxShAmt->uge(ShAmt))
-      continue;
-    MaxShAmt = &ShAmt;
-  }
-  return MaxShAmt;
+  if (std::optional<ConstantRange> AmtRange =
+          getValidShiftAmountRange(V, DemandedElts, Depth))
+    return AmtRange->getUnsignedMax().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMaximumShiftAmountConstant(V, DemandedElts);
+  return getValidMaximumShiftAmount(V, DemandedElts, Depth);
 }
 
 /// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3583,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
 
     // Minimum shift low bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setLowBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
+      Known.Zero.setLowBits(*ShMinAmt);
     break;
   }
   case ISD::SRL:
@@ -3581,9 +3595,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
                             Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setHighBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
+      Known.Zero.setHighBits(*ShMinAmt);
     break;
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4601,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
     // SRA X, C -> adds C sign bits.
-    if (const APInt *ShAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+    if (std::optional<uint64_t> ShAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
+      Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
     return Tmp;
   case ISD::SHL:
-    if (const APInt *ShAmt =
-            getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> ShAmt =
+            getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      if (ShAmt->ult(Tmp))
-        return Tmp - ShAmt->getZExtValue();
+      if (*ShAmt < Tmp)
+        return Tmp - *ShAmt;
     }
     break;
   case ISD::AND:
@@ -5270,7 +5284,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
   case ISD::SRA:
     // If the max shift amount isn't in range, then the shift can create poison.
-    return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
+    return !getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1);
 
   case ISD::SCALAR_TO_VECTOR:
     // Check if we demand any upper (undef) elements.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 623b6343994a4..f856c8a51984e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -796,10 +796,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> MaxSA =
+            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
-      unsigned ShAmt = MaxSA->getZExtValue();
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1736,9 +1736,9 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op1 = Op.getOperand(1);
     EVT ShiftVT = Op1.getValueType();
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
@@ -1748,10 +1748,9 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SRL) {
         if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
-          KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                       DemandedElts, Depth + 1);
-          if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
-            unsigned C1 = InnerSA.getConstant().getZExtValue();
+          if (std::optional<uint64_t> InnerSA =
+                  TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
+            unsigned C1 = *InnerSA;
             unsigned Opc = ISD::SHL;
             int Diff = ShAmt - C1;
             if (Diff < 0) {
@@ -1789,9 +1788,9 @@ bool TargetLowering::SimplifyDemandedBits(
         // TODO - support non-uniform vector amounts.
         if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
             InnerOp.hasOneUse()) {
-          if (const APInt *SA2 =
-                  TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
-            unsigned InnerShAmt = SA2->getZExtValue();
+          if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+                  InnerOp, DemandedElts, Depth + 2)) {
+            unsigned InnerShAmt = *SA2;
             if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
                 DemandedBits.getActiveBits() <=
                     (InnerBits - InnerShAmt + ShAmt) &&
@@ -1918,9 +1917,9 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = MaxSA->getZExtValue();
+    if (std::optional<uint64_t> MaxSA =
+            TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1934,9 +1933,9 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op1 = Op.getOperand(1);
     EVT ShiftVT = Op1.getValueType();
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
@@ -1946,10 +1945,9 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SHL) {
         if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
-          KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                       DemandedElts, Depth + 1);
-          if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
-            unsigned C1 = InnerSA.getConstant().getZExtValue();
+          if (std::optional<uint64_t> InnerSA =
+                  TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
+            unsigned C1 = *InnerSA;
             unsigned Opc = ISD::SRL;
             int Diff = ShAmt - C1;
             if (Diff < 0) {
@@ -2042,25 +2040,24 @@ bool TargetLowering::SimplifyDemandedBits(
     if (DemandedBits.isOne())
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
       // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
       // supports sext_inreg.
       if (Op0.getOpcode() == ISD::SHL) {
-        KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                     DemandedElts, Depth + 1);
-        if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
+        if (std::optional<uint64_t> InnerSA =
+                TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
           unsigned LowBits = BitWidth - ShAmt;
           EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
           if (VT.isVector())
             ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT,
                                      VT.getVectorElementCount());
 
-          if (InnerSA.getConstant() == ShAmt) {
+          if (*InnerSA == ShAmt) {
             if (!TLO.LegalOperations() ||
                 getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) == Legal)
               return TLO.CombineTo(
@@ -2598,11 +2595,11 @@ bool TargetLowering::SimplifyDemandedBits(
         break;
 
       if (Src.getNode()->hasOneUse()) {
-        const APInt *ShAmtC =
-            TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
-        if (!ShAmtC || ShAmtC->uge(BitWidth))
+        std::optional<uint64_t> ShAmtC =
+            TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
+        if (!ShAmtC || *ShAmtC >= BitWidth)
           break;
-        uint64_t ShVal = ShAmtC->getZExtValue();
+        uint64_t ShVal = *ShAmtC;
 
         APInt HighBits =
             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7b9e6c0a00273..0e377dd53b742 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20519,7 +20519,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
   // the truncation then we can use PACKSS by converting the srl to a sra.
   // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
   if (In.getOpcode() == ISD::SRL && In->hasOneUse())
-    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
+    if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
       if (*ShAmt == MinSignBits) {
         PackOpcode = X86ISD::PACKSS;
         return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());

diff  --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
 ; CHECK-NEXT:    mflr r0
 ; CHECK-NEXT:    std r30, -16(r1) # 8-byte Folded Spill
 ; CHECK-NEXT:    stdu r1, -48(r1)
-; CHECK-NEXT:    std r0, 64(r1)
 ; CHECK-NEXT:    mr r30, r3
-; CHECK-NEXT:    ld r3, 8(r3)
+; CHECK-NEXT:    std r0, 64(r1)
+; CHECK-NEXT:    lwz r3, 8(r3)
 ; CHECK-NEXT:    lwz r4, 36(r30)
-; CHECK-NEXT:    rldicl r3, r3, 60, 4
+; CHECK-NEXT:    rlwinm r3, r3, 27, 0, 0
 ; CHECK-NEXT:    clrlwi r4, r4, 31
-; CHECK-NEXT:    slwi r3, r3, 31
 ; CHECK-NEXT:    rlwimi r4, r3, 0, 0, 0
 ; CHECK-NEXT:    bl _ZN1llsE1d
 ; CHECK-NEXT:    nop


        


More information about the llvm-commits mailing list