[llvm] [VectorCombine] Scalarize extracts of ZExt if profitable. (PR #142976)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 30 07:10:23 PDT 2025


================
@@ -1770,6 +1771,73 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  auto *Ext = dyn_cast<ZExtInst>(&I);
+  if (!Ext)
+    return false;
+
+  // Try to convert a vector zext feeding only extracts to a set of scalar
+  //   (Src >> ExtIdx * Size) & ((1 << Size) - 1)
+  // operations (Size = source element width in bits), if profitable.
+  auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  if (!SrcTy)
+    return false;
+  auto *DstTy = cast<FixedVectorType>(Ext->getType());
+
+  Type *ScalarDstTy = DstTy->getElementType();
+  if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
+    return false;
+
+  InstructionCost VectorCost =
+      TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
+                           TTI::CastContextHint::None, CostKind, Ext);
+  unsigned ExtCnt = 0;
+  bool ExtLane0 = false;
+  for (User *U : Ext->users()) {
+    const APInt *Idx;
+    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+      return false;
+    if (cast<Instruction>(U)->use_empty())
+      continue;
+    ExtCnt += 1;
+    ExtLane0 |= Idx->isZero();
+    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+                                         CostKind, Idx->getZExtValue(), U);
+  }
+
+  InstructionCost ScalarCost =
+      ExtCnt * TTI.getArithmeticInstrCost(
+                   Instruction::And, ScalarDstTy, CostKind,
+                   {TTI::OK_AnyValue, TTI::OP_None},
+                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+      (ExtCnt - ExtLane0) *
+          TTI.getArithmeticInstrCost(
+              Instruction::LShr, ScalarDstTy, CostKind,
+              {TTI::OK_AnyValue, TTI::OP_None},
+              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+  if (ScalarCost > VectorCost)
+    return false;
+
+  Value *ScalarV = Ext->getOperand(0);
+  if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast<Instruction>(ScalarV),
+                                 &DT))
+    ScalarV = Builder.CreateFreeze(ScalarV);
+  ScalarV = Builder.CreateBitCast(
+      ScalarV,
+      IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
+  unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+  unsigned EltBitMask = (1ull << SrcEltSizeInBits) - 1;
+  for (User *U : Ext->users()) {
+    auto *Extract = cast<ExtractElementInst>(U);
+    uint64_t Idx =
+        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
+    Value *S = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits);
+    Value *A = Builder.CreateAnd(S, EltBitMask);
+    U->replaceAllUsesWith(A);
----------------
artagnon wrote:

```suggestion
    Value *LShr = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits);
    Value *And = Builder.CreateAnd(LShr, EltBitMask);
    U->replaceAllUsesWith(And);
```

https://github.com/llvm/llvm-project/pull/142976


More information about the llvm-commits mailing list