[llvm] [VectorCombine] New folding pattern for extract/binop/shuffle chains (PR #145232)
Rajveer Singh Bharadwaj via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 30 01:04:24 PDT 2025
================
@@ -2988,6 +2989,240 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
return foldSelectShuffle(*Shuffle, true);
}
+bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
+  auto *EEI = dyn_cast<ExtractElementInst>(&I);
+  if (!EEI)
+    return false;
+
+  std::queue<Value *> InstWorklist;
+  Value *InitEEV = nullptr;
+
+  unsigned int CommonCallOp = 0, CommonBinOp = 0;
+
+  bool IsFirstCallOrBinInst = true;
+  bool ShouldBeCallOrBinInst = true;
+
+  SmallVector<Value *, 3> PrevVecV(3, nullptr);
+  int64_t ShuffleMaskHalf = -1, ExpectedShuffleMaskHalf = 1;
+  int64_t VecSize = -1;
+
+  Value *VecOp;
+  if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
+    return false;
+
+  auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
+  if (!FVT)
+    return false;
+
+  VecSize = FVT->getNumElements();
+  if (VecSize < 2 || (VecSize % 2) != 0)
+    return false;
+
+  ShuffleMaskHalf = 1;
+  PrevVecV[2] = VecOp;
+  InitEEV = EEI;
+
+  InstWorklist.push(PrevVecV[2]);
+
+  while (!InstWorklist.empty()) {
+    Value *V = InstWorklist.front();
+    InstWorklist.pop();
+
+    auto *CI = dyn_cast<Instruction>(V);
+    if (!CI)
+      return false;
+
+    if (auto *CallI = dyn_cast<CallInst>(CI)) {
+      if (!ShouldBeCallOrBinInst || !PrevVecV[2])
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (CallI != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      auto *II = dyn_cast<IntrinsicInst>(CallI);
+      if (!II)
+        return false;
+
+      if (!CommonCallOp)
+        CommonCallOp = II->getIntrinsicID();
+      if (II->getIntrinsicID() != CommonCallOp)
+        return false;
+
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::umin:
+      case Intrinsic::umax:
+      case Intrinsic::smin:
+      case Intrinsic::smax: {
+        auto *Op0 = CallI->getOperand(0);
+        auto *Op1 = CallI->getOperand(1);
+        PrevVecV[0] = Op0;
+        PrevVecV[1] = Op1;
+        break;
+      }
+      default:
+        return false;
+      }
+      ShouldBeCallOrBinInst ^= 1;
+
+      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
+        std::swap(PrevVecV[0], PrevVecV[1]);
+      InstWorklist.push(PrevVecV[1]);
+      InstWorklist.push(PrevVecV[0]);
+    } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
+      if (!ShouldBeCallOrBinInst || !PrevVecV[2])
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      if (!CommonBinOp)
+        CommonBinOp = CI->getOpcode();
+      if (CI->getOpcode() != CommonBinOp)
+        return false;
+
+      switch (CI->getOpcode()) {
+      case BinaryOperator::Add:
+      case BinaryOperator::Mul:
+      case BinaryOperator::Or:
+      case BinaryOperator::And:
+      case BinaryOperator::Xor: {
+        auto *Op0 = BinOp->getOperand(0);
+        auto *Op1 = BinOp->getOperand(1);
+        PrevVecV[0] = Op0;
+        PrevVecV[1] = Op1;
+        break;
+      }
+      default:
+        return false;
+      }
+      ShouldBeCallOrBinInst ^= 1;
+
+      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
+        std::swap(PrevVecV[0], PrevVecV[1]);
+      InstWorklist.push(PrevVecV[1]);
+      InstWorklist.push(PrevVecV[0]);
+    } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
+      if (ShouldBeCallOrBinInst ||
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (SVInst != PrevVecV[1])
+        return false;
+
+      auto *ShuffleVec = SVInst->getOperand(0);
+      if (!ShuffleVec || ShuffleVec != PrevVecV[0])
+        return false;
+
+      SmallVector<int> CurMask;
+      SVInst->getShuffleMask(CurMask);
+
+      if (ShuffleMaskHalf != ExpectedShuffleMaskHalf)
+        return false;
+      ExpectedShuffleMaskHalf *= 2;
+
+      for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
+        if (Mask < ShuffleMaskHalf && CurMask[Mask] != ShuffleMaskHalf + Mask)
+          return false;
+        if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
+          return false;
+      }
+      ShuffleMaskHalf *= 2;
+      if (ExpectedShuffleMaskHalf == VecSize)
+        break;
+      ShouldBeCallOrBinInst ^= 1;
+    } else {
+      return false;
+    }
+  }
+
+  if (ShouldBeCallOrBinInst)
+    return false;
+
+  assert(VecSize != -1 && ExpectedShuffleMaskHalf == VecSize &&
+         "Expected Match for Vector Size and Mask Half");
+
+  Value *FinalVecV = PrevVecV[0];
+  auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType());
+
+  if (!InitEEV || !FinalVecV)
+    return false;
+
+  assert(FinalVecVTy && "Expected non-null value for Vector Type");
+
+  Intrinsic::ID ReducedOp = 0;
+  if (CommonCallOp) {
+    switch (CommonCallOp) {
+    case Intrinsic::umin:
+      ReducedOp = Intrinsic::vector_reduce_umin;
+      break;
+    case Intrinsic::umax:
+      ReducedOp = Intrinsic::vector_reduce_umax;
+      break;
+    case Intrinsic::smin:
+      ReducedOp = Intrinsic::vector_reduce_smin;
+      break;
+    case Intrinsic::smax:
+      ReducedOp = Intrinsic::vector_reduce_smax;
+      break;
+    default:
+      return false;
+    }
+  } else if (CommonBinOp) {
+    switch (CommonBinOp) {
+    case BinaryOperator::Add:
+      ReducedOp = Intrinsic::vector_reduce_add;
+      break;
+    case BinaryOperator::Mul:
+      ReducedOp = Intrinsic::vector_reduce_mul;
+      break;
+    case BinaryOperator::Or:
+      ReducedOp = Intrinsic::vector_reduce_or;
+      break;
+    case BinaryOperator::And:
+      ReducedOp = Intrinsic::vector_reduce_and;
+      break;
+    case BinaryOperator::Xor:
+      ReducedOp = Intrinsic::vector_reduce_xor;
+      break;
+    default:
+      return false;
+    }
+  }
+
+  InstructionCost OrigCost = 0;
+  unsigned int NumLevels = Log2_64(VecSize);
+
+  for (unsigned int Level = 0; Level < NumLevels; ++Level) {
+    OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
+                                   FinalVecVTy, FinalVecVTy);
+    OrigCost += TTI.getArithmeticInstrCost(Instruction::ICmp, FinalVecVTy);
----------------
Rajveer100 wrote:
Regarding this, I was actually wondering from the beginning whether the cost analysis is even worth it, since isn't it always cheaper to replace the multi-instruction chain with a single reduction?
In fact, if we account for all of the original instructions as you suggest, the accumulated cost only grows, which biases the comparison even further toward the single reduce.
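To make the comparison concrete, here is a minimal sketch of the kind of cost check being discussed, assuming the TargetTransformInfo hooks already used in the patch; the helper name `isReduceCheaper` and its parameters are illustrative only and not part of this PR:

```cpp
// Hypothetical helper, not part of the patch: compares the accumulated cost
// of the original log2(VecSize) shuffle+op levels against a single
// llvm.vector.reduce.* intrinsic.
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/MathExtras.h"

using namespace llvm;

static bool isReduceCheaper(const TargetTransformInfo &TTI,
                            FixedVectorType *VecTy, unsigned ChainOpcode,
                            Intrinsic::ID ReducedOp) {
  constexpr auto CostKind = TargetTransformInfo::TCK_RecipThroughput;

  // Cost of the existing pattern: one shuffle and one vector op per halving
  // level (ChainOpcode is e.g. Instruction::ICmp for a min/max chain, or the
  // common binary opcode otherwise, mirroring the patch).
  InstructionCost OrigCost = 0;
  unsigned NumLevels = Log2_64(VecTy->getNumElements());
  for (unsigned Level = 0; Level < NumLevels; ++Level) {
    OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                   VecTy, VecTy);
    OrigCost += TTI.getArithmeticInstrCost(ChainOpcode, VecTy, CostKind);
  }

  // Cost of the replacement: a single reduction intrinsic producing the
  // scalar result directly.
  IntrinsicCostAttributes ICA(ReducedOp, VecTy->getElementType(), {VecTy});
  InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);

  // OrigCost grows with log2(VecSize) while NewCost stays a single call, so
  // summing the original instructions only pushes the result further toward
  // the reduction.
  return NewCost <= OrigCost;
}
```

So the more of the original chain we charge for, the more lopsided the comparison becomes, which is why I was questioning whether the check pays for its complexity.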
https://github.com/llvm/llvm-project/pull/145232