[llvm] [VectorCombine] New folding pattern for extract/binop/shuffle chains (PR #145232)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 22 06:39:40 PDT 2025


================
@@ -2910,6 +2911,133 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }
 
+bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
+  auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
+  if (!SVI)
+    return false;
+
+  std::queue<Value *> Worklist;
+  SmallVector<Instruction *> ToEraseFromParent;
+
+  SmallVector<int> ShuffleMask;
+  bool IsShuffleOp = true;
+
+  Worklist.push(SVI);
+  SVI->getShuffleMask(ShuffleMask);
+
+  if (ShuffleMask.size() < 2)
+    return false;
+
+  Instruction *Prev0 = nullptr, *Prev1 = nullptr;
+  Instruction *LastOp = nullptr;
+
+  int MaskHalfPos = ShuffleMask.size() / 2;
+  bool IsFirst = true;
+
+  while (!Worklist.empty()) {
+    Value *V = Worklist.front();
+    Worklist.pop();
+
+    auto *CI = dyn_cast<Instruction>(V);
+    if (!CI)
+      return false;
+
+    if (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
+      if (!IsShuffleOp || MaskHalfPos < 1 || (!Prev1 && !IsFirst))
+        return false;
+
+      auto *Op0 = SV->getOperand(0);
+      auto *Op1 = SV->getOperand(1);
+      if (!Op0 || !Op1)
+        return false;
+
+      auto *FVT = dyn_cast<FixedVectorType>(Op1->getType());
+      if (!FVT || !isa<PoisonValue>(Op1))
+        return false;
+
+      SmallVector<int> CurrentMask;
+      SV->getShuffleMask(CurrentMask);
+
+      int64_t MaskSize = CurrentMask.size();
+      for (int MaskPos = 0; MaskPos != MaskSize; ++MaskPos) {
+        if (MaskPos < MaskHalfPos &&
+            CurrentMask[MaskPos] != MaskHalfPos + MaskPos)
+          return false;
+        if (MaskPos >= MaskHalfPos && CurrentMask[MaskPos] != -1)
+          return false;
+      }
+      MaskHalfPos /= 2;
+      Prev0 = SV;
+    } else if (auto *Call = dyn_cast<CallInst>(V)) {
----------------
RKSimon wrote:

How do you intend to generalise this for smin/umax/smax calls and then add/mul/and/or/xor?

https://github.com/llvm/llvm-project/pull/145232


More information about the llvm-commits mailing list