[llvm] [RISCV][TTI] Refactor getCastInstrCost to exit early (PR #86619)
Shih-Po Hung via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 25 23:54:52 PDT 2024
================
@@ -897,76 +897,73 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind,
const Instruction *I) {
- if (isa<VectorType>(Dst) && isa<VectorType>(Src)) {
- // FIXME: Need to compute legalizing cost for illegal types.
- if (!isTypeLegal(Src) || !isTypeLegal(Dst))
- return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
- // Skip if element size of Dst or Src is bigger than ELEN.
- if (Src->getScalarSizeInBits() > ST->getELen() ||
- Dst->getScalarSizeInBits() > ST->getELen())
- return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
- int ISD = TLI->InstructionOpcodeToISD(Opcode);
- assert(ISD && "Invalid opcode");
-
- // FIXME: Need to consider vsetvli and lmul.
- int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
- (int)Log2_32(Src->getScalarSizeInBits());
- switch (ISD) {
- case ISD::SIGN_EXTEND:
- case ISD::ZERO_EXTEND:
- if (Src->getScalarSizeInBits() == 1) {
- // We do not use vsext/vzext to extend from mask vector.
- // Instead we use the following instructions to extend from mask vector:
- // vmv.v.i v8, 0
- // vmerge.vim v8, v8, -1, v0
- return 2;
- }
- return 1;
- case ISD::TRUNCATE:
- if (Dst->getScalarSizeInBits() == 1) {
- // We do not use several vncvt to truncate to mask vector. So we could
- // not use PowDiff to calculate it.
- // Instead we use the following instructions to truncate to mask vector:
- // vand.vi v8, v8, 1
- // vmsne.vi v0, v8, 0
- return 2;
- }
- [[fallthrough]];
- case ISD::FP_EXTEND:
- case ISD::FP_ROUND:
- // Counts of narrow/widen instructions.
- return std::abs(PowDiff);
- case ISD::FP_TO_SINT:
- case ISD::FP_TO_UINT:
- case ISD::SINT_TO_FP:
- case ISD::UINT_TO_FP:
- if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
- // The cost of convert from or to mask vector is different from other
- // cases. We could not use PowDiff to calculate it.
- // For mask vector to fp, we should use the following instructions:
- // vmv.v.i v8, 0
- // vmerge.vim v8, v8, -1, v0
- // vfcvt.f.x.v v8, v8
-
- // And for fp vector to mask, we use:
- // vfncvt.rtz.x.f.w v9, v8
- // vand.vi v8, v9, 1
- // vmsne.vi v0, v8, 0
- return 3;
- }
- if (std::abs(PowDiff) <= 1)
- return 1;
- // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
- // so it only need two conversion.
- if (Src->isIntOrIntVectorTy())
- return 2;
- // Counts of narrow/widen instructions.
- return std::abs(PowDiff);
+ bool IsVectorType = isa<VectorType>(Dst) && isa<VectorType>(Src);
+ bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
+ (Src->getScalarSizeInBits() <= ST->getELen()) &&
+ (Dst->getScalarSizeInBits() <= ST->getELen());
+
+ // FIXME: Need to compute legalizing cost for illegal types.
+ if (!IsVectorType || !IsTypeLegal)
+ return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+
+ int ISD = TLI->InstructionOpcodeToISD(Opcode);
+ assert(ISD && "Invalid opcode");
+
+ // FIXME: Need to consider vsetvli and lmul.
+ int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
+ (int)Log2_32(Src->getScalarSizeInBits());
+ switch (ISD) {
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND:
+ if (Src->getScalarSizeInBits() == 1) {
+ // We do not use vsext/vzext to extend from mask vector.
+ // Instead we use the following instructions to extend from mask vector:
+ // vmv.v.i v8, 0
+ // vmerge.vim v8, v8, -1, v0
+ return 2;
}
+ return 1;
+ case ISD::TRUNCATE:
+ if (Dst->getScalarSizeInBits() == 1) {
+ // We do not use several vncvt to truncate to mask vector. So we could
+ // not use PowDiff to calculate it.
+ // Instead we use the following instructions to truncate to mask vector:
+ // vand.vi v8, v8, 1
+ // vmsne.vi v0, v8, 0
+ return 2;
+ }
+ [[fallthrough]];
+ case ISD::FP_EXTEND:
+ case ISD::FP_ROUND:
+ // Counts of narrow/widen instructions.
+ return std::abs(PowDiff);
+ case ISD::FP_TO_SINT:
+ case ISD::FP_TO_UINT:
+ case ISD::SINT_TO_FP:
+ case ISD::UINT_TO_FP:
+ if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
+ // The cost of convert from or to mask vector is different from other
+ // cases. We could not use PowDiff to calculate it.
+ // For mask vector to fp, we should use the following instructions:
+ // vmv.v.i v8, 0
+ // vmerge.vim v8, v8, -1, v0
+ // vfcvt.f.x.v v8, v8
+
+ // And for fp vector to mask, we use:
+ // vfncvt.rtz.x.f.w v9, v8
+ // vand.vi v8, v9, 1
+ // vmsne.vi v0, v8, 0
+ return 3;
+ }
+ if (std::abs(PowDiff) <= 1)
+ return 1;
+ // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
+ // so it only need two conversion.
+ if (Src->isIntOrIntVectorTy())
+ return 2;
+ // Counts of narrow/widen instructions.
+ return std::abs(PowDiff);
}
- return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}
----------------
arcbbb wrote:
Sorry for the trouble. I fixed it in commit 5dc0c75aabb9811e03cc8025905fed6dc2dd7bda.
https://github.com/llvm/llvm-project/pull/86619
More information about the llvm-commits mailing list