[llvm] [VectorCombine] New folding pattern for extract/binop/shuffle chains (PR #145232)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sat Jul 5 00:04:27 PDT 2025


================
@@ -2988,6 +2989,241 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }
 
+bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
+  auto *EEI = dyn_cast<ExtractElementInst>(&I);
+  if (!EEI)
+    return false;
+
+  std::queue<Value *> InstWorklist;
+  InstructionCost OrigCost = 0;
+
+  Value *InitEEV = nullptr;
+
+  unsigned int CommonCallOp = 0;
+  Instruction::BinaryOps CommonBinOp = Instruction::BinaryOpsEnd;
+
+  bool IsFirstCallOrBinInst = true;
+  bool ShouldBeCallOrBinInst = true;
+
+  SmallVector<Value *, 3> PrevVecV(3, nullptr);
+  int64_t VecSize = -1;
+
+  Value *VecOp;
+  if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
+    return false;
+
+  auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
+  if (!FVT)
+    return false;
+
+  VecSize = FVT->getNumElements();
+  if (VecSize < 2)
+    return false;
+
+  unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
+  int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
+
+  for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
+       Cur = (Cur + 1) / 2, --Mask) {
+    if (Cur & 1)
+      ExpectedParityMask |= (1ll << Mask);
+  }
+
+  PrevVecV[2] = VecOp;
+  InitEEV = EEI;
+
+  InstWorklist.push(PrevVecV[2]);
+
+  while (!InstWorklist.empty()) {
+    Value *V = InstWorklist.front();
+    InstWorklist.pop();
+
+    auto *CI = dyn_cast<Instruction>(V);
+    if (!CI)
+      return false;
+
+    if (auto *CallI = dyn_cast<CallInst>(CI)) {
+      if (!ShouldBeCallOrBinInst || !PrevVecV[2])
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (CallI != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      auto *II = dyn_cast<IntrinsicInst>(CallI);
+      if (!II)
+        return false;
----------------
dtcxzyw wrote:

```suggestion
    if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
      if (!ShouldBeCallOrBinInst || !PrevVecV[2])
        return false;

      if (!IsFirstCallOrBinInst &&
          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
        return false;

      if (II != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
        return false;
      IsFirstCallOrBinInst = false;
```

https://github.com/llvm/llvm-project/pull/145232


More information about the llvm-commits mailing list