[llvm] [RISCV][TTI] Scale the cost of FP-Int conversion with LMUL (PR #87506)

Shih-Po Hung via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 4 23:48:34 PDT 2024


================
@@ -988,31 +988,106 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     return Cost;
   }
   case ISD::FP_TO_SINT:
-  case ISD::FP_TO_UINT:
+  case ISD::FP_TO_UINT: {
+    unsigned IsSigned = ISD == ISD::FP_TO_SINT;
+    unsigned FCVT = IsSigned ? RISCV::VFCVT_RTZ_X_F_V : RISCV::VFCVT_RTZ_XU_F_V;
+    unsigned FWCVT =
+        IsSigned ? RISCV::VFWCVT_RTZ_X_F_V : RISCV::VFWCVT_RTZ_XU_F_V;
+    unsigned FNCVT =
+        IsSigned ? RISCV::VFNCVT_RTZ_X_F_W : RISCV::VFNCVT_RTZ_XU_F_W;
+    unsigned SrcEltSize = Src->getScalarSizeInBits();
+    unsigned DstEltSize = Dst->getScalarSizeInBits();
+    if (DstEltSize == 1) {
+      // For fp vector to mask, we use:
+      // vfncvt.rtz.x.f.w v9, v8
+      // vand.vi v8, v9, 1
+      // vmsne.vi v0, v8, 0
+      SrcEltSize /= 2;
+      MVT ElementVT = MVT::getIntegerVT(SrcEltSize);
+      MVT InterimVT = SrcLT.second.changeVectorElementType(ElementVT);
+      return getRISCVInstructionCost(FNCVT, InterimVT, CostKind) +
+             getRISCVInstructionCost({RISCV::VAND_VI, RISCV::VMSNE_VI},
+                                     DstLT.second, CostKind);
+    }
+    if (DstEltSize == SrcEltSize)
+      return getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
+    if (DstEltSize == (2 * SrcEltSize))
+      return getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
+    if (DstEltSize == (4 * SrcEltSize) && (SrcEltSize == 16)) {
+      // Convert f16 to f32 then convert f32 to i64.
+      MVT VecF32VT = DstLT.second.changeVectorElementType(MVT::f32);
+      return getRISCVInstructionCost(RISCV::VFWCVT_F_F_V, VecF32VT, CostKind) +
+             getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
+    }
+    if (DstEltSize < SrcEltSize) {
+      SrcEltSize /= 2;
+      MVT ElementVT = MVT::getIntegerVT(SrcEltSize);
+      MVT InterimVT = DstLT.second.changeVectorElementType(ElementVT);
+      InstructionCost Cost =
+          getRISCVInstructionCost(FNCVT, InterimVT, CostKind);
+      while (DstEltSize < SrcEltSize) {
+        SrcEltSize /= 2;
+        ElementVT = MVT::getIntegerVT(SrcEltSize);
+        InterimVT = DstLT.second.changeVectorElementType(ElementVT);
+        Cost += getRISCVInstructionCost(RISCV::VNSRL_WI, InterimVT, CostKind);
+      }
+      return Cost;
+    }
+    return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+  }
   case ISD::SINT_TO_FP:
-  case ISD::UINT_TO_FP:
-    if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
-      // The cost of convert from or to mask vector is different from other
-      // cases. We could not use PowDiff to calculate it.
-      // For mask vector to fp, we should use the following instructions:
+  case ISD::UINT_TO_FP: {
+    unsigned IsSigned = ISD == ISD::SINT_TO_FP;
+    unsigned FCVT = IsSigned ? RISCV::VFCVT_F_X_V : RISCV::VFCVT_F_XU_V;
+    unsigned FWCVT = IsSigned ? RISCV::VFWCVT_F_X_V : RISCV::VFWCVT_F_XU_V;
+    unsigned FNCVT = IsSigned ? RISCV::VFNCVT_F_X_W : RISCV::VFNCVT_F_XU_W;
+    unsigned SrcEltSize = Src->getScalarSizeInBits();
+    unsigned DstEltSize = Dst->getScalarSizeInBits();
+
+    if (SrcEltSize == 1) {
+      // For mask vector to fp, we use:
       // vmv.v.i v8, 0
       // vmerge.vim v8, v8, -1, v0
-      // vfcvt.f.x.v v8, v8
+      // vfwcvt.f.x.v v8, v8
+      MVT ElementVT = MVT::getIntegerVT(DstEltSize >> 1);
+      MVT VecHalfVT = DstLT.second.changeVectorElementType(ElementVT);
+      return getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
+                                     VecHalfVT, CostKind) +
+             getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
+    }
 
-      // And for fp vector to mask, we use:
-      // vfncvt.rtz.x.f.w v9, v8
-      // vand.vi v8, v9, 1
-      // vmsne.vi v0, v8, 0
-      return 3;
+    if (DstEltSize == SrcEltSize)
+      return getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
+
+    if (DstEltSize == (2 * SrcEltSize))
+      return getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
----------------
arcbbb wrote:

Thanks for catching that! I have re-implemented the cost to handle fp16 for VFHMIN.

https://github.com/llvm/llvm-project/pull/87506


More information about the llvm-commits mailing list