[clang] [AArch64] Refactor implementation of FP8 types (NFC) (PR #123604)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Jan 20 04:34:45 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
Author: Momchil Velikov (momchil-velikov)
<details>
<summary>Changes</summary>
- The FP8 scalar type (`__mfp8`) was described as a vector type
- The FP8 vector types were described (or assumed) to have an integer element type (the element type ought to be `__mfp8`)
- Add support for the `m` type specifier (denoting `__mfp8`) in `DecodeTypeFromStr` and create builtin function prototypes using that specifier instead of `int8_t`
Supersedes https://github.com/llvm/llvm-project/pull/118969
---
Patch is 59.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123604.diff
22 Files Affected:
- (modified) clang/include/clang/AST/Type.h (+5)
- (modified) clang/include/clang/Basic/AArch64SVEACLETypes.def (+17-20)
- (modified) clang/include/clang/Basic/TargetBuiltins.h (+3-1)
- (modified) clang/lib/AST/ASTContext.cpp (+18-12)
- (modified) clang/lib/AST/ItaniumMangle.cpp (+6-1)
- (modified) clang/lib/AST/Type.cpp (+1-3)
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+1)
- (modified) clang/lib/CodeGen/CGExpr.cpp (+9-2)
- (modified) clang/lib/CodeGen/CodeGenTypes.cpp (+13-5)
- (modified) clang/lib/CodeGen/Targets/AArch64.cpp (+6-8)
- (modified) clang/lib/Sema/SemaARM.cpp (+2)
- (modified) clang/lib/Sema/SemaExpr.cpp (+6-1)
- (modified) clang/lib/Sema/SemaType.cpp (+2-1)
- (added) clang/test/CodeGen/AArch64/builtin-shufflevector-fp8.c (+123)
- (added) clang/test/CodeGen/AArch64/fp8-cast.c (+193)
- (modified) clang/test/CodeGen/arm-mfp8.c (+53-35)
- (modified) clang/test/CodeGenCXX/aarch64-mangle-neon-vectors.cpp (+7)
- (modified) clang/test/CodeGenCXX/mangle-neon-vectors.cpp (+11)
- (added) clang/test/Sema/aarch64-fp8-cast.c (+104)
- (modified) clang/test/Sema/arm-mfp8.cpp (+22-12)
- (modified) clang/utils/TableGen/NeonEmitter.cpp (+3-8)
- (modified) clang/utils/TableGen/SveEmitter.cpp (+2-2)
``````````diff
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 3457d524c63aaa..1d9743520654eb 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2518,6 +2518,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
bool isFloat32Type() const;
bool isDoubleType() const;
bool isBFloat16Type() const;
+ bool isMFloat8Type() const;
bool isFloat128Type() const;
bool isIbm128Type() const;
bool isRealType() const; // C99 6.2.5p17 (real floating + integer)
@@ -8537,6 +8538,10 @@ inline bool Type::isBFloat16Type() const {
return isSpecificBuiltinType(BuiltinType::BFloat16);
}
+inline bool Type::isMFloat8Type() const {
+ return isSpecificBuiltinType(BuiltinType::MFloat8);
+}
+
inline bool Type::isFloat128Type() const {
return isSpecificBuiltinType(BuiltinType::Float128);
}
diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 063cac1f4a58ee..a408bb0c54057c 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -57,6 +57,11 @@
// - IsBF true for vector of brain float elements.
//===----------------------------------------------------------------------===//
+#ifndef SVE_SCALAR_TYPE
+#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
+ SVE_TYPE(Name, Id, SingletonId)
+#endif
+
#ifndef SVE_VECTOR_TYPE
#define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
SVE_TYPE(Name, Id, SingletonId)
@@ -72,6 +77,11 @@
SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
#endif
+#ifndef SVE_VECTOR_TYPE_MFLOAT
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
+ SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, false)
+#endif
+
#ifndef SVE_VECTOR_TYPE_FLOAT
#define SVE_VECTOR_TYPE_FLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
@@ -97,16 +107,6 @@
SVE_TYPE(Name, Id, SingletonId)
#endif
-#ifndef AARCH64_VECTOR_TYPE
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
- SVE_TYPE(Name, Id, SingletonId)
-#endif
-
-#ifndef AARCH64_VECTOR_TYPE_MFLOAT
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
- AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
-#endif
-
//===- Vector point types -----------------------------------------------===//
SVE_VECTOR_TYPE_INT("__SVInt8_t", "__SVInt8_t", SveInt8, SveInt8Ty, 16, 8, 1, true)
@@ -125,8 +125,7 @@ SVE_VECTOR_TYPE_FLOAT("__SVFloat64_t", "__SVFloat64_t", SveFloat64, SveFloat64Ty
SVE_VECTOR_TYPE_BFLOAT("__SVBfloat16_t", "__SVBfloat16_t", SveBFloat16, SveBFloat16Ty, 8, 16, 1)
-// This is a 8 bits opaque type.
-SVE_VECTOR_TYPE_INT("__SVMfloat8_t", "__SVMfloat8_t", SveMFloat8, SveMFloat8Ty, 16, 8, 1, false)
+SVE_VECTOR_TYPE_MFLOAT("__SVMfloat8_t", "__SVMfloat8_t", SveMFloat8, SveMFloat8Ty, 16, 8, 1)
//
// x2
@@ -148,7 +147,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x2_t", "svfloat64x2_t", SveFloat64x2, Sv
SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x2_t", "svbfloat16x2_t", SveBFloat16x2, SveBFloat16x2Ty, 8, 16, 2)
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2)
//
// x3
@@ -170,7 +169,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x3_t", "svfloat64x3_t", SveFloat64x3, Sv
SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x3_t", "svbfloat16x3_t", SveBFloat16x3, SveBFloat16x3Ty, 8, 16, 3)
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3)
//
// x4
@@ -192,7 +191,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x4_t", "svfloat64x4_t", SveFloat64x4, Sv
SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x4_t", "svbfloat16x4_t", SveBFloat16x4, SveBFloat16x4Ty, 8, 16, 4)
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4)
SVE_PREDICATE_TYPE_ALL("__SVBool_t", "__SVBool_t", SveBool, SveBoolTy, 16, 1)
SVE_PREDICATE_TYPE_ALL("__clang_svboolx2_t", "svboolx2_t", SveBoolx2, SveBoolx2Ty, 16, 2)
@@ -200,17 +199,15 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
-AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
+SVE_SCALAR_TYPE("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 8)
#undef SVE_VECTOR_TYPE
+#undef SVE_VECTOR_TYPE_MFLOAT
#undef SVE_VECTOR_TYPE_BFLOAT
#undef SVE_VECTOR_TYPE_FLOAT
#undef SVE_VECTOR_TYPE_INT
#undef SVE_PREDICATE_TYPE
#undef SVE_PREDICATE_TYPE_ALL
#undef SVE_OPAQUE_TYPE
-#undef AARCH64_VECTOR_TYPE_MFLOAT
-#undef AARCH64_VECTOR_TYPE
+#undef SVE_SCALAR_TYPE
#undef SVE_TYPE
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index 4dc8b24ed8ae6c..83ef015018f1a1 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -208,7 +208,8 @@ namespace clang {
Float16,
Float32,
Float64,
- BFloat16
+ BFloat16,
+ MFloat8
};
NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -230,6 +231,7 @@ namespace clang {
switch (getEltType()) {
case Int8:
case Poly8:
+ case MFloat8:
return 8;
case Int16:
case Float16:
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 155dbcfcaeed31..51764095a68980 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2269,11 +2269,10 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
Width = 0; \
Align = 16; \
break;
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
- ElBits, NF) \
+#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
case BuiltinType::Id: \
- Width = NumEls * ElBits * NF; \
- Align = NumEls * ElBits; \
+ Width = Bits; \
+ Align = Bits; \
break;
#include "clang/Basic/AArch64SVEACLETypes.def"
#define PPC_VECTOR_TYPE(Name, Id, Size) \
@@ -4395,15 +4394,14 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
ElBits, NF) \
case BuiltinType::Id: \
return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
+ ElBits, NF) \
+ case BuiltinType::Id: \
+ return {MFloat8Ty, llvm::ElementCount::getScalable(NumEls), NF};
#define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
case BuiltinType::Id: \
return {BoolTy, llvm::ElementCount::getScalable(NumEls), NF};
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
- ElBits, NF) \
- case BuiltinType::Id: \
- return {getIntTypeForBitwidth(ElBits, false), \
- llvm::ElementCount::getFixed(NumEls), NF};
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"
#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, \
@@ -4465,11 +4463,16 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
EltTySize == ElBits && NumElts == (NumEls * NF) && NumFields == 1) { \
return SingletonId; \
}
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
+ ElBits, NF) \
+ if (EltTy->isMFloat8Type() && EltTySize == ElBits && \
+ NumElts == (NumEls * NF) && NumFields == 1) { \
+ return SingletonId; \
+ }
#define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
if (EltTy->isBooleanType() && NumElts == (NumEls * NF) && NumFields == 1) \
return SingletonId;
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"
} else if (Target->hasRISCVVTypes()) {
uint64_t EltTySize = getTypeSize(EltTy);
@@ -12354,6 +12357,9 @@ static QualType DecodeTypeFromStr(const char *&Str, const ASTContext &Context,
case 'p':
Type = Context.getProcessIDType();
break;
+ case 'm':
+ Type = Context.MFloat8Ty;
+ break;
}
// If there are modifiers and if we're allowed to parse them, go for it.
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 1dd936cf4fb518..49089c0ea3c8ac 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3433,7 +3433,7 @@ void CXXNameMangler::mangleType(const BuiltinType *T) {
type_name = MangledName; \
Out << (type_name == Name ? "u" : "") << type_name.size() << type_name; \
break;
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
+#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
case BuiltinType::Id: \
type_name = MangledName; \
Out << (type_name == Name ? "u" : "") << type_name.size() << type_name; \
@@ -3919,6 +3919,9 @@ void CXXNameMangler::mangleNeonVectorType(const VectorType *T) {
case BuiltinType::Float: EltName = "float32_t"; break;
case BuiltinType::Half: EltName = "float16_t"; break;
case BuiltinType::BFloat16: EltName = "bfloat16_t"; break;
+ case BuiltinType::MFloat8:
+ EltName = "mfloat8_t";
+ break;
default:
llvm_unreachable("unexpected Neon vector element type");
}
@@ -3972,6 +3975,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
return "Float64";
case BuiltinType::BFloat16:
return "Bfloat16";
+ case BuiltinType::MFloat8:
+ return "Mfloat8";
default:
llvm_unreachable("Unexpected vector element base type");
}
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index caa0ac858a1bea..fde0746a175705 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2527,9 +2527,7 @@ bool Type::isSVESizelessBuiltinType() const {
#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId) \
case BuiltinType::Id: \
return true;
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
- case BuiltinType::Id: \
- return false;
+#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"
default:
return false;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index b80833fd91884d..5fd4c3e34ebeba 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6819,6 +6819,7 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
switch (TypeFlags.getEltType()) {
case NeonTypeFlags::Int8:
case NeonTypeFlags::Poly8:
+ case NeonTypeFlags::MFloat8:
return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
case NeonTypeFlags::Int16:
case NeonTypeFlags::Poly16:
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 9a9a8c7f6eae09..eeff77db2d2684 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2414,8 +2414,15 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
Vec = Builder.CreateBitCast(Vec, IRVecTy);
// iN --> <N x i1>.
}
- Vec = Builder.CreateInsertElement(Vec, Src.getScalarVal(),
- Dst.getVectorIdx(), "vecins");
+ llvm::Value *SrcVal = Src.getScalarVal();
+ // Allow inserting `<1 x T>` into an `<N x T>`. It can happen with scalar
+ // types which are mapped to vector LLVM IR types (e.g. for implementing
+ // an ABI).
+ if (auto *EltTy = dyn_cast<llvm::FixedVectorType>(SrcVal->getType());
+ EltTy && EltTy->getNumElements() == 1)
+ SrcVal = Builder.CreateBitCast(SrcVal, EltTy->getElementType());
+ Vec = Builder.CreateInsertElement(Vec, SrcVal, Dst.getVectorIdx(),
+ "vecins");
if (IRStoreTy) {
// <N x i1> --> <iN>.
Vec = Builder.CreateBitCast(Vec, IRStoreTy);
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 09191a4901f493..405242e97e75cb 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -505,15 +505,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
case BuiltinType::Id:
#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId) \
case BuiltinType::Id:
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
- case BuiltinType::Id:
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"
{
ASTContext::BuiltinVectorTypeInfo Info =
Context.getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
- auto VTy =
- llvm::VectorType::get(ConvertType(Info.ElementType), Info.EC);
+ // The `__mfp8` type maps to `<1 x i8>` which can't be used to build
+ // a <N x i8> vector type, hence bypass the call to `ConvertType` for
+ // the element type and create the vector type directly.
+ auto *EltTy = Info.ElementType->isMFloat8Type()
+ ? llvm::Type::getInt8Ty(getLLVMContext())
+ : ConvertType(Info.ElementType);
+ auto *VTy = llvm::VectorType::get(EltTy, Info.EC);
switch (Info.NumVectors) {
default:
llvm_unreachable("Expected 1, 2, 3 or 4 vectors!");
@@ -529,6 +532,9 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
}
case BuiltinType::SveCount:
return llvm::TargetExtType::get(getLLVMContext(), "aarch64.svcount");
+ case BuiltinType::MFloat8:
+ return llvm::VectorType::get(llvm::Type::getInt8Ty(getLLVMContext()), 1,
+ false);
#define PPC_VECTOR_TYPE(Name, Id, Size) \
case BuiltinType::Id: \
ResultType = \
@@ -650,6 +656,8 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
// An ext_vector_type of Bool is really a vector of bits.
llvm::Type *IRElemTy = VT->isExtVectorBoolType()
? llvm::Type::getInt1Ty(getLLVMContext())
+ : VT->getElementType()->isMFloat8Type()
+ ? llvm::Type::getInt8Ty(getLLVMContext())
: ConvertType(VT->getElementType());
ResultType = llvm::FixedVectorType::get(IRElemTy, VT->getNumElements());
break;
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 7db67ecba07c8f..057199c66f5a10 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -244,6 +244,7 @@ AArch64ABIInfo::convertFixedToScalableVectorType(const VectorType *VT) const {
case BuiltinType::SChar:
case BuiltinType::UChar:
+ case BuiltinType::MFloat8:
return llvm::ScalableVectorType::get(
llvm::Type::getInt8Ty(getVMContext()), 16);
@@ -383,10 +384,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
NSRN = std::min(NSRN + 1, 8u);
else {
switch (BT->getKind()) {
- case BuiltinType::MFloat8x8:
- case BuiltinType::MFloat8x16:
- NSRN = std::min(NSRN + 1, 8u);
- break;
case BuiltinType::SveBool:
case BuiltinType::SveCount:
NPRN = std::min(NPRN + 1, 4u);
@@ -629,8 +626,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
// but with the difference that any floating-point type is allowed,
// including __fp16.
if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
- if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
- BT->getKind() == BuiltinType::MFloat8x8)
+ if (BT->isFloatingPoint())
return true;
} else if (const VectorType *VT = Ty->getAs<VectorType>()) {
if (auto Kind = VT->getVectorKind();
@@ -781,8 +777,10 @@ bool AArch64ABIInfo::passAsPureScalableType(
NPred += Info.NumVectors;
else
NVec += Info.NumVectors;
- auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
- Info.EC.getKnownMinValue());
+ llvm::Type *EltTy = Info.ElementType->isMFloat8Type()
+ ? llvm::Type::getInt8Ty(getVMContext())
+ : CGT.ConvertType(Info.ElementType);
+ auto *VTy = llvm::ScalableVectorType::get(EltTy, Info.EC.getKnownMinValue());
if (CoerceToSeq.size() + Info.NumVectors > 12)
return false;
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index db418d80e0e09c..2620bbc97ba02a 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -352,6 +352,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
return Context.DoubleTy;
case NeonTypeFlags::BFloat16:
return Context.BFloat16Ty;
+ case NeonTypeFlags::MFloat8:
+ return Context.MFloat8Ty;
}
llvm_unreachable("Invalid NeonTypeFlag!");
}
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ae40895980d90a..2087e8b251d8c6 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -7502,7 +7502,7 @@ static bool breakDownVectorType(QualType type, uint64_t &len,
if (const VectorType *vecType = type->getAs<VectorType>()) {
len = vecType->getNumElements();
eltType = vecType->getElementType();
- assert(eltType->isScalarType());
+ assert(eltType->isScalarType() || eltType->isMFloat8Type());
return true;
}
@@ -10168,6 +10168,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
IsCompAssign);
+ // Any operation with MFloat8 type is only possible with C intrinsics
+ if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
+ (RHSVecType && RHSVecType->getElementType()->isMFloat8Type...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/123604
More information about the cfe-commits
mailing list