[llvm] [AMDGPU][AMDGPULateCodeGenPrepare] Combine scalarized selects back into vector selects (PR #173990)
Pankaj Dwivedi via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 30 08:38:38 PST 2025
================
@@ -551,6 +593,113 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
return true;
}
+bool AMDGPULateCodeGenPrepare::tryCombineSelectsFromBitcast(BitCastInst &BC) {
+ auto *SrcVecTy = dyn_cast<FixedVectorType>(BC.getSrcTy());
+ auto *DstVecTy = dyn_cast<FixedVectorType>(BC.getDestTy());
+ if (!SrcVecTy || !DstVecTy)
+ return false;
+
+ // Must be: bitcast <N x i32> to <M x i8>
+ if (!SrcVecTy->getElementType()->isIntegerTy(32) ||
+ !DstVecTy->getElementType()->isIntegerTy(8))
+ return false;
+
+ unsigned NumDstElts = DstVecTy->getNumElements();
+ BasicBlock *BB = BC.getParent();
+
+ // Require at least half the elements to have matching selects.
+ // For v16i8 (from v4i32), this means at least 8 selects must match.
+ // This threshold ensures the transformation is profitable.
+ unsigned MinRequired = NumDstElts / 2;
+
+ // Early exit: not enough users to possibly meet the threshold.
+ if (BC.getNumUses() < MinRequired)
+ return false;
+
+ // Group selects by their condition value. Different conditions selecting
+ // from the same bitcast are handled as independent groups, allowing us to
+ // optimize multiple select patterns from a single bitcast.
+ struct SelectGroup {
+ // Map from element index to (select, extractelement) pair.
+ SmallDenseMap<unsigned, std::pair<SelectInst *, ExtractElementInst *>, 16>
+ Selects;
+ // Track the earliest select instruction for correct insertion point.
+ SelectInst *FirstSelect = nullptr;
+ };
+ DenseMap<Value *, SelectGroup> ConditionGroups;
+
+ // Collect all matching select patterns in a single pass.
+ // Pattern: select i1 %cond, i8 (extractelement %bc, idx), i8 0
+ for (User *U : BC.users()) {
+ auto *Ext = dyn_cast<ExtractElementInst>(U);
+ if (!Ext || Ext->getParent() != BB)
+ continue;
+
+ auto *IdxC = dyn_cast<ConstantInt>(Ext->getIndexOperand());
+ if (!IdxC || IdxC->getZExtValue() >= NumDstElts)
+ continue;
+
+ unsigned Idx = IdxC->getZExtValue();
+
+ for (User *EU : Ext->users()) {
+ auto *Sel = dyn_cast<SelectInst>(EU);
+ // Must be: select %cond, %extract, 0 (in same BB)
+ if (!Sel || Sel->getParent() != BB || Sel->getTrueValue() != Ext ||
+ !match(Sel->getFalseValue(), m_Zero()))
+ continue;
+
+ auto &Group = ConditionGroups[Sel->getCondition()];
+ Group.Selects[Idx] = {Sel, Ext};
+
+ // Track earliest select to ensure correct dominance for insertion.
+ if (!Group.FirstSelect || Sel->comesBefore(Group.FirstSelect))
+ Group.FirstSelect = Sel;
+ }
+ }
+
+ bool Changed = false;
+
+ // Process each condition group that meets the threshold.
+ for (auto &[Cond, Group] : ConditionGroups) {
+ if (Group.Selects.size() < MinRequired)
+ continue;
+
+ LLVM_DEBUG(dbgs() << "AMDGPULateCodeGenPrepare: Combining "
+ << Group.Selects.size()
+ << " scalar selects into vector select\n");
+
+ // Insert before the first select to maintain dominance.
+ IRBuilder<> Builder(Group.FirstSelect);
+
+ // Create vector select: select i1 %cond, <N x i32> %src, zeroinitializer
+ Value *VecSel =
+ Builder.CreateSelect(Cond, BC.getOperand(0),
+ Constant::getNullValue(SrcVecTy), "combined.sel");
+
+ // Bitcast the selected vector back to the byte vector type.
+ Value *NewBC = Builder.CreateBitCast(VecSel, DstVecTy, "combined.bc");
+
+ // Replace each scalar select with an extract from the combined result.
+ for (auto &[Idx, Pair] : Group.Selects) {
+ Value *NewExt = Builder.CreateExtractElement(NewBC, Idx);
+ Pair.first->replaceAllUsesWith(NewExt);
+ DeadInsts.emplace_back(Pair.first);
+
+ // Mark the original extract as dead if it has no remaining uses.
+ if (Pair.second->use_empty())
+ DeadInsts.emplace_back(Pair.second);
----------------
PankajDwivedi-25 wrote:
I have added a test for this case in the latest revision of the patch.
https://github.com/llvm/llvm-project/pull/173990
More information about the llvm-commits
mailing list