[llvm] [VectorCombine] New folding pattern for extract/binop/shuffle chains (PR #145232)

Rajveer Singh Bharadwaj via llvm-commits llvm-commits at lists.llvm.org
Sat Jul 19 10:15:40 PDT 2025


================
@@ -2988,6 +2989,305 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }
 
+/// For a given chain of patterns of the following form:
+///
+/// ```
+///   %1 = shufflevector <n x ty1> %0, <n x ty1> poison, <n x ty2> mask
+///
+///   %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0,
+///            <n x ty1> %1)
+///     OR
+///   %2 = add/mul/or/and/xor <n x ty1> %0, %1
+///
+///   %3 = shufflevector <n x ty1> %2, <n x ty1> poison, <n x ty2> mask
+///   ...
+///   ...
+///   %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(
+///            <n x ty1> %(i - 3), <n x ty1> %(i - 2))
+///     OR
+///   %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
+///
+///   %(i) = extractelement <n x ty1> %(i - 1), 0
+/// ```
+///
+/// Where:
+///    `mask` follows a partition pattern:
+///
+/// Ex:
+///    [n = 8, p = poison]
+///
+///    4 5 6 7 | p p p p
+///    2 3 | p p p p p p
+///    1 | p p p p p p p
+///
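+///    In general, when `w` lanes are still live, the mask moves lanes
+///    floor(w / 2) .. w - 1 into positions 0 .. ceil(w / 2) - 1 and leaves
+///    the remaining positions poison; the next shuffle down the chain then
+///    works on ceil(w / 2) live lanes. For powers of 2 this is a plain
+///    halving, but for other sizes the parity of `w` at each step decides
+///    where the next partition starts (see `ExpectedParityMask` for how this
+///    is tracked while matching bottom-up).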
+///
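+/// Ex:
+///    [n = 6]
+///
+///    3 4 5 | p p p
+///    1 2 | p p p p
+///    1 | p p p p p
+///
+/// NOTE(review): for odd live widths the two halves overlap. In the n = 6
+/// example, the middle step combines the 3 live lanes [a, b, c] with
+/// [b, c, p], giving [a.b, b.c, ...], and the last step produces a.b.b.c,
+/// so lane b is folded in twice. That equals a full reduction only for
+/// idempotent operations (umin/umax/smin/smax/and/or); for add/mul/xor it
+/// double-counts. Confirm the fold is restricted accordingly.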
+bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
+  // Going bottom-up for the pattern.
+  auto *EEI = dyn_cast<ExtractElementInst>(&I);
+  if (!EEI)
+    return false;
+
+  std::queue<Value *> InstWorklist;
+  InstructionCost OrigCost = 0;
+
+  Value *InitEEV = nullptr;
+
+  // Common instruction operation after each shuffle op.
+  unsigned int CommonCallOp = 0;
+  Instruction::BinaryOps CommonBinOp = Instruction::BinaryOpsEnd;
+
+  bool IsFirstCallOrBinInst = true;
+  bool ShouldBeCallOrBinInst = true;
+
+  // This stores the last used instructions for shuffle/common op.
+  //
+  // PrevVecV[2] stores the first vector from extract element instruction,
+  // while PrevVecV[0] / PrevVecV[1] store the last two simultaneous
+  // instructions from either shuffle/common op.
+  SmallVector<Value *, 3> PrevVecV(3, nullptr);
+
+  Value *VecOp;
+  if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
+    return false;
+
+  auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
+  if (!FVT)
+    return false;
+
+  int64_t VecSize = FVT->getNumElements();
+  if (VecSize < 2)
+    return false;
+
+  // Number of levels would be ~log2(n), considering we always partition
+  // by half for this fold pattern.
+  unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
+  int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
+
+  // This is how we generalise for all element sizes.
+  // At each step, if vector size is odd, we need non-poison
+  // values to cover the dominant half so we don't miss out on any element.
+  //
+  // This mask will help us retrieve this as we go from bottom to top:
+  //
+  // Mask Set -> N = N * 2 - 1
+  // Mask Unset -> N = N * 2
+  for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
+       Cur = (Cur + 1) / 2, --Mask) {
+    if (Cur & 1)
+      ExpectedParityMask |= (1ll << Mask);
+  }
+
+  PrevVecV[2] = VecOp;
+  InitEEV = EEI;
+
+  InstWorklist.push(PrevVecV[2]);
+
+  while (!InstWorklist.empty()) {
+    Value *V = InstWorklist.front();
+    InstWorklist.pop();
+
+    auto *CI = dyn_cast<Instruction>(V);
+    if (!CI)
+      return false;
+
+    if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
+      if (!ShouldBeCallOrBinInst || !PrevVecV[2])
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      // For the first found call/bin op, the vector has to come from the
+      // extract element op.
+      if (II != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      if (!CommonCallOp)
+        CommonCallOp = II->getIntrinsicID();
+      if (II->getIntrinsicID() != CommonCallOp)
+        return false;
+
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::umin:
+      case Intrinsic::umax:
+      case Intrinsic::smin:
+      case Intrinsic::smax: {
+        auto *Op0 = II->getOperand(0);
+        auto *Op1 = II->getOperand(1);
+        PrevVecV[0] = Op0;
+        PrevVecV[1] = Op1;
+        break;
+      }
+      default:
+        return false;
+      }
+      ShouldBeCallOrBinInst ^= 1;
+
+      IntrinsicCostAttributes ICA(
+          CommonCallOp, II->getType(),
+          {PrevVecV[0]->getType(), PrevVecV[1]->getType()});
+      OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
+
+      // We may need a swap here since it can be (a, b) or (b, a)
+      // and accordinly change as we go up.
+      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
+        std::swap(PrevVecV[0], PrevVecV[1]);
+      InstWorklist.push(PrevVecV[1]);
+      InstWorklist.push(PrevVecV[0]);
+    } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
+      // Similar logic for bin ops.
+
+      if (!ShouldBeCallOrBinInst || !PrevVecV[2])
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      if (CommonBinOp == Instruction::BinaryOpsEnd)
+        CommonBinOp = BinOp->getOpcode();
+
+      if (BinOp->getOpcode() != CommonBinOp)
+        return false;
+
+      switch (CommonBinOp) {
+      case BinaryOperator::Add:
+      case BinaryOperator::Mul:
+      case BinaryOperator::Or:
+      case BinaryOperator::And:
+      case BinaryOperator::Xor: {
+        auto *Op0 = BinOp->getOperand(0);
+        auto *Op1 = BinOp->getOperand(1);
+        PrevVecV[0] = Op0;
+        PrevVecV[1] = Op1;
+        break;
+      }
+      default:
+        return false;
+      }
+      ShouldBeCallOrBinInst ^= 1;
+
+      OrigCost +=
+          TTI.getArithmeticInstrCost(CommonBinOp, BinOp->getType(), CostKind);
+
+      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
+        std::swap(PrevVecV[0], PrevVecV[1]);
+      InstWorklist.push(PrevVecV[1]);
+      InstWorklist.push(PrevVecV[0]);
+    } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
+      // We shouldn't have any null values in the previous vectors,
+      // is so, there was a mismatch in pattern.
+      if (ShouldBeCallOrBinInst ||
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (SVInst != PrevVecV[1])
+        return false;
+
+      auto *ShuffleVec = SVInst->getOperand(0);
+      if (!ShuffleVec || ShuffleVec != PrevVecV[0])
+        return false;
+
+      if (!isa<PoisonValue>(SVInst->getOperand(1)))
+        return false;
+
+      ArrayRef<int> CurMask = SVInst->getShuffleMask();
+
+      // Subtract the parity mask when checking the condition.
+      for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
+        if (Mask < ShuffleMaskHalf &&
+            CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
+          return false;
+        if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
+          return false;
+      }
+
----------------
Rajveer100 wrote:

Could you give an example of your logic with respect to the already added test cases, i.e., the non-power-of-2 ones?

https://github.com/llvm/llvm-project/pull/145232


More information about the llvm-commits mailing list