[llvm] [VectorCombine] New folding pattern for extract/binop/shuffle chains (PR #145232)

Rajveer Singh Bharadwaj via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 30 04:06:30 PDT 2025


================
@@ -2988,6 +2989,240 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }
 
+bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
+  auto *EEI = dyn_cast<ExtractElementInst>(&I);
+  if (!EEI)
+    return false;
+
+  std::queue<Value *> InstWorklist;
+  Value *InitEEV = nullptr;
+
+  unsigned int CommonCallOp = 0, CommonBinOp = 0;
+
+  bool IsFirstCallOrBinInst = true;
+  bool ShouldBeCallOrBinInst = true;
+
+  SmallVector<Value *, 3> PrevVecV(3, nullptr);
+  int64_t ShuffleMaskHalf = -1, ExpectedShuffleMaskHalf = 1;
+  int64_t VecSize = -1;
+
+  Value *VecOp;
+  if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
+    return false;
+
+  auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
+  if (!FVT)
+    return false;
+
+  VecSize = FVT->getNumElements();
+  if (VecSize < 2 || (VecSize % 2) != 0)
+    return false;
+
+  ShuffleMaskHalf = 1;
+  PrevVecV[2] = VecOp;
+  InitEEV = EEI;
+
+  InstWorklist.push(PrevVecV[2]);
+
+  while (!InstWorklist.empty()) {
+    Value *V = InstWorklist.front();
+    InstWorklist.pop();
+
+    auto *CI = dyn_cast<Instruction>(V);
+    if (!CI)
+      return false;
+
+    if (auto *CallI = dyn_cast<CallInst>(CI)) {
+      if (!ShouldBeCallOrBinInst || !PrevVecV[2])
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (CallI != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      auto *II = dyn_cast<IntrinsicInst>(CallI);
+      if (!II)
+        return false;
+
+      if (!CommonCallOp)
+        CommonCallOp = II->getIntrinsicID();
+      if (II->getIntrinsicID() != CommonCallOp)
+        return false;
+
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::umin:
+      case Intrinsic::umax:
+      case Intrinsic::smin:
+      case Intrinsic::smax: {
+        auto *Op0 = CallI->getOperand(0);
+        auto *Op1 = CallI->getOperand(1);
+        PrevVecV[0] = Op0;
+        PrevVecV[1] = Op1;
+        break;
+      }
+      default:
+        return false;
+      }
+      ShouldBeCallOrBinInst ^= 1;
+
+      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
+        std::swap(PrevVecV[0], PrevVecV[1]);
+      InstWorklist.push(PrevVecV[1]);
+      InstWorklist.push(PrevVecV[0]);
+    } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
+      if (!ShouldBeCallOrBinInst || !PrevVecV[2])
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      if (!CommonBinOp)
+        CommonBinOp = CI->getOpcode();
+      if (CI->getOpcode() != CommonBinOp)
+        return false;
+
+      switch (CI->getOpcode()) {
+      case BinaryOperator::Add:
+      case BinaryOperator::Mul:
+      case BinaryOperator::Or:
+      case BinaryOperator::And:
+      case BinaryOperator::Xor: {
+        auto *Op0 = BinOp->getOperand(0);
+        auto *Op1 = BinOp->getOperand(1);
+        PrevVecV[0] = Op0;
+        PrevVecV[1] = Op1;
+        break;
+      }
+      default:
+        return false;
+      }
+      ShouldBeCallOrBinInst ^= 1;
+
+      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
+        std::swap(PrevVecV[0], PrevVecV[1]);
+      InstWorklist.push(PrevVecV[1]);
+      InstWorklist.push(PrevVecV[0]);
+    } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
+      if (ShouldBeCallOrBinInst ||
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (SVInst != PrevVecV[1])
+        return false;
+
+      auto *ShuffleVec = SVInst->getOperand(0);
+      if (!ShuffleVec || ShuffleVec != PrevVecV[0])
+        return false;
+
+      SmallVector<int> CurMask;
+      SVInst->getShuffleMask(CurMask);
+
+      if (ShuffleMaskHalf != ExpectedShuffleMaskHalf)
+        return false;
+      ExpectedShuffleMaskHalf *= 2;
+
+      for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
+        if (Mask < ShuffleMaskHalf && CurMask[Mask] != ShuffleMaskHalf + Mask)
+          return false;
+        if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
+          return false;
+      }
+      ShuffleMaskHalf *= 2;
+      if (ExpectedShuffleMaskHalf == VecSize)
+        break;
+      ShouldBeCallOrBinInst ^= 1;
+    } else {
+      return false;
+    }
+  }
+
+  if (ShouldBeCallOrBinInst)
+    return false;
+
+  assert(VecSize != -1 && ExpectedShuffleMaskHalf == VecSize &&
+         "Expected Match for Vector Size and Mask Half");
+
+  Value *FinalVecV = PrevVecV[0];
+  auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType());
+
+  if (!InitEEV || !FinalVecV)
+    return false;
+
+  assert(FinalVecVTy && "Expected non-null value for Vector Type");
+
+  Intrinsic::ID ReducedOp = 0;
+  if (CommonCallOp) {
+    switch (CommonCallOp) {
+    case Intrinsic::umin:
+      ReducedOp = Intrinsic::vector_reduce_umin;
+      break;
+    case Intrinsic::umax:
+      ReducedOp = Intrinsic::vector_reduce_umax;
+      break;
+    case Intrinsic::smin:
+      ReducedOp = Intrinsic::vector_reduce_smin;
+      break;
+    case Intrinsic::smax:
+      ReducedOp = Intrinsic::vector_reduce_smax;
+      break;
+    default:
+      return false;
+    }
+  } else if (CommonBinOp) {
+    switch (CommonBinOp) {
+    case BinaryOperator::Add:
+      ReducedOp = Intrinsic::vector_reduce_add;
+      break;
+    case BinaryOperator::Mul:
+      ReducedOp = Intrinsic::vector_reduce_mul;
+      break;
+    case BinaryOperator::Or:
+      ReducedOp = Intrinsic::vector_reduce_or;
+      break;
+    case BinaryOperator::And:
+      ReducedOp = Intrinsic::vector_reduce_and;
+      break;
+    case BinaryOperator::Xor:
+      ReducedOp = Intrinsic::vector_reduce_xor;
+      break;
+    default:
+      return false;
+    }
+  }
+
+  InstructionCost OrigCost = 0;
+  unsigned int NumLevels = Log2_64(VecSize);
+
+  for (unsigned int Level = 0; Level < NumLevels; ++Level) {
+    OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
+                                   FinalVecVTy, FinalVecVTy);
+    OrigCost += TTI.getArithmeticInstrCost(Instruction::ICmp, FinalVecVTy);
----------------
Rajveer100 wrote:

In that case, I think we may need to check explicitly whether the target supports the reduction, since the cost analysis alone wouldn't tell us that. Wouldn't the new cost otherwise always come out cheaper?

https://github.com/llvm/llvm-project/pull/145232


More information about the llvm-commits mailing list