[llvm] [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (PR #93182)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 30 11:14:15 PDT 2024


================
@@ -3020,91 +3020,113 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
     // Shifting more than the bitwidth is not valid.
     const APInt &ShAmt = SA->getAPIntValue();
     if (ShAmt.ult(BitWidth))
-      return &ShAmt;
+      return ShAmt.getZExtValue();
+  } else {
+    // Use computeKnownBits to find a hidden constant (usually type legalized).
+    // e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
+    KnownBits KnownAmt =
+        computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+    if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+      return KnownAmt.getConstant().getZExtValue();
   }
-  return nullptr;
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidShiftAmountConstant(V, DemandedElts);
+  return getValidShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MinShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MinShAmt && MinShAmt->ule(ShAmt))
-      continue;
-    MinShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MinShAmt && MinShAmt->ule(ShAmt))
+        continue;
+      MinShAmt = &ShAmt;
+    }
+    if (MinShAmt)
+      return MinShAmt->getZExtValue();
   }
-  return MinShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMinValue().getZExtValue();
+  return std::nullopt;
 }
 
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
   EVT VT = V.getValueType();
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return getValidMinimumShiftAmountConstant(V, DemandedElts);
+  return getValidMinimumShiftAmount(V, DemandedElts, Depth);
 }
 
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
-    SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+                                         unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
-  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
-    return ValidAmt;
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
-  if (!BV)
-    return nullptr;
-  const APInt *MaxShAmt = nullptr;
-  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
-    if (!SA)
-      return nullptr;
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.uge(BitWidth))
-      return nullptr;
-    if (MaxShAmt && MaxShAmt->uge(ShAmt))
-      continue;
-    MaxShAmt = &ShAmt;
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MaxShAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MaxShAmt = nullptr;
+        break;
+      }
+      // Shifting more than the bitwidth is not valid.
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (MaxShAmt && MaxShAmt->uge(ShAmt))
+        continue;
+      MaxShAmt = &ShAmt;
+    }
+    if (MaxShAmt)
+      return MaxShAmt->getZExtValue();
   }
-  return MaxShAmt;
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+  if (KnownAmt.getMaxValue().ult(BitWidth))
+    return KnownAmt.getMaxValue().getZExtValue();
+  return std::nullopt;
----------------
goldsteinn wrote:

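It's a bit of a mixed bag. The KnownBits implementation is more complete, but it has to throw away range information that isn't tied to particular bits (e.g. a shift amount known to lie in [3, 5] only yields "upper bits are zero" as KnownBits, so getMinValue()/getMaxValue() report 0 and 7 instead of 3 and 5). ConstantRange has a less complete implementation, but it obviously represents what we want here better.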

https://github.com/llvm/llvm-project/pull/93182


More information about the llvm-commits mailing list