[llvm] [VectorCombine] support mismatching extract/insert indices for foldInsExtFNeg (PR #126408)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 27 01:57:18 PST 2025


================
@@ -664,79 +664,89 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
 /// shuffle.
 bool VectorCombine::foldInsExtFNeg(Instruction &I) {
   // Match an insert (op (extract)) pattern.
-  Value *DestVec;
-  uint64_t Index;
+  Value *DstVec;
+  uint64_t ExtIdx, InsIdx;
   Instruction *FNeg;
-  if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
-                             m_ConstantInt(Index))))
+  if (!match(&I, m_InsertElt(m_Value(DstVec), m_OneUse(m_Instruction(FNeg)),
+                             m_ConstantInt(InsIdx))))
     return false;
 
   // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
   Value *SrcVec;
   Instruction *Extract;
   if (!match(FNeg, m_FNeg(m_CombineAnd(
                        m_Instruction(Extract),
-                       m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
+                       m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx))))))
     return false;
 
-  auto *VecTy = cast<FixedVectorType>(I.getType());
-  auto *ScalarTy = VecTy->getScalarType();
+  auto *DstVecTy = cast<FixedVectorType>(DstVec->getType());
+  auto *DstVecScalarTy = DstVecTy->getScalarType();
   auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
-  if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType())
+  if (!SrcVecTy || DstVecScalarTy != SrcVecTy->getScalarType())
     return false;
 
   // Ignore bogus insert/extract index.
-  unsigned NumElts = VecTy->getNumElements();
-  if (Index >= NumElts)
+  unsigned NumDstElts = DstVecTy->getNumElements();
+  unsigned NumSrcElts = SrcVecTy->getNumElements();
+  if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
     return false;
 
   // We are inserting the negated element into the same lane that we extracted
   // from. This is equivalent to a select-shuffle that chooses all but the
   // negated element from the destination vector.
-  SmallVector<int> Mask(NumElts);
+  SmallVector<int> Mask(NumDstElts);
   std::iota(Mask.begin(), Mask.end(), 0);
-  Mask[Index] = Index + NumElts;
+  Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
   InstructionCost OldCost =
-      TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) +
-      TTI.getVectorInstrCost(I, VecTy, CostKind, Index);
+      TTI.getArithmeticInstrCost(Instruction::FNeg, DstVecScalarTy, CostKind) +
+      TTI.getVectorInstrCost(I, DstVecTy, CostKind, InsIdx);
 
   // If the extract has one use, it will be eliminated, so count it in the
   // original cost. If it has more than one use, ignore the cost because it will
   // be the same before/after.
   if (Extract->hasOneUse())
-    OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);
+    OldCost += TTI.getVectorInstrCost(*Extract, SrcVecTy, CostKind, ExtIdx);
 
   InstructionCost NewCost =
-      TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) +
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, VecTy, Mask,
+      TTI.getArithmeticInstrCost(Instruction::FNeg, SrcVecTy, CostKind) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, DstVecTy, Mask,
                          CostKind);
 
-  bool NeedLenChg = SrcVecTy->getNumElements() != NumElts;
+  bool NeedLenChg = SrcVecTy->getNumElements() != NumDstElts;
   // If the lengths of the two vectors are not equal,
   // we need to add a length-change vector. Add this cost.
   SmallVector<int> SrcMask;
   if (NeedLenChg) {
-    SrcMask.assign(NumElts, PoisonMaskElem);
-    SrcMask[Index] = Index;
+    SrcMask.assign(NumDstElts, PoisonMaskElem);
+    SrcMask[(ExtIdx % NumDstElts)] = ExtIdx;
     NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                   SrcVecTy, SrcMask, CostKind);
   }
 
+  LLVM_DEBUG(dbgs() << "Found an insertion of (extract)fneg : " << I
+                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
   if (NewCost > OldCost)
     return false;
 
-  Value *NewShuf;
-  // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index
+  Value *NewShuf, *LenChgShuf = nullptr;
+  // insertelt DstVec, (fneg (extractelt SrcVec, Index)), Index
   Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
   if (NeedLenChg) {
-    // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
-    Value *LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
-    NewShuf = Builder.CreateShuffleVector(DestVec, LenChgShuf, Mask);
+    // shuffle DstVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
+    LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
+    NewShuf = Builder.CreateShuffleVector(DstVec, LenChgShuf, Mask);
+    Worklist.pushValue(LenChgShuf);
   } else {
-    // shuffle DestVec, (fneg SrcVec), Mask
-    NewShuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
+    // shuffle DstVec, (fneg SrcVec), Mask
+    NewShuf = Builder.CreateShuffleVector(DstVec, VecFNeg, Mask);
   }
 
+  if (LenChgShuf)
+    Worklist.pushValue(LenChgShuf);
----------------
RKSimon wrote:

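Duplicate? `LenChgShuf` is already pushed onto the worklist inside the `NeedLenChg` branch, so this second `Worklist.pushValue(LenChgShuf)` is redundant.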

https://github.com/llvm/llvm-project/pull/126408


More information about the llvm-commits mailing list