[llvm] [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (PR #93182)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 30 12:03:14 PDT 2024


================
@@ -3009,102 +3008,118 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   return SDValue();
 }
 
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
-                                          const APInt &DemandedElts) const {
+std::optional<ConstantRange>
+SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
+                                       unsigned Depth) const {
   assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
           V.getOpcode() == ISD::SRA) &&
          "Unknown shift node");
+  // Shifting more than the bitwidth is not valid.
   unsigned BitWidth = V.getScalarValueSizeInBits();
-  if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1), DemandedElts)) {
-    // Shifting more than the bitwidth is not valid.
-    const APInt &ShAmt = SA->getAPIntValue();
-    if (ShAmt.ult(BitWidth))
-      return &ShAmt;
+
+  if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
+    const APInt &ShAmt = Cst->getAPIntValue();
+    if (ShAmt.uge(BitWidth))
+      return std::nullopt;
+    return ConstantRange(ShAmt);
   }
-  return nullptr;
+
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+    const APInt *MinAmt = nullptr, *MaxAmt = nullptr;
+    for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+      if (!SA) {
+        MinAmt = MaxAmt = nullptr;
+        break;
+      }
+      const APInt &ShAmt = SA->getAPIntValue();
+      if (ShAmt.uge(BitWidth))
+        return std::nullopt;
+      if (!MinAmt || MinAmt->ugt(ShAmt))
+        MinAmt = &ShAmt;
+      if (!MaxAmt || MaxAmt->ult(ShAmt))
+        MaxAmt = &ShAmt;
+    }
+    assert(((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
+           "Failed to find matching min/max shift amounts");
+    if (MinAmt && MaxAmt)
+      return ConstantRange(*MinAmt, *MaxAmt);
+  }
+
+  // Use computeKnownBits to find a hidden constant/knownbits (usually type
+  // legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
+  KnownBits KnownAmt =
+      computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
----------------
goldsteinn wrote:

I still think this should be `Depth`, and that the callers (i.e. the callers of `getValidShiftAmount`/`getValidMinimumShiftAmount`/`getValidMaximumShiftAmount`) should be in charge of managing `Depth + 1`.

https://github.com/llvm/llvm-project/pull/93182


More information about the llvm-commits mailing list