[llvm] [AMDGPU][AMDGPULateCodeGenPrepare] Combine scalarized selects back into vector selects (PR #173990)

Pankaj Dwivedi via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 30 08:38:38 PST 2025


================
@@ -551,6 +593,113 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
   return true;
 }
 
+bool AMDGPULateCodeGenPrepare::tryCombineSelectsFromBitcast(BitCastInst &BC) {
+  auto *SrcVecTy = dyn_cast<FixedVectorType>(BC.getSrcTy());
+  auto *DstVecTy = dyn_cast<FixedVectorType>(BC.getDestTy());
+  if (!SrcVecTy || !DstVecTy)
+    return false;
+
+  // Must be: bitcast <N x i32> to <M x i8>
+  if (!SrcVecTy->getElementType()->isIntegerTy(32) ||
+      !DstVecTy->getElementType()->isIntegerTy(8))
+    return false;
+
+  unsigned NumDstElts = DstVecTy->getNumElements();
+  BasicBlock *BB = BC.getParent();
+
+  // Require at least half the elements to have matching selects.
+  // For v16i8 (from v4i32), this means at least 8 selects must match.
+  // This threshold ensures the transformation is profitable.
+  unsigned MinRequired = NumDstElts / 2;
+
+  // Early exit: not enough users to possibly meet the threshold.
+  if (BC.getNumUses() < MinRequired)
+    return false;
+
+  // Group selects by their condition value. Different conditions selecting
+  // from the same bitcast are handled as independent groups, allowing us to
+  // optimize multiple select patterns from a single bitcast.
+  struct SelectGroup {
+    // Map from element index to (select, extractelement) pair.
+    SmallDenseMap<unsigned, std::pair<SelectInst *, ExtractElementInst *>, 16>
+        Selects;
+    // Track the earliest select instruction for correct insertion point.
+    SelectInst *FirstSelect = nullptr;
+  };
+  DenseMap<Value *, SelectGroup> ConditionGroups;
+
+  // Collect all matching select patterns in a single pass.
+  // Pattern: select i1 %cond, i8 (extractelement %bc, idx), i8 0
+  for (User *U : BC.users()) {
+    auto *Ext = dyn_cast<ExtractElementInst>(U);
+    if (!Ext || Ext->getParent() != BB)
+      continue;
+
+    auto *IdxC = dyn_cast<ConstantInt>(Ext->getIndexOperand());
+    if (!IdxC || IdxC->getZExtValue() >= NumDstElts)
+      continue;
+
+    unsigned Idx = IdxC->getZExtValue();
+
+    for (User *EU : Ext->users()) {
+      auto *Sel = dyn_cast<SelectInst>(EU);
+      // Must be: select %cond, %extract, 0 (in same BB)
+      if (!Sel || Sel->getParent() != BB || Sel->getTrueValue() != Ext ||
+          !match(Sel->getFalseValue(), m_Zero()))
+        continue;
+
+      auto &Group = ConditionGroups[Sel->getCondition()];
+      Group.Selects[Idx] = {Sel, Ext};
+
+      // Track earliest select to ensure correct dominance for insertion.
+      if (!Group.FirstSelect || Sel->comesBefore(Group.FirstSelect))
+        Group.FirstSelect = Sel;
+    }
+  }
+
+  bool Changed = false;
+
+  // Process each condition group that meets the threshold.
+  for (auto &[Cond, Group] : ConditionGroups) {
+    if (Group.Selects.size() < MinRequired)
+      continue;
+
+    LLVM_DEBUG(dbgs() << "AMDGPULateCodeGenPrepare: Combining "
+                      << Group.Selects.size()
+                      << " scalar selects into vector select\n");
+
+    // Insert before the first select to maintain dominance.
+    IRBuilder<> Builder(Group.FirstSelect);
+
+    // Create vector select: select i1 %cond, <N x i32> %src, zeroinitializer
+    Value *VecSel =
+        Builder.CreateSelect(Cond, BC.getOperand(0),
+                             Constant::getNullValue(SrcVecTy), "combined.sel");
+
+    // Bitcast the selected vector back to the byte vector type.
+    Value *NewBC = Builder.CreateBitCast(VecSel, DstVecTy, "combined.bc");
+
+    // Replace each scalar select with an extract from the combined result.
+    for (auto &[Idx, Pair] : Group.Selects) {
+      Value *NewExt = Builder.CreateExtractElement(NewBC, Idx);
+      Pair.first->replaceAllUsesWith(NewExt);
+      DeadInsts.emplace_back(Pair.first);
+
+      // Mark the original extract as dead if it has no remaining uses.
+      if (Pair.second->use_empty())
+        DeadInsts.emplace_back(Pair.second);
----------------
PankajDwivedi-25 wrote:

I have added a test for this in the latest revision of the patch.

https://github.com/llvm/llvm-project/pull/173990


More information about the llvm-commits mailing list