[llvm] [LV][VPlan] Add initial support for CSA vectorization (PR #106560)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 12 00:11:58 PDT 2024


================
@@ -2107,6 +2241,223 @@ void VPScalarCastRecipe ::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+// Print this recipe in the form
+//   <Indent>EMIT <result> = csa-data-phi <operands>
+void VPCSAHeaderPHIRecipe::print(raw_ostream &O, const Twine &Indent,
+                                 VPSlotTracker &SlotTracker) const {
+  O << Indent;
+  O << "EMIT ";
+  printAsOperand(O, SlotTracker);
+  O << " = csa-data-phi ";
+  printOperands(O, SlotTracker);
+}
+#endif
+
+// Create the CSA data phi at the start of the block currently being emitted.
+// Only the preheader incoming value (the initial data) is attached here. The
+// loop-carried incoming value is attached by VPCSADataUpdateRecipe::execute,
+// because that recipe may not have been executed yet when this one runs.
+void VPCSAHeaderPHIRecipe::execute(VPTransformState &State) {
+  auto &Builder = State.Builder;
+  // Restore the caller's insert point once the phi has been created.
+  IRBuilder<>::InsertPointGuard Guard(Builder);
+  // State.CFG.PrevBB is the block being emitted into; the phi must precede any
+  // non-phi instruction already placed there.
+  Builder.SetInsertPoint(State.CFG.PrevBB->getFirstNonPHI());
+
+  Value *Init = State.get(getVPInitData(), 0);
+  PHINode *Phi = Builder.CreatePHI(Init->getType(), 2, "csa.data.phi");
+  Phi->addIncoming(Init, State.CFG.getPreheaderBBFor(this));
+
+  // A single phi is shared by every unrolled part.
+  for (unsigned Part = 0, UF = State.UF; Part != UF; ++Part)
+    State.set(this, Phi, Part);
+}
+
+// Cost of setting up the CSA state in the preheader: one broadcast each for
+// CSAInitMask and CSAInitData. Nothing is widened for a scalar VF.
+InstructionCost VPCSAHeaderPHIRecipe::computeCost(ElementCount VF,
+                                                  VPCostContext &Ctx) const {
+  if (VF.isScalar())
+    return 0;
+
+  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
+  // them here for now since there is no VPInstruction::computeCost support.
+  auto *VecTy = VectorType::get(getUnderlyingValue()->getType(), VF);
+  const InstructionCost BroadcastCost =
+      Ctx.TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);
+  // CSAInitMask + CSAInitData.
+  return BroadcastCost + BroadcastCost;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+// Print this recipe in the form
+//   <Indent>EMIT <result> = csa-data-update <operands>
+void VPCSADataUpdateRecipe::print(raw_ostream &O, const Twine &Indent,
+                                  VPSlotTracker &SlotTracker) const {
+  O << Indent;
+  O << "EMIT ";
+  printAsOperand(O, SlotTracker);
+  O << " = csa-data-update ";
+  printOperands(O, SlotTracker);
+}
+#endif
+
+// For each unrolled part, select between the newly computed data and the data
+// carried in from the previous part, depending on whether any lane was active.
+// Part 0 starts from the header phi. The last part's result becomes the phi's
+// loop-carried incoming value, which VPCSAHeaderPHIRecipe deliberately left
+// unset.
+void VPCSADataUpdateRecipe::execute(VPTransformState &State) {
+  // Of the true/false operands, whichever one is not the header phi holds the
+  // newly computed data.
+  VPValue *VPNewData =
+      getVPDataPhi() == getVPTrue() ? getVPFalse() : getVPTrue();
+
+  for (unsigned Part = 0; Part < State.UF; ++Part) {
+    Value *AnyActive = State.get(getVPAnyActive(), Part, /*NeedsScalar=*/true);
+    Value *NewData = State.get(VPNewData, Part);
+    auto *Phi = cast<PHINode>(State.get(getVPDataPhi(), Part));
+    // Chain the data (not a mask) through the unrolled parts.
+    Value *CarriedData = Part == 0 ? Phi : State.get(this, Part - 1);
+    Value *Sel = State.Builder.CreateSelect(AnyActive, NewData, CarriedData,
+                                            "csa.data.sel");
+    State.set(this, Sel, Part);
+
+    const bool IsLastPart = Part + 1 == State.UF;
+    if (IsLastPart)
+      Phi->addIncoming(Sel, State.CFG.PrevBB);
+  }
+}
+
+// Cost of the data-update select plus, for now, the CSA bookkeeping that feeds
+// it (any-active reduction, EVL select and mask update).
+InstructionCost VPCSADataUpdateRecipe::computeCost(ElementCount VF,
+                                                   VPCostContext &Ctx) const {
+  if (VF.isScalar())
+    return 0;
+
+  auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
+  Type *BoolTy = IntegerType::getInt1Ty(VTy->getContext());
+  auto *MaskTy = VectorType::get(BoolTy, VF);
+  constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  const TargetTransformInfo &TTI = Ctx.TTI;
+
+  // Selects are costed with getCmpSelInstrCost, the TTI hook targets implement
+  // for select; getArithmeticInstrCost models arithmetic/logical operators.
+  // The data-update select in execute() is conditioned on the scalar i1
+  // AnyActive, hence the scalar condition type.
+  // NOTE(review): the remaining selects are assumed to be gated by the same
+  // scalar AnyActive value; confirm against the VPInstructions they model.
+  auto SelectCost = [&](Type *ValTy) {
+    return TTI.getCmpSelInstrCost(Instruction::Select, ValTy, BoolTy,
+                                  CmpInst::BAD_ICMP_PREDICATE, CostKind);
+  };
+
+  InstructionCost C = 0;
+  // Data update: select(AnyActive, NewData, OldData).
+  C += SelectCost(VTy);
+
+  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
+  // them here for now since they are related to updating the data and there is
+  // no VPInstruction::computeCost support at the moment.
+  // CSAInitMask
+  C += SelectCost(VTy);
+  // AnyActive (vp.reduce.or): its scalar result is used as an i1 select
+  // condition in execute(), so the reduced operand is a <VF x i1> mask, not a
+  // vector of the data type.
+  C += TTI.getArithmeticReductionCost(Instruction::Or, MaskTy, std::nullopt,
+                                      CostKind);
+  // VPVLSel
+  C += SelectCost(VTy);
+  // MaskUpdate
+  C += SelectCost(MaskTy);
+  return C;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+// Print this recipe in the form
+//   <Indent>EMIT <result> = CSA-EXTRACT-SCALAR <operands>
+void VPCSAExtractScalarRecipe::print(raw_ostream &O, const Twine &Indent,
+                                     VPSlotTracker &SlotTracker) const {
+  O << Indent;
+  O << "EMIT ";
+  printAsOperand(O, SlotTracker);
+  O << " = CSA-EXTRACT-SCALAR ";
+  printOperands(O, SlotTracker);
+}
+#endif
+
+// Produce the final scalar value of the CSA in the exit block. Take the
+// element of the last unrolled part's data vector at the last lane that is
+// active in that part's mask. If no lane is active, use the initial scalar.
+void VPCSAExtractScalarRecipe::execute(VPTransformState &State) {
+  IRBuilder<>::InsertPointGuard Guard(State.Builder);
+  State.Builder.SetInsertPoint(State.CFG.ExitBB->getFirstNonPHI());
+
+  unsigned LastPart = State.UF - 1;
+  Value *InitScalar = getVPInitScalar()->getLiveInIRValue();
+  Value *MaskSel = State.get(getVPMaskSel(), LastPart);
+  Value *DataSel = State.get(getVPDataSel(), LastPart);
+
+  Value *IndexVec = State.Builder.CreateStepVector(
+      VectorType::get(State.Builder.getInt32Ty(), State.VF), "csa.step");
+  Type *IdxTy = IndexVec->getType()->getScalarType();
+  Value *NegOne = ConstantInt::get(IdxTy, -1);
+  Value *Zero = ConstantInt::get(IdxTy, 0);
+
+  // LastIdx is the index of the last active lane in MaskSel, or -1 if no lane
+  // is active.
+  Value *LastIdx = nullptr;
+  if (usesEVL()) {
+    // A vp.reduce.smax over IndexVec, masked by MaskSel and starting from -1,
+    // yields the last active index directly, or -1 if no lane is active.
+    Value *VL = State.get(getVPCSAVLSel(), LastPart, /*NeedsScalar=*/true);
+    LastIdx = State.Builder.CreateIntrinsic(NegOne->getType(),
+                                            Intrinsic::vp_reduce_smax,
+                                            {NegOne, IndexVec, MaskSel, VL});
+  } else {
+    // Lanes active in the mask keep their index; inactive lanes become 0.
+    Value *ActiveLaneIdxs = State.Builder.CreateSelect(
+        MaskSel, IndexVec, ConstantAggregateZero::get(IndexVec->getType()));
+    // The unsigned max of those indices is the last active index. It is 0 both
+    // when only lane 0 is active and when no lane is active, so lane 0 of the
+    // mask is used to tell the two cases apart.
+    Value *MaybeLastIdx = State.Builder.CreateIntMaxReduce(ActiveLaneIdxs);
+    Value *IsLaneZeroActive =
+        State.Builder.CreateExtractElement(MaskSel, static_cast<uint64_t>(0));
+    Value *MaybeLastIdxEQZero = State.Builder.CreateICmpEQ(MaybeLastIdx, Zero);
+    Value *NoLaneActive = State.Builder.CreateAnd(
+        State.Builder.CreateNot(IsLaneZeroActive), MaybeLastIdxEQZero);
+    // Keep the reduced index (which may legitimately be non-zero); only
+    // substitute -1 when no lane at all is active.
+    LastIdx = State.Builder.CreateSelect(NoLaneActive, NegOne, MaybeLastIdx);
+  }
+
+  // When LastIdx is -1 the extract is out of range and yields poison. That is
+  // harmless because the select below then picks InitScalar instead.
+  Value *ExtractFromVec =
+      State.Builder.CreateExtractElement(DataSel, LastIdx, "csa.extract");
+  Value *LastIdxGEZero = State.Builder.CreateICmpSGE(LastIdx, Zero);
+  Value *ChooseFromVecOrInit =
+      State.Builder.CreateSelect(LastIdxGEZero, ExtractFromVec, InitScalar);
+  State.set(this, ChooseFromVecOrInit, 0, /*IsScalar=*/true);
+}
+
+InstructionCost
+VPCSAExtractScalarRecipe::computeCost(ElementCount VF,
+                                      VPCostContext &Ctx) const {
+  if (VF.isScalar())
+    return 0;
+
+  InstructionCost C = 0;
+  auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
+  auto *Int32VTy =
+      VectorType::get(IntegerType::getInt32Ty(VTy->getContext()), VF);
+  auto *MaskTy = VectorType::get(IntegerType::getInt1Ty(VTy->getContext()), VF);
+  constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  const TargetTransformInfo &TTI = Ctx.TTI;
+
+  // StepVector
+  ArrayRef<Value *> Args;
+  IntrinsicCostAttributes CostAttrs(Intrinsic::stepvector, Int32VTy, Args);
+  C += TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
+  // NegOneSplat
+  C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, Int32VTy);
+  // LastIdx
+  if (usesEVL()) {
+    C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
+                                    CostKind);
+  } else {
+    // ActiveLaneIdxs
+    C += TTI.getArithmeticInstrCost(Instruction::Select,
+                                    MaskTy->getScalarType(), CostKind);
+    // MaybeLastIdx
+    C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
+                                    CostKind);
+    // IsLaneZeroActive
+    C += TTI.getArithmeticInstrCost(Instruction::ExtractElement, MaskTy,
+                                    CostKind);
----------------
ElvisWang123 wrote:

We should use `TTI.getVectorInstrCost()` to get the instruction cost of `ExtractElement` instructions.

https://github.com/llvm/llvm-project/pull/106560


More information about the llvm-commits mailing list