[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 06:03:05 PST 2024


================
@@ -291,6 +291,53 @@ InstructionCost VPRecipeBase::computeCost(ElementCount VF,
   llvm_unreachable("subclasses should implement computeCost");
 }
 
+InstructionCost
+VPPartialReductionRecipe::computeCost(ElementCount VF,
+                                      VPCostContext &Ctx) const {
+  auto *BinOp = cast<BinaryOperator>(getOperand(0)->getUnderlyingValue());
+  auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe());
+  auto *Phi = cast<PHINode>(PhiR->getUnderlyingValue());
+  auto *ExtA = cast<Instruction>(BinOp->getOperand(0));
+  auto *ExtB = cast<Instruction>(BinOp->getOperand(1));
+  Value *A = ExtA->getOperand(0);
+  return Ctx.TTI.getPartialReductionCost(
+      Opcode, A->getType(), Phi->getType(), VF,
+      TargetTransformInfo::getPartialReductionExtendKind(ExtA),
+      TargetTransformInfo::getPartialReductionExtendKind(ExtB),
+      std::make_optional(BinOp->getOpcode()));
+}
+
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+  State.setDebugLocFrom(getDebugLoc());
+  auto &Builder = State.Builder;
+
+  assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
+
+  Value *BinOpVal = State.get(getOperand(0));
+  Value *PhiVal = State.get(getOperand(1));
+  assert(PhiVal && BinOpVal && "Phi and Mul must be set");
+
+  Type *RetTy = PhiVal->getType();
+
+  CallInst *V = Builder.CreateIntrinsic(
+      RetTy, Intrinsic::experimental_vector_partial_reduce_add,
+      {PhiVal, BinOpVal}, nullptr, Twine("partial.reduce"));
+
+  State.set(this, V);
+  State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                     VPSlotTracker &SlotTracker) const {
+  O << Indent << "PARTIAL-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = " << Instruction::getOpcodeName(Opcode);
+  printFlags(O);
----------------
NickGuy-Arm wrote:

> Does the recipe support flags?

I don't think it does, no. With this PR as-is, the call to `setFlags` (introduced in `93fc7af`) does nothing, because the recipe's operation type is `OperationType::Other`. That would change with my suggestion of setting the underlying value via another constructor, so it is worth deciding now: either add proper support for the flags, or remove the `setFlags` and `printFlags` calls.

Additionally, unless I'm missing something, `CallInst` doesn't support any of the flags that could be propagated via `setFlags`. Do correct me if I'm wrong on that, though.
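
To make the `setFlags` point concrete, here is a minimal standalone sketch of flag propagation gated on an operation type. It is a toy model written for illustration, not the VPlan code: `ToyRecipe`, `ToyInst`, `applyFlags` and `checkToyFlags` are hypothetical names, and only the `OperationType::Other` case mirrors what is described above.

```cpp
#include <cassert>

// Hypothetical, simplified stand-ins. These are NOT the LLVM classes.
enum class OperationType { OverflowingBinOp, Other };

// Toy "instruction" that can carry wrap flags.
struct ToyInst {
  bool HasNUW = false;
  bool HasNSW = false;
};

// Toy "recipe" that records flags alongside the kind of operation it models.
struct ToyRecipe {
  OperationType OpType;
  bool NUW;
  bool NSW;

  // Copies the recorded flags onto I, but only for kinds that carry flags.
  void setFlags(ToyInst &I) const {
    switch (OpType) {
    case OperationType::OverflowingBinOp:
      I.HasNUW = NUW;
      I.HasNSW = NSW;
      break;
    case OperationType::Other:
      // Nothing to propagate: the call is a no-op for this kind.
      break;
    }
  }
};

// Builds a fresh toy instruction and applies the recipe's flags to it.
ToyInst applyFlags(const ToyRecipe &R) {
  ToyInst I;
  R.setFlags(I);
  return I;
}

// Usage: the same recorded flags reach the instruction only for a flagged kind.
void checkToyFlags() {
  ToyRecipe Flagged{OperationType::OverflowingBinOp, true, false};
  ToyRecipe Unflagged{OperationType::Other, true, false};
  assert(applyFlags(Flagged).HasNUW);    // flag propagated
  assert(!applyFlags(Unflagged).HasNUW); // silently dropped
}
```

In this model the `Other` case fails silently, which is why the call looks harmless but does nothing. Either of the two options above (supporting the flags properly, or dropping the calls) removes that ambiguity.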

https://github.com/llvm/llvm-project/pull/92418

