[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 09:14:20 PST 2024


================
@@ -291,6 +291,53 @@ InstructionCost VPRecipeBase::computeCost(ElementCount VF,
   llvm_unreachable("subclasses should implement computeCost");
 }
 
+InstructionCost
+VPPartialReductionRecipe::computeCost(ElementCount VF,
+                                      VPCostContext &Ctx) const { // Cost of this recipe at VF, as reported by Ctx.TTI.getPartialReductionCost().
+  auto *BinOp = cast<BinaryOperator>(getOperand(0)->getUnderlyingValue()); // Operand 0's underlying IR value is required to be a BinaryOperator.
+  auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()); // Operand 1 is required to be defined by a reduction-phi recipe.
+  auto *Phi = cast<PHINode>(PhiR->getUnderlyingValue()); // IR phi behind that recipe; its type is the third cost-query argument.
+  auto *ExtA = cast<Instruction>(BinOp->getOperand(0)); // Both binop inputs are required to be Instructions; presumably extends, judging by the names -- TODO confirm against the recipe's creator.
+  auto *ExtB = cast<Instruction>(BinOp->getOperand(1));
+  Value *A = ExtA->getOperand(0); // Only ExtA's input is inspected for a type; ExtB's input type is not queried here.
+  return Ctx.TTI.getPartialReductionCost(
+      Opcode, A->getType(), Phi->getType(), VF,
+      TargetTransformInfo::getPartialReductionExtendKind(ExtA),
+      TargetTransformInfo::getPartialReductionExtendKind(ExtB),
+      std::make_optional(BinOp->getOpcode())); // The binop's own opcode is forwarded in addition to the recipe's Opcode.
+}
+
+void VPPartialReductionRecipe::execute(VPTransformState &State) { // Emits one experimental_vector_partial_reduce_add call for this recipe and registers it in State.
+  State.setDebugLocFrom(getDebugLoc()); // Done before any instruction is created through the builder below.
+  auto &Builder = State.Builder;
+
+  assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode"); // Only Add is handled: the intrinsic emitted below is the add variant.
+
+  Value *BinOpVal = State.get(getOperand(0)); // Value State holds for operand 0.
+  Value *PhiVal = State.get(getOperand(1)); // Value State holds for operand 1.
+  assert(PhiVal && BinOpVal && "Phi and Mul must be set"); // NOTE(review): message says "Mul" while the local is named BinOpVal -- confirm whether only mul is expected.
+
+  Type *RetTy = PhiVal->getType(); // The call's return type is taken from the phi value.
+
+  CallInst *V = Builder.CreateIntrinsic(
+      RetTy, Intrinsic::experimental_vector_partial_reduce_add,
+      {PhiVal, BinOpVal}, nullptr, Twine("partial.reduce")); // Call arguments: phi value first, then the binop value; result is named "partial.reduce".
+
+  State.set(this, V); // Registers V in State under this recipe.
+  State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue())); // dyn_cast_or_null: the underlying value may be null or not an Instruction.
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                     VPSlotTracker &SlotTracker) const {
+  O << Indent << "PARTIAL-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = " << Instruction::getOpcodeName(Opcode);
+  printFlags(O);
----------------
SamTebbs33 wrote:

Done.

https://github.com/llvm/llvm-project/pull/92418


More information about the llvm-commits mailing list