[clang] [clang] fix half && bfloat16 convert node expr codegen (PR #89051)

via cfe-commits cfe-commits at lists.llvm.org
Wed Sep 4 20:16:27 PDT 2024


JinjinLi868 wrote:

> The vector tests should still be added

Sorry, if I remove the vector change, I also have to remove the test case. The current code has a bug when converting between vector types of half and bfloat16: it hits the assertion "Invalid cast!" in `CastInst::Create`:
`CastInst *CastInst::Create(Instruction::CastOps op, Value *S, Type *Ty,
                           const Twine &Name, InsertPosition InsertBefore) {
  assert(castIsValid(op, S, Ty) && "Invalid cast!");
  // Construct and return the appropriate CastInst subclass
  switch (op) {
  case Trunc:         return new TruncInst         (S, Ty, Name, InsertBefore);
  case ZExt:          return new ZExtInst          (S, Ty, Name, InsertBefore);
  case SExt:          return new SExtInst          (S, Ty, Name, InsertBefore);
  case FPTrunc:       return new FPTruncInst       (S, Ty, Name, InsertBefore);
  case FPExt:         return new FPExtInst         (S, Ty, Name, InsertBefore);

  default:
    llvm_unreachable("Invalid opcode provided");
  }
}`
`bool CastInst::castIsValid(Instruction::CastOps op, Type *SrcTy, Type *DstTy) {
  if (!SrcTy->isFirstClassType() || !DstTy->isFirstClassType() ||
      SrcTy->isAggregateType() || DstTy->isAggregateType())
    return false;

  // Get the size of the types in bits, and whether we are dealing
  // with vector types, we'll need this later.
  bool SrcIsVec = isa<VectorType>(SrcTy);
  bool DstIsVec = isa<VectorType>(DstTy);
  unsigned SrcScalarBitSize = SrcTy->getScalarSizeInBits();
  unsigned DstScalarBitSize = DstTy->getScalarSizeInBits();

  // If these are vector types, get the lengths of the vectors (using zero for
  // scalar types means that checking that vector lengths match also checks that
  // scalars are not being converted to vectors or vectors to scalars).
  ElementCount SrcEC = SrcIsVec ? cast<VectorType>(SrcTy)->getElementCount()
                                : ElementCount::getFixed(0);
  ElementCount DstEC = DstIsVec ? cast<VectorType>(DstTy)->getElementCount()
                                : ElementCount::getFixed(0);

  // Switch on the opcode provided
  switch (op) {
  default: return false; // This is an input error
  case Instruction::FPExt:
    return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() &&
           SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize;
`
Currently, a vector conversion between half and bfloat16 is codegen'd as FPExt. For FPExt, castIsValid() requires SrcScalarBitSize < DstScalarBitSize. But for vectors of half and bfloat16, SrcScalarBitSize is equal to DstScalarBitSize (both element types are 16 bits wide), so the check fails and the assertion fires.
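To make the failing condition concrete, here is a minimal standalone sketch of the FPExt check. It does not use LLVM: `ToyType`, `fpExtIsValid`, and the three example constants are hypothetical names invented for illustration, and the 4-element vector length is an arbitrary assumption. Only the shape of the condition is taken from the castIsValid() snippet quoted above.

```cpp
#include <cassert>

// Hypothetical, simplified stand-in for llvm::Type (NOT the LLVM API).
struct ToyType {
  bool IsFloatingPoint;  // true for half, bfloat, float, ...
  unsigned ScalarBits;   // bit width of a single element
  unsigned NumElements;  // 0 for scalars, mirroring the quoted snippet
};

// Assumed example types: 4-element vectors of half, bfloat and float.
constexpr ToyType HalfVec4{true, 16, 4};
constexpr ToyType BFloatVec4{true, 16, 4};
constexpr ToyType FloatVec4{true, 32, 4};

// Same shape as the FPExt case quoted above: both sides are FP,
// element counts match, and the source is strictly narrower.
constexpr bool fpExtIsValid(const ToyType &Src, const ToyType &Dst) {
  return Src.IsFloatingPoint && Dst.IsFloatingPoint &&
         Src.NumElements == Dst.NumElements &&
         Src.ScalarBits < Dst.ScalarBits;
}

// half -> float widens 16 to 32 bits, so the check passes.
static_assert(fpExtIsValid(HalfVec4, FloatVec4),
              "half -> float should be a valid extension");
// half -> bfloat is 16 to 16 bits, so the strict '<' rejects it.
static_assert(!fpExtIsValid(HalfVec4, BFloatVec4),
              "half -> bfloat must not count as an extension");
```

Because both element types are 16 bits wide, the strict `<` comparison can never hold for half and bfloat16 in either direction, which matches the "Invalid cast!" assertion described above.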

https://github.com/llvm/llvm-project/pull/89051

