[llvm] [LV][VPlan] Add initial support for CSA vectorization (PR #106560)
Ramkumar Ramachandra via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 30 13:22:37 PDT 2024
================
@@ -1998,6 +2119,223 @@ void VPScalarCastRecipe ::print(raw_ostream &O, const Twine &Indent,
}
#endif
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+/// Debug printer: writes \p Indent, the "EMIT " keyword, this recipe's value
+/// as an operand, the " = csa-data-phi " mnemonic and finally the operands.
+void VPCSAHeaderPHIRecipe::print(raw_ostream &O, const Twine &Indent,
+                                 VPSlotTracker &SlotTracker) const {
+  // Line prefix.
+  O << Indent;
+  O << "EMIT ";
+  // Defined value on the left-hand side of the '='.
+  printAsOperand(O, SlotTracker);
+  // Mnemonic followed by the operand list.
+  O << " = csa-data-phi ";
+  printOperands(O, SlotTracker);
+}
+#endif
+
+/// Emit the "csa.data.phi" PHI node at the first non-PHI position of
+/// State.CFG.PrevBB and record it as this recipe's value for every unrolled
+/// part. Only the preheader incoming value is wired up here.
+void VPCSAHeaderPHIRecipe::execute(VPTransformState &State) {
+  // The guard restores the builder's insertion point when this function
+  // returns. PrevBB is the block this recipe is being emitted into.
+  IRBuilder<>::InsertPointGuard RestoreInsertPoint(State.Builder);
+  State.Builder.SetInsertPoint(State.CFG.PrevBB->getFirstNonPHI());
+
+  // The PHI has the type of the initial data value (taken from part 0) and
+  // reserves room for two incoming values.
+  Value *Init = State.get(getVPInitData(), 0);
+  PHINode *Phi = State.Builder.CreatePHI(Init->getType(), 2, "csa.data.phi");
+  Phi->addIncoming(Init, State.CFG.getPreheaderBBFor(this));
+  // The incoming value for the updated data is deliberately not added here:
+  // VPCSADataUpdateRecipe may not have been executed yet, so
+  // VPCSADataUpdateRecipe::execute adds that incoming operand itself.
+
+  // All unrolled parts share the one PHI.
+  for (unsigned Part = 0; Part != State.UF; ++Part)
+    State.set(this, Phi, Part);
+}
+
+/// Cost of this recipe at vectorization factor \p VF: zero for a scalar VF,
+/// otherwise the sum of two broadcast-shuffle costs queried from Ctx.TTI.
+InstructionCost VPCSAHeaderPHIRecipe::computeCost(ElementCount VF,
+                                                  VPCostContext &Ctx) const {
+  if (VF.isScalar())
+    return 0;
+
+  const TargetTransformInfo &TTI = Ctx.TTI;
+  auto *DataVecTy = VectorType::get(getUnderlyingValue()->getType(), VF);
+
+  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
+  // them here for now since there is no VPInstruction::computeCost support.
+  // NOTE(review): the mask broadcast is costed on the data vector type, same
+  // as the data broadcast -- confirm that is intended.
+  InstructionCost InitMaskCost =
+      TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, DataVecTy);
+  InstructionCost InitDataCost =
+      TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, DataVecTy);
+  return InitMaskCost + InitDataCost;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+/// Debug printer: writes \p Indent, the "EMIT " keyword, this recipe's value
+/// as an operand, the " = csa-data-update " mnemonic and finally the operands.
+void VPCSADataUpdateRecipe::print(raw_ostream &O, const Twine &Indent,
+                                  VPSlotTracker &SlotTracker) const {
+  // Line prefix.
+  O << Indent;
+  O << "EMIT ";
+  // Defined value on the left-hand side of the '='.
+  printAsOperand(O, SlotTracker);
+  // Mnemonic followed by the operand list.
+  O << " = csa-data-update ";
+  printOperands(O, SlotTracker);
+}
+#endif
+
+/// For each unrolled part, emit a "csa.data.sel" select that picks the
+/// updated data when the part's any-active value is set and otherwise keeps
+/// the data carried so far. The select of the last part is added to the data
+/// PHI as the incoming value from State.CFG.PrevBB.
+void VPCSADataUpdateRecipe::execute(VPTransformState &State) {
+  // When the data PHI is the true operand, the updated data is the false
+  // operand; otherwise it is the true operand.
+  VPValue *UpdateOperand = getVPTrue();
+  if (getVPDataPhi() == UpdateOperand)
+    UpdateOperand = getVPFalse();
+
+  const unsigned LastPart = State.UF - 1;
+  for (unsigned Part = 0; Part != State.UF; ++Part) {
+    Value *AnyActive = State.get(getVPAnyActive(), Part, /*NeedsScalar=*/true);
+    Value *NewData = State.get(UpdateOperand, Part);
+    PHINode *DataPhi = cast<PHINode>(State.get(getVPDataPhi(), Part));
+
+    // Part 0 falls back to the PHI itself; every later part falls back to the
+    // data select produced for the previous unrolled part.
+    Value *CarriedData = DataPhi;
+    if (Part != 0)
+      CarriedData = State.get(this, Part - 1);
+
+    Value *Selected = State.Builder.CreateSelect(AnyActive, NewData,
+                                                 CarriedData, "csa.data.sel");
+
+    // Only the final part's select is added as an incoming value of the PHI.
+    if (Part == LastPart)
+      DataPhi->addIncoming(Selected, State.CFG.PrevBB);
+    State.set(this, Selected, Part);
+  }
+}
+
+/// Cost of this recipe at vectorization factor \p VF: zero for a scalar VF,
+/// otherwise the data-update select plus the costs of the related mask /
+/// any-active / VL computations that have no cost hook of their own yet.
+InstructionCost VPCSADataUpdateRecipe::computeCost(ElementCount VF,
+                                                   VPCostContext &Ctx) const {
+  if (VF.isScalar())
+    return 0;
+
+  auto *DataVecTy = VectorType::get(getUnderlyingValue()->getType(), VF);
+  auto *MaskVecTy =
+      VectorType::get(IntegerType::getInt1Ty(DataVecTy->getContext()), VF);
+  constexpr TargetTransformInfo::TargetCostKind CostKind =
+      TargetTransformInfo::TCK_RecipThroughput;
+  const TargetTransformInfo &TTI = Ctx.TTI;
+
+  // NOTE(review): every select below is costed via getArithmeticInstrCost;
+  // TTI also provides getCmpSelInstrCost for selects -- confirm which hook is
+  // intended here.
+
+  // Data Update
+  InstructionCost Cost =
+      TTI.getArithmeticInstrCost(Instruction::Select, DataVecTy, CostKind);
+
+  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
+  // them here for now since they are related to updating the data and there is
+  // no VPInstruction::computeCost support at the moment.
+  // CSAInitMask AnyActive
+  Cost += TTI.getArithmeticInstrCost(Instruction::Select, DataVecTy, CostKind);
+  // vp.reduce.or
+  Cost += TTI.getArithmeticReductionCost(Instruction::Or, DataVecTy,
+                                         std::nullopt, CostKind);
+  // VPVLSel
+  Cost += TTI.getArithmeticInstrCost(Instruction::Select, DataVecTy, CostKind);
+  // MaskUpdate
+  Cost += TTI.getArithmeticInstrCost(Instruction::Select, MaskVecTy, CostKind);
+  return Cost;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+/// Debug printer: writes \p Indent, the "EMIT " keyword, this recipe's value
+/// as an operand, the " = CSA-EXTRACT-SCALAR " mnemonic and finally the
+/// operands.
+void VPCSAExtractScalarRecipe::print(raw_ostream &O, const Twine &Indent,
+                                     VPSlotTracker &SlotTracker) const {
+  // Line prefix.
+  O << Indent;
+  O << "EMIT ";
+  // Defined value on the left-hand side of the '='.
+  printAsOperand(O, SlotTracker);
+  // Mnemonic followed by the operand list.
+  O << " = CSA-EXTRACT-SCALAR ";
+  printOperands(O, SlotTracker);
+}
+#endif
+
+void VPCSAExtractScalarRecipe::execute(VPTransformState &State) {
+ IRBuilder<>::InsertPointGuard Guard(State.Builder);
+ State.Builder.SetInsertPoint(State.CFG.ExitBB->getFirstNonPHI());
+
+ unsigned LastPart = State.UF - 1;
+ Value *InitScalar = getVPInitScalar()->getLiveInIRValue();
+ Value *MaskSel = State.get(getVPMaskSel(), LastPart);
+ Value *DataSel = State.get(getVPDataSel(), LastPart);
+
+ Value *LastIdx = nullptr;
+ Value *IndexVec = State.Builder.CreateStepVector(
+ VectorType::get(State.Builder.getInt32Ty(), State.VF), "csa.step");
+ Value *NegOne = ConstantInt::get(IndexVec->getType()->getScalarType(), -1);
+ if (usesEVL()) {
+ // A vp.reduce.smax over the IndexVec with the MaskSel as the mask will
+ // give us the last active index into MaskSel, which gives us the correct
+ // index in the data vector to extract from. If no element in the mask
+ // is active, we pick -1. If we pick -1, then we will use the initial scalar
+ // value instead of extracting from the data vector.
+ Value *VL = State.get(getVPCSAVLSel(), LastPart, /*NeedsScalar=*/true);
+ LastIdx = State.Builder.CreateIntrinsic(NegOne->getType(),
+ Intrinsic::vp_reduce_smax,
+ {NegOne, IndexVec, MaskSel, VL});
+ } else {
+ // Get a vector where the elements are zero when the last active mask is
+ // false and the index in the vector when the mask is true.
+ Value *ActiveLaneIdxs = State.Builder.CreateSelect(
+ MaskSel, IndexVec, ConstantAggregateZero::get(IndexVec->getType()));
+ // Get the last active index in the mask. When no lanes in the mask are
+ // active, vector.umax will have value 0. Take the additional step to set
+ // LastIdx as -1 in this case to avoid the case of lane 0 of the mask being
+ // inactive, which would also cause the reduction to have value 0.
+ Value *MaybeLastIdx = State.Builder.CreateIntMaxReduce(ActiveLaneIdxs);
+ Value *IsLaneZeroActive =
+ State.Builder.CreateExtractElement(MaskSel, (uint64_t)0);
----------------
artagnon wrote:
Prefer `static_cast<uint64_t>(0)` over the C-style cast `(uint64_t)0` here.
https://github.com/llvm/llvm-project/pull/106560
More information about the llvm-commits
mailing list