[llvm] [InstCombine] Pull extract through broadcast (PR #143380)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 10 08:02:47 PDT 2025
================
@@ -542,27 +542,40 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
}
}
} else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
- // If this is extracting an element from a shufflevector, figure out where
- // it came from and extract from the appropriate input element instead.
- // Restrict the following transformation to fixed-length vector.
- if (isa<FixedVectorType>(SVI->getType()) && isa<ConstantInt>(Index)) {
- int SrcIdx =
- SVI->getMaskValue(cast<ConstantInt>(Index)->getZExtValue());
- Value *Src;
- unsigned LHSWidth = cast<FixedVectorType>(SVI->getOperand(0)->getType())
- ->getNumElements();
-
- if (SrcIdx < 0)
- return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
- if (SrcIdx < (int)LHSWidth)
- Src = SVI->getOperand(0);
- else {
- SrcIdx -= LHSWidth;
- Src = SVI->getOperand(1);
+ int SplatIndex = getSplatIndex(SVI->getShuffleMask());
+ // We know such a splat must be reading from the first operand, even
+ // in the case of scalable vectors (vscale is always > 0).
+ if (SplatIndex == 0)
+ return ExtractElementInst::Create(SVI->getOperand(0), Builder.getInt64(0));
+ // Restrict the non-zero index case to fixed-length vectors
+ if (isa<FixedVectorType>(SVI->getType())) {
+
+ // getSplatIndex doesn't distinguish between the all-poison splat and
+ // a non-splat mask. However, if Index is -1, we still want to propagate
+ // that poison value.
+ int SrcIdx = -2;
----------------
agorenstein-nvidia wrote:
I'm not entirely happy with this. I also considered a tiny helper function like `bool test(int* outParam)`, carrying along a `bool ValidSrcIdx`, or rewriting `getSplatIndex` so that it can meaningfully return an all-poison splat indicator; of those options, this seemed the best.
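The index arithmetic in the hunk can be modeled outside LLVM using plain `std::vector<int>` masks, where a negative element stands for a poison lane. This covers both the fixed-length mapping from the removed code and the new lane-0 splat shortcut. The sketch below is illustrative only, and all of its names are hypothetical:

- `splatIndexModel`, `resolveExtract`, `ExtractSource`, `SplatInfo`, and `classifySplat` are not LLVM APIs.
- `splatIndexModel` mirrors the documented behavior of `llvm::getSplatIndex` rather than calling it.
- `classifySplat` shows one possible shape of the "all-poison splat indicator" alternative mentioned in the comment. It is not what the patch does.

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of llvm::getSplatIndex's documented behavior: negative
// (poison) mask elements are ignored, and -1 is returned both for a
// non-splat mask and for an all-poison mask.
int splatIndexModel(const std::vector<int> &Mask) {
  int Splat = -1;
  for (int M : Mask) {
    if (M < 0)
      continue;                 // poison lane: says nothing about the splat
    if (Splat != -1 && Splat != M)
      return -1;                // two different source lanes: not a splat
    Splat = M;
  }
  return Splat;                 // -1 here may also mean "every lane is poison"
}

// Hypothetical description of what
//   extractelement (shufflevector LHS, RHS, Mask), Index
// reads, for a fixed-length shuffle and a constant Index.
struct ExtractSource {
  enum Kind { Poison, LHS, RHS } K;
  int Lane; // lane within the chosen operand; unused when K == Poison
};

ExtractSource resolveExtract(const std::vector<int> &Mask, int LHSWidth,
                             unsigned Index) {
  // Splat of lane 0: every defined lane reads LHS[0], so Index is irrelevant
  // (a poison lane may legally be refined to that value).
  if (splatIndexModel(Mask) == 0)
    return {ExtractSource::LHS, 0};
  int SrcIdx = Mask.at(Index);
  if (SrcIdx < 0)
    return {ExtractSource::Poison, -1};           // extracting a poison lane
  if (SrcIdx < LHSWidth)
    return {ExtractSource::LHS, SrcIdx};          // reads the first operand
  return {ExtractSource::RHS, SrcIdx - LHSWidth}; // rebase into the second
}

// Hypothetical alternative: a splat query that distinguishes "all poison"
// from "not a splat" instead of folding both into -1.
struct SplatInfo {
  enum Kind { NotSplat, AllPoison, Splat } K;
  int Index; // meaningful only when K == Splat
};

SplatInfo classifySplat(const std::vector<int> &Mask) {
  int Idx = splatIndexModel(Mask);
  if (Idx >= 0)
    return {SplatInfo::Splat, Idx};
  for (int M : Mask)
    if (M >= 0)
      return {SplatInfo::NotSplat, -1}; // a defined lane exists, so not a splat
  return {SplatInfo::AllPoison, -1};    // no defined lanes at all
}
```

For example, take mask `<1, 6, -1, 3>` over two 4-lane operands:

- Extracting lane 1 resolves to lane 2 of the second operand.
- Extracting lane 2 resolves to poison.

The lane-0 splat shortcut runs before the mask is indexed. So with mask `<0, 0, -1, 0>`, even the poison lane resolves to `LHS[0]`, matching the patch's unconditional `SplatIndex == 0` early return. The model only covers fixed-length vectors. The patch also applies the lane-0 case to scalable vectors.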
https://github.com/llvm/llvm-project/pull/143380