[llvm] [RISCV][TTI] Scale the cost of the sext/zext with LMUL (PR #86617)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 22:37:28 PDT 2024


================
@@ -906,23 +906,33 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (!IsVectorType || !IsTypeLegal)
     return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 
+  std::pair<InstructionCost, MVT> DstLT = getTypeLegalizationCost(Dst);
+
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
-  // FIXME: Need to consider vsetvli and lmul.
   int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
                 (int)Log2_32(Src->getScalarSizeInBits());
   switch (ISD) {
   case ISD::SIGN_EXTEND:
-  case ISD::ZERO_EXTEND:
-    if (Src->getScalarSizeInBits() == 1) {
+  case ISD::ZERO_EXTEND: {
+    const unsigned SrcEltSize = Src->getScalarSizeInBits();
+    if (SrcEltSize == 1) {
       // We do not use vsext/vzext to extend from mask vector.
       // Instead we use the following instructions to extend from mask vector:
       // vmv.v.i v8, 0
       // vmerge.vim v8, v8, -1, v0
-      return 2;
+      return getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
+                                     DstLT.second, CostKind);
     }
-    return 1;
+    if (PowDiff > 3)
+      return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+    unsigned SExtOp[] = {RISCV::VSEXT_VF2, RISCV::VSEXT_VF4, RISCV::VSEXT_VF8};
+    unsigned ZExtOp[] = {RISCV::VZEXT_VF2, RISCV::VZEXT_VF4, RISCV::VZEXT_VF8};
+    unsigned Op =
+        (ISD == ISD::SIGN_EXTEND) ? SExtOp[PowDiff - 1] : ZExtOp[PowDiff - 1];
----------------
lukel97 wrote:

I guess something is calling this with a sign/zero extend to the same element type, i.e. PowDiff == 0? That would index SExtOp/ZExtOp at PowDiff - 1 == -1, so it's probably best to check for it defensively.

https://github.com/llvm/llvm-project/pull/86617


More information about the llvm-commits mailing list