[llvm] [RISCV][TTI] Avoid an infinite recursion issue in getCastInstrCost (PR #110164)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 27 01:49:33 PDT 2024


================
@@ -1163,9 +1163,47 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
       Dst->getScalarSizeInBits() > ST->getELen())
     return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 
+  int ISD = TLI->InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
   std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
   std::pair<InstructionCost, MVT> DstLT = getTypeLegalizationCost(Dst);
 
+  // Handle i1 source and dest cases *before* calling logic in BasicTTI.
+  // The shared implementation doesn't model vector widening during legalization
+  // and instead assumes scalarization.  In order to scalarize an <N x i1>
+  // vector, we need to extend/trunc to/from i8.  If we don't special case
+  // this, we can get an infinite recursion cycle.
+  switch (ISD) {
+  default:
+    break;
+  case ISD::SIGN_EXTEND:
+  case ISD::ZERO_EXTEND:
+    if (Src->getScalarSizeInBits() == 1) {
+      // We do not use vsext/vzext to extend from mask vector.
+      // Instead we use the following instructions to extend from mask vector:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      return DstLT.first *
+                 getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
+                                         DstLT.second, CostKind) +
+             DstLT.first - 1;
----------------
lukel97 wrote:

Is this DstLT.first - 1 the cost of gluing together two or more split ops?

Also, a minor observation from an example I tried (not related to this PR): we only need to cost one `vmv.v.i`, since the zero splat is materialized once and reused by both `vmerge.vim`s:

```llvm
define <256 x i8> @f(<256 x i1> %v) {
  %g = sext <256 x i1> %v to <256 x i8>
  ret <256 x i8> %g
}
```

```asm
f:
	vmv1r.v	v16, v8
	li	a0, 128
	vsetvli	zero, a0, e8, m8, ta, ma
	vmv.v.i	v24, 0
	vmerge.vim	v8, v24, -1, v0
	vmv1r.v	v0, v16
	vmerge.vim	v16, v24, -1, v0
	ret
```
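The gap between the patch's formula and the emitted code can be shown with a small sketch of the arithmetic alone. It assumes a hypothetical unit cost for each instruction (the real values come from `getRISCVInstructionCost` and depend on LMUL), and it treats `numParts` as `DstLT.first`. For the `<256 x i8>` example that value should be 2, which matches the two `vmerge.vim` above. The functions `patchCost` and `sharedSplatCost` are illustrative names, not LLVM APIs.

```cpp
#include <cassert>

// Hypothetical unit costs for illustration only; the real numbers come from
// RISCVTTIImpl::getRISCVInstructionCost and scale with LMUL.
constexpr int kVmvViCost = 1;
constexpr int kVmergeVimCost = 1;

// Mirrors the formula in the patch: each legalized part pays for both
// vmv.v.i and vmerge.vim, plus (numParts - 1) extra.
constexpr int patchCost(int numParts) {
  return numParts * (kVmvViCost + kVmergeVimCost) + numParts - 1;
}

// Mirrors the emitted code: one shared zero splat (vmv.v.i), then one
// vmerge.vim per part, keeping the same (numParts - 1) term.
constexpr int sharedSplatCost(int numParts) {
  return kVmvViCost + numParts * kVmergeVimCost + numParts - 1;
}

// With a single part (no splitting) the two models agree.
static_assert(patchCost(1) == sharedSplatCost(1), "no split: same cost");
```

Under these assumed unit costs, `patchCost(2)` is 5 and `sharedSplatCost(2)` is 4. The difference is exactly one `vmv.v.i` for each part beyond the first.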


https://github.com/llvm/llvm-project/pull/110164

