[clang] [clang] fix half && bfloat16 convert node expr codegen (PR #89051)
via cfe-commits
cfe-commits at lists.llvm.org
Wed Sep 4 20:16:27 PDT 2024
JinjinLi868 wrote:
> The vector tests should still be added
Sorry. If I remove the vector part of the change, I also have to remove the test case, because the current code for converting between half and bfloat16 vector types has a bug: it hits the "Invalid cast!" assertion in `CastInst::Create`:
```cpp
CastInst *CastInst::Create(Instruction::CastOps op, Value *S, Type *Ty,
                           const Twine &Name, InsertPosition InsertBefore) {
  assert(castIsValid(op, S, Ty) && "Invalid cast!");
  // Construct and return the appropriate CastInst subclass
  switch (op) {
  case Trunc:   return new TruncInst   (S, Ty, Name, InsertBefore);
  case ZExt:    return new ZExtInst    (S, Ty, Name, InsertBefore);
  case SExt:    return new SExtInst    (S, Ty, Name, InsertBefore);
  case FPTrunc: return new FPTruncInst (S, Ty, Name, InsertBefore);
  case FPExt:   return new FPExtInst   (S, Ty, Name, InsertBefore);
  // ... remaining cast opcodes omitted ...
  default:
    llvm_unreachable("Invalid opcode provided");
  }
}
```
```cpp
bool CastInst::castIsValid(Instruction::CastOps op, Type *SrcTy, Type *DstTy) {
  if (!SrcTy->isFirstClassType() || !DstTy->isFirstClassType() ||
      SrcTy->isAggregateType() || DstTy->isAggregateType())
    return false;
  // Get the size of the types in bits, and whether we are dealing
  // with vector types, we'll need this later.
  bool SrcIsVec = isa<VectorType>(SrcTy);
  bool DstIsVec = isa<VectorType>(DstTy);
  unsigned SrcScalarBitSize = SrcTy->getScalarSizeInBits();
  unsigned DstScalarBitSize = DstTy->getScalarSizeInBits();
  // If these are vector types, get the lengths of the vectors (using zero for
  // scalar types means that checking that vector lengths match also checks that
  // scalars are not being converted to vectors or vectors to scalars).
  ElementCount SrcEC = SrcIsVec ? cast<VectorType>(SrcTy)->getElementCount()
                                : ElementCount::getFixed(0);
  ElementCount DstEC = DstIsVec ? cast<VectorType>(DstTy)->getElementCount()
                                : ElementCount::getFixed(0);
  // Switch on the opcode provided
  switch (op) {
  default: return false; // This is an input error
  case Instruction::FPExt:
    return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() &&
           SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize;
  // ... remaining cases omitted ...
  }
}
```
Currently, a vector conversion between half and bfloat16 is emitted as an `FPExt`. For `FPExt`, `castIsValid()` requires `SrcScalarBitSize < DstScalarBitSize`, but half and bfloat16 are both 16 bits wide, so `SrcScalarBitSize` equals `DstScalarBitSize` and the check fails.
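A minimal, self-contained sketch of that reasoning, without depending on LLVM: `FPTypeModel`, `fpExtIsValid`, `fpTruncIsValid`, and `planViaFloat` are hypothetical names that model only the `FPExt` and `FPTrunc` rules of `castIsValid()` (scalar bit width plus element count, with 0 lanes standing for a scalar). It shows why a same-width half/bfloat16 `FPExt` is rejected, and that widening to 32-bit float first and then truncating satisfies both rules. The two-step route is an illustration under these assumptions, not necessarily the exact lowering this PR implements.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for an LLVM FP (vector) type. Only the two
// properties castIsValid() inspects for FPExt/FPTrunc are kept.
struct FPTypeModel {
  unsigned ScalarBits; // what getScalarSizeInBits() would return
  unsigned Lanes;      // 0 for a scalar, N for a fixed vector <N x T>
};

// Mirrors the FPExt case of CastInst::castIsValid(): same element count and
// strictly widening scalar width.
bool fpExtIsValid(FPTypeModel Src, FPTypeModel Dst) {
  return Src.Lanes == Dst.Lanes && Src.ScalarBits < Dst.ScalarBits;
}

// Mirrors the FPTrunc case: same element count and strictly narrowing.
bool fpTruncIsValid(FPTypeModel Src, FPTypeModel Dst) {
  return Src.Lanes == Dst.Lanes && Src.ScalarBits > Dst.ScalarBits;
}

// Hypothetical planner for the same-width case (e.g. half <-> bfloat), where a
// single FPExt is rejected: widen to 32-bit float, then truncate.
std::vector<std::string> planViaFloat(FPTypeModel Src, FPTypeModel Dst) {
  assert(Src.ScalarBits == Dst.ScalarBits && Src.Lanes == Dst.Lanes);
  FPTypeModel Wide{32, Src.Lanes}; // float, or <N x float>
  std::vector<std::string> Ops;
  if (fpExtIsValid(Src, Wide))
    Ops.push_back("fpext");   // e.g. <4 x half> -> <4 x float>
  if (fpTruncIsValid(Wide, Dst))
    Ops.push_back("fptrunc"); // e.g. <4 x float> -> <4 x bfloat>
  return Ops;
}
```

For example, `planViaFloat({16, 4}, {16, 4})` yields `{"fpext", "fptrunc"}`, while `fpExtIsValid({16, 4}, {16, 4})` is false. Going through float adds no extra rounding: every half and every bfloat16 value is exactly representable as a float, so the `fpext` is exact and only the final `fptrunc` rounds.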
https://github.com/llvm/llvm-project/pull/89051