[llvm-branch-commits] [llvm] Add DoNotPoisonEltMask to several SimplifyDemanded functions in TargetLowering (PR #145903)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jun 26 07:30:54 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Björn Pettersson (bjope)

<details>
<summary>Changes</summary>

This is currently several commits in one PR. It should perhaps be split into several pull requests.

---

Patch is 1.15 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145903.diff


112 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+24-2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+4-2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+225-125) 
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+117-84) 
- (modified) llvm/lib/Target/X86/X86ISelLowering.h (+9-12) 
- (modified) llvm/test/CodeGen/AMDGPU/load-constant-i1.ll (+278-324) 
- (modified) llvm/test/CodeGen/AMDGPU/load-constant-i16.ll (+154-184) 
- (modified) llvm/test/CodeGen/AMDGPU/load-constant-i8.ll (+445-517) 
- (modified) llvm/test/CodeGen/AMDGPU/shift-i128.ll (+7-9) 
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v2bf16.ll (+15-30) 
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v3bf16.ll (+15-30) 
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v4bf16.ll (+15-30) 
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v2f16.ll (+15-30) 
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v3f16.ll (+15-30) 
- (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v4f16.ll (+15-30) 
- (modified) llvm/test/CodeGen/ARM/fpclamptosat_vec.ll (+210-234) 
- (modified) llvm/test/CodeGen/Thumb2/mve-fpclamptosat_vec.ll (+38-42) 
- (modified) llvm/test/CodeGen/Thumb2/mve-gather-ind8-unscaled.ll (+5) 
- (modified) llvm/test/CodeGen/Thumb2/mve-laneinterleaving.ll (+41-45) 
- (modified) llvm/test/CodeGen/Thumb2/mve-pred-ext.ll (-1) 
- (modified) llvm/test/CodeGen/Thumb2/mve-satmul-loops.ll (+95-118) 
- (modified) llvm/test/CodeGen/Thumb2/mve-scatter-ind8-unscaled.ll (+7-2) 
- (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-addpred.ll (+4-4) 
- (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll (+4-4) 
- (modified) llvm/test/CodeGen/Thumb2/mve-vst3.ll (+19-19) 
- (modified) llvm/test/CodeGen/X86/avx512-intrinsics-fast-isel.ll (+2-2) 
- (modified) llvm/test/CodeGen/X86/avx512-intrinsics-upgrade.ll (+14-32) 
- (modified) llvm/test/CodeGen/X86/avx512fp16-mov.ll (+16-16) 
- (modified) llvm/test/CodeGen/X86/avx512vl-intrinsics-upgrade.ll (+14-38) 
- (modified) llvm/test/CodeGen/X86/avx512vl-vec-masked-cmp.ll (+30-60) 
- (modified) llvm/test/CodeGen/X86/bitcast-and-setcc-128.ll (+12-16) 
- (modified) llvm/test/CodeGen/X86/bitcast-setcc-128.ll (+5-7) 
- (modified) llvm/test/CodeGen/X86/bitcast-vector-bool.ll (-4) 
- (modified) llvm/test/CodeGen/X86/buildvec-widen-dotproduct.ll (+15-19) 
- (modified) llvm/test/CodeGen/X86/combine-pmuldq.ll (+18-69) 
- (modified) llvm/test/CodeGen/X86/combine-sdiv.ll (+19-20) 
- (modified) llvm/test/CodeGen/X86/combine-sra.ll (+31-32) 
- (modified) llvm/test/CodeGen/X86/combine-udiv.ll (+1-4) 
- (modified) llvm/test/CodeGen/X86/extractelement-load.ll (+5-4) 
- (modified) llvm/test/CodeGen/X86/f16c-intrinsics-fast-isel.ll (-4) 
- (modified) llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll (+45-44) 
- (modified) llvm/test/CodeGen/X86/gfni-funnel-shifts.ll (+150-150) 
- (modified) llvm/test/CodeGen/X86/half.ll (+3-3) 
- (modified) llvm/test/CodeGen/X86/hoist-and-by-const-from-shl-in-eqcmp-zero.ll (+28-20) 
- (modified) llvm/test/CodeGen/X86/known-never-zero.ll (+1-2) 
- (modified) llvm/test/CodeGen/X86/known-pow2.ll (+6-6) 
- (modified) llvm/test/CodeGen/X86/known-signbits-shl.ll (+1-2) 
- (modified) llvm/test/CodeGen/X86/known-signbits-vector.ll (+7-23) 
- (modified) llvm/test/CodeGen/X86/masked_store.ll (+1-1) 
- (modified) llvm/test/CodeGen/X86/movmsk-cmp.ll (+8-35) 
- (modified) llvm/test/CodeGen/X86/mulvi32.ll (+7-8) 
- (modified) llvm/test/CodeGen/X86/omit-urem-of-power-of-two-or-zero-when-comparing-with-zero.ll (+11-14) 
- (modified) llvm/test/CodeGen/X86/pmul.ll (+66-74) 
- (modified) llvm/test/CodeGen/X86/pmulh.ll (+2-4) 
- (modified) llvm/test/CodeGen/X86/pr107423.ll (+13-13) 
- (modified) llvm/test/CodeGen/X86/pr35918.ll (+2-2) 
- (modified) llvm/test/CodeGen/X86/pr41619.ll (-2) 
- (modified) llvm/test/CodeGen/X86/pr42727.ll (+1-1) 
- (modified) llvm/test/CodeGen/X86/pr45563-2.ll (+1-1) 
- (modified) llvm/test/CodeGen/X86/pr45833.ll (+1-1) 
- (modified) llvm/test/CodeGen/X86/pr77459.ll (+1-1) 
- (modified) llvm/test/CodeGen/X86/psubus.ll (+123-132) 
- (modified) llvm/test/CodeGen/X86/rotate-extract-vector.ll (-2) 
- (modified) llvm/test/CodeGen/X86/sadd_sat_vec.ll (+93-114) 
- (modified) llvm/test/CodeGen/X86/sat-add.ll (+9-12) 
- (modified) llvm/test/CodeGen/X86/sdiv-exact.ll (+13-17) 
- (modified) llvm/test/CodeGen/X86/shrink_vmul.ll (+20-14) 
- (modified) llvm/test/CodeGen/X86/srem-seteq-vec-nonsplat.ll (+9-11) 
- (modified) llvm/test/CodeGen/X86/sshl_sat_vec.ll (+38-38) 
- (modified) llvm/test/CodeGen/X86/ssub_sat_vec.ll (+123-151) 
- (modified) llvm/test/CodeGen/X86/test-shrink-bug.ll (+1-1) 
- (modified) llvm/test/CodeGen/X86/udiv-exact.ll (+13-17) 
- (modified) llvm/test/CodeGen/X86/urem-seteq-illegal-types.ll (+6-4) 
- (modified) llvm/test/CodeGen/X86/urem-seteq-vec-nonsplat.ll (+122-172) 
- (modified) llvm/test/CodeGen/X86/ushl_sat_vec.ll (+14-15) 
- (modified) llvm/test/CodeGen/X86/vec_minmax_sint.ll (+66-92) 
- (modified) llvm/test/CodeGen/X86/vec_minmax_uint.ll (+66-92) 
- (modified) llvm/test/CodeGen/X86/vec_smulo.ll (+32-38) 
- (modified) llvm/test/CodeGen/X86/vec_umulo.ll (+60-70) 
- (modified) llvm/test/CodeGen/X86/vector-compare-all_of.ll (+15-21) 
- (modified) llvm/test/CodeGen/X86/vector-compare-any_of.ll (+15-21) 
- (modified) llvm/test/CodeGen/X86/vector-fshl-128.ll (+139-159) 
- (modified) llvm/test/CodeGen/X86/vector-fshl-256.ll (+47-48) 
- (modified) llvm/test/CodeGen/X86/vector-fshl-rot-128.ll (+37-56) 
- (modified) llvm/test/CodeGen/X86/vector-fshl-rot-256.ll (+14-15) 
- (modified) llvm/test/CodeGen/X86/vector-fshr-128.ll (+27-33) 
- (modified) llvm/test/CodeGen/X86/vector-fshr-256.ll (+19-20) 
- (modified) llvm/test/CodeGen/X86/vector-fshr-rot-128.ll (+28-33) 
- (modified) llvm/test/CodeGen/X86/vector-fshr-rot-256.ll (+2-2) 
- (modified) llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll (+839-863) 
- (modified) llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-8.ll (+214-222) 
- (modified) llvm/test/CodeGen/X86/vector-mul.ll (+34-34) 
- (modified) llvm/test/CodeGen/X86/vector-pcmp.ll (+1-2) 
- (modified) llvm/test/CodeGen/X86/vector-reduce-fmaximum.ll (+81-82) 
- (modified) llvm/test/CodeGen/X86/vector-reduce-mul.ll (+57-113) 
- (modified) llvm/test/CodeGen/X86/vector-reduce-smax.ll (+69-99) 
- (modified) llvm/test/CodeGen/X86/vector-reduce-smin.ll (+65-96) 
- (modified) llvm/test/CodeGen/X86/vector-reduce-umax.ll (+69-99) 
- (modified) llvm/test/CodeGen/X86/vector-reduce-umin.ll (+65-96) 
- (modified) llvm/test/CodeGen/X86/vector-rotate-128.ll (+37-56) 
- (modified) llvm/test/CodeGen/X86/vector-rotate-256.ll (+14-15) 
- (modified) llvm/test/CodeGen/X86/vector-shift-shl-128.ll (+24-28) 
- (modified) llvm/test/CodeGen/X86/vector-shift-shl-256.ll (+20-22) 
- (modified) llvm/test/CodeGen/X86/vector-shift-shl-sub128.ll (+48-56) 
- (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-avx.ll (+6-6) 
- (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-ssse3.ll (+6-3) 
- (modified) llvm/test/CodeGen/X86/vector-shuffle-combining.ll (+3-6) 
- (modified) llvm/test/CodeGen/X86/vector-trunc-packus.ll (+806-945) 
- (modified) llvm/test/CodeGen/X86/vector-trunc-ssat.ll (+609-734) 
- (modified) llvm/test/CodeGen/X86/vector-trunc-usat.ll (+368-398) 
- (modified) llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll (+5-4) 
- (modified) llvm/test/CodeGen/X86/vselect.ll (+18-8) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 727526055e592..191f4cc78fcc5 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4187,6 +4187,16 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// More limited version of SimplifyDemandedBits that can be used to "look
   /// through" ops that don't contribute to the DemandedBits/DemandedElts -
   /// bitwise ops etc.
+  /// Vector elements that aren't demanded can be turned into poison unless the
+  /// corresponding bit in the \p DoNotPoisonEltMask is set.
+  SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
+                                          const APInt &DemandedElts,
+                                          const APInt &DoNotPoisonEltMask,
+                                          SelectionDAG &DAG,
+                                          unsigned Depth = 0) const;
+
+  /// Helper wrapper around SimplifyMultipleUseDemandedBits, with
+  /// DoNotPoisonEltMask being set to zero.
   SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
                                           const APInt &DemandedElts,
                                           SelectionDAG &DAG,
@@ -4202,6 +4212,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// bits from only some vector elements.
   SDValue SimplifyMultipleUseDemandedVectorElts(SDValue Op,
                                                 const APInt &DemandedElts,
+                                                const APInt &DoNotPoisonEltMask,
                                                 SelectionDAG &DAG,
                                                 unsigned Depth = 0) const;
 
@@ -4219,6 +4230,15 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   ///    results of this function, because simply replacing TLO.Old
   ///    with TLO.New will be incorrect when this parameter is true and TLO.Old
   ///    has multiple uses.
+  /// Vector elements that aren't demanded can be turned into poison unless the
+  /// corresponding bit in \p DoNotPoisonEltMask is set.
+  bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
+                                  const APInt &DoNotPoisonEltMask,
+                                  APInt &KnownUndef, APInt &KnownZero,
+                                  TargetLoweringOpt &TLO, unsigned Depth = 0,
+                                  bool AssumeSingleUse = false) const;
+  /// Version of SimplifyDemandedVectorElts without the DoNotPoisonEltMask
+  /// argument. All undemanded elements can be turned into poison.
   bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
                                   APInt &KnownUndef, APInt &KnownZero,
                                   TargetLoweringOpt &TLO, unsigned Depth = 0,
@@ -4303,8 +4323,9 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// (used to simplify the caller). The KnownUndef/Zero elements may only be
   /// accurate for those bits in the DemandedMask.
   virtual bool SimplifyDemandedVectorEltsForTargetNode(
-      SDValue Op, const APInt &DemandedElts, APInt &KnownUndef,
-      APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth = 0) const;
+      SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+      APInt &KnownUndef, APInt &KnownZero, TargetLoweringOpt &TLO,
+      unsigned Depth = 0) const;
 
   /// Attempt to simplify any target nodes based on the demanded bits/elts,
   /// returning true on success. Otherwise, analyze the
@@ -4323,6 +4344,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// bitwise ops etc.
   virtual SDValue SimplifyMultipleUseDemandedBitsForTargetNode(
       SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+      const APInt &DoNotPoisonEltMask,
       SelectionDAG &DAG, unsigned Depth) const;
 
   /// Return true if this function can prove that \p Op is never poison
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c8088d68b6f1b..1f53742b2c6f7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1457,8 +1457,10 @@ bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
                                              bool AssumeSingleUse) {
   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
   APInt KnownUndef, KnownZero;
-  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
-                                      TLO, 0, AssumeSingleUse))
+  APInt DoNotPoisonElts = APInt::getZero(DemandedElts.getBitWidth());
+  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, DoNotPoisonElts,
+                                      KnownUndef, KnownZero, TLO, 0,
+                                      AssumeSingleUse))
     return false;
 
   // Revisit the node.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 524c97ab3eab8..931154b922924 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -679,7 +679,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
 // TODO: Under what circumstances can we create nodes? Constant folding?
 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
-    SelectionDAG &DAG, unsigned Depth) const {
+    const APInt &DoNotPoisonEltMask, SelectionDAG &DAG, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
   // Limit search depth.
@@ -713,9 +713,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
     if (NumSrcEltBits == NumDstEltBits)
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedBits, DemandedElts, DAG, Depth + 1))
+              Src, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
 
+    // FIXME: Handle DoNotPoisonEltMask better?
+
     if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
       unsigned Scale = NumDstEltBits / NumSrcEltBits;
       unsigned NumSrcElts = SrcVT.getVectorNumElements();
@@ -731,9 +734,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
       // destination element, since recursive calls below may turn not demanded
       // elements into poison.
       APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      APInt DoNotPoisonSrcElts =
+          APIntOps::ScaleBitMask(DoNotPoisonEltMask, NumSrcElts);
 
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+              Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
     }
 
@@ -743,15 +749,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
       unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
       APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
+      APInt DoNotPoisonSrcElts = APInt::getZero(NumSrcElts);
       for (unsigned i = 0; i != NumElts; ++i)
         if (DemandedElts[i]) {
           unsigned Offset = (i % Scale) * NumDstEltBits;
           DemandedSrcBits.insertBits(DemandedBits, Offset);
           DemandedSrcElts.setBit(i / Scale);
+        } else if (DoNotPoisonEltMask[i]) {
+          DoNotPoisonSrcElts.setBit(i / Scale);
         }
 
+      // FIXME: Handle DoNotPoisonEltMask better?
+
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+              Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
     }
 
@@ -759,7 +771,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   }
   case ISD::FREEZE: {
     SDValue N0 = Op.getOperand(0);
-    if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
+    if (DAG.isGuaranteedNotToBeUndefOrPoison(N0,
+                                             DemandedElts | DoNotPoisonEltMask,
                                              /*PoisonOnly=*/false))
       return N0;
     break;
@@ -815,12 +828,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (std::optional<uint64_t> MaxSA =
-            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+            Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
       unsigned ShAmt = *MaxSA;
-      unsigned NumSignBits =
-          DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+      unsigned NumSignBits = DAG.ComputeNumSignBits(
+          Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
       if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
         return Op0;
@@ -830,15 +843,15 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SRL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (std::optional<uint64_t> MaxSA =
-            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+            Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
       unsigned ShAmt = *MaxSA;
       // Must already be signbits in DemandedBits bounds, and can't demand any
       // shifted in zeroes.
       if (DemandedBits.countl_zero() >= ShAmt) {
-        unsigned NumSignBits =
-            DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+        unsigned NumSignBits = DAG.ComputeNumSignBits(
+            Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
         if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
           return Op0;
       }
@@ -875,7 +888,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
         shouldRemoveRedundantExtend(Op))
       return Op0;
     // If the input is already sign extended, just drop the extension.
-    unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+    // FIXME: Can we skip DoNotPoisonEltMask here?
+    unsigned NumSignBits = DAG.ComputeNumSignBits(
+        Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
     if (NumSignBits >= (BitWidth - ExBits + 1))
       return Op0;
     break;
@@ -891,7 +906,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     EVT DstVT = Op.getValueType();
-    if (IsLE && DemandedElts == 1 &&
+    // FIXME: Can we skip DoNotPoisonEltMask here?
+    if (IsLE && (DemandedElts | DoNotPoisonEltMask) == 1 &&
         DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
         DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
       return DAG.getBitcast(DstVT, Src);
@@ -906,8 +922,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Vec = Op.getOperand(0);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
     EVT VecVT = Vec.getValueType();
+    // FIXME: Handle DoNotPoisonEltMask better.
     if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
-        !DemandedElts[CIdx->getZExtValue()])
+        !DemandedElts[CIdx->getZExtValue()] &&
+        !DoNotPoisonEltMask[CIdx->getZExtValue()])
       return Vec;
     break;
   }
@@ -920,8 +938,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     uint64_t Idx = Op.getConstantOperandVal(2);
     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
+    APInt DoNotPoisonSubElts = DoNotPoisonEltMask.extractBits(NumSubElts, Idx);
     // If we don't demand the inserted subvector, return the base vector.
-    if (DemandedSubElts == 0)
+    // FIXME: Handle DoNotPoisonEltMask better.
+    if (DemandedSubElts == 0 && DoNotPoisonSubElts == 0)
       return Vec;
     break;
   }
@@ -934,9 +954,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
     for (unsigned i = 0; i != NumElts; ++i) {
       int M = ShuffleMask[i];
-      if (M < 0 || !DemandedElts[i])
+      if (M < 0 || (!DemandedElts[i] && !DoNotPoisonEltMask[i]))
         continue;
-      AllUndef = false;
+      if (DemandedElts[i])
+        AllUndef = false;
       IdentityLHS &= (M == (int)i);
       IdentityRHS &= ((M - NumElts) == i);
     }
@@ -957,13 +978,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
 
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
       if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
-              Op, DemandedBits, DemandedElts, DAG, Depth))
+              Op, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG, Depth))
         return V;
     break;
   }
   return SDValue();
 }
 
+SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    SelectionDAG &DAG, unsigned Depth) const {
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+                                         DoNotPoisonEltMask, DAG, Depth);
+}
+
 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
     unsigned Depth) const {
@@ -974,13 +1003,14 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
-                                         Depth);
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+                                         DoNotPoisonEltMask, DAG, Depth);
 }
 
 SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
-    SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
-    unsigned Depth) const {
+    SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+    SelectionDAG &DAG, unsigned Depth) const {
   APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
                                          Depth);
@@ -2756,34 +2786,51 @@ bool TargetLowering::SimplifyDemandedBits(
                              TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
       }
     }
-
     // Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
     // Demand the elt/bit if any of the original elts/bits are demanded.
     if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
       unsigned Scale = BitWidth / NumSrcEltBits;
       unsigned NumSrcElts = SrcVT.getVectorNumElements();
       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
+      APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
       for (unsigned i = 0; i != Scale; ++i) {
         unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
         unsigned BitOffset = EltOffset * NumSrcEltBits;
         APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
-        if (!Sub.isZero())
+        if (!Sub.isZero()) {
           DemandedSrcBits |= Sub;
+          for (unsigned j = 0; j != NumElts; ++j)
+            if (DemandedElts[j])
+              DemandedSrcElts.setBit((j * Scale) + i);
+        }
       }
-      // Need to demand all smaller source elements that maps to a demanded
-      // destination element, since recursive calls below may turn not demanded
-      // elements into poison.
-      APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      // Need to "semi demand" all smaller source elements that maps to a
+      // demanded destination element, since recursive calls below may turn not
+      // demanded elements into poison. Instead of demanding such elements we
+      // use a special bitmask to indicate that the recursive calls must not
+      // turn such elements into poison.
+      APInt NoPoisonSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
 
       APInt KnownSrcUndef, KnownSrcZero;
-      if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
-                                     KnownSrcZero, TLO, Depth + 1))
+      if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, NoPoisonSrcElts,
+                                     KnownSrcUndef, KnownSrcZero, TLO,
+                                     Depth + 1))
         return true;
 
       KnownBits KnownSrcBits;
-      if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
-                               KnownSrcBits, TLO, Depth + 1))
+      if (SimplifyDemandedBits(Src, DemandedSrcBits,
+                               DemandedSrcElts | NoPoisonSrcElts, KnownSrcBits,
+                               TLO, Depth + 1))
         return true;
+
+      if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
+        if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
+                Src, DemandedSrcBits, DemandedSrcElts, NoPoisonSrcElts, TLO.DAG,
+                Depth + 1)) {
+          SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
+          return TLO.CombineTo(Op, NewOp);
+        }
+      }
     } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
       // TODO - bigendian once we have test coverage.
       unsigned Scale = NumSrcEltBits / BitWidth;
@@ -3090,8 +3137,9 @@ bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
                         !DCI.isBeforeLegalizeOps());
 
   APInt KnownUndef, KnownZero;
-  bool Simplified =
-      SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  bool Simplified = SimplifyDemandedVectorElts(
+      Op, DemandedElts, DoNotPoisonEltMask, KnownUndef, KnownZero, TLO);
   if (Simplified) {
     DCI.AddToWorklist(Op.getNode());
     DCI.CommitTargetLoweringOpt(TLO);
@@ -3152,6 +3200,16 @@ bool TargetLowering::SimplifyDemandedVectorElts(
     SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
     APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
     bool AssumeSingleUse) const {
+  APInt DoNotPoisonEltMask = APInt::getZero(OriginalDemandedElts.getBitWidth());
+  return SimplifyDemandedVectorElts(Op, OriginalDemandedElts,
+                                    DoNotPoisonEltMask, KnownUndef, KnownZero,
+                                    TLO, Depth, AssumeSingleUse);
+}
+
+bool TargetLowering::SimplifyDemandedVectorElts(
+    SDValue Op, const APInt &OriginalDemandedElts,
+    const APInt &DoNotPoisonEltMask, APInt &KnownUndef, APInt &KnownZero,
+    TargetLoweringOpt &TLO, unsigned Depth, bool AssumeSingleUse) const {
   EVT VT = Op.getValueType();
   unsigned Opcode = Op.getOpcode();
   APInt DemandedElts = OriginalDemandedElts;
@@ -3190,6 +3248,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return false;
 
+  APInt DemandedEltsInclDoNotPoison = DemandedElts | DoNotPoisonEltMask;
   SDLoc DL(Op);
   unsigned EltSizeInBits = VT.getScalarSizeInBits();
   bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
@@ -3197,10 +3256,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(
   // Helper for demanding the specified elements and all the bits of both binary
   // operands.
   auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
-    SDValue NewOp0 = S...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/145903


More information about the llvm-branch-commits mailing list