[llvm] [VectorCombine] Handle shuffle of selects (PR #128032)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 27 02:17:34 PST 2025


================
@@ -1899,6 +1900,74 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
   return true;
 }
 
+/// Try to convert,
+/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
+/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
+bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
+  ArrayRef<int> Mask;
+  Value *C1, *T1, *F1, *C2, *T2, *F2;
+  if (!match(&I, m_Shuffle(
+                     m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
+                     m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
+                     m_Mask(Mask))))
+    return false;
+
+  auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
+  auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
+  auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
+  if (!C1VecTy || !C2VecTy)
+    return false;
+
+  // SelectInsts must have the same FMF.
+  auto *Select0 = cast<Instruction>(I.getOperand(0));
+  if (auto *SI0FOp = dyn_cast<FPMathOperator>(Select0))
+    if (auto *SI1FOp = dyn_cast<FPMathOperator>((I.getOperand(1))))
+      if (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())
+        return false;
+
+  auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
+  auto SelOp = Instruction::Select;
+  InstructionCost OldCost = TTI.getCmpSelInstrCost(
+      SelOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
+  OldCost += TTI.getCmpSelInstrCost(SelOp, T2->getType(), C2VecTy,
+                                    CmpInst::BAD_ICMP_PREDICATE, CostKind);
+  OldCost += TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr,
+                                {I.getOperand(0), I.getOperand(1)}, &I);
+
+  auto *C1C2VecTy = cast<FixedVectorType>(
+      toVectorTy(Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements()));
+  InstructionCost NewCost =
+      TTI.getShuffleCost(SK, C1C2VecTy, Mask, CostKind, 0, nullptr, {C1, C2});
+  NewCost +=
+      TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {T1, T2});
+  NewCost +=
+      TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {F1, F2});
+  NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, DstVecTy,
+                                    CmpInst::BAD_ICMP_PREDICATE, CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
+                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
+  if (NewCost > OldCost)
+    return false;
+
+  Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
+  Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
+  Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
+  Value *NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
+
+  // We presuppose that the SelectInsts have the same FMF.
+  if (isa<FPMathOperator>(NewSel))
+    cast<Instruction>(NewSel)->setFastMathFlags(Select0->getFastMathFlags());
+
+  Worklist.pushValue(ShuffleCmp);
+  Worklist.pushValue(ShuffleTrue);
+  Worklist.pushValue(ShuffleFalse);
+  Worklist.pushValue(NewSel);
----------------
RKSimon wrote:

No need to pushValue(NewSel) - replaceValue should handle this for us

https://github.com/llvm/llvm-project/pull/128032


More information about the llvm-commits mailing list