[llvm] [RISCV][TTI] Refactor getCastInstrCost to exit early (PR #86619)

Shih-Po Hung via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 25 19:38:15 PDT 2024


https://github.com/arcbbb created https://github.com/llvm/llvm-project/pull/86619

To reduce indentation, this patch uses early returns: the fallback to the base implementation for illegal types and non-vector types is hoisted to the top of the function.

It should be mostly NFC (no functional change).

>From cb14c3359c7b34aa02a3c0ad12c490bafbd8094a Mon Sep 17 00:00:00 2001
From: ShihPo Hung <shihpo.hung at sifive.com>
Date: Mon, 25 Mar 2024 19:28:35 -0700
Subject: [PATCH] [RISCV][TTI] Refactor getCastInstrCost to exit early

To reduce indentation, this patch uses early returns: the fallback
to the base implementation for illegal types and non-vector types is
hoisted to the top of the function.

It should be mostly NFC (no functional change).
---
 .../Target/RISCV/RISCVTargetTransformInfo.cpp | 135 +++++++++---------
 1 file changed, 68 insertions(+), 67 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index f75b3d3caa62f2..65142a03f0a624 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -897,76 +897,77 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
                                                TTI::CastContextHint CCH,
                                                TTI::TargetCostKind CostKind,
                                                const Instruction *I) {
-  if (isa<VectorType>(Dst) && isa<VectorType>(Src)) {
-    // FIXME: Need to compute legalizing cost for illegal types.
-    if (!isTypeLegal(Src) || !isTypeLegal(Dst))
-      return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
-    // Skip if element size of Dst or Src is bigger than ELEN.
-    if (Src->getScalarSizeInBits() > ST->getELen() ||
-        Dst->getScalarSizeInBits() > ST->getELen())
-      return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
-    int ISD = TLI->InstructionOpcodeToISD(Opcode);
-    assert(ISD && "Invalid opcode");
-
-    // FIXME: Need to consider vsetvli and lmul.
-    int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
-                  (int)Log2_32(Src->getScalarSizeInBits());
-    switch (ISD) {
-    case ISD::SIGN_EXTEND:
-    case ISD::ZERO_EXTEND:
-      if (Src->getScalarSizeInBits() == 1) {
-        // We do not use vsext/vzext to extend from mask vector.
-        // Instead we use the following instructions to extend from mask vector:
-        // vmv.v.i v8, 0
-        // vmerge.vim v8, v8, -1, v0
-        return 2;
-      }
-      return 1;
-    case ISD::TRUNCATE:
-      if (Dst->getScalarSizeInBits() == 1) {
-        // We do not use several vncvt to truncate to mask vector. So we could
-        // not use PowDiff to calculate it.
-        // Instead we use the following instructions to truncate to mask vector:
-        // vand.vi v8, v8, 1
-        // vmsne.vi v0, v8, 0
-        return 2;
-      }
-      [[fallthrough]];
-    case ISD::FP_EXTEND:
-    case ISD::FP_ROUND:
-      // Counts of narrow/widen instructions.
-      return std::abs(PowDiff);
-    case ISD::FP_TO_SINT:
-    case ISD::FP_TO_UINT:
-    case ISD::SINT_TO_FP:
-    case ISD::UINT_TO_FP:
-      if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
-        // The cost of convert from or to mask vector is different from other
-        // cases. We could not use PowDiff to calculate it.
-        // For mask vector to fp, we should use the following instructions:
-        // vmv.v.i v8, 0
-        // vmerge.vim v8, v8, -1, v0
-        // vfcvt.f.x.v v8, v8
-
-        // And for fp vector to mask, we use:
-        // vfncvt.rtz.x.f.w v9, v8
-        // vand.vi v8, v9, 1
-        // vmsne.vi v0, v8, 0
-        return 3;
-      }
-      if (std::abs(PowDiff) <= 1)
-        return 1;
-      // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
-      // so it only need two conversion.
-      if (Src->isIntOrIntVectorTy())
-        return 2;
-      // Counts of narrow/widen instructions.
-      return std::abs(PowDiff);
+  bool IsVectorType = isa<VectorType>(Dst) && isa<VectorType>(Src);
+  // Element sizes of Dst or Src bigger than ELEN are handled like illegal types.
+  bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
+                     (Src->getScalarSizeInBits() <= ST->getELen()) &&
+                     (Dst->getScalarSizeInBits() <= ST->getELen());
+
+  // FIXME: Need to compute legalizing cost for illegal types.
+  if (!IsVectorType || !IsTypeLegal)
+    return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+
+  int ISD = TLI->InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
+
+  // FIXME: Need to consider vsetvli and lmul.
+  int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
+                (int)Log2_32(Src->getScalarSizeInBits());
+  switch (ISD) {
+  case ISD::SIGN_EXTEND:
+  case ISD::ZERO_EXTEND:
+    if (Src->getScalarSizeInBits() == 1) {
+      // We do not use vsext/vzext to extend from mask vector.
+      // Instead we use the following instructions to extend from mask vector:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      return 2;
     }
+    return 1;
+  case ISD::TRUNCATE:
+    if (Dst->getScalarSizeInBits() == 1) {
+      // We do not use several vncvt to truncate to mask vector. So we could
+      // not use PowDiff to calculate it.
+      // Instead we use the following instructions to truncate to mask vector:
+      // vand.vi v8, v8, 1
+      // vmsne.vi v0, v8, 0
+      return 2;
+    }
+    [[fallthrough]];
+  case ISD::FP_EXTEND:
+  case ISD::FP_ROUND:
+    // Counts of narrow/widen instructions.
+    return std::abs(PowDiff);
+  case ISD::FP_TO_SINT:
+  case ISD::FP_TO_UINT:
+  case ISD::SINT_TO_FP:
+  case ISD::UINT_TO_FP:
+    if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
+      // The cost of convert from or to mask vector is different from other
+      // cases. We could not use PowDiff to calculate it.
+      // For mask vector to fp, we should use the following instructions:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      // vfcvt.f.x.v v8, v8
+
+      // And for fp vector to mask, we use:
+      // vfncvt.rtz.x.f.w v9, v8
+      // vand.vi v8, v9, 1
+      // vmsne.vi v0, v8, 0
+      return 3;
+    }
+    if (std::abs(PowDiff) <= 1)
+      return 1;
+    // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
+    // so it only need two conversion.
+    if (Src->isIntOrIntVectorTy())
+      return 2;
+    // Counts of narrow/widen instructions.
+    return std::abs(PowDiff);
   }
+  // Cast opcodes not handled by the switch above (e.g. bitcast) fall back to
+  // the generic cost model; without this the function would fall off the end.
   return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 }
 
 unsigned RISCVTTIImpl::getEstimatedVLFor(VectorType *Ty) {



More information about the llvm-commits mailing list