[llvm] [VectorCombine] New folding pattern for extract/binop/shuffle chains (PR #145232)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Sat Jul 5 00:04:27 PDT 2025
================
@@ -2988,6 +2989,241 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
return foldSelectShuffle(*Shuffle, true);
}
+bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
+ auto *EEI = dyn_cast<ExtractElementInst>(&I);
+ if (!EEI)
+ return false;
+
+ std::queue<Value *> InstWorklist;
+ InstructionCost OrigCost = 0;
+
+ Value *InitEEV = nullptr;
+
+ unsigned int CommonCallOp = 0;
+ Instruction::BinaryOps CommonBinOp = Instruction::BinaryOpsEnd;
+
+ bool IsFirstCallOrBinInst = true;
+ bool ShouldBeCallOrBinInst = true;
+
+ SmallVector<Value *, 3> PrevVecV(3, nullptr);
+ int64_t VecSize = -1;
+
+ Value *VecOp;
+ if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
+ return false;
+
+ auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
+ if (!FVT)
+ return false;
+
+ VecSize = FVT->getNumElements();
+ if (VecSize < 2)
+ return false;
+
+ unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
+ int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
+
+ for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
+ Cur = (Cur + 1) / 2, --Mask) {
+ if (Cur & 1)
+ ExpectedParityMask |= (1ll << Mask);
+ }
+
+ PrevVecV[2] = VecOp;
+ InitEEV = EEI;
+
+ InstWorklist.push(PrevVecV[2]);
+
+ while (!InstWorklist.empty()) {
+ Value *V = InstWorklist.front();
+ InstWorklist.pop();
+
+ auto *CI = dyn_cast<Instruction>(V);
+ if (!CI)
+ return false;
+
+ if (auto *CallI = dyn_cast<CallInst>(CI)) {
+ if (!ShouldBeCallOrBinInst || !PrevVecV[2])
+ return false;
+
+ if (!IsFirstCallOrBinInst &&
+ any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+ return false;
+
+ if (CallI != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
+ return false;
+ IsFirstCallOrBinInst = false;
+
+ auto *II = dyn_cast<IntrinsicInst>(CallI);
+ if (!II)
+ return false;
----------------
dtcxzyw wrote:
```suggestion
if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
if (!ShouldBeCallOrBinInst || !PrevVecV[2])
return false;
if (!IsFirstCallOrBinInst &&
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
return false;
if (CallI != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
return false;
IsFirstCallOrBinInst = false;
```
https://github.com/llvm/llvm-project/pull/145232
More information about the llvm-commits mailing list