[llvm] [VectorCombine] New folding pattern for extract/binop/shuffle chains (PR #145232)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 22 06:39:38 PDT 2025


================
@@ -2910,6 +2911,133 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }
 
+bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
+  auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
+  if (!SVI)
+    return false;
+
+  std::queue<Value *> Worklist;
+  SmallVector<Instruction *> ToEraseFromParent;
+
+  SmallVector<int> ShuffleMask;
+  bool IsShuffleOp = true;
+
+  Worklist.push(SVI);
+  SVI->getShuffleMask(ShuffleMask);
+
+  if (ShuffleMask.size() < 2)
+    return false;
+
+  Instruction *Prev0 = nullptr, *Prev1 = nullptr;
+  Instruction *LastOp = nullptr;
+
+  int MaskHalfPos = ShuffleMask.size() / 2;
+  bool IsFirst = true;
+
+  while (!Worklist.empty()) {
+    Value *V = Worklist.front();
+    Worklist.pop();
+
+    auto *CI = dyn_cast<Instruction>(V);
+    if (!CI)
+      return false;
+
+    if (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
+      if (!IsShuffleOp || MaskHalfPos < 1 || (!Prev1 && !IsFirst))
+        return false;
+
+      auto *Op0 = SV->getOperand(0);
+      auto *Op1 = SV->getOperand(1);
+      if (!Op0 || !Op1)
+        return false;
+
+      auto *FVT = dyn_cast<FixedVectorType>(Op1->getType());
+      if (!FVT || !isa<PoisonValue>(Op1))
+        return false;
+
+      SmallVector<int> CurrentMask;
+      SV->getShuffleMask(CurrentMask);
+
+      int64_t MaskSize = CurrentMask.size();
+      for (int MaskPos = 0; MaskPos != MaskSize; ++MaskPos) {
+        if (MaskPos < MaskHalfPos &&
+            CurrentMask[MaskPos] != MaskHalfPos + MaskPos)
+          return false;
+        if (MaskPos >= MaskHalfPos && CurrentMask[MaskPos] != -1)
+          return false;
+      }
+      MaskHalfPos /= 2;
+      Prev0 = SV;
+    } else if (auto *Call = dyn_cast<CallInst>(V)) {
+      if (IsShuffleOp || !Prev0)
+        return false;
+
+      auto *II = dyn_cast<IntrinsicInst>(Call);
+      if (!II)
+        return false;
+
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::umin: {
+        auto *Op0 = Call->getOperand(0);
+        auto *Op1 = Call->getOperand(1);
+        if (!(Op0 == Prev0 && Op1 == Prev1) &&
+            !(Op0 == Prev1 && Op1 == Prev0) && !IsFirst)
+          return false;
+
+        if (!IsFirst)
+          Prev0 = Prev1;
+        else
+          IsFirst = false;
+        Prev1 = Call;
+        break;
+      }
+      default:
+        return false;
+      }
+    } else if (auto *ExtractElement = dyn_cast<ExtractElementInst>(CI)) {
+      if (!IsShuffleOp || !Prev0 || !Prev1 || MaskHalfPos != 0)
+        return false;
+
+      auto *Op0 = ExtractElement->getOperand(0);
+      auto *Op1 = ExtractElement->getOperand(1);
+      if (Op0 != Prev1)
+        return false;
+
+      if (auto *Op1Idx = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Idx->getValue() != 0)
+          return false;
+      } else {
+        return false;
+      }
+      LastOp = ExtractElement;
+      break;
+    }
+    IsShuffleOp ^= 1;
+    ToEraseFromParent.push_back(CI);
+
+    auto *NextI = CI->getNextNode();
+    if (!NextI)
+      return false;
+    Worklist.push(NextI);
+  }
+
+  if (!LastOp)
+    return false;
+
----------------
RKSimon wrote:

You need to compare costs: in VectorCombine, a fold is expected to occur only when there is a cost benefit.

https://github.com/llvm/llvm-project/pull/145232


More information about the llvm-commits mailing list