[llvm] [VectorCombine] Scalarize extracts of ZExt if profitable. (PR #142976)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 5 07:52:15 PDT 2025


================
@@ -1710,6 +1711,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  if (!match(&I, m_ZExt(m_Value())))
+    return false;
+
+  // Try to convert a vector zext feeding only extracts to a set of scalar
+  // (Src >> (ExtIdx * Size)) & mask operations (lshr + and), if profitable.
+  auto *Ext = cast<ZExtInst>(&I);
+  auto *SrcTy = cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  auto *DstTy = cast<FixedVectorType>(Ext->getType());
+
+  if (DL->getTypeSizeInBits(SrcTy) !=
+      DL->getTypeSizeInBits(DstTy->getElementType()))
+    return false;
+
+  InstructionCost VectorCost = TTI.getCastInstrCost(
+      Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
+  unsigned ExtCnt = 0;
+  bool ExtLane0 = false;
+  for (User *U : Ext->users()) {
+    const APInt *Idx;
+    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+      return false;
+    if (cast<Instruction>(U)->use_empty())
+      continue;
+    ExtCnt += 1;
+    ExtLane0 |= Idx->isZero();
+    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+                                         CostKind, Idx->getZExtValue(), U);
+  }
+
+  Type *ScalarDstTy = DstTy->getElementType();
+  InstructionCost ScalarCost =
+      ExtCnt * TTI.getArithmeticInstrCost(
+                   Instruction::And, ScalarDstTy, CostKind,
+                   {TTI::OK_AnyValue, TTI::OP_None},
+                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+      (ExtCnt - ExtLane0) *
+          TTI.getArithmeticInstrCost(
+
+              Instruction::LShr, ScalarDstTy, CostKind,
+              {TTI::OK_AnyValue, TTI::OP_None},
+              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+  if (ScalarCost > VectorCost)
+    return false;
+
+  Value *ScalarV = Ext->getOperand(0);
+  if (!isGuaranteedNotToBePoison(ScalarV, &AC))
+    ScalarV = Builder.CreateFreeze(ScalarV);
+  ScalarV = Builder.CreateBitCast(
+      ScalarV,
+      IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
----------------
RKSimon wrote:

Where do you account for the cost of the bitcast (vec->int transfer)?

https://github.com/llvm/llvm-project/pull/142976


More information about the llvm-commits mailing list