[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 9 09:56:08 PST 2024
================
@@ -291,6 +291,67 @@ InstructionCost VPRecipeBase::computeCost(ElementCount VF,
llvm_unreachable("subclasses should implement computeCost");
}
+InstructionCost
+VPPartialReductionRecipe::computeCost(ElementCount VF,
+ VPCostContext &Ctx) const {
+ std::optional<unsigned> Opcode = std::nullopt;
+ VPRecipeBase *BinOpR = getOperand(0)->getDefiningRecipe();
+ if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
+ Opcode = std::make_optional(WidenR->getOpcode());
+
+ VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
+ VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
+
+ auto GetExtendKind = [](VPRecipeBase *R) {
+ auto *WidenCastR = dyn_cast<VPWidenCastRecipe>(R);
+ if (!WidenCastR)
+ return TargetTransformInfo::PR_None;
+ if (WidenCastR->getOpcode() == Instruction::CastOps::ZExt)
+ return TargetTransformInfo::PR_ZeroExtend;
+ if (WidenCastR->getOpcode() == Instruction::CastOps::SExt)
+ return TargetTransformInfo::PR_SignExtend;
+ return TargetTransformInfo::PR_None;
+ };
+
+ auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
+ auto *ExtTy = Ctx.Types.inferScalarType(ExtAR->getOperand(0));
+
+ return Ctx.TTI.getPartialReductionCost(getOpcode(), ExtTy, PhiType, VF,
+ GetExtendKind(ExtAR),
+ GetExtendKind(ExtBR), Opcode);
+}
+
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+ State.setDebugLocFrom(getDebugLoc());
+ auto &Builder = State.Builder;
+
+ assert(getOpcode() == Instruction::Add &&
+ "Unhandled partial reduction opcode");
+
+ Value *BinOpVal = State.get(getOperand(0));
+ Value *PhiVal = State.get(getOperand(1));
+ assert(PhiVal && BinOpVal && "Phi and Mul must be set");
+
+ Type *RetTy = PhiVal->getType();
+
+ CallInst *V = Builder.CreateIntrinsic(
+ RetTy, Intrinsic::experimental_vector_partial_reduce_add,
+ {PhiVal, BinOpVal}, nullptr, Twine("partial.reduce"));
+
+ State.set(this, V);
+ State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
----------------
fhahn wrote:
This copies the metadata from a binary instruction onto a call instruction, and I think some of the metadata for binary ops may not be valid for call instructions.
Better to drop?
https://github.com/llvm/llvm-project/pull/92418
More information about the llvm-commits
mailing list