[llvm] [RISCV][TTI] Refactor getCastInstrCost to exit early (PR #86619)

Shih-Po Hung via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 25 23:54:52 PDT 2024


================
@@ -897,76 +897,73 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
                                                TTI::CastContextHint CCH,
                                                TTI::TargetCostKind CostKind,
                                                const Instruction *I) {
-  if (isa<VectorType>(Dst) && isa<VectorType>(Src)) {
-    // FIXME: Need to compute legalizing cost for illegal types.
-    if (!isTypeLegal(Src) || !isTypeLegal(Dst))
-      return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
-    // Skip if element size of Dst or Src is bigger than ELEN.
-    if (Src->getScalarSizeInBits() > ST->getELen() ||
-        Dst->getScalarSizeInBits() > ST->getELen())
-      return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
-    int ISD = TLI->InstructionOpcodeToISD(Opcode);
-    assert(ISD && "Invalid opcode");
-
-    // FIXME: Need to consider vsetvli and lmul.
-    int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
-                  (int)Log2_32(Src->getScalarSizeInBits());
-    switch (ISD) {
-    case ISD::SIGN_EXTEND:
-    case ISD::ZERO_EXTEND:
-      if (Src->getScalarSizeInBits() == 1) {
-        // We do not use vsext/vzext to extend from mask vector.
-        // Instead we use the following instructions to extend from mask vector:
-        // vmv.v.i v8, 0
-        // vmerge.vim v8, v8, -1, v0
-        return 2;
-      }
-      return 1;
-    case ISD::TRUNCATE:
-      if (Dst->getScalarSizeInBits() == 1) {
-        // We do not use several vncvt to truncate to mask vector. So we could
-        // not use PowDiff to calculate it.
-        // Instead we use the following instructions to truncate to mask vector:
-        // vand.vi v8, v8, 1
-        // vmsne.vi v0, v8, 0
-        return 2;
-      }
-      [[fallthrough]];
-    case ISD::FP_EXTEND:
-    case ISD::FP_ROUND:
-      // Counts of narrow/widen instructions.
-      return std::abs(PowDiff);
-    case ISD::FP_TO_SINT:
-    case ISD::FP_TO_UINT:
-    case ISD::SINT_TO_FP:
-    case ISD::UINT_TO_FP:
-      if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
-        // The cost of convert from or to mask vector is different from other
-        // cases. We could not use PowDiff to calculate it.
-        // For mask vector to fp, we should use the following instructions:
-        // vmv.v.i v8, 0
-        // vmerge.vim v8, v8, -1, v0
-        // vfcvt.f.x.v v8, v8
-
-        // And for fp vector to mask, we use:
-        // vfncvt.rtz.x.f.w v9, v8
-        // vand.vi v8, v9, 1
-        // vmsne.vi v0, v8, 0
-        return 3;
-      }
-      if (std::abs(PowDiff) <= 1)
-        return 1;
-      // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
-      // so it only need two conversion.
-      if (Src->isIntOrIntVectorTy())
-        return 2;
-      // Counts of narrow/widen instructions.
-      return std::abs(PowDiff);
+  bool IsVectorType = isa<VectorType>(Dst) && isa<VectorType>(Src);
+  bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
+                     (Src->getScalarSizeInBits() <= ST->getELen()) &&
+                     (Dst->getScalarSizeInBits() <= ST->getELen());
+
+  // FIXME: Need to compute legalizing cost for illegal types.
+  if (!IsVectorType || !IsTypeLegal)
+    return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+
+  int ISD = TLI->InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
+
+  // FIXME: Need to consider vsetvli and lmul.
+  int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
+                (int)Log2_32(Src->getScalarSizeInBits());
+  switch (ISD) {
+  case ISD::SIGN_EXTEND:
+  case ISD::ZERO_EXTEND:
+    if (Src->getScalarSizeInBits() == 1) {
+      // We do not use vsext/vzext to extend from mask vector.
+      // Instead we use the following instructions to extend from mask vector:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      return 2;
     }
+    return 1;
+  case ISD::TRUNCATE:
+    if (Dst->getScalarSizeInBits() == 1) {
+      // We do not use several vncvt to truncate to mask vector. So we could
+      // not use PowDiff to calculate it.
+      // Instead we use the following instructions to truncate to mask vector:
+      // vand.vi v8, v8, 1
+      // vmsne.vi v0, v8, 0
+      return 2;
+    }
+    [[fallthrough]];
+  case ISD::FP_EXTEND:
+  case ISD::FP_ROUND:
+    // Counts of narrow/widen instructions.
+    return std::abs(PowDiff);
+  case ISD::FP_TO_SINT:
+  case ISD::FP_TO_UINT:
+  case ISD::SINT_TO_FP:
+  case ISD::UINT_TO_FP:
+    if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
+      // The cost of convert from or to mask vector is different from other
+      // cases. We could not use PowDiff to calculate it.
+      // For mask vector to fp, we should use the following instructions:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      // vfcvt.f.x.v v8, v8
+
+      // And for fp vector to mask, we use:
+      // vfncvt.rtz.x.f.w v9, v8
+      // vand.vi v8, v9, 1
+      // vmsne.vi v0, v8, 0
+      return 3;
+    }
+    if (std::abs(PowDiff) <= 1)
+      return 1;
+    // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
+    // so it only need two conversion.
+    if (Src->isIntOrIntVectorTy())
+      return 2;
+    // Counts of narrow/widen instructions.
+    return std::abs(PowDiff);
   }
-  return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 }
----------------
arcbbb wrote:

Sorry for the trouble. I fixed it in commit 5dc0c75aabb9811e03cc8025905fed6dc2dd7bda.

https://github.com/llvm/llvm-project/pull/86619


More information about the llvm-commits mailing list