[llvm] [RISCV][CostModel] Add cost for fabs/fsqrt of type bf16/f16 (PR #118608)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 2 08:10:21 PST 2025
================
@@ -1035,21 +1036,61 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
}
break;
}
- case Intrinsic::fabs:
+ case Intrinsic::fabs: {
+ auto LT = getTypeLegalizationCost(RetTy);
+ if (ST->hasVInstructions() && LT.second.isVector()) {
+ // lui a0, 8
+ // addi a0, a0, -1
+ // vsetvli a1, zero, e16, m1, ta, ma
+ // vand.vx v8, v8, a0
+ // f16 with zvfhmin and bf16 with zvfhbmin
+ if (LT.second.getVectorElementType() == MVT::bf16 ||
+ (LT.second.getVectorElementType() == MVT::f16 &&
+ !ST->hasVInstructionsF16()))
+ return LT.first * getRISCVInstructionCost(RISCV::VAND_VX, LT.second,
+ CostKind) +
+ 2;
+ else
+ return LT.first *
+ getRISCVInstructionCost(RISCV::VFSGNJX_VV, LT.second, CostKind);
+ }
+ break;
+ }
case Intrinsic::sqrt: {
auto LT = getTypeLegalizationCost(RetTy);
- // TODO: add f16/bf16, bf16 with zvfbfmin && f16 with zvfhmin
+ auto NVT = LT.second;
if (ST->hasVInstructions() && LT.second.isVector()) {
- unsigned Op;
- switch (ICA.getID()) {
- case Intrinsic::fabs:
- Op = RISCV::VFSGNJX_VV;
- break;
- case Intrinsic::sqrt:
- Op = RISCV::VFSQRT_V;
- break;
+ SmallVector<unsigned, 3> Opcodes;
+ // f16 with zvfhmin and bf16 with zvfbfmin and the type of nxv32[b]f16
+ // will be spilt.
+ if (LT.second.getVectorElementType() == MVT::bf16) {
+ if (LT.second == MVT::nxv32bf16) {
+ Opcodes = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFWCVTBF16_F_F_V,
+ RISCV::VFSQRT_V, RISCV::VFSQRT_V,
+ RISCV::VFNCVTBF16_F_F_W, RISCV::VFNCVTBF16_F_F_W};
+ NVT = TLI->getTypeToPromoteTo(ISD::FSQRT,
+ NVT.getHalfNumVectorElementsVT());
+ } else {
+ Opcodes = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFSQRT_V,
+ RISCV::VFNCVTBF16_F_F_W};
+ NVT = TLI->getTypeToPromoteTo(ISD::FSQRT, NVT);
+ }
+ } else if (LT.second.getVectorElementType() == MVT::f16 &&
+ !ST->hasVInstructionsF16()) {
+ if (LT.second == MVT::nxv32f16) {
+ Opcodes = {RISCV::VFWCVT_F_F_V, RISCV::VFWCVT_F_F_V,
+ RISCV::VFSQRT_V, RISCV::VFSQRT_V,
+ RISCV::VFNCVT_F_F_W, RISCV::VFNCVT_F_F_W};
+ NVT = TLI->getTypeToPromoteTo(ISD::FSQRT,
----------------
topperc wrote:
You can't use getTypeToPromoteTo. That only works when LegalizeVectorOps/LegalizeDAG does the promotion via `setOperationAction(ISD::FSQRT, VT, Promote)`. This case uses `setOperationAction(ISD::FSQRT, MVT::nxv32f16, Custom)`. You'll need to hardcode the type as MVT::nxv16f32.
https://github.com/llvm/llvm-project/pull/118608
More information about the llvm-commits
mailing list