[llvm] [VectorCombine] Scalarize extracts of ZExt if profitable. (PR #142976)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 20 14:58:40 PDT 2025


================
@@ -1710,6 +1711,73 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  if (!match(&I, m_ZExt(m_Value())))
+    return false;
+
+  // Try to convert a vector zext feeding only extracts to a set of scalar
+  // (Src >> (ExtIdx * Size)) & ((1 << Size) - 1) operations, if profitable.
+  auto *Ext = cast<ZExtInst>(&I);
+  auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  if (!SrcTy)
+    return false;
+  auto *DstTy = cast<FixedVectorType>(Ext->getType());
+
+  Type *ScalarDstTy = DstTy->getElementType();
+  if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
+    return false;
+
+  InstructionCost VectorCost =
+      TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
+                           TTI::CastContextHint::None, CostKind, Ext);
+  unsigned ExtCnt = 0;
+  bool ExtLane0 = false;
+  for (User *U : Ext->users()) {
+    const APInt *Idx;
+    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+      return false;
+    if (cast<Instruction>(U)->use_empty())
+      continue;
+    ExtCnt += 1;
+    ExtLane0 |= Idx->isZero();
+    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+                                         CostKind, Idx->getZExtValue(), U);
+  }
+
+  InstructionCost ScalarCost =
+      ExtCnt * TTI.getArithmeticInstrCost(
+                   Instruction::And, ScalarDstTy, CostKind,
+                   {TTI::OK_AnyValue, TTI::OP_None},
+                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+      (ExtCnt - ExtLane0) *
+          TTI.getArithmeticInstrCost(
+              Instruction::LShr, ScalarDstTy, CostKind,
+              {TTI::OK_AnyValue, TTI::OP_None},
+              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+  if (ScalarCost > VectorCost)
+    return false;
+
+  Value *ScalarV = Ext->getOperand(0);
+  if (!isGuaranteedNotToBePoison(ScalarV, &AC))
+    ScalarV = Builder.CreateFreeze(ScalarV);
+  ScalarV = Builder.CreateBitCast(
+      ScalarV,
+      IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
+  unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+  Value *EltBitMask =
+      ConstantInt::get(ScalarV->getType(), (1ull << SrcEltSizeInBits) - 1);
+  for (auto *U : to_vector(Ext->users())) {
+    auto *Extract = cast<ExtractElementInst>(U);
+    unsigned Idx =
+        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
+    auto *S = Builder.CreateLShr(
+        ScalarV, ConstantInt::get(ScalarV->getType(), Idx * SrcEltSizeInBits));
----------------
artagnon wrote:

This ConstantInt::get is also unnecessary?

https://github.com/llvm/llvm-project/pull/142976


More information about the llvm-commits mailing list