[llvm] [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (PR #93182)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri May 24 03:45:26 PDT 2024


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/93182

From 9ced5e9548273789dc132ec033a31385ef429a21 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Thu, 23 May 2024 13:20:49 +0100
Subject: [PATCH] [DAG] Replace getValid*ShiftAmountConstant helpers with
 getValid*ShiftAmount helpers to support KnownBits analysis

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as Thumb2 MVE).

This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are the same - the helpers fail if any shift amount is out of bounds, getValidShiftAmount must still resolve to a single uniform constant, etc.
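
As a rough illustration of the uniform-amount rule above, here is a minimal standalone sketch. It does not use LLVM: SimpleKnownBits, makeConstant and validUniformShiftAmount are hypothetical stand-ins (limited to 64-bit values) that only mimic the "known constant and less than the bit-width" check described in this patch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical, simplified stand-in for llvm::KnownBits (<= 64 bits only).
struct SimpleKnownBits {
  uint64_t Zero;  // bits known to be 0
  uint64_t One;   // bits known to be 1
  unsigned Width; // width of the value in bits

  uint64_t mask() const { return Width >= 64 ? ~0ULL : ((1ULL << Width) - 1); }
  // Every bit is known, so the value is a single constant.
  bool isConstant() const { return ((Zero | One) & mask()) == mask(); }
  uint64_t getConstant() const {
    assert(isConstant() && "value is not fully known");
    return One & mask();
  }
};

// Build fully-known bits for the constant V (illustrative helper).
inline SimpleKnownBits makeConstant(uint64_t V, unsigned Width) {
  SimpleKnownBits K{0, 0, Width};
  K.One = V & K.mask();
  K.Zero = ~V & K.mask();
  return K;
}

// Return the shift amount only if it is one known constant that is
// strictly less than the bit-width of the shifted value.
inline std::optional<uint64_t>
validUniformShiftAmount(const SimpleKnownBits &Amt, unsigned BitWidth) {
  if (Amt.isConstant() && Amt.getConstant() < BitWidth)
    return Amt.getConstant();
  return std::nullopt;
}
```

In this sketch a fully known amount of 5 on a 32-bit shift is accepted, while a known amount of 40, or an amount with any unknown bit, yields std::nullopt.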

However, getValidMinimumShiftAmount/getValidMaximumShiftAmount can now return bounds that do not occur in the actual data, as they are based on the common KnownBits of every vector element.
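
To make the point about bounds concrete, the following standalone sketch models only the KnownBits fallback path, not the exact BUILD_VECTOR constant loop. It is an assumption-laden illustration: Bounds and shiftAmountBounds are hypothetical names, and concrete per-element amounts are used purely to show what intersecting the known bits of every element does.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical result type for this sketch.
struct Bounds {
  uint64_t Min;
  uint64_t Max;
};

// Intersect the known bits of all elements, then derive min/max from them:
// min sets only the bits known to be one, max sets every bit not known zero.
inline std::optional<Bounds>
shiftAmountBounds(const std::vector<uint64_t> &Amts, unsigned BitWidth) {
  assert(BitWidth >= 1 && BitWidth <= 64 && "sketch handles 1..64 bits");
  if (Amts.empty())
    return std::nullopt;
  uint64_t KnownOne = ~0ULL;  // bits set in every element
  uint64_t KnownZero = ~0ULL; // bits clear in every element
  for (uint64_t A : Amts) {
    KnownOne &= A;
    KnownZero &= ~A;
  }
  uint64_t Min = KnownOne;
  uint64_t Max = ~KnownZero;
  // Fail if the largest possible amount could be out of range.
  if (Max >= BitWidth)
    return std::nullopt;
  return Bounds{Min, Max};
}
```

For per-element amounts {1, 4} on a 32-bit shift this gives Min = 0 and Max = 5, neither of which is an actual element. The approach is also conservative: amounts {8, 7} with a bit-width of 10 are rejected in this model because the common known bits allow a value up to 15.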

This addresses feedback on #92096
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  56 ++++---
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 158 ++++++++++--------
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |  65 ++++---
 llvm/lib/Target/X86/X86ISelLowering.cpp       |   2 +-
 llvm/test/CodeGen/PowerPC/pr44183.ll          |   7 +-
 5 files changed, 153 insertions(+), 135 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 96a6270690468..95afbeb5dd6ec 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2159,36 +2159,38 @@ class SelectionDAG {
   /// splatted value it will return SDValue().
   SDValue getSplatValue(SDValue V, bool LegalTypes = false);
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has a uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V,
-                                           const APInt &DemandedElts) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              const APInt &DemandedElts,
+                                              unsigned Depth = 0) const;
 
-  /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+  /// If a SHL/SRA/SRL node \p V has a uniform shift amount
   /// that is less than the element bit-width of the shift node, return it.
-  const APInt *getValidShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the minimum value.
-  const APInt *
-  getValidMinimumShiftAmountConstant(SDValue V) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V,
-                                     const APInt &DemandedElts) const;
-
-  /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
-  /// than the element bit-width of the shift node, return the maximum value.
-  const APInt *
-  getValidMaximumShiftAmountConstant(SDValue V) const;
+  std::optional<uint64_t> getValidShiftAmount(SDValue V,
+                                              unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the minimum possible value.
+  std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     const APInt &DemandedElts,
+                                                     unsigned Depth = 0) const;
+
+  /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+  /// element bit-width of the shift node, return the maximum possible value.
+  std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+                                                     unsigned Depth = 0) const;
 
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..b71b496c0aa84 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   return SDValue();
 }
 
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
-                                          const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
+                                  unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
@@ -3020,91 +3020,111 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
     // Shifting more than the bitwidth is not valid.
     const APInt &ShAmt = SA->getAPIntValue();
     if (ShAmt.ult(BitWidth))
-      return &ShAmt;
+      return ShAmt.getZExtValue();
+  } else {
+    KnownBits KnownAmt =
+        computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+    if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+      return KnownAmt.getConstant().getZExtValue();
   }
-  return nullptr;
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidShiftAmountConstant(V, DemandedElts);
+  return getValidShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MinShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MinShAmt && MinShAmt->ule(ShAmt))
-      continue;
-    MinShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MinShAmt && MinShAmt->ule(ShAmt))
+        continue;
+      MinShAmt = &ShAmt;
+    }
+    if (MinShAmt)
+      return MinShAmt->getZExtValue();
   }
-  return MinShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMinValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMinimumShiftAmountConstant(V, DemandedElts);
+  return getValidMinimumShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MaxShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MaxShAmt && MaxShAmt->uge(ShAmt))
-      continue;
-    MaxShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MaxShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MaxShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MaxShAmt && MaxShAmt->uge(ShAmt))
+        continue;
+      MaxShAmt = &ShAmt;
+    }
+    if (MaxShAmt)
+      return MaxShAmt->getZExtValue();
   }
-  return MaxShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMaxValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMaximumShiftAmountConstant(V, DemandedElts);
+  return getValidMaximumShiftAmount(V, DemandedElts, Depth);
 }
 
 /// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3589,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
 
     // Minimum shift low bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setLowBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setLowBits(*ShMinAmt);
     break;
   }
   case ISD::SRL:
@@ -3581,9 +3601,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
                             Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
-    if (const APInt *ShMinAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Known.Zero.setHighBits(ShMinAmt->getZExtValue());
+    if (std::optional<uint64_t> ShMinAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Known.Zero.setHighBits(*ShMinAmt);
     break;
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4607,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
     // SRA X, C -> adds C sign bits.
-    if (const APInt *ShAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
-      Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+    if (std::optional<uint64_t> ShAmt =
+            getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+      Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
     return Tmp;
   case ISD::SHL:
-    if (const APInt *ShAmt =
-            getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> ShAmt =
+            getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      if (ShAmt->ult(Tmp))
-        return Tmp - ShAmt->getZExtValue();
+      if (*ShAmt < Tmp)
+        return Tmp - *ShAmt;
     }
     break;
   case ISD::AND:
@@ -5270,7 +5290,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
   case ISD::SRA:
     // If the max shift amount isn't in range, then the shift can create poison.
-    return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
+    return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
 
   case ISD::SCALAR_TO_VECTOR:
     // Check if we demand any upper (undef) elements.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 37c72339fe295..63d5821bedd95 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -796,10 +796,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+    if (std::optional<uint64_t> MaxSA =
+            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
       SDValue Op0 = Op.getOperand(0);
-      unsigned ShAmt = MaxSA->getZExtValue();
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1736,9 +1736,9 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op1 = Op.getOperand(1);
     EVT ShiftVT = Op1.getValueType();
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
@@ -1748,10 +1748,9 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SRL) {
         if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
-          KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                       DemandedElts, Depth + 1);
-          if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
-            unsigned C1 = InnerSA.getConstant().getZExtValue();
+          if (std::optional<uint64_t> InnerSA =
+                  TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
+            unsigned C1 = *InnerSA;
             unsigned Opc = ISD::SHL;
             int Diff = ShAmt - C1;
             if (Diff < 0) {
@@ -1789,9 +1788,9 @@ bool TargetLowering::SimplifyDemandedBits(
         // TODO - support non-uniform vector amounts.
         if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
             InnerOp.hasOneUse()) {
-          if (const APInt *SA2 =
-                  TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
-            unsigned InnerShAmt = SA2->getZExtValue();
+          if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+                  InnerOp, DemandedElts, Depth + 1)) {
+            unsigned InnerShAmt = *SA2;
             if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
                 DemandedBits.getActiveBits() <=
                     (InnerBits - InnerShAmt + ShAmt) &&
@@ -1918,9 +1917,9 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (const APInt *MaxSA =
-            TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = MaxSA->getZExtValue();
+    if (std::optional<uint64_t> MaxSA =
+            TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *MaxSA;
       unsigned NumSignBits =
           TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1934,9 +1933,9 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op1 = Op.getOperand(1);
     EVT ShiftVT = Op1.getValueType();
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
@@ -1946,10 +1945,9 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SHL) {
         if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
-          KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                       DemandedElts, Depth + 1);
-          if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
-            unsigned C1 = InnerSA.getConstant().getZExtValue();
+          if (std::optional<uint64_t> InnerSA =
+                  TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
+            unsigned C1 = *InnerSA;
             unsigned Opc = ISD::SRL;
             int Diff = ShAmt - C1;
             if (Diff < 0) {
@@ -2042,25 +2040,24 @@ bool TargetLowering::SimplifyDemandedBits(
     if (DemandedBits.isOne())
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
 
-    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
-      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
+    if (std::optional<uint64_t> KnownSA =
+            TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
+      unsigned ShAmt = *KnownSA;
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
       // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
       // supports sext_inreg.
       if (Op0.getOpcode() == ISD::SHL) {
-        KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
-                                                     DemandedElts, Depth + 1);
-        if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
+        if (std::optional<uint64_t> InnerSA =
+                TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
           unsigned LowBits = BitWidth - ShAmt;
           EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
           if (VT.isVector())
             ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT,
                                      VT.getVectorElementCount());
 
-          if (InnerSA.getConstant() == ShAmt) {
+          if (*InnerSA == ShAmt) {
             if (!TLO.LegalOperations() ||
                 getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) == Legal)
               return TLO.CombineTo(
@@ -2598,11 +2595,11 @@ bool TargetLowering::SimplifyDemandedBits(
         break;
 
       if (Src.getNode()->hasOneUse()) {
-        const APInt *ShAmtC =
-            TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
-        if (!ShAmtC || ShAmtC->uge(BitWidth))
+        std::optional<uint64_t> ShAmtC =
+            TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 1);
+        if (!ShAmtC || *ShAmtC >= BitWidth)
           break;
-        uint64_t ShVal = ShAmtC->getZExtValue();
+        uint64_t ShVal = *ShAmtC;
 
         APInt HighBits =
             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 95282c7c1455b..41e3ee4b845bc 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20490,7 +20490,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
   // the truncation then we can use PACKSS by converting the srl to a sra.
   // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
   if (In.getOpcode() == ISD::SRL && In->hasOneUse())
-    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
+    if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
       if (*ShAmt == MinSignBits) {
         PackOpcode = X86ISD::PACKSS;
         return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
 ; CHECK-NEXT:    mflr r0
 ; CHECK-NEXT:    std r30, -16(r1) # 8-byte Folded Spill
 ; CHECK-NEXT:    stdu r1, -48(r1)
-; CHECK-NEXT:    std r0, 64(r1)
 ; CHECK-NEXT:    mr r30, r3
-; CHECK-NEXT:    ld r3, 8(r3)
+; CHECK-NEXT:    std r0, 64(r1)
+; CHECK-NEXT:    lwz r3, 8(r3)
 ; CHECK-NEXT:    lwz r4, 36(r30)
-; CHECK-NEXT:    rldicl r3, r3, 60, 4
+; CHECK-NEXT:    rlwinm r3, r3, 27, 0, 0
 ; CHECK-NEXT:    clrlwi r4, r4, 31
-; CHECK-NEXT:    slwi r3, r3, 31
 ; CHECK-NEXT:    rlwimi r4, r3, 0, 0, 0
 ; CHECK-NEXT:    bl _ZN1llsE1d
 ; CHECK-NEXT:    nop


