[clang] [NFC][Clang][AArch64]Refactor implementation of Neon vectors MFloat8… (PR #114804)

via cfe-commits cfe-commits at lists.llvm.org
Mon Nov 4 06:44:49 PST 2024


https://github.com/CarolineConcatto created https://github.com/llvm/llvm-project/pull/114804

…x8 and MFloat8x16

This patch removes the builtins for MFloat8x8 and MFloat8x16 and builds these types the same way the other Neon vectors are built. It uses the scalar type (mfloat8).

>From 5d027a374f489e22ba929445b8a701a0461b38b5 Mon Sep 17 00:00:00 2001
From: Caroline Concatto <caroline.concatto at arm.com>
Date: Wed, 30 Oct 2024 16:20:16 +0000
Subject: [PATCH] [NFC][Clang][AArch64]Refactor implementation of Neon vectors 
 MFloat8x8 and MFloat8x16

This patch removes the builtins for MFloat8x8 and MFloat8x16 and builds these types
the same way the other Neon vectors are built. It uses the scalar type (mfloat8).
---
 clang/include/clang/AST/Type.h                |  5 +++
 .../clang/Basic/AArch64SVEACLETypes.def       |  2 -
 clang/include/clang/Basic/TargetBuiltins.h    |  4 +-
 clang/include/clang/Basic/arm_neon_incl.td    |  1 +
 clang/lib/AST/ItaniumMangle.cpp               |  2 +
 clang/lib/CodeGen/CGBuiltin.cpp               |  2 +
 clang/lib/CodeGen/CodeGenTypes.cpp            |  5 +++
 clang/lib/CodeGen/Targets/AArch64.cpp         |  7 +---
 clang/lib/Sema/SemaARM.cpp                    |  2 +
 clang/lib/Sema/SemaExpr.cpp                   |  5 +++
 clang/lib/Sema/SemaType.cpp                   |  3 +-
 clang/test/CodeGen/arm-mfp8.c                 |  4 +-
 clang/test/Sema/arm-mfp8.cpp                  | 28 ++++++-------
 clang/utils/TableGen/NeonEmitter.cpp          | 42 ++++++++++++-------
 14 files changed, 70 insertions(+), 42 deletions(-)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index ba3161c366f4d9..23413f498b4503 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2521,6 +2521,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFloat32Type() const;
   bool isDoubleType() const;
   bool isBFloat16Type() const;
+  bool isMFloat8Type() const;
   bool isFloat128Type() const;
   bool isIbm128Type() const;
   bool isRealType() const;         // C99 6.2.5p17 (real floating + integer)
@@ -8527,6 +8528,10 @@ inline bool Type::isBFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::BFloat16);
 }
 
+inline bool Type::isMFloat8Type() const {
+  return isSpecificBuiltinType(BuiltinType::MFloat8);
+}
+
 inline bool Type::isFloat128Type() const {
   return isSpecificBuiltinType(BuiltinType::Float128);
 }
diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 62f6087e962466..1734ed6cf31780 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -201,8 +201,6 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
 SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
 
 AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8_t", "__MFloat8_t", MFloat8, MFloat8Ty, 1, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
 
 #undef SVE_VECTOR_TYPE
 #undef SVE_VECTOR_TYPE_BFLOAT
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index d0f41b17c154f3..c0f9a98b433560 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -198,7 +198,8 @@ namespace clang {
       Float16,
       Float32,
       Float64,
-      BFloat16
+      BFloat16,
+      MFloat8
     };
 
     NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -220,6 +221,7 @@ namespace clang {
       switch (getEltType()) {
       case Int8:
       case Poly8:
+      case MFloat8:
         return 8;
       case Int16:
       case Float16:
diff --git a/clang/include/clang/Basic/arm_neon_incl.td b/clang/include/clang/Basic/arm_neon_incl.td
index b088e0794cdea3..fd800e5a6278e4 100644
--- a/clang/include/clang/Basic/arm_neon_incl.td
+++ b/clang/include/clang/Basic/arm_neon_incl.td
@@ -218,6 +218,7 @@ def OP_UNAVAILABLE : Operation {
 // h: half-float
 // d: double
 // b: bfloat16
+// m: mfloat8
 //
 // Typespec modifiers
 // ------------------
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index b3e46508cf596d..59dd35ffb9c1b3 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3902,6 +3902,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
     return "Float64";
   case BuiltinType::BFloat16:
     return "Bfloat16";
+  case BuiltinType::MFloat8:
+    return "Mfloat8";
   default:
     llvm_unreachable("Unexpected vector element base type");
   }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 65d7f5c54a1913..8d8231f47b025e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6513,6 +6513,8 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
   case NeonTypeFlags::Int8:
   case NeonTypeFlags::Poly8:
     return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
+  case NeonTypeFlags::MFloat8:
+    return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
   case NeonTypeFlags::Int16:
   case NeonTypeFlags::Poly16:
     return llvm::FixedVectorType::get(CGF->Int16Ty, V1Ty ? 1 : (4 << IsQuad));
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 09191a4901f493..500d4dcd2cc392 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -647,6 +647,11 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
   case Type::ExtVector:
   case Type::Vector: {
     const auto *VT = cast<VectorType>(Ty);
+    if (VT->getElementType()->isMFloat8Type()) {
+      ResultType = llvm::FixedVectorType::get(
+          llvm::Type::getInt8Ty(getLLVMContext()), VT->getNumElements());
+      break;
+    }
     // An ext_vector_type of Bool is really a vector of bits.
     llvm::Type *IRElemTy = VT->isExtVectorBoolType()
                                ? llvm::Type::getInt1Ty(getLLVMContext())
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 9320c6ef06efab..7b727a04740680 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -375,10 +375,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
         NSRN = std::min(NSRN + 1, 8u);
       else {
         switch (BT->getKind()) {
-        case BuiltinType::MFloat8x8:
-        case BuiltinType::MFloat8x16:
-          NSRN = std::min(NSRN + 1, 8u);
-          break;
         case BuiltinType::SveBool:
         case BuiltinType::SveCount:
           NPRN = std::min(NPRN + 1, 4u);
@@ -620,8 +616,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
   // but with the difference that any floating-point type is allowed,
   // including __fp16.
   if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
-    if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
-        BT->getKind() == BuiltinType::MFloat8x8)
+    if (BT->isFloatingPoint())
       return true;
   } else if (const VectorType *VT = Ty->getAs<VectorType>()) {
     if (auto Kind = VT->getVectorKind();
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index c3a6e5ef8a9d44..7ed0654169723b 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -323,6 +323,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
   switch (Flags.getEltType()) {
   case NeonTypeFlags::Int8:
     return Flags.isUnsigned() ? Context.UnsignedCharTy : Context.SignedCharTy;
+  case NeonTypeFlags::MFloat8:
+    return Context.MFloat8Ty;
   case NeonTypeFlags::Int16:
     return Flags.isUnsigned() ? Context.UnsignedShortTy : Context.ShortTy;
   case NeonTypeFlags::Int32:
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..a4f2588651e3c1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -10156,6 +10156,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
     return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
                                               IsCompAssign);
 
+  // Any operation with MFloat8 type is only possible with C intrinsics
+  if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
+      (RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
+    return InvalidOperands(Loc, LHS, RHS);
+
   // AltiVec-style "vector bool op vector bool" combinations are allowed
   // for some operators but not others.
   if (!AllowBothBool && LHSVecType &&
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 5d043e6684573c..1ae95c42228ce0 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -8180,7 +8180,8 @@ static bool isPermittedNeonBaseType(QualType &Ty, VectorKind VecKind, Sema &S) {
          BTy->getKind() == BuiltinType::ULongLong ||
          BTy->getKind() == BuiltinType::Float ||
          BTy->getKind() == BuiltinType::Half ||
-         BTy->getKind() == BuiltinType::BFloat16;
+         BTy->getKind() == BuiltinType::BFloat16 ||
+         BTy->getKind() == BuiltinType::MFloat8;
 }
 
 static bool verifyValidIntegerConstantExpr(Sema &S, const ParsedAttr &Attr,
diff --git a/clang/test/CodeGen/arm-mfp8.c b/clang/test/CodeGen/arm-mfp8.c
index 8c817fd5be1c9b..c94410986a7e5c 100644
--- a/clang/test/CodeGen/arm-mfp8.c
+++ b/clang/test/CodeGen/arm-mfp8.c
@@ -15,7 +15,7 @@
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
 // CHECK-C-NEXT:    ret <16 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_tu14__MFloat8x16_t(
+// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_t14__Mfloat8x16_t(
 // CHECK-CXX-SAME: <16 x i8> [[V:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
@@ -35,7 +35,7 @@ mfloat8x16_t test_ret_mfloat8x16_t(mfloat8x16_t v) {
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
 // CHECK-C-NEXT:    ret <8 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_tu13__MFloat8x8_t(
+// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_t13__Mfloat8x8_t(
 // CHECK-CXX-SAME: <8 x i8> [[V:%.*]]) #[[ATTR0]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
diff --git a/clang/test/Sema/arm-mfp8.cpp b/clang/test/Sema/arm-mfp8.cpp
index e882c382522c22..5d098c40e17a65 100644
--- a/clang/test/Sema/arm-mfp8.cpp
+++ b/clang/test/Sema/arm-mfp8.cpp
@@ -11,23 +11,22 @@ void test_vector_sve(svmfloat8_t a, svuint8_t c) {
   a / c;  // sve-error {{cannot convert between vector type 'svuint8_t' (aka '__SVUint8_t') and vector type 'svmfloat8_t' (aka '__SVMfloat8_t') as implicit conversion would cause truncation}}
 }
 
-
 #include <arm_neon.h>
 
 void test_vector(mfloat8x8_t a, mfloat8x16_t b, uint8x8_t c) {
-  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-
-  a + c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a - c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a * c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a / c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  c + b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c - b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c * b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c / b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+
+  a + c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a - c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a * c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a / c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  c + b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c - b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c * b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c / b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
 }
 __mfp8 test_static_cast_from_char(char in) {
   return static_cast<__mfp8>(in); // scalar-error {{static_cast from 'char' to '__mfp8' (aka '__MFloat8_t') is not allowed}}
@@ -60,4 +59,3 @@ void test(bool b) {
   u8 = mfp8;   // scalar-error {{assigning to 'char' from incompatible type '__mfp8' (aka '__MFloat8_t')}}
   mfp8 + (b ? u8 : mfp8);  // scalar-error {{incompatible operand types ('char' and '__mfp8' (aka '__MFloat8_t'))}}
 }
-
diff --git a/clang/utils/TableGen/NeonEmitter.cpp b/clang/utils/TableGen/NeonEmitter.cpp
index c6d82646b40de2..de7f8aa1d73769 100644
--- a/clang/utils/TableGen/NeonEmitter.cpp
+++ b/clang/utils/TableGen/NeonEmitter.cpp
@@ -101,7 +101,8 @@ enum EltType {
   Float16,
   Float32,
   Float64,
-  BFloat16
+  BFloat16,
+  MFloat8
 };
 
 } // end namespace NeonTypeFlags
@@ -143,14 +144,7 @@ class Type {
 private:
   TypeSpec TS;
 
-  enum TypeKind {
-    Void,
-    Float,
-    SInt,
-    UInt,
-    Poly,
-    BFloat16
-  };
+  enum TypeKind { Void, Float, SInt, UInt, Poly, BFloat16, MFloat8 };
   TypeKind Kind;
   bool Immediate, Constant, Pointer;
   // ScalarForMangling and NoManglingQ are really not suited to live here as
@@ -203,6 +197,7 @@ class Type {
   bool isLong() const { return isInteger() && ElementBitwidth == 64; }
   bool isVoid() const { return Kind == Void; }
   bool isBFloat16() const { return Kind == BFloat16; }
+  bool isMFloat8() const { return Kind == MFloat8; }
   unsigned getNumElements() const { return Bitwidth / ElementBitwidth; }
   unsigned getSizeInBits() const { return Bitwidth; }
   unsigned getElementSizeInBits() const { return ElementBitwidth; }
@@ -657,6 +652,8 @@ std::string Type::str() const {
     S += "float";
   else if (isBFloat16())
     S += "bfloat";
+  else if (isMFloat8())
+    S += "mfloat";
   else
     S += "int";
 
@@ -699,6 +696,9 @@ std::string Type::builtin_str() const {
   else if (isBFloat16()) {
     assert(ElementBitwidth == 16 && "BFloat16 can only be 16 bits");
     S += "y";
+  } else if (isMFloat8()) {
+    assert(ElementBitwidth == 8 && "MFloat8 can only be 8 bits");
+    S += "m";
   } else
     switch (ElementBitwidth) {
     case 16: S += "h"; break;
@@ -758,6 +758,10 @@ unsigned Type::getNeonEnum() const {
     Base = (unsigned)NeonTypeFlags::BFloat16;
   }
 
+  if (isMFloat8()) {
+    Base = (unsigned)NeonTypeFlags::MFloat8;
+  }
+
   if (Bitwidth == 128)
     Base |= (unsigned)NeonTypeFlags::QuadFlag;
   if (isInteger() && !isSigned())
@@ -779,6 +783,8 @@ Type Type::fromTypedefName(StringRef Name) {
     T.Kind = Poly;
   } else if (Name.consume_front("bfloat")) {
     T.Kind = BFloat16;
+  } else if (Name.consume_front("mfloat")) {
+    T.Kind = MFloat8;
   } else {
     assert(Name.starts_with("int"));
     Name = Name.drop_front(3);
@@ -879,6 +885,10 @@ void Type::applyTypespec(bool &Quad) {
       Kind = BFloat16;
       ElementBitwidth = 16;
       break;
+    case 'm':
+      Kind = MFloat8;
+      ElementBitwidth = 8;
+      break;
     default:
       llvm_unreachable("Unhandled type code!");
     }
@@ -993,6 +1003,9 @@ std::string Intrinsic::getInstTypeCode(Type T, ClassKind CK) const {
   if (T.isBFloat16())
     return "bf16";
 
+  if (T.isMFloat8())
+    return "mfp8";
+
   if (T.isPoly())
     typeCode = 'p';
   else if (T.isInteger())
@@ -1030,7 +1043,7 @@ std::string Intrinsic::getBuiltinTypeStr() {
 
   Type RetT = getReturnType();
   if ((LocalCK == ClassI || LocalCK == ClassW) && RetT.isScalar() &&
-      !RetT.isFloating() && !RetT.isBFloat16())
+      !RetT.isFloating() && !RetT.isBFloat16() && !RetT.isMFloat8())
     RetT.makeInteger(RetT.getElementSizeInBits(), false);
 
   // Since the return value must be one type, return a vector type of the
@@ -2270,7 +2283,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
   for (auto &TS : TDTypeVec) {
     bool IsA64 = false;
     Type T(TS, ".");
-    if (T.isDouble())
+    if (T.isDouble() || T.isMFloat8())
       IsA64 = true;
 
     if (InIfdef && !IsA64) {
@@ -2303,7 +2316,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
     for (auto &TS : TDTypeVec) {
       bool IsA64 = false;
       Type T(TS, ".");
-      if (T.isDouble())
+      if (T.isDouble() || T.isMFloat8())
         IsA64 = true;
 
       if (InIfdef && !IsA64) {
@@ -2589,8 +2602,7 @@ void NeonEmitter::runVectorTypes(raw_ostream &OS) {
 
   OS << "#if defined(__aarch64__) || defined(__arm64ec__)\n";
   OS << "typedef __MFloat8_t __mfp8;\n";
-  OS << "typedef __MFloat8x8_t mfloat8x8_t;\n";
-  OS << "typedef __MFloat8x16_t mfloat8x16_t;\n";
+  OS << "typedef __mfp8 mfloat8_t;\n";
   OS << "typedef double float64_t;\n";
   OS << "#endif\n\n";
 
@@ -2648,7 +2660,7 @@ __arm_set_fpm_lscale2(fpm_t __fpm, uint64_t __scale) {
 
 )";
 
-  emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlhQhfQfdQd", OS);
+  emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlmQmhQhfQfdQd", OS);
 
   emitNeonTypeDefs("bQb", OS);
   OS << "#endif // __ARM_NEON_TYPES_H\n";



More information about the cfe-commits mailing list