[llvm] [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (PR #93182)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 30 11:14:15 PDT 2024
================
@@ -3020,91 +3020,113 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
// Shifting more than the bitwidth is not valid.
const APInt &ShAmt = SA->getAPIntValue();
if (ShAmt.ult(BitWidth))
- return &ShAmt;
+ return ShAmt.getZExtValue();
+ } else {
+ // Use computeKnownBits to find a hidden constant (usually type legalized).
+ // e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+ return KnownAmt.getConstant().getZExtValue();
}
- return nullptr;
+ return std::nullopt;
}
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return getValidShiftAmountConstant(V, DemandedElts);
+ return getValidShiftAmount(V, DemandedElts, Depth);
}
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
- SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
- if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
- return ValidAmt;
unsigned BitWidth = V.getScalarValueSizeInBits();
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
- if (!BV)
- return nullptr;
- const APInt *MinShAmt = nullptr;
- for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
- if (!DemandedElts[i])
- continue;
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
- if (!SA)
- return nullptr;
- // Shifting more than the bitwidth is not valid.
- const APInt &ShAmt = SA->getAPIntValue();
- if (ShAmt.uge(BitWidth))
- return nullptr;
- if (MinShAmt && MinShAmt->ule(ShAmt))
- continue;
- MinShAmt = &ShAmt;
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+ const APInt *MinShAmt = nullptr;
+ for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+ if (!SA) {
+ MinShAmt = nullptr;
+ break;
+ }
+ // Shifting more than the bitwidth is not valid.
+ const APInt &ShAmt = SA->getAPIntValue();
+ if (ShAmt.uge(BitWidth))
+ return std::nullopt;
+ if (MinShAmt && MinShAmt->ule(ShAmt))
+ continue;
+ MinShAmt = &ShAmt;
+ }
+ if (MinShAmt)
+ return MinShAmt->getZExtValue();
}
- return MinShAmt;
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.getMaxValue().ult(BitWidth))
+ return KnownAmt.getMinValue().getZExtValue();
+ return std::nullopt;
}
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return getValidMinimumShiftAmountConstant(V, DemandedElts);
+ return getValidMinimumShiftAmount(V, DemandedElts, Depth);
}
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
- SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
- if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
- return ValidAmt;
unsigned BitWidth = V.getScalarValueSizeInBits();
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
- if (!BV)
- return nullptr;
- const APInt *MaxShAmt = nullptr;
- for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
- if (!DemandedElts[i])
- continue;
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
- if (!SA)
- return nullptr;
- // Shifting more than the bitwidth is not valid.
- const APInt &ShAmt = SA->getAPIntValue();
- if (ShAmt.uge(BitWidth))
- return nullptr;
- if (MaxShAmt && MaxShAmt->uge(ShAmt))
- continue;
- MaxShAmt = &ShAmt;
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+ const APInt *MaxShAmt = nullptr;
+ for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+ if (!SA) {
+ MaxShAmt = nullptr;
+ break;
+ }
+ // Shifting more than the bitwidth is not valid.
+ const APInt &ShAmt = SA->getAPIntValue();
+ if (ShAmt.uge(BitWidth))
+ return std::nullopt;
+ if (MaxShAmt && MaxShAmt->uge(ShAmt))
+ continue;
+ MaxShAmt = &ShAmt;
+ }
+ if (MaxShAmt)
+ return MaxShAmt->getZExtValue();
}
- return MaxShAmt;
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.getMaxValue().ult(BitWidth))
+ return KnownAmt.getMaxValue().getZExtValue();
+ return std::nullopt;
----------------
goldsteinn wrote:
It's a bit of a mixed bag. The KnownBits implementation is more complete, but it has to throw away range information that isn't tied to particular bits. ConstantRange has a less complete implementation, but it obviously represents what we want here better.
https://github.com/llvm/llvm-project/pull/93182
More information about the llvm-commits mailing list