[llvm] [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (PR #93182)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 30 12:03:14 PDT 2024
================
@@ -3009,102 +3008,118 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
return SDValue();
}
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const {
+std::optional<ConstantRange>
+SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
+ // Shifting more than the bitwidth is not valid.
unsigned BitWidth = V.getScalarValueSizeInBits();
- if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1), DemandedElts)) {
- // Shifting more than the bitwidth is not valid.
- const APInt &ShAmt = SA->getAPIntValue();
- if (ShAmt.ult(BitWidth))
- return &ShAmt;
+
+ if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
+ const APInt &ShAmt = Cst->getAPIntValue();
+ if (ShAmt.uge(BitWidth))
+ return std::nullopt;
+ return ConstantRange(ShAmt);
}
- return nullptr;
+
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+ const APInt *MinAmt = nullptr, *MaxAmt = nullptr;
+ for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+ if (!SA) {
+ MinAmt = MaxAmt = nullptr;
+ break;
+ }
+ const APInt &ShAmt = SA->getAPIntValue();
+ if (ShAmt.uge(BitWidth))
+ return std::nullopt;
+ if (!MinAmt || MinAmt->ugt(ShAmt))
+ MinAmt = &ShAmt;
+ if (!MaxAmt || MaxAmt->ult(ShAmt))
+ MaxAmt = &ShAmt;
+ }
+ assert(((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
+ "Failed to find matching min/max shift amounts");
+ if (MinAmt && MaxAmt)
+ return ConstantRange(*MinAmt, *MaxAmt);
+ }
+
+ // Use computeKnownBits to find a hidden constant/knownbits (usually type
+ // legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
----------------
goldsteinn wrote:
I still think this should be `Depth`, and that the callers (that is, the callers of `getValidShiftAmount`/`getValidMinimumShiftAmount`/`getValidMaximumShiftAmount`) should be in charge of passing `Depth + 1`.
https://github.com/llvm/llvm-project/pull/93182
More information about the llvm-commits mailing list