[llvm] [RISCV][TTI] Refactor getCastInstrCost to exit early (PR #86619)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 25 19:38:43 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Shih-Po Hung (arcbbb)

<details>
<summary>Changes</summary>

To reduce indentation by using early returns, this patch hoists the returns for illegal types and non-vector types to the top of the function.

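It should be mostly NFC (no functional change). NOTE(review): as posted, the diff also deletes the trailing `return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);` after the `switch`. Legal vector casts whose ISD opcode the switch does not handle (e.g. `ISD::BITCAST`, `ISD::PTRTOINT`, `ISD::INTTOPTR`, `ISD::ADDRSPACECAST`) would then fall off the end of a non-void function, which is undefined behavior. That line must be kept after the switch for the change to be NFC.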

---
Full diff: https://github.com/llvm/llvm-project/pull/86619.diff


1 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+65-68) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index f75b3d3caa62f2..65142a03f0a624 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -897,76 +897,73 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
                                                TTI::CastContextHint CCH,
                                                TTI::TargetCostKind CostKind,
                                                const Instruction *I) {
-  if (isa<VectorType>(Dst) && isa<VectorType>(Src)) {
-    // FIXME: Need to compute legalizing cost for illegal types.
-    if (!isTypeLegal(Src) || !isTypeLegal(Dst))
-      return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
-    // Skip if element size of Dst or Src is bigger than ELEN.
-    if (Src->getScalarSizeInBits() > ST->getELen() ||
-        Dst->getScalarSizeInBits() > ST->getELen())
-      return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
-    int ISD = TLI->InstructionOpcodeToISD(Opcode);
-    assert(ISD && "Invalid opcode");
-
-    // FIXME: Need to consider vsetvli and lmul.
-    int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
-                  (int)Log2_32(Src->getScalarSizeInBits());
-    switch (ISD) {
-    case ISD::SIGN_EXTEND:
-    case ISD::ZERO_EXTEND:
-      if (Src->getScalarSizeInBits() == 1) {
-        // We do not use vsext/vzext to extend from mask vector.
-        // Instead we use the following instructions to extend from mask vector:
-        // vmv.v.i v8, 0
-        // vmerge.vim v8, v8, -1, v0
-        return 2;
-      }
-      return 1;
-    case ISD::TRUNCATE:
-      if (Dst->getScalarSizeInBits() == 1) {
-        // We do not use several vncvt to truncate to mask vector. So we could
-        // not use PowDiff to calculate it.
-        // Instead we use the following instructions to truncate to mask vector:
-        // vand.vi v8, v8, 1
-        // vmsne.vi v0, v8, 0
-        return 2;
-      }
-      [[fallthrough]];
-    case ISD::FP_EXTEND:
-    case ISD::FP_ROUND:
-      // Counts of narrow/widen instructions.
-      return std::abs(PowDiff);
-    case ISD::FP_TO_SINT:
-    case ISD::FP_TO_UINT:
-    case ISD::SINT_TO_FP:
-    case ISD::UINT_TO_FP:
-      if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
-        // The cost of convert from or to mask vector is different from other
-        // cases. We could not use PowDiff to calculate it.
-        // For mask vector to fp, we should use the following instructions:
-        // vmv.v.i v8, 0
-        // vmerge.vim v8, v8, -1, v0
-        // vfcvt.f.x.v v8, v8
-
-        // And for fp vector to mask, we use:
-        // vfncvt.rtz.x.f.w v9, v8
-        // vand.vi v8, v9, 1
-        // vmsne.vi v0, v8, 0
-        return 3;
-      }
-      if (std::abs(PowDiff) <= 1)
-        return 1;
-      // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
-      // so it only need two conversion.
-      if (Src->isIntOrIntVectorTy())
-        return 2;
-      // Counts of narrow/widen instructions.
-      return std::abs(PowDiff);
+  bool IsVectorType = isa<VectorType>(Dst) && isa<VectorType>(Src);
+  bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
+                     (Src->getScalarSizeInBits() <= ST->getELen()) &&
+                     (Dst->getScalarSizeInBits() <= ST->getELen());
+
+  // FIXME: Need to compute legalizing cost for illegal types.
+  if (!IsVectorType || !IsTypeLegal)
+    return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+
+  int ISD = TLI->InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
+
+  // FIXME: Need to consider vsetvli and lmul.
+  int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
+                (int)Log2_32(Src->getScalarSizeInBits());
+  switch (ISD) {
+  case ISD::SIGN_EXTEND:
+  case ISD::ZERO_EXTEND:
+    if (Src->getScalarSizeInBits() == 1) {
+      // We do not use vsext/vzext to extend from mask vector.
+      // Instead we use the following instructions to extend from mask vector:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      return 2;
     }
+    return 1;
+  case ISD::TRUNCATE:
+    if (Dst->getScalarSizeInBits() == 1) {
+      // We do not use several vncvt to truncate to mask vector. So we could
+      // not use PowDiff to calculate it.
+      // Instead we use the following instructions to truncate to mask vector:
+      // vand.vi v8, v8, 1
+      // vmsne.vi v0, v8, 0
+      return 2;
+    }
+    [[fallthrough]];
+  case ISD::FP_EXTEND:
+  case ISD::FP_ROUND:
+    // Counts of narrow/widen instructions.
+    return std::abs(PowDiff);
+  case ISD::FP_TO_SINT:
+  case ISD::FP_TO_UINT:
+  case ISD::SINT_TO_FP:
+  case ISD::UINT_TO_FP:
+    if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
+      // The cost of convert from or to mask vector is different from other
+      // cases. We could not use PowDiff to calculate it.
+      // For mask vector to fp, we should use the following instructions:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      // vfcvt.f.x.v v8, v8
+
+      // And for fp vector to mask, we use:
+      // vfncvt.rtz.x.f.w v9, v8
+      // vand.vi v8, v9, 1
+      // vmsne.vi v0, v8, 0
+      return 3;
+    }
+    if (std::abs(PowDiff) <= 1)
+      return 1;
+    // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
+    // so it only need two conversion.
+    if (Src->isIntOrIntVectorTy())
+      return 2;
+    // Counts of narrow/widen instructions.
+    return std::abs(PowDiff);
   }
-  return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 }
 
 unsigned RISCVTTIImpl::getEstimatedVLFor(VectorType *Ty) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/86619


More information about the llvm-commits mailing list