[llvm] [RISCV][CostModel] Add cost for fabs/fsqrt of type bf16/f16 (PR #118608)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 10 00:55:28 PST 2025


================
@@ -1035,21 +1036,66 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
     }
     break;
   }
-  case Intrinsic::fabs:
+  case Intrinsic::fabs: {
+    auto LT = getTypeLegalizationCost(RetTy);
+    if (ST->hasVInstructions() && LT.second.isVector()) {
+      // lui a0, 8
+      // addi a0, a0, -1
+      // vsetvli a1, zero, e16, m1, ta, ma
+      // vand.vx v8, v8, a0
+      // f16 with zvfhmin and bf16 with zvfhbmin
+      if (LT.second.getVectorElementType() == MVT::bf16 ||
+          (LT.second.getVectorElementType() == MVT::f16 &&
+           !ST->hasVInstructionsF16()))
+        return LT.first * getRISCVInstructionCost(RISCV::VAND_VX, LT.second,
+                                                  CostKind) +
+               2;
+      else
+        return LT.first *
+               getRISCVInstructionCost(RISCV::VFSGNJX_VV, LT.second, CostKind);
+    }
+    break;
+  }
   case Intrinsic::sqrt: {
     auto LT = getTypeLegalizationCost(RetTy);
-    // TODO: add f16/bf16, bf16 with zvfbfmin && f16 with zvfhmin
     if (ST->hasVInstructions() && LT.second.isVector()) {
-      unsigned Op;
-      switch (ICA.getID()) {
-      case Intrinsic::fabs:
-        Op = RISCV::VFSGNJX_VV;
-        break;
-      case Intrinsic::sqrt:
-        Op = RISCV::VFSQRT_V;
-        break;
+      SmallVector<unsigned, 4> ConvOp;
+      SmallVector<unsigned, 2> FsqrtOp;
+      MVT ConvType = LT.second;
+      MVT FsqrtType = LT.second;
+      // f16 with zvfhmin and bf16 with zvfbfmin and the type of nxv32[b]f16
+      // will be spilt.
+      if (LT.second.getVectorElementType() == MVT::bf16) {
+        if (LT.second == MVT::nxv32bf16) {
+          ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFWCVTBF16_F_F_V,
+                    RISCV::VFNCVTBF16_F_F_W, RISCV::VFNCVTBF16_F_F_W};
+          FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V};
+          ConvType = MVT::nxv16f16;
+          FsqrtType = MVT::nxv16f32;
----------------
LiqinWeng wrote:

The ConvOp list has already been split, so there is no need to multiply by 2.

https://github.com/llvm/llvm-project/pull/118608


More information about the llvm-commits mailing list