[llvm-branch-commits] [llvm] Add DoNotPoisonEltMask to several SimplifyDemanded function in TargetLowering (PR #145903)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jun 26 07:30:54 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-amdgpu
Author: Björn Pettersson (bjope)
Changes:
This is currently several commits in one PR. It should perhaps be split into several pull requests.
---
Patch is 1.15 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145903.diff
112 Files Affected:
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+24-2)
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+4-2)
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+225-125)
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+117-84)
- (modified) llvm/lib/Target/X86/X86ISelLowering.h (+9-12)
- (modified) llvm/test/CodeGen/AMDGPU/load-constant-i1.ll (+278-324)
- (modified) llvm/test/CodeGen/AMDGPU/load-constant-i16.ll (+154-184)
- (modified) llvm/test/CodeGen/AMDGPU/load-constant-i8.ll (+445-517)
- (modified) llvm/test/CodeGen/AMDGPU/shift-i128.ll (+7-9)
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v2bf16.ll (+15-30)
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v3bf16.ll (+15-30)
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v4bf16.ll (+15-30)
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v2f16.ll (+15-30)
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v3f16.ll (+15-30)
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v4f16.ll (+15-30)
- (modified) llvm/test/CodeGen/ARM/fpclamptosat_vec.ll (+210-234)
- (modified) llvm/test/CodeGen/Thumb2/mve-fpclamptosat_vec.ll (+38-42)
- (modified) llvm/test/CodeGen/Thumb2/mve-gather-ind8-unscaled.ll (+5)
- (modified) llvm/test/CodeGen/Thumb2/mve-laneinterleaving.ll (+41-45)
- (modified) llvm/test/CodeGen/Thumb2/mve-pred-ext.ll (-1)
- (modified) llvm/test/CodeGen/Thumb2/mve-satmul-loops.ll (+95-118)
- (modified) llvm/test/CodeGen/Thumb2/mve-scatter-ind8-unscaled.ll (+7-2)
- (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-addpred.ll (+4-4)
- (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll (+4-4)
- (modified) llvm/test/CodeGen/Thumb2/mve-vst3.ll (+19-19)
- (modified) llvm/test/CodeGen/X86/avx512-intrinsics-fast-isel.ll (+2-2)
- (modified) llvm/test/CodeGen/X86/avx512-intrinsics-upgrade.ll (+14-32)
- (modified) llvm/test/CodeGen/X86/avx512fp16-mov.ll (+16-16)
- (modified) llvm/test/CodeGen/X86/avx512vl-intrinsics-upgrade.ll (+14-38)
- (modified) llvm/test/CodeGen/X86/avx512vl-vec-masked-cmp.ll (+30-60)
- (modified) llvm/test/CodeGen/X86/bitcast-and-setcc-128.ll (+12-16)
- (modified) llvm/test/CodeGen/X86/bitcast-setcc-128.ll (+5-7)
- (modified) llvm/test/CodeGen/X86/bitcast-vector-bool.ll (-4)
- (modified) llvm/test/CodeGen/X86/buildvec-widen-dotproduct.ll (+15-19)
- (modified) llvm/test/CodeGen/X86/combine-pmuldq.ll (+18-69)
- (modified) llvm/test/CodeGen/X86/combine-sdiv.ll (+19-20)
- (modified) llvm/test/CodeGen/X86/combine-sra.ll (+31-32)
- (modified) llvm/test/CodeGen/X86/combine-udiv.ll (+1-4)
- (modified) llvm/test/CodeGen/X86/extractelement-load.ll (+5-4)
- (modified) llvm/test/CodeGen/X86/f16c-intrinsics-fast-isel.ll (-4)
- (modified) llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll (+45-44)
- (modified) llvm/test/CodeGen/X86/gfni-funnel-shifts.ll (+150-150)
- (modified) llvm/test/CodeGen/X86/half.ll (+3-3)
- (modified) llvm/test/CodeGen/X86/hoist-and-by-const-from-shl-in-eqcmp-zero.ll (+28-20)
- (modified) llvm/test/CodeGen/X86/known-never-zero.ll (+1-2)
- (modified) llvm/test/CodeGen/X86/known-pow2.ll (+6-6)
- (modified) llvm/test/CodeGen/X86/known-signbits-shl.ll (+1-2)
- (modified) llvm/test/CodeGen/X86/known-signbits-vector.ll (+7-23)
- (modified) llvm/test/CodeGen/X86/masked_store.ll (+1-1)
- (modified) llvm/test/CodeGen/X86/movmsk-cmp.ll (+8-35)
- (modified) llvm/test/CodeGen/X86/mulvi32.ll (+7-8)
- (modified) llvm/test/CodeGen/X86/omit-urem-of-power-of-two-or-zero-when-comparing-with-zero.ll (+11-14)
- (modified) llvm/test/CodeGen/X86/pmul.ll (+66-74)
- (modified) llvm/test/CodeGen/X86/pmulh.ll (+2-4)
- (modified) llvm/test/CodeGen/X86/pr107423.ll (+13-13)
- (modified) llvm/test/CodeGen/X86/pr35918.ll (+2-2)
- (modified) llvm/test/CodeGen/X86/pr41619.ll (-2)
- (modified) llvm/test/CodeGen/X86/pr42727.ll (+1-1)
- (modified) llvm/test/CodeGen/X86/pr45563-2.ll (+1-1)
- (modified) llvm/test/CodeGen/X86/pr45833.ll (+1-1)
- (modified) llvm/test/CodeGen/X86/pr77459.ll (+1-1)
- (modified) llvm/test/CodeGen/X86/psubus.ll (+123-132)
- (modified) llvm/test/CodeGen/X86/rotate-extract-vector.ll (-2)
- (modified) llvm/test/CodeGen/X86/sadd_sat_vec.ll (+93-114)
- (modified) llvm/test/CodeGen/X86/sat-add.ll (+9-12)
- (modified) llvm/test/CodeGen/X86/sdiv-exact.ll (+13-17)
- (modified) llvm/test/CodeGen/X86/shrink_vmul.ll (+20-14)
- (modified) llvm/test/CodeGen/X86/srem-seteq-vec-nonsplat.ll (+9-11)
- (modified) llvm/test/CodeGen/X86/sshl_sat_vec.ll (+38-38)
- (modified) llvm/test/CodeGen/X86/ssub_sat_vec.ll (+123-151)
- (modified) llvm/test/CodeGen/X86/test-shrink-bug.ll (+1-1)
- (modified) llvm/test/CodeGen/X86/udiv-exact.ll (+13-17)
- (modified) llvm/test/CodeGen/X86/urem-seteq-illegal-types.ll (+6-4)
- (modified) llvm/test/CodeGen/X86/urem-seteq-vec-nonsplat.ll (+122-172)
- (modified) llvm/test/CodeGen/X86/ushl_sat_vec.ll (+14-15)
- (modified) llvm/test/CodeGen/X86/vec_minmax_sint.ll (+66-92)
- (modified) llvm/test/CodeGen/X86/vec_minmax_uint.ll (+66-92)
- (modified) llvm/test/CodeGen/X86/vec_smulo.ll (+32-38)
- (modified) llvm/test/CodeGen/X86/vec_umulo.ll (+60-70)
- (modified) llvm/test/CodeGen/X86/vector-compare-all_of.ll (+15-21)
- (modified) llvm/test/CodeGen/X86/vector-compare-any_of.ll (+15-21)
- (modified) llvm/test/CodeGen/X86/vector-fshl-128.ll (+139-159)
- (modified) llvm/test/CodeGen/X86/vector-fshl-256.ll (+47-48)
- (modified) llvm/test/CodeGen/X86/vector-fshl-rot-128.ll (+37-56)
- (modified) llvm/test/CodeGen/X86/vector-fshl-rot-256.ll (+14-15)
- (modified) llvm/test/CodeGen/X86/vector-fshr-128.ll (+27-33)
- (modified) llvm/test/CodeGen/X86/vector-fshr-256.ll (+19-20)
- (modified) llvm/test/CodeGen/X86/vector-fshr-rot-128.ll (+28-33)
- (modified) llvm/test/CodeGen/X86/vector-fshr-rot-256.ll (+2-2)
- (modified) llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll (+839-863)
- (modified) llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-8.ll (+214-222)
- (modified) llvm/test/CodeGen/X86/vector-mul.ll (+34-34)
- (modified) llvm/test/CodeGen/X86/vector-pcmp.ll (+1-2)
- (modified) llvm/test/CodeGen/X86/vector-reduce-fmaximum.ll (+81-82)
- (modified) llvm/test/CodeGen/X86/vector-reduce-mul.ll (+57-113)
- (modified) llvm/test/CodeGen/X86/vector-reduce-smax.ll (+69-99)
- (modified) llvm/test/CodeGen/X86/vector-reduce-smin.ll (+65-96)
- (modified) llvm/test/CodeGen/X86/vector-reduce-umax.ll (+69-99)
- (modified) llvm/test/CodeGen/X86/vector-reduce-umin.ll (+65-96)
- (modified) llvm/test/CodeGen/X86/vector-rotate-128.ll (+37-56)
- (modified) llvm/test/CodeGen/X86/vector-rotate-256.ll (+14-15)
- (modified) llvm/test/CodeGen/X86/vector-shift-shl-128.ll (+24-28)
- (modified) llvm/test/CodeGen/X86/vector-shift-shl-256.ll (+20-22)
- (modified) llvm/test/CodeGen/X86/vector-shift-shl-sub128.ll (+48-56)
- (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-avx.ll (+6-6)
- (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-ssse3.ll (+6-3)
- (modified) llvm/test/CodeGen/X86/vector-shuffle-combining.ll (+3-6)
- (modified) llvm/test/CodeGen/X86/vector-trunc-packus.ll (+806-945)
- (modified) llvm/test/CodeGen/X86/vector-trunc-ssat.ll (+609-734)
- (modified) llvm/test/CodeGen/X86/vector-trunc-usat.ll (+368-398)
- (modified) llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll (+5-4)
- (modified) llvm/test/CodeGen/X86/vselect.ll (+18-8)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 727526055e592..191f4cc78fcc5 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4187,6 +4187,16 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
/// More limited version of SimplifyDemandedBits that can be used to "look
/// through" ops that don't contribute to the DemandedBits/DemandedElts -
/// bitwise ops etc.
+ /// Vector elements that aren't demanded can be turned into poison unless the
+ /// corresponding bit in the \p DoNotPoisonEltMask is set.
+ SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
+ const APInt &DemandedElts,
+ const APInt &DoNotPoisonEltMask,
+ SelectionDAG &DAG,
+ unsigned Depth = 0) const;
+
+ /// Helper wrapper around SimplifyMultipleUseDemandedBits, with
+ /// DoNotPoisonEltMask being set to zero.
SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
const APInt &DemandedElts,
SelectionDAG &DAG,
@@ -4202,6 +4212,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
/// bits from only some vector elements.
SDValue SimplifyMultipleUseDemandedVectorElts(SDValue Op,
const APInt &DemandedElts,
+ const APInt &DoNotPoisonEltMask,
SelectionDAG &DAG,
unsigned Depth = 0) const;
@@ -4219,6 +4230,15 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
/// results of this function, because simply replacing TLO.Old
/// with TLO.New will be incorrect when this parameter is true and TLO.Old
/// has multiple uses.
+ /// Vector elements that aren't demanded can be turned into poison unless the
+ /// corresponding bit in \p DoNotPoisonEltMask is set.
+ bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
+ const APInt &DoNotPoisonEltMask,
+ APInt &KnownUndef, APInt &KnownZero,
+ TargetLoweringOpt &TLO, unsigned Depth = 0,
+ bool AssumeSingleUse = false) const;
+ /// Version of SimplifyDemandedVectorElts without the DoNotPoisonEltMask
+ /// argument. All undemanded elements can be turned into poison.
bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
APInt &KnownUndef, APInt &KnownZero,
TargetLoweringOpt &TLO, unsigned Depth = 0,
@@ -4303,8 +4323,9 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
/// (used to simplify the caller). The KnownUndef/Zero elements may only be
/// accurate for those bits in the DemandedMask.
virtual bool SimplifyDemandedVectorEltsForTargetNode(
- SDValue Op, const APInt &DemandedElts, APInt &KnownUndef,
- APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth = 0) const;
+ SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+ APInt &KnownUndef, APInt &KnownZero, TargetLoweringOpt &TLO,
+ unsigned Depth = 0) const;
/// Attempt to simplify any target nodes based on the demanded bits/elts,
/// returning true on success. Otherwise, analyze the
@@ -4323,6 +4344,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
/// bitwise ops etc.
virtual SDValue SimplifyMultipleUseDemandedBitsForTargetNode(
SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+ const APInt &DoNotPoisonEltMask,
SelectionDAG &DAG, unsigned Depth) const;
/// Return true if this function can prove that \p Op is never poison
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c8088d68b6f1b..1f53742b2c6f7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1457,8 +1457,10 @@ bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
bool AssumeSingleUse) {
TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
APInt KnownUndef, KnownZero;
- if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
- TLO, 0, AssumeSingleUse))
+ APInt DoNotPoisonElts = APInt::getZero(DemandedElts.getBitWidth());
+ if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, DoNotPoisonElts,
+ KnownUndef, KnownZero, TLO, 0,
+ AssumeSingleUse))
return false;
// Revisit the node.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 524c97ab3eab8..931154b922924 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -679,7 +679,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
// TODO: Under what circumstances can we create nodes? Constant folding?
SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
- SelectionDAG &DAG, unsigned Depth) const {
+ const APInt &DoNotPoisonEltMask, SelectionDAG &DAG, unsigned Depth) const {
EVT VT = Op.getValueType();
// Limit search depth.
@@ -713,9 +713,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
if (NumSrcEltBits == NumDstEltBits)
if (SDValue V = SimplifyMultipleUseDemandedBits(
- Src, DemandedBits, DemandedElts, DAG, Depth + 1))
+ Src, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG,
+ Depth + 1))
return DAG.getBitcast(DstVT, V);
+ // FIXME: Handle DoNotPoisonEltMask better?
+
if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
unsigned Scale = NumDstEltBits / NumSrcEltBits;
unsigned NumSrcElts = SrcVT.getVectorNumElements();
@@ -731,9 +734,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
// destination element, since recursive calls below may turn not demanded
// elements into poison.
APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+ APInt DoNotPoisonSrcElts =
+ APIntOps::ScaleBitMask(DoNotPoisonEltMask, NumSrcElts);
if (SDValue V = SimplifyMultipleUseDemandedBits(
- Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+ Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+ Depth + 1))
return DAG.getBitcast(DstVT, V);
}
@@ -743,15 +749,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
+ APInt DoNotPoisonSrcElts = APInt::getZero(NumSrcElts);
for (unsigned i = 0; i != NumElts; ++i)
if (DemandedElts[i]) {
unsigned Offset = (i % Scale) * NumDstEltBits;
DemandedSrcBits.insertBits(DemandedBits, Offset);
DemandedSrcElts.setBit(i / Scale);
+ } else if (DoNotPoisonEltMask[i]) {
+ DoNotPoisonSrcElts.setBit(i / Scale);
}
+ // FIXME: Handle DoNotPoisonEltMask better?
+
if (SDValue V = SimplifyMultipleUseDemandedBits(
- Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+ Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+ Depth + 1))
return DAG.getBitcast(DstVT, V);
}
@@ -759,7 +771,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
}
case ISD::FREEZE: {
SDValue N0 = Op.getOperand(0);
- if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
+ if (DAG.isGuaranteedNotToBeUndefOrPoison(N0,
+ DemandedElts | DoNotPoisonEltMask,
/*PoisonOnly=*/false))
return N0;
break;
@@ -815,12 +828,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
case ISD::SHL: {
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
- DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+ if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+ Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
SDValue Op0 = Op.getOperand(0);
unsigned ShAmt = *MaxSA;
- unsigned NumSignBits =
- DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+ unsigned NumSignBits = DAG.ComputeNumSignBits(
+ Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
return Op0;
@@ -830,15 +843,15 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
case ISD::SRL: {
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
- DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+ if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+ Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
SDValue Op0 = Op.getOperand(0);
unsigned ShAmt = *MaxSA;
// Must already be signbits in DemandedBits bounds, and can't demand any
// shifted in zeroes.
if (DemandedBits.countl_zero() >= ShAmt) {
- unsigned NumSignBits =
- DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+ unsigned NumSignBits = DAG.ComputeNumSignBits(
+ Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
return Op0;
}
@@ -875,7 +888,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
shouldRemoveRedundantExtend(Op))
return Op0;
// If the input is already sign extended, just drop the extension.
- unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+ // FIXME: Can we skip DoNotPoisonEltMask here?
+ unsigned NumSignBits = DAG.ComputeNumSignBits(
+ Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
if (NumSignBits >= (BitWidth - ExBits + 1))
return Op0;
break;
@@ -891,7 +906,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
SDValue Src = Op.getOperand(0);
EVT SrcVT = Src.getValueType();
EVT DstVT = Op.getValueType();
- if (IsLE && DemandedElts == 1 &&
+ // FIXME: Can we skip DoNotPoisonEltMask here?
+ if (IsLE && (DemandedElts | DoNotPoisonEltMask) == 1 &&
DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
return DAG.getBitcast(DstVT, Src);
@@ -906,8 +922,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
SDValue Vec = Op.getOperand(0);
auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
EVT VecVT = Vec.getValueType();
+ // FIXME: Handle DoNotPoisonEltMask better.
if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
- !DemandedElts[CIdx->getZExtValue()])
+ !DemandedElts[CIdx->getZExtValue()] &&
+ !DoNotPoisonEltMask[CIdx->getZExtValue()])
return Vec;
break;
}
@@ -920,8 +938,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
uint64_t Idx = Op.getConstantOperandVal(2);
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
+ APInt DoNotPoisonSubElts = DoNotPoisonEltMask.extractBits(NumSubElts, Idx);
// If we don't demand the inserted subvector, return the base vector.
- if (DemandedSubElts == 0)
+ // FIXME: Handle DoNotPoisonEltMask better.
+ if (DemandedSubElts == 0 && DoNotPoisonSubElts == 0)
return Vec;
break;
}
@@ -934,9 +954,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
for (unsigned i = 0; i != NumElts; ++i) {
int M = ShuffleMask[i];
- if (M < 0 || !DemandedElts[i])
+ if (M < 0 || (!DemandedElts[i] && !DoNotPoisonEltMask[i]))
continue;
- AllUndef = false;
+ if (DemandedElts[i])
+ AllUndef = false;
IdentityLHS &= (M == (int)i);
IdentityRHS &= ((M - NumElts) == i);
}
@@ -957,13 +978,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
- Op, DemandedBits, DemandedElts, DAG, Depth))
+ Op, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG, Depth))
return V;
break;
}
return SDValue();
}
+SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
+ SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+ SelectionDAG &DAG, unsigned Depth) const {
+ APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+ return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+ DoNotPoisonEltMask, DAG, Depth);
+}
+
SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
unsigned Depth) const {
@@ -974,13 +1003,14 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
- Depth);
+ APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+ return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+ DoNotPoisonEltMask, DAG, Depth);
}
SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
- SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
- unsigned Depth) const {
+ SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+ SelectionDAG &DAG, unsigned Depth) const {
APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
Depth);
@@ -2756,34 +2786,51 @@ bool TargetLowering::SimplifyDemandedBits(
TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
}
}
-
// Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
// Demand the elt/bit if any of the original elts/bits are demanded.
if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
unsigned Scale = BitWidth / NumSrcEltBits;
unsigned NumSrcElts = SrcVT.getVectorNumElements();
APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
+ APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
for (unsigned i = 0; i != Scale; ++i) {
unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
unsigned BitOffset = EltOffset * NumSrcEltBits;
APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
- if (!Sub.isZero())
+ if (!Sub.isZero()) {
DemandedSrcBits |= Sub;
+ for (unsigned j = 0; j != NumElts; ++j)
+ if (DemandedElts[j])
+ DemandedSrcElts.setBit((j * Scale) + i);
+ }
}
- // Need to demand all smaller source elements that maps to a demanded
- // destination element, since recursive calls below may turn not demanded
- // elements into poison.
- APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+ // Need to "semi demand" all smaller source elements that maps to a
+ // demanded destination element, since recursive calls below may turn not
+ // demanded elements into poison. Instead of demanding such elements we
+ // use a special bitmask to indicate that the recursive calls must not
+ // turn such elements into poison.
+ APInt NoPoisonSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
APInt KnownSrcUndef, KnownSrcZero;
- if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
- KnownSrcZero, TLO, Depth + 1))
+ if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, NoPoisonSrcElts,
+ KnownSrcUndef, KnownSrcZero, TLO,
+ Depth + 1))
return true;
KnownBits KnownSrcBits;
- if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
- KnownSrcBits, TLO, Depth + 1))
+ if (SimplifyDemandedBits(Src, DemandedSrcBits,
+ DemandedSrcElts | NoPoisonSrcElts, KnownSrcBits,
+ TLO, Depth + 1))
return true;
+
+ if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
+ if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
+ Src, DemandedSrcBits, DemandedSrcElts, NoPoisonSrcElts, TLO.DAG,
+ Depth + 1)) {
+ SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
+ return TLO.CombineTo(Op, NewOp);
+ }
+ }
} else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
// TODO - bigendian once we have test coverage.
unsigned Scale = NumSrcEltBits / BitWidth;
@@ -3090,8 +3137,9 @@ bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
!DCI.isBeforeLegalizeOps());
APInt KnownUndef, KnownZero;
- bool Simplified =
- SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
+ APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+ bool Simplified = SimplifyDemandedVectorElts(
+ Op, DemandedElts, DoNotPoisonEltMask, KnownUndef, KnownZero, TLO);
if (Simplified) {
DCI.AddToWorklist(Op.getNode());
DCI.CommitTargetLoweringOpt(TLO);
@@ -3152,6 +3200,16 @@ bool TargetLowering::SimplifyDemandedVectorElts(
SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
bool AssumeSingleUse) const {
+ APInt DoNotPoisonEltMask = APInt::getZero(OriginalDemandedElts.getBitWidth());
+ return SimplifyDemandedVectorElts(Op, OriginalDemandedElts,
+ DoNotPoisonEltMask, KnownUndef, KnownZero,
+ TLO, Depth, AssumeSingleUse);
+}
+
+bool TargetLowering::SimplifyDemandedVectorElts(
+ SDValue Op, const APInt &OriginalDemandedElts,
+ const APInt &DoNotPoisonEltMask, APInt &KnownUndef, APInt &KnownZero,
+ TargetLoweringOpt &TLO, unsigned Depth, bool AssumeSingleUse) const {
EVT VT = Op.getValueType();
unsigned Opcode = Op.getOpcode();
APInt DemandedElts = OriginalDemandedElts;
@@ -3190,6 +3248,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
if (Depth >= SelectionDAG::MaxRecursionDepth)
return false;
+ APInt DemandedEltsInclDoNotPoison = DemandedElts | DoNotPoisonEltMask;
SDLoc DL(Op);
unsigned EltSizeInBits = VT.getScalarSizeInBits();
bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
@@ -3197,10 +3256,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(
// Helper for demanding the specified elements and all the bits of both binary
// operands.
auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
- SDValue NewOp0 = S...
[truncated]
``````````
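For readers skimming the truncated diff: the new `DoNotPoisonEltMask` parameter marks vector lanes that are not demanded but still must not be turned into poison, e.g. narrow source lanes that alias a demanded destination lane across a bitcast. Below is a minimal standalone sketch of that rule using plain `uint64_t` bitmasks instead of LLVM's `APInt`; the helper names (`mayTurnIntoPoison`, `scaleMaskUp`) are made up for illustration and are not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// A lane may be simplified to poison only if it is neither demanded
// nor protected by the do-not-poison mask.
static bool mayTurnIntoPoison(unsigned Elt, uint64_t DemandedElts,
                              uint64_t DoNotPoisonEltMask) {
  uint64_t Bit = uint64_t(1) << Elt;
  return (DemandedElts & Bit) == 0 && (DoNotPoisonEltMask & Bit) == 0;
}

// Widen a destination-lane mask to a source-lane mask, in the spirit of
// APIntOps::ScaleBitMask: each destination lane covers
// NumSrcElts / NumDstElts source lanes.
static uint64_t scaleMaskUp(uint64_t DstMask, unsigned NumDstElts,
                            unsigned NumSrcElts) {
  assert(NumSrcElts % NumDstElts == 0 && "expected an exact multiple");
  unsigned Scale = NumSrcElts / NumDstElts;
  uint64_t SrcMask = 0;
  for (unsigned I = 0; I != NumDstElts; ++I)
    if (DstMask & (uint64_t(1) << I))
      for (unsigned J = 0; J != Scale; ++J)
        SrcMask |= uint64_t(1) << (I * Scale + J);
  return SrcMask;
}

int main() {
  // Model a v2i32 value bitcast from v4i16, where only destination lane 0
  // is demanded. Source lanes 0 and 1 feed that lane, so instead of
  // demanding them outright, the patch marks them do-not-poison.
  uint64_t DemandedDst = 0b01;                              // dst lane 0
  uint64_t DoNotPoisonSrc = scaleMaskUp(DemandedDst, 2, 4); // src lanes 0,1
  uint64_t DemandedSrc = 0; // demanded bits are tracked separately

  for (unsigned Elt = 0; Elt != 4; ++Elt)
    printf("src lane %u may become poison: %s\n", Elt,
           mayTurnIntoPoison(Elt, DemandedSrc, DoNotPoisonSrc) ? "yes" : "no");
  // Lanes 0 and 1 are protected; lanes 2 and 3 may be poisoned away.
  return 0;
}
```

In the real patch the masks are `APInt`s sized to the vector element count, and the new overloads that take no `DoNotPoisonEltMask` simply pass an all-zero mask, preserving the previous behavior where every undemanded lane may become poison.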
https://github.com/llvm/llvm-project/pull/145903
More information about the llvm-branch-commits mailing list