[llvm] [RISCV][TTI] Avoid an infinite recursion issue in getCastInstrCost (PR #110164)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 27 08:24:35 PDT 2024


================
@@ -1163,9 +1163,47 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
       Dst->getScalarSizeInBits() > ST->getELen())
     return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 
+  int ISD = TLI->InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
   std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
   std::pair<InstructionCost, MVT> DstLT = getTypeLegalizationCost(Dst);
 
+  // Handle i1 source and dest cases *before* calling logic in BasicTTI.
+  // The shared implementation doesn't model vector widening during legalization
+  // and instead assumes scalarization.  In order to scalarize an <N x i1>
+  // vector, we need to extend/trunc to/from i8.  If we don't special case
+  // this, we can get an infinite recursion cycle.
+  switch (ISD) {
+  default:
+    break;
+  case ISD::SIGN_EXTEND:
+  case ISD::ZERO_EXTEND:
+    if (Src->getScalarSizeInBits() == 1) {
+      // We do not use vsext/vzext to extend from mask vector.
+      // Instead we use the following instructions to extend from mask vector:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      return DstLT.first *
+                 getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
+                                         DstLT.second, CostKind) +
+             DstLT.first - 1;
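A minimal sketch of the arithmetic in the return statement above, for readers checking the cost model. It assumes, purely for illustration, that `vmv.v.i` and `vmerge.vim` each cost 1 per legal register group. The real `getRISCVInstructionCost` derives per-instruction costs from the legal type, so `hypotheticalInstrCost` and `maskExtendCost` are hypothetical names, not LLVM APIs. `NumParts` stands in for `DstLT.first`.

```cpp
#include <cstdint>

// Hypothetical stand-in for RISCVTTIImpl::getRISCVInstructionCost: the real
// hook derives each instruction's cost from the legal type (e.g. its LMUL).
// This sketch simply charges 1 per instruction per legal register group.
constexpr int64_t hypotheticalInstrCost(int64_t NumInstrs) { return NumInstrs; }

// Mirrors the return expression in the diff,
//   DstLT.first * getRISCVInstructionCost({VMV_V_I, VMERGE_VIM}, ...)
//       + DstLT.first - 1
// with NumParts playing the role of DstLT.first, i.e. the number of legal
// vector types the destination is split into during legalization.
constexpr int64_t maskExtendCost(int64_t NumParts) {
  // One vmv.v.i plus one vmerge.vim per legal part.
  const int64_t PerPart = hypotheticalInstrCost(2);
  // Multiplication binds tighter than addition, so this is
  // (NumParts * PerPart) + NumParts - 1, as in the diff.
  return NumParts * PerPart + NumParts - 1;
}
```

Under these assumptions a destination that legalizes to a single part costs 2, and one that splits into two parts costs `2 * 2 + 2 - 1 = 5`: the trailing `DstLT.first - 1` term adds one unit for each part beyond the first.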
----------------
preames wrote:

Follow-up posted here: https://github.com/llvm/llvm-project/pull/110282

https://github.com/llvm/llvm-project/pull/110164
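The recursion described in the diff's comment can be illustrated with a deliberately simplified toy model. None of these names are LLVM APIs: `targetCastCost` plays the RISC-V hook, `genericCastCost` plays the shared BasicTTI fallback, and `HandleI1First` toggles the early i1 special case the patch adds. Only the extend direction is modeled. The truncate-to-i1 case described in the comment would follow the same pattern.

```cpp
#include <cstdint>

// Toy model (hypothetical, not LLVM's API) of the cycle described in the
// patch comment: a target cost hook that falls back to a generic
// implementation, and a generic implementation that, for an <N x i1>
// operand, prices scalarization by asking the target about an i1 -> i8
// extend.
struct CastQuery {
  unsigned SrcBits; // scalar size of the source element type
  unsigned DstBits; // scalar size of the destination element type
};

constexpr int64_t targetCastCost(CastQuery Q, bool HandleI1First);

// Generic fallback: assumes the vector is scalarized. An i1 element must be
// extended to i8 first, so it re-queries the target hook.
constexpr int64_t genericCastCost(CastQuery Q, bool HandleI1First) {
  if (Q.SrcBits == 1)
    return targetCastCost({1, 8}, HandleI1First) + 1;
  return 1; // toy per-element cost
}

// Target hook. With HandleI1First == false, an i1 source is forwarded to
// the generic code, which forwards it straight back: unbounded recursion.
// With HandleI1First == true, the i1 case is priced directly (here as the
// two-instruction vmv.v.i + vmerge.vim sequence), breaking the cycle.
constexpr int64_t targetCastCost(CastQuery Q, bool HandleI1First) {
  if (HandleI1First && Q.SrcBits == 1)
    return 2;
  return genericCastCost(Q, HandleI1First);
}
```

In this model, `targetCastCost({1, 32}, true)` returns 2 without ever entering the fallback. Calling it with `false` would never terminate, which is the shape of the bug the patch avoids by handling the i1 cases before deferring to the shared implementation.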

