[clang] [NFC][Clang][AArch64]Refactor implementation of Neon vectors MFloat8… (PR #114804)

via cfe-commits cfe-commits at lists.llvm.org
Thu Nov 21 01:38:23 PST 2024


https://github.com/CarolineConcatto updated https://github.com/llvm/llvm-project/pull/114804

>From fc97e36c08964c42debda9ad1a289d15b5327d23 Mon Sep 17 00:00:00 2001
From: Caroline Concatto <caroline.concatto at arm.com>
Date: Wed, 30 Oct 2024 16:20:16 +0000
Subject: [PATCH 1/4] [NFC][Clang][AArch64]Refactor implementation of Neon
 vectors  MFloat8x8 and MFloat8x16

This patch removes the builtins for MFloat8x8 and Mfloat8x16 and build these types
the same way the other neon vectors are build. It uses the scalar type(mfloat8).
---
 clang/include/clang/AST/Type.h                |  5 +++
 .../clang/Basic/AArch64SVEACLETypes.def       |  2 -
 clang/include/clang/Basic/TargetBuiltins.h    |  4 +-
 clang/include/clang/Basic/arm_neon_incl.td    |  1 +
 clang/lib/AST/ItaniumMangle.cpp               |  2 +
 clang/lib/CodeGen/CGBuiltin.cpp               |  1 +
 clang/lib/CodeGen/CodeGenTypes.cpp            |  5 +++
 clang/lib/CodeGen/Targets/AArch64.cpp         |  7 +---
 clang/lib/Sema/SemaARM.cpp                    |  2 +
 clang/lib/Sema/SemaExpr.cpp                   |  5 +++
 clang/lib/Sema/SemaType.cpp                   |  3 +-
 clang/test/CodeGen/arm-mfp8.c                 |  4 +-
 clang/test/Sema/arm-mfp8.cpp                  | 27 ++++++------
 clang/utils/TableGen/NeonEmitter.cpp          | 41 ++++++++++++-------
 14 files changed, 68 insertions(+), 41 deletions(-)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 8979129017163b..8e76d19495a6a6 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2534,6 +2534,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFloat32Type() const;
   bool isDoubleType() const;
   bool isBFloat16Type() const;
+  bool isMFloat8Type() const;
   bool isFloat128Type() const;
   bool isIbm128Type() const;
   bool isRealType() const;         // C99 6.2.5p17 (real floating + integer)
@@ -8542,6 +8543,10 @@ inline bool Type::isBFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::BFloat16);
 }
 
+inline bool Type::isMFloat8Type() const {
+  return isSpecificBuiltinType(BuiltinType::MFloat8);
+}
+
 inline bool Type::isFloat128Type() const {
   return isSpecificBuiltinType(BuiltinType::Float128);
 }
diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 063cac1f4a58ee..2dd2754e778d60 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -201,8 +201,6 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
 SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
 
 AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
 
 #undef SVE_VECTOR_TYPE
 #undef SVE_VECTOR_TYPE_BFLOAT
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index 89ebf5758a5b55..86432601115d44 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -200,7 +200,8 @@ namespace clang {
       Float16,
       Float32,
       Float64,
-      BFloat16
+      BFloat16,
+      MFloat8
     };
 
     NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -222,6 +223,7 @@ namespace clang {
       switch (getEltType()) {
       case Int8:
       case Poly8:
+      case MFloat8:
         return 8;
       case Int16:
       case Float16:
diff --git a/clang/include/clang/Basic/arm_neon_incl.td b/clang/include/clang/Basic/arm_neon_incl.td
index b088e0794cdea3..fd800e5a6278e4 100644
--- a/clang/include/clang/Basic/arm_neon_incl.td
+++ b/clang/include/clang/Basic/arm_neon_incl.td
@@ -218,6 +218,7 @@ def OP_UNAVAILABLE : Operation {
 // h: half-float
 // d: double
 // b: bfloat16
+// m: mfloat8
 //
 // Typespec modifiers
 // ------------------
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 14bc260d0245fb..2858202426f1ee 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3901,6 +3901,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
     return "Float64";
   case BuiltinType::BFloat16:
     return "Bfloat16";
+  case BuiltinType::MFloat8:
+    return "Mfloat8";
   default:
     llvm_unreachable("Unexpected vector element base type");
   }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e03cb6e8c3c861..278c3b17110740 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6594,6 +6594,7 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
   switch (TypeFlags.getEltType()) {
   case NeonTypeFlags::Int8:
   case NeonTypeFlags::Poly8:
+  case NeonTypeFlags::MFloat8:
     return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
   case NeonTypeFlags::Int16:
   case NeonTypeFlags::Poly16:
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 09191a4901f493..500d4dcd2cc392 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -647,6 +647,11 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
   case Type::ExtVector:
   case Type::Vector: {
     const auto *VT = cast<VectorType>(Ty);
+    if (VT->getElementType()->isMFloat8Type()) {
+      ResultType = llvm::FixedVectorType::get(
+          llvm::Type::getInt8Ty(getLLVMContext()), VT->getNumElements());
+      break;
+    }
     // An ext_vector_type of Bool is really a vector of bits.
     llvm::Type *IRElemTy = VT->isExtVectorBoolType()
                                ? llvm::Type::getInt1Ty(getLLVMContext())
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 9320c6ef06efab..7b727a04740680 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -375,10 +375,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
         NSRN = std::min(NSRN + 1, 8u);
       else {
         switch (BT->getKind()) {
-        case BuiltinType::MFloat8x8:
-        case BuiltinType::MFloat8x16:
-          NSRN = std::min(NSRN + 1, 8u);
-          break;
         case BuiltinType::SveBool:
         case BuiltinType::SveCount:
           NPRN = std::min(NPRN + 1, 4u);
@@ -620,8 +616,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
   // but with the difference that any floating-point type is allowed,
   // including __fp16.
   if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
-    if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
-        BT->getKind() == BuiltinType::MFloat8x8)
+    if (BT->isFloatingPoint())
       return true;
   } else if (const VectorType *VT = Ty->getAs<VectorType>()) {
     if (auto Kind = VT->getVectorKind();
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index c3a6e5ef8a9d44..7ed0654169723b 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -323,6 +323,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
   switch (Flags.getEltType()) {
   case NeonTypeFlags::Int8:
     return Flags.isUnsigned() ? Context.UnsignedCharTy : Context.SignedCharTy;
+  case NeonTypeFlags::MFloat8:
+    return Context.MFloat8Ty;
   case NeonTypeFlags::Int16:
     return Flags.isUnsigned() ? Context.UnsignedShortTy : Context.ShortTy;
   case NeonTypeFlags::Int32:
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index c9de7811adfe0a..88a13b02971693 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -10199,6 +10199,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
     return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
                                               IsCompAssign);
 
+  // Any operation with MFloat8 type is only possible with C intrinsics
+  if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
+      (RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
+    return InvalidOperands(Loc, LHS, RHS);
+
   // AltiVec-style "vector bool op vector bool" combinations are allowed
   // for some operators but not others.
   if (!AllowBothBool && LHSVecType &&
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 4fac71ba095668..c3d0b1d3161b39 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -8240,7 +8240,8 @@ static bool isPermittedNeonBaseType(QualType &Ty, VectorKind VecKind, Sema &S) {
          BTy->getKind() == BuiltinType::ULongLong ||
          BTy->getKind() == BuiltinType::Float ||
          BTy->getKind() == BuiltinType::Half ||
-         BTy->getKind() == BuiltinType::BFloat16;
+         BTy->getKind() == BuiltinType::BFloat16 ||
+         BTy->getKind() == BuiltinType::MFloat8;
 }
 
 static bool verifyValidIntegerConstantExpr(Sema &S, const ParsedAttr &Attr,
diff --git a/clang/test/CodeGen/arm-mfp8.c b/clang/test/CodeGen/arm-mfp8.c
index bf91066335a25c..8c3cc18cc6453e 100644
--- a/clang/test/CodeGen/arm-mfp8.c
+++ b/clang/test/CodeGen/arm-mfp8.c
@@ -15,7 +15,7 @@
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
 // CHECK-C-NEXT:    ret <16 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_tu14__MFloat8x16_t(
+// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_t14__Mfloat8x16_t(
 // CHECK-CXX-SAME: <16 x i8> [[V:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
@@ -35,7 +35,7 @@ mfloat8x16_t test_ret_mfloat8x16_t(mfloat8x16_t v) {
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
 // CHECK-C-NEXT:    ret <8 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_tu13__MFloat8x8_t(
+// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_t13__Mfloat8x8_t(
 // CHECK-CXX-SAME: <8 x i8> [[V:%.*]]) #[[ATTR0]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
diff --git a/clang/test/Sema/arm-mfp8.cpp b/clang/test/Sema/arm-mfp8.cpp
index be5bc9bb71dbd2..a85bf7006bb9f8 100644
--- a/clang/test/Sema/arm-mfp8.cpp
+++ b/clang/test/Sema/arm-mfp8.cpp
@@ -44,21 +44,20 @@ void test_vector_sve(svmfloat8_t a, svuint8_t c) {
   a / c;  // sve-error {{cannot convert between vector type 'svuint8_t' (aka '__SVUint8_t') and vector type 'svmfloat8_t' (aka '__SVMfloat8_t') as implicit conversion would cause truncation}}
 }
 
-
 #include <arm_neon.h>
 
 void test_vector(mfloat8x8_t a, mfloat8x16_t b, uint8x8_t c) {
-  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-
-  a + c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a - c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a * c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a / c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  c + b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c - b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c * b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c / b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+
+  a + c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a - c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a * c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a / c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  c + b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c - b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c * b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c / b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
 }
diff --git a/clang/utils/TableGen/NeonEmitter.cpp b/clang/utils/TableGen/NeonEmitter.cpp
index 86b8e6b5543c76..2d143e96efe692 100644
--- a/clang/utils/TableGen/NeonEmitter.cpp
+++ b/clang/utils/TableGen/NeonEmitter.cpp
@@ -101,7 +101,8 @@ enum EltType {
   Float16,
   Float32,
   Float64,
-  BFloat16
+  BFloat16,
+  MFloat8
 };
 
 } // end namespace NeonTypeFlags
@@ -143,14 +144,7 @@ class Type {
 private:
   TypeSpec TS;
 
-  enum TypeKind {
-    Void,
-    Float,
-    SInt,
-    UInt,
-    Poly,
-    BFloat16
-  };
+  enum TypeKind { Void, Float, SInt, UInt, Poly, BFloat16, MFloat8 };
   TypeKind Kind;
   bool Immediate, Constant, Pointer;
   // ScalarForMangling and NoManglingQ are really not suited to live here as
@@ -203,6 +197,7 @@ class Type {
   bool isLong() const { return isInteger() && ElementBitwidth == 64; }
   bool isVoid() const { return Kind == Void; }
   bool isBFloat16() const { return Kind == BFloat16; }
+  bool isMFloat8() const { return Kind == MFloat8; }
   unsigned getNumElements() const { return Bitwidth / ElementBitwidth; }
   unsigned getSizeInBits() const { return Bitwidth; }
   unsigned getElementSizeInBits() const { return ElementBitwidth; }
@@ -657,6 +652,8 @@ std::string Type::str() const {
     S += "float";
   else if (isBFloat16())
     S += "bfloat";
+  else if (isMFloat8())
+    S += "mfloat";
   else
     S += "int";
 
@@ -699,6 +696,9 @@ std::string Type::builtin_str() const {
   else if (isBFloat16()) {
     assert(ElementBitwidth == 16 && "BFloat16 can only be 16 bits");
     S += "y";
+  } else if (isMFloat8()) {
+    assert(ElementBitwidth == 8 && "MFloat8 can only be 8 bits");
+    S += "m";
   } else
     switch (ElementBitwidth) {
     case 16: S += "h"; break;
@@ -758,6 +758,10 @@ unsigned Type::getNeonEnum() const {
     Base = (unsigned)NeonTypeFlags::BFloat16;
   }
 
+  if (isMFloat8()) {
+    Base = (unsigned)NeonTypeFlags::MFloat8;
+  }
+
   if (Bitwidth == 128)
     Base |= (unsigned)NeonTypeFlags::QuadFlag;
   if (isInteger() && !isSigned())
@@ -779,6 +783,8 @@ Type Type::fromTypedefName(StringRef Name) {
     T.Kind = Poly;
   } else if (Name.consume_front("bfloat")) {
     T.Kind = BFloat16;
+  } else if (Name.consume_front("mfloat")) {
+    T.Kind = MFloat8;
   } else {
     assert(Name.starts_with("int"));
     Name = Name.drop_front(3);
@@ -879,6 +885,10 @@ void Type::applyTypespec(bool &Quad) {
       Kind = BFloat16;
       ElementBitwidth = 16;
       break;
+    case 'm':
+      Kind = MFloat8;
+      ElementBitwidth = 8;
+      break;
     default:
       llvm_unreachable("Unhandled type code!");
     }
@@ -993,6 +1003,9 @@ std::string Intrinsic::getInstTypeCode(Type T, ClassKind CK) const {
   if (T.isBFloat16())
     return "bf16";
 
+  if (T.isMFloat8())
+    return "mfp8";
+
   if (T.isPoly())
     typeCode = 'p';
   else if (T.isInteger())
@@ -1030,7 +1043,7 @@ std::string Intrinsic::getBuiltinTypeStr() {
 
   Type RetT = getReturnType();
   if ((LocalCK == ClassI || LocalCK == ClassW) && RetT.isScalar() &&
-      !RetT.isFloating() && !RetT.isBFloat16())
+      !RetT.isFloating() && !RetT.isBFloat16() && !RetT.isMFloat8())
     RetT.makeInteger(RetT.getElementSizeInBits(), false);
 
   // Since the return value must be one type, return a vector type of the
@@ -2270,7 +2283,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
   for (auto &TS : TDTypeVec) {
     bool IsA64 = false;
     Type T(TS, ".");
-    if (T.isDouble())
+    if (T.isDouble() || T.isMFloat8())
       IsA64 = true;
 
     if (InIfdef && !IsA64) {
@@ -2303,7 +2316,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
     for (auto &TS : TDTypeVec) {
       bool IsA64 = false;
       Type T(TS, ".");
-      if (T.isDouble())
+      if (T.isDouble() || T.isMFloat8())
         IsA64 = true;
 
       if (InIfdef && !IsA64) {
@@ -2589,8 +2602,6 @@ void NeonEmitter::runVectorTypes(raw_ostream &OS) {
 
   OS << "#if defined(__aarch64__) || defined(__arm64ec__)\n";
   OS << "typedef __mfp8 mfloat8_t;\n";
-  OS << "typedef __MFloat8x8_t mfloat8x8_t;\n";
-  OS << "typedef __MFloat8x16_t mfloat8x16_t;\n";
   OS << "typedef double float64_t;\n";
   OS << "#endif\n\n";
 
@@ -2648,7 +2659,7 @@ __arm_set_fpm_lscale2(fpm_t __fpm, uint64_t __scale) {
 
 )";
 
-  emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlhQhfQfdQd", OS);
+  emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlmQmhQhfQfdQd", OS);
 
   emitNeonTypeDefs("bQb", OS);
   OS << "#endif // __ARM_NEON_TYPES_H\n";

>From 20aeccce6fa679422d41dbc92a1423ba4f5a4d70 Mon Sep 17 00:00:00 2001
From: Caroline Concatto <caroline.concatto at arm.com>
Date: Mon, 18 Nov 2024 13:20:17 +0000
Subject: [PATCH 2/4] Keep neon builtins

---
 clang/include/clang/AST/Type.h                |  5 ----
 .../clang/Basic/AArch64SVEACLETypes.def       |  2 ++
 clang/include/clang/Basic/TargetBuiltins.h    |  4 +--
 clang/lib/AST/ItaniumMangle.cpp               |  2 --
 clang/lib/CodeGen/CGBuiltin.cpp               |  1 -
 clang/lib/CodeGen/CodeGenTypes.cpp            |  5 ----
 clang/lib/CodeGen/Targets/AArch64.cpp         |  7 ++++-
 clang/lib/Sema/SemaARM.cpp                    |  2 --
 clang/lib/Sema/SemaExpr.cpp                   |  5 ----
 clang/lib/Sema/SemaType.cpp                   |  3 +--
 clang/test/CodeGen/arm-mfp8.c                 |  4 +--
 clang/test/Sema/arm-mfp8.cpp                  | 26 +++++++++----------
 clang/utils/TableGen/NeonEmitter.cpp          | 15 ++++++++---
 13 files changed, 36 insertions(+), 45 deletions(-)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 8e76d19495a6a6..8979129017163b 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2534,7 +2534,6 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFloat32Type() const;
   bool isDoubleType() const;
   bool isBFloat16Type() const;
-  bool isMFloat8Type() const;
   bool isFloat128Type() const;
   bool isIbm128Type() const;
   bool isRealType() const;         // C99 6.2.5p17 (real floating + integer)
@@ -8543,10 +8542,6 @@ inline bool Type::isBFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::BFloat16);
 }
 
-inline bool Type::isMFloat8Type() const {
-  return isSpecificBuiltinType(BuiltinType::MFloat8);
-}
-
 inline bool Type::isFloat128Type() const {
   return isSpecificBuiltinType(BuiltinType::Float128);
 }
diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 2dd2754e778d60..063cac1f4a58ee 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -201,6 +201,8 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
 SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
 
 AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
+AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
+AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
 
 #undef SVE_VECTOR_TYPE
 #undef SVE_VECTOR_TYPE_BFLOAT
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index 86432601115d44..89ebf5758a5b55 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -200,8 +200,7 @@ namespace clang {
       Float16,
       Float32,
       Float64,
-      BFloat16,
-      MFloat8
+      BFloat16
     };
 
     NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -223,7 +222,6 @@ namespace clang {
       switch (getEltType()) {
       case Int8:
       case Poly8:
-      case MFloat8:
         return 8;
       case Int16:
       case Float16:
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 2858202426f1ee..14bc260d0245fb 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3901,8 +3901,6 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
     return "Float64";
   case BuiltinType::BFloat16:
     return "Bfloat16";
-  case BuiltinType::MFloat8:
-    return "Mfloat8";
   default:
     llvm_unreachable("Unexpected vector element base type");
   }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 278c3b17110740..e03cb6e8c3c861 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6594,7 +6594,6 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
   switch (TypeFlags.getEltType()) {
   case NeonTypeFlags::Int8:
   case NeonTypeFlags::Poly8:
-  case NeonTypeFlags::MFloat8:
     return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
   case NeonTypeFlags::Int16:
   case NeonTypeFlags::Poly16:
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 500d4dcd2cc392..09191a4901f493 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -647,11 +647,6 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
   case Type::ExtVector:
   case Type::Vector: {
     const auto *VT = cast<VectorType>(Ty);
-    if (VT->getElementType()->isMFloat8Type()) {
-      ResultType = llvm::FixedVectorType::get(
-          llvm::Type::getInt8Ty(getLLVMContext()), VT->getNumElements());
-      break;
-    }
     // An ext_vector_type of Bool is really a vector of bits.
     llvm::Type *IRElemTy = VT->isExtVectorBoolType()
                                ? llvm::Type::getInt1Ty(getLLVMContext())
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 7b727a04740680..9320c6ef06efab 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -375,6 +375,10 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
         NSRN = std::min(NSRN + 1, 8u);
       else {
         switch (BT->getKind()) {
+        case BuiltinType::MFloat8x8:
+        case BuiltinType::MFloat8x16:
+          NSRN = std::min(NSRN + 1, 8u);
+          break;
         case BuiltinType::SveBool:
         case BuiltinType::SveCount:
           NPRN = std::min(NPRN + 1, 4u);
@@ -616,7 +620,8 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
   // but with the difference that any floating-point type is allowed,
   // including __fp16.
   if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
-    if (BT->isFloatingPoint())
+    if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
+        BT->getKind() == BuiltinType::MFloat8x8)
       return true;
   } else if (const VectorType *VT = Ty->getAs<VectorType>()) {
     if (auto Kind = VT->getVectorKind();
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index 7ed0654169723b..c3a6e5ef8a9d44 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -323,8 +323,6 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
   switch (Flags.getEltType()) {
   case NeonTypeFlags::Int8:
     return Flags.isUnsigned() ? Context.UnsignedCharTy : Context.SignedCharTy;
-  case NeonTypeFlags::MFloat8:
-    return Context.MFloat8Ty;
   case NeonTypeFlags::Int16:
     return Flags.isUnsigned() ? Context.UnsignedShortTy : Context.ShortTy;
   case NeonTypeFlags::Int32:
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 88a13b02971693..c9de7811adfe0a 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -10199,11 +10199,6 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
     return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
                                               IsCompAssign);
 
-  // Any operation with MFloat8 type is only possible with C intrinsics
-  if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
-      (RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
-    return InvalidOperands(Loc, LHS, RHS);
-
   // AltiVec-style "vector bool op vector bool" combinations are allowed
   // for some operators but not others.
   if (!AllowBothBool && LHSVecType &&
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index c3d0b1d3161b39..4fac71ba095668 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -8240,8 +8240,7 @@ static bool isPermittedNeonBaseType(QualType &Ty, VectorKind VecKind, Sema &S) {
          BTy->getKind() == BuiltinType::ULongLong ||
          BTy->getKind() == BuiltinType::Float ||
          BTy->getKind() == BuiltinType::Half ||
-         BTy->getKind() == BuiltinType::BFloat16 ||
-         BTy->getKind() == BuiltinType::MFloat8;
+         BTy->getKind() == BuiltinType::BFloat16;
 }
 
 static bool verifyValidIntegerConstantExpr(Sema &S, const ParsedAttr &Attr,
diff --git a/clang/test/CodeGen/arm-mfp8.c b/clang/test/CodeGen/arm-mfp8.c
index 8c3cc18cc6453e..bf91066335a25c 100644
--- a/clang/test/CodeGen/arm-mfp8.c
+++ b/clang/test/CodeGen/arm-mfp8.c
@@ -15,7 +15,7 @@
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
 // CHECK-C-NEXT:    ret <16 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_t14__Mfloat8x16_t(
+// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_tu14__MFloat8x16_t(
 // CHECK-CXX-SAME: <16 x i8> [[V:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
@@ -35,7 +35,7 @@ mfloat8x16_t test_ret_mfloat8x16_t(mfloat8x16_t v) {
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
 // CHECK-C-NEXT:    ret <8 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_t13__Mfloat8x8_t(
+// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_tu13__MFloat8x8_t(
 // CHECK-CXX-SAME: <8 x i8> [[V:%.*]]) #[[ATTR0]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
diff --git a/clang/test/Sema/arm-mfp8.cpp b/clang/test/Sema/arm-mfp8.cpp
index a85bf7006bb9f8..96f9938b52423c 100644
--- a/clang/test/Sema/arm-mfp8.cpp
+++ b/clang/test/Sema/arm-mfp8.cpp
@@ -47,17 +47,17 @@ void test_vector_sve(svmfloat8_t a, svuint8_t c) {
 #include <arm_neon.h>
 
 void test_vector(mfloat8x8_t a, mfloat8x16_t b, uint8x8_t c) {
-  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
-  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
-  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
-  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
-
-  a + c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a - c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a * c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a / c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  c + b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
-  c - b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
-  c * b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
-  c / b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+
+  a + c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a - c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a * c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a / c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  c + b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  c - b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  c * b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  c / b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
 }
diff --git a/clang/utils/TableGen/NeonEmitter.cpp b/clang/utils/TableGen/NeonEmitter.cpp
index 2d143e96efe692..ddf11cda2ef6eb 100644
--- a/clang/utils/TableGen/NeonEmitter.cpp
+++ b/clang/utils/TableGen/NeonEmitter.cpp
@@ -102,7 +102,7 @@ enum EltType {
   Float32,
   Float64,
   BFloat16,
-  MFloat8
+  MFloat8 // Not used by Sema or CodeGen in Clang
 };
 
 } // end namespace NeonTypeFlags
@@ -2295,15 +2295,22 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
       InIfdef = true;
     }
 
-    if (T.isPoly())
+    if (T.isMFloat8())
+      OS << "typedef __MFloat8x";
+    else if (T.isPoly())
       OS << "typedef __attribute__((neon_polyvector_type(";
     else
       OS << "typedef __attribute__((neon_vector_type(";
 
     Type T2 = T;
     T2.makeScalar();
-    OS << T.getNumElements() << "))) ";
-    OS << T2.str();
+    OS << T.getNumElements();
+    if (T.isMFloat8())
+      OS << "_t ";
+    else {
+      OS << "))) ";
+      OS << T2.str();
+    }
     OS << " " << T.str() << ";\n";
   }
   if (InIfdef)

>From 74464306f7e5072dc74efb0d38e6e362f262661a Mon Sep 17 00:00:00 2001
From: CarolineConcatto <caroline.concatto at arm.com>
Date: Wed, 20 Nov 2024 16:49:18 +0000
Subject: [PATCH 3/4] Update arm-mfp8.cpp

---
 clang/test/Sema/arm-mfp8.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/clang/test/Sema/arm-mfp8.cpp b/clang/test/Sema/arm-mfp8.cpp
index 96f9938b52423c..be5bc9bb71dbd2 100644
--- a/clang/test/Sema/arm-mfp8.cpp
+++ b/clang/test/Sema/arm-mfp8.cpp
@@ -44,6 +44,7 @@ void test_vector_sve(svmfloat8_t a, svuint8_t c) {
   a / c;  // sve-error {{cannot convert between vector type 'svuint8_t' (aka '__SVUint8_t') and vector type 'svmfloat8_t' (aka '__SVMfloat8_t') as implicit conversion would cause truncation}}
 }
 
+
 #include <arm_neon.h>
 
 void test_vector(mfloat8x8_t a, mfloat8x16_t b, uint8x8_t c) {

>From fd26e8ef5d2533ee07a529e812a8288fa9db1349 Mon Sep 17 00:00:00 2001
From: CarolineConcatto <caroline.concatto at arm.com>
Date: Wed, 20 Nov 2024 16:51:52 +0000
Subject: [PATCH 4/4] Update NeonEmitter.cpp

---
 clang/utils/TableGen/NeonEmitter.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/clang/utils/TableGen/NeonEmitter.cpp b/clang/utils/TableGen/NeonEmitter.cpp
index ddf11cda2ef6eb..193af2ff2a2625 100644
--- a/clang/utils/TableGen/NeonEmitter.cpp
+++ b/clang/utils/TableGen/NeonEmitter.cpp
@@ -2307,10 +2307,8 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
     OS << T.getNumElements();
     if (T.isMFloat8())
       OS << "_t ";
-    else {
-      OS << "))) ";
-      OS << T2.str();
-    }
+    else
+      OS << "))) " << T2.str();
     OS << " " << T.str() << ";\n";
   }
   if (InIfdef)



More information about the cfe-commits mailing list