[llvm] [DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (PR #93182)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 23 05:23:50 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-powerpc
Author: Simon Pilgrim (RKSimon)
<details>
<summary>Changes</summary>
The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as 32-bit x86, Thumb2 MVE).
This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.
Most restrictions are the same - the helpers fail if any shift amount is out of bounds, getValidShiftAmount must still resolve to a single uniform constant value, etc.
However, getValidMinimumShiftAmount/getValidMaximumShiftAmount can now return bounds that are not shift amounts actually present in the data, as they are derived from the common KnownBits of every vector element.
This addresses feedback on #92096
---
Full diff: https://github.com/llvm/llvm-project/pull/93182.diff
5 Files Affected:
- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+29-27)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+89-69)
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+13-13)
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
- (modified) llvm/test/CodeGen/PowerPC/pr44183.ll (+3-4)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 96a6270690468..95afbeb5dd6ec 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2159,36 +2159,38 @@ class SelectionDAG {
/// splatted value it will return SDValue().
SDValue getSplatValue(SDValue V, bool LegalTypes = false);
- /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+ /// If a SHL/SRA/SRL node \p V has a uniform shift amount
/// that is less than the element bit-width of the shift node, return it.
- const APInt *getValidShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const;
+ std::optional<uint64_t> getValidShiftAmount(SDValue V,
+ const APInt &DemandedElts,
+ unsigned Depth = 0) const;
- /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+ /// If a SHL/SRA/SRL node \p V has a uniform shift amount
/// that is less than the element bit-width of the shift node, return it.
- const APInt *getValidShiftAmountConstant(SDValue V) const;
-
- /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
- /// than the element bit-width of the shift node, return the minimum value.
- const APInt *
- getValidMinimumShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const;
-
- /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
- /// than the element bit-width of the shift node, return the minimum value.
- const APInt *
- getValidMinimumShiftAmountConstant(SDValue V) const;
-
- /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
- /// than the element bit-width of the shift node, return the maximum value.
- const APInt *
- getValidMaximumShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const;
-
- /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
- /// than the element bit-width of the shift node, return the maximum value.
- const APInt *
- getValidMaximumShiftAmountConstant(SDValue V) const;
+ std::optional<uint64_t> getValidShiftAmount(SDValue V,
+ unsigned Depth = 0) const;
+
+ /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+ /// element bit-width of the shift node, return the minimum possible value.
+ std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+ const APInt &DemandedElts,
+ unsigned Depth = 0) const;
+
+ /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+ /// element bit-width of the shift node, return the minimum possible value.
+ std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+ unsigned Depth = 0) const;
+
+ /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+ /// element bit-width of the shift node, return the maximum possible value.
+ std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+ const APInt &DemandedElts,
+ unsigned Depth = 0) const;
+
+ /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+ /// element bit-width of the shift node, return the maximum possible value.
+ std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+ unsigned Depth = 0) const;
/// Match a binop + shuffle pyramid that represents a horizontal reduction
/// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..b71b496c0aa84 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
return SDValue();
}
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
@@ -3020,91 +3020,111 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
// Shifting more than the bitwidth is not valid.
const APInt &ShAmt = SA->getAPIntValue();
if (ShAmt.ult(BitWidth))
- return &ShAmt;
+ return ShAmt.getZExtValue();
+ } else {
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+ return KnownAmt.getConstant().getZExtValue();
}
- return nullptr;
+ return std::nullopt;
}
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return getValidShiftAmountConstant(V, DemandedElts);
+ return getValidShiftAmount(V, DemandedElts, Depth);
}
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
- SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
- if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
- return ValidAmt;
unsigned BitWidth = V.getScalarValueSizeInBits();
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
- if (!BV)
- return nullptr;
- const APInt *MinShAmt = nullptr;
- for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
- if (!DemandedElts[i])
- continue;
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
- if (!SA)
- return nullptr;
- // Shifting more than the bitwidth is not valid.
- const APInt &ShAmt = SA->getAPIntValue();
- if (ShAmt.uge(BitWidth))
- return nullptr;
- if (MinShAmt && MinShAmt->ule(ShAmt))
- continue;
- MinShAmt = &ShAmt;
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+ const APInt *MinShAmt = nullptr;
+ for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+ if (!SA) {
+ MinShAmt = nullptr;
+ break;
+ }
+ // Shifting more than the bitwidth is not valid.
+ const APInt &ShAmt = SA->getAPIntValue();
+ if (ShAmt.uge(BitWidth))
+ return std::nullopt;
+ if (MinShAmt && MinShAmt->ule(ShAmt))
+ continue;
+ MinShAmt = &ShAmt;
+ }
+ if (MinShAmt)
+ return MinShAmt->getZExtValue();
}
- return MinShAmt;
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.getMaxValue().ult(BitWidth))
+ return KnownAmt.getMinValue().getZExtValue();
+ return std::nullopt;
}
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return getValidMinimumShiftAmountConstant(V, DemandedElts);
+ return getValidMinimumShiftAmount(V, DemandedElts, Depth);
}
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
- SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
- if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
- return ValidAmt;
unsigned BitWidth = V.getScalarValueSizeInBits();
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
- if (!BV)
- return nullptr;
- const APInt *MaxShAmt = nullptr;
- for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
- if (!DemandedElts[i])
- continue;
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
- if (!SA)
- return nullptr;
- // Shifting more than the bitwidth is not valid.
- const APInt &ShAmt = SA->getAPIntValue();
- if (ShAmt.uge(BitWidth))
- return nullptr;
- if (MaxShAmt && MaxShAmt->uge(ShAmt))
- continue;
- MaxShAmt = &ShAmt;
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+ const APInt *MaxShAmt = nullptr;
+ for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+ if (!SA) {
+ MaxShAmt = nullptr;
+ break;
+ }
+ // Shifting more than the bitwidth is not valid.
+ const APInt &ShAmt = SA->getAPIntValue();
+ if (ShAmt.uge(BitWidth))
+ return std::nullopt;
+ if (MaxShAmt && MaxShAmt->uge(ShAmt))
+ continue;
+ MaxShAmt = &ShAmt;
+ }
+ if (MaxShAmt)
+ return MaxShAmt->getZExtValue();
}
- return MaxShAmt;
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.getMaxValue().ult(BitWidth))
+ return KnownAmt.getMaxValue().getZExtValue();
+ return std::nullopt;
}
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return getValidMaximumShiftAmountConstant(V, DemandedElts);
+ return getValidMaximumShiftAmount(V, DemandedElts, Depth);
}
/// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3589,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
// Minimum shift low bits are known zero.
- if (const APInt *ShMinAmt =
- getValidMinimumShiftAmountConstant(Op, DemandedElts))
- Known.Zero.setLowBits(ShMinAmt->getZExtValue());
+ if (std::optional<uint64_t> ShMinAmt =
+ getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+ Known.Zero.setLowBits(*ShMinAmt);
break;
}
case ISD::SRL:
@@ -3581,9 +3601,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Op->getFlags().hasExact());
// Minimum shift high bits are known zero.
- if (const APInt *ShMinAmt =
- getValidMinimumShiftAmountConstant(Op, DemandedElts))
- Known.Zero.setHighBits(ShMinAmt->getZExtValue());
+ if (std::optional<uint64_t> ShMinAmt =
+ getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+ Known.Zero.setHighBits(*ShMinAmt);
break;
case ISD::SRA:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4607,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRA:
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
// SRA X, C -> adds C sign bits.
- if (const APInt *ShAmt =
- getValidMinimumShiftAmountConstant(Op, DemandedElts))
- Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+ if (std::optional<uint64_t> ShAmt =
+ getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+ Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
return Tmp;
case ISD::SHL:
- if (const APInt *ShAmt =
- getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+ if (std::optional<uint64_t> ShAmt =
+ getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
- if (ShAmt->ult(Tmp))
- return Tmp - ShAmt->getZExtValue();
+ if (*ShAmt < Tmp)
+ return Tmp - *ShAmt;
}
break;
case ISD::AND:
@@ -5270,7 +5290,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
case ISD::SRL:
case ISD::SRA:
// If the max shift amount isn't in range, then the shift can create poison.
- return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
+ return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
case ISD::SCALAR_TO_VECTOR:
// Check if we demand any upper (undef) elements.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 37c72339fe295..dfcd5439b8a9d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -796,10 +796,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
case ISD::SHL: {
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (const APInt *MaxSA =
- DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+ if (std::optional<uint64_t> MaxSA =
+ DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
SDValue Op0 = Op.getOperand(0);
- unsigned ShAmt = MaxSA->getZExtValue();
+ unsigned ShAmt = *MaxSA;
unsigned NumSignBits =
DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1789,9 +1789,9 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
InnerOp.hasOneUse()) {
- if (const APInt *SA2 =
- TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
- unsigned InnerShAmt = SA2->getZExtValue();
+ if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+ InnerOp, DemandedElts, Depth + 1)) {
+ unsigned InnerShAmt = *SA2;
if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
DemandedBits.getActiveBits() <=
(InnerBits - InnerShAmt + ShAmt) &&
@@ -1918,9 +1918,9 @@ bool TargetLowering::SimplifyDemandedBits(
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (const APInt *MaxSA =
- TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
- unsigned ShAmt = MaxSA->getZExtValue();
+ if (std::optional<uint64_t> MaxSA =
+ TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
+ unsigned ShAmt = *MaxSA;
unsigned NumSignBits =
TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -2598,11 +2598,11 @@ bool TargetLowering::SimplifyDemandedBits(
break;
if (Src.getNode()->hasOneUse()) {
- const APInt *ShAmtC =
- TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
- if (!ShAmtC || ShAmtC->uge(BitWidth))
+ std::optional<uint64_t> ShAmtC =
+ TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 1);
+ if (!ShAmtC || *ShAmtC >= BitWidth)
break;
- uint64_t ShVal = ShAmtC->getZExtValue();
+ uint64_t ShVal = *ShAmtC;
APInt HighBits =
APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 215cbc308e43d..fd99f0e345d14 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20490,7 +20490,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
// the truncation then we can use PACKSS by converting the srl to a sra.
// SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
if (In.getOpcode() == ISD::SRL && In->hasOneUse())
- if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
+ if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
if (*ShAmt == MinSignBits) {
PackOpcode = X86ISD::PACKSS;
return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
; CHECK-NEXT: mflr r0
; CHECK-NEXT: std r30, -16(r1) # 8-byte Folded Spill
; CHECK-NEXT: stdu r1, -48(r1)
-; CHECK-NEXT: std r0, 64(r1)
; CHECK-NEXT: mr r30, r3
-; CHECK-NEXT: ld r3, 8(r3)
+; CHECK-NEXT: std r0, 64(r1)
+; CHECK-NEXT: lwz r3, 8(r3)
; CHECK-NEXT: lwz r4, 36(r30)
-; CHECK-NEXT: rldicl r3, r3, 60, 4
+; CHECK-NEXT: rlwinm r3, r3, 27, 0, 0
; CHECK-NEXT: clrlwi r4, r4, 31
-; CHECK-NEXT: slwi r3, r3, 31
; CHECK-NEXT: rlwimi r4, r3, 0, 0, 0
; CHECK-NEXT: bl _ZN1llsE1d
; CHECK-NEXT: nop
``````````
</details>
https://github.com/llvm/llvm-project/pull/93182
More information about the llvm-commits
mailing list