[llvm] ecdf48f - [ARM] Basic bfloat support

Alexandros Lamprineas via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 18 09:26:57 PDT 2020


Author: Alexandros Lamprineas
Date: 2020-06-18T17:26:24+01:00
New Revision: ecdf48f15bd2d1a73ae6ab5b46387b0ebead6e99

URL: https://github.com/llvm/llvm-project/commit/ecdf48f15bd2d1a73ae6ab5b46387b0ebead6e99
DIFF: https://github.com/llvm/llvm-project/commit/ecdf48f15bd2d1a73ae6ab5b46387b0ebead6e99.diff

LOG: [ARM] Basic bfloat support

This patch adds basic support for BFloat in the Arm backend.
For now the code generation relies on fullfp16 being present.

Briefly:
* adds the bfloat scalar and vector types to the necessary register classes,
* adjusts the calling convention to cope with bfloat argument passing and return,
* adds codegen patterns for moves, loads and stores.

It is mostly tested by the intrinsic patches that depend on it (load/store, convert/copy).

The following people contributed to this patch:

 * Alexandros Lamprineas
 * Ties Stuij

Differential Revision: https://reviews.llvm.org/D81373

Added: 
    llvm/test/CodeGen/ARM/bfloat.ll

Modified: 
    llvm/lib/Target/ARM/ARMCallingConv.cpp
    llvm/lib/Target/ARM/ARMCallingConv.td
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/lib/Target/ARM/ARMInstrFormats.td
    llvm/lib/Target/ARM/ARMInstrNEON.td
    llvm/lib/Target/ARM/ARMInstrVFP.td
    llvm/lib/Target/ARM/ARMRegisterInfo.td
    llvm/lib/Target/ARM/ARMSubtarget.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/ARM/ARMCallingConv.cpp b/llvm/lib/Target/ARM/ARMCallingConv.cpp
index 9868ce4b099b..d98edc268773 100644
--- a/llvm/lib/Target/ARM/ARMCallingConv.cpp
+++ b/llvm/lib/Target/ARM/ARMCallingConv.cpp
@@ -209,14 +209,17 @@ static bool CC_ARM_AAPCS_Custom_Aggregate(unsigned ValNo, MVT ValVT,
     break;
   }
   case MVT::f16:
+  case MVT::bf16:
   case MVT::f32:
     RegList = SRegList;
     break;
   case MVT::v4f16:
+  case MVT::v4bf16:
   case MVT::f64:
     RegList = DRegList;
     break;
   case MVT::v8f16:
+  case MVT::v8bf16:
   case MVT::v2f64:
     RegList = QRegList;
     break;

diff  --git a/llvm/lib/Target/ARM/ARMCallingConv.td b/llvm/lib/Target/ARM/ARMCallingConv.td
index b7a52b0781fd..65fc9d1cf7cd 100644
--- a/llvm/lib/Target/ARM/ARMCallingConv.td
+++ b/llvm/lib/Target/ARM/ARMCallingConv.td
@@ -30,8 +30,8 @@ def CC_ARM_APCS : CallingConv<[
   CCIfSwiftError<CCIfType<[i32], CCAssignToReg<[R8]>>>,
 
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   // f64 and v2f64 are passed in adjacent GPRs, possibly split onto the stack
   CCIfType<[f64, v2f64], CCCustom<"CC_ARM_APCS_Custom_f64">>,
@@ -56,8 +56,8 @@ def RetCC_ARM_APCS : CallingConv<[
   CCIfSwiftError<CCIfType<[i32], CCAssignToReg<[R8]>>>,
 
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   CCIfType<[f64, v2f64], CCCustom<"RetCC_ARM_APCS_Custom_f64">>,
 
@@ -71,8 +71,8 @@ def RetCC_ARM_APCS : CallingConv<[
 let Entry = 1 in
 def FastCC_ARM_APCS : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   CCIfType<[v2f64], CCAssignToReg<[Q0, Q1, Q2, Q3]>>,
   CCIfType<[f64], CCAssignToReg<[D0, D1, D2, D3, D4, D5, D6, D7]>>,
@@ -91,8 +91,8 @@ def FastCC_ARM_APCS : CallingConv<[
 let Entry = 1 in
 def RetFastCC_ARM_APCS : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   CCIfType<[v2f64], CCAssignToReg<[Q0, Q1, Q2, Q3]>>,
   CCIfType<[f64], CCAssignToReg<[D0, D1, D2, D3, D4, D5, D6, D7]>>,
@@ -108,8 +108,8 @@ def RetFastCC_ARM_APCS : CallingConv<[
 let Entry = 1 in
 def CC_ARM_APCS_GHC : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   CCIfType<[v2f64], CCAssignToReg<[Q4, Q5]>>,
   CCIfType<[f64], CCAssignToReg<[D8, D9, D10, D11]>>,
@@ -139,7 +139,7 @@ def CC_ARM_AAPCS_Common : CallingConv<[
 
   CCIfType<[i32], CCIfAlign<"8", CCAssignToStackWithShadow<4, 8, [R0, R1, R2, R3]>>>,
   CCIfType<[i32], CCAssignToStackWithShadow<4, 4, [R0, R1, R2, R3]>>,
-  CCIfType<[f16, f32], CCAssignToStackWithShadow<4, 4, [Q0, Q1, Q2, Q3]>>,
+  CCIfType<[f16, bf16, f32], CCAssignToStackWithShadow<4, 4, [Q0, Q1, Q2, Q3]>>,
   CCIfType<[f64], CCAssignToStackWithShadow<8, 8, [Q0, Q1, Q2, Q3]>>,
   CCIfType<[v2f64], CCIfAlign<"16",
            CCAssignToStackWithShadow<16, 16, [Q0, Q1, Q2, Q3]>>>,
@@ -165,8 +165,8 @@ def CC_ARM_AAPCS : CallingConv<[
   CCIfNest<CCAssignToReg<[R12]>>,
 
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   // Pass SwiftSelf in a callee saved register.
   CCIfSwiftSelf<CCIfType<[i32], CCAssignToReg<[R10]>>>,
@@ -176,15 +176,15 @@ def CC_ARM_AAPCS : CallingConv<[
 
   CCIfType<[f64, v2f64], CCCustom<"CC_ARM_AAPCS_Custom_f64">>,
   CCIfType<[f32], CCBitConvertToType<i32>>,
-  CCIfType<[f16], CCCustom<"CC_ARM_AAPCS_Custom_f16">>,
+  CCIfType<[f16, bf16], CCCustom<"CC_ARM_AAPCS_Custom_f16">>,
   CCDelegateTo<CC_ARM_AAPCS_Common>
 ]>;
 
 let Entry = 1 in
 def RetCC_ARM_AAPCS : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   // Pass SwiftSelf in a callee saved register.
   CCIfSwiftSelf<CCIfType<[i32], CCAssignToReg<[R10]>>>,
@@ -194,7 +194,7 @@ def RetCC_ARM_AAPCS : CallingConv<[
 
   CCIfType<[f64, v2f64], CCCustom<"RetCC_ARM_AAPCS_Custom_f64">>,
   CCIfType<[f32], CCBitConvertToType<i32>>,
-  CCIfType<[f16], CCCustom<"CC_ARM_AAPCS_Custom_f16">>,
+  CCIfType<[f16, bf16], CCCustom<"CC_ARM_AAPCS_Custom_f16">>,
 
   CCDelegateTo<RetCC_ARM_AAPCS_Common>
 ]>;
@@ -210,8 +210,8 @@ def CC_ARM_AAPCS_VFP : CallingConv<[
   CCIfByVal<CCPassByVal<4, 4>>,
 
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   // Pass SwiftSelf in a callee saved register.
   CCIfSwiftSelf<CCIfType<[i32], CCAssignToReg<[R10]>>>,
@@ -226,15 +226,15 @@ def CC_ARM_AAPCS_VFP : CallingConv<[
   CCIfType<[f64], CCAssignToReg<[D0, D1, D2, D3, D4, D5, D6, D7]>>,
   CCIfType<[f32], CCAssignToReg<[S0, S1, S2, S3, S4, S5, S6, S7, S8,
                                  S9, S10, S11, S12, S13, S14, S15]>>,
-  CCIfType<[f16], CCCustom<"CC_ARM_AAPCS_VFP_Custom_f16">>,
+  CCIfType<[f16, bf16], CCCustom<"CC_ARM_AAPCS_VFP_Custom_f16">>,
   CCDelegateTo<CC_ARM_AAPCS_Common>
 ]>;
 
 let Entry = 1 in
 def RetCC_ARM_AAPCS_VFP : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
 
   // Pass SwiftSelf in a callee saved register.
   CCIfSwiftSelf<CCIfType<[i32], CCAssignToReg<[R10]>>>,
@@ -246,7 +246,7 @@ def RetCC_ARM_AAPCS_VFP : CallingConv<[
   CCIfType<[f64], CCAssignToReg<[D0, D1, D2, D3, D4, D5, D6, D7]>>,
   CCIfType<[f32], CCAssignToReg<[S0, S1, S2, S3, S4, S5, S6, S7, S8,
                                  S9, S10, S11, S12, S13, S14, S15]>>,
-  CCIfType<[f16], CCCustom<"CC_ARM_AAPCS_VFP_Custom_f16">>,
+  CCIfType<[f16, bf16], CCCustom<"CC_ARM_AAPCS_VFP_Custom_f16">>,
   CCDelegateTo<RetCC_ARM_AAPCS_Common>
 ]>;
 

diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 95132ec253a7..a646f63e0839 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -721,6 +721,10 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
 
     setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
     setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
+
+    // For the time being bfloat is only supported when fullfp16 is present.
+    if (Subtarget->hasBF16())
+      addRegisterClass(MVT::bf16, &ARM::HPRRegClass);
   }
 
   for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
@@ -770,6 +774,11 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
       addQRTypeForNEON(MVT::v8f16);
       addDRTypeForNEON(MVT::v4f16);
     }
+
+    if (Subtarget->hasBF16()) {
+      addQRTypeForNEON(MVT::v8bf16);
+      addDRTypeForNEON(MVT::v4bf16);
+    }
   }
 
   if (Subtarget->hasMVEIntegerOps() || Subtarget->hasNEON()) {
@@ -2077,9 +2086,10 @@ SDValue ARMTargetLowering::LowerCallResult(
     // f16 arguments have their size extended to 4 bytes and passed as if they
     // had been copied to the LSBs of a 32-bit register.
     // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI)
-    if (VA.needsCustom() && VA.getValVT() == MVT::f16) {
+    if (VA.needsCustom() &&
+        (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
       assert(Subtarget->hasFullFP16() &&
-             "Lowering f16 type return without full fp16 support");
+             "Lowering half precision fp return without full fp16 support");
       Val = DAG.getNode(ISD::BITCAST, dl,
                         MVT::getIntegerVT(VA.getLocVT().getSizeInBits()), Val);
       Val = DAG.getNode(ARMISD::VMOVhr, dl, VA.getValVT(), Val);
@@ -2256,9 +2266,10 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // f16 arguments have their size extended to 4 bytes and passed as if they
     // had been copied to the LSBs of a 32-bit register.
     // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI)
-    if (VA.needsCustom() && VA.getValVT() == MVT::f16) {
+    if (VA.needsCustom() &&
+        (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
       assert(Subtarget->hasFullFP16() &&
-             "Lowering f16 type argument without full fp16 support");
+             "Lowering half precision fp argument without full fp16 support");
       Arg = DAG.getNode(ARMISD::VMOVrh, dl,
                         MVT::getIntegerVT(VA.getLocVT().getSizeInBits()), Arg);
       Arg = DAG.getNode(ISD::BITCAST, dl, VA.getLocVT(), Arg);
@@ -3005,8 +3016,8 @@ ARMTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     // Guarantee that all emitted copies are
     // stuck together, avoiding something bad.
     Flag = Chain.getValue(1);
-    RetOps.push_back(DAG.getRegister(VA.getLocReg(),
-                                     ReturnF16 ? MVT::f16 : VA.getLocVT()));
+    RetOps.push_back(DAG.getRegister(
+        VA.getLocReg(), ReturnF16 ? Arg.getValueType() : VA.getLocVT()));
   }
   const ARMBaseRegisterInfo *TRI = Subtarget->getRegisterInfo();
   const MCPhysReg *I =
@@ -4139,7 +4150,8 @@ bool ARMTargetLowering::splitValueIntoRegisterParts(
     unsigned NumParts, MVT PartVT, Optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.hasValue();
   EVT ValueVT = Val.getValueType();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
     unsigned ValueBits = ValueVT.getSizeInBits();
     unsigned PartBits = PartVT.getSizeInBits();
     Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val);
@@ -4155,7 +4167,8 @@ SDValue ARMTargetLowering::joinRegisterPartsIntoValue(
     SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
     MVT PartVT, EVT ValueVT, Optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.hasValue();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
     unsigned ValueBits = ValueVT.getSizeInBits();
     unsigned PartBits = PartVT.getSizeInBits();
     SDValue Val = Parts[0];
@@ -4266,14 +4279,15 @@ SDValue ARMTargetLowering::LowerFormalArguments(
       } else {
         const TargetRegisterClass *RC;
 
-
-        if (RegVT == MVT::f16)
+        if (RegVT == MVT::f16 || RegVT == MVT::bf16)
           RC = &ARM::HPRRegClass;
         else if (RegVT == MVT::f32)
           RC = &ARM::SPRRegClass;
-        else if (RegVT == MVT::f64 || RegVT == MVT::v4f16)
+        else if (RegVT == MVT::f64 || RegVT == MVT::v4f16 ||
+                 RegVT == MVT::v4bf16)
           RC = &ARM::DPRRegClass;
-        else if (RegVT == MVT::v2f64 || RegVT == MVT::v8f16)
+        else if (RegVT == MVT::v2f64 || RegVT == MVT::v8f16 ||
+                 RegVT == MVT::v8bf16)
           RC = &ARM::QPRRegClass;
         else if (RegVT == MVT::i32)
           RC = AFI->isThumb1OnlyFunction() ? &ARM::tGPRRegClass
@@ -4316,9 +4330,10 @@ SDValue ARMTargetLowering::LowerFormalArguments(
       // f16 arguments have their size extended to 4 bytes and passed as if they
       // had been copied to the LSBs of a 32-bit register.
       // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI)
-      if (VA.needsCustom() && VA.getValVT() == MVT::f16) {
+      if (VA.needsCustom() &&
+          (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
         assert(Subtarget->hasFullFP16() &&
-               "Lowering f16 type argument without full fp16 support");
+               "Lowering half precision fp argument without full fp16 support");
         ArgValue = DAG.getNode(ISD::BITCAST, dl,
                                MVT::getIntegerVT(VA.getLocVT().getSizeInBits()),
                                ArgValue);
@@ -5914,18 +5929,18 @@ static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG,
   EVT SrcVT = Op.getValueType();
   EVT DstVT = N->getValueType(0);
 
-  if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
+  if (SrcVT == MVT::i16 && (DstVT == MVT::f16 || DstVT == MVT::bf16)) {
     if (!Subtarget->hasFullFP16())
       return SDValue();
-    // f16 bitcast i16 -> VMOVhr
-    return DAG.getNode(ARMISD::VMOVhr, SDLoc(N), MVT::f16,
+    // (b)f16 bitcast i16 -> VMOVhr
+    return DAG.getNode(ARMISD::VMOVhr, SDLoc(N), DstVT,
                        DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), MVT::i32, Op));
   }
 
-  if (SrcVT == MVT::f16 && DstVT == MVT::i16) {
+  if ((SrcVT == MVT::f16 || SrcVT == MVT::bf16) && DstVT == MVT::i16) {
     if (!Subtarget->hasFullFP16())
       return SDValue();
-    // i16 bitcast f16 -> VMOVrh
+    // i16 bitcast (b)f16 -> VMOVrh
     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), MVT::i16,
                        DAG.getNode(ARMISD::VMOVrh, SDLoc(N), MVT::i32, Op));
   }
@@ -13196,7 +13211,7 @@ static SDValue PerformVMOVhrCombine(SDNode *N, TargetLowering::DAGCombinerInfo &
         Copy->getOpcode() == ISD::CopyFromReg) {
       SDValue Ops[] = {Copy->getOperand(0), Copy->getOperand(1)};
       SDValue NewCopy =
-          DCI.DAG.getNode(ISD::CopyFromReg, SDLoc(N), MVT::f16, Ops);
+          DCI.DAG.getNode(ISD::CopyFromReg, SDLoc(N), N->getValueType(0), Ops);
       return NewCopy;
     }
   }
@@ -13205,8 +13220,9 @@ static SDValue PerformVMOVhrCombine(SDNode *N, TargetLowering::DAGCombinerInfo &
   if (LoadSDNode *LN0 = dyn_cast<LoadSDNode>(Op0)) {
     if (LN0->hasOneUse() && LN0->isUnindexed() &&
         LN0->getMemoryVT() == MVT::i16) {
-      SDValue Load = DCI.DAG.getLoad(MVT::f16, SDLoc(N), LN0->getChain(),
-                                     LN0->getBasePtr(), LN0->getMemOperand());
+      SDValue Load =
+          DCI.DAG.getLoad(N->getValueType(0), SDLoc(N), LN0->getChain(),
+                          LN0->getBasePtr(), LN0->getMemOperand());
       DCI.DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Load.getValue(0));
       DCI.DAG.ReplaceAllUsesOfValueWith(Op0.getValue(1), Load.getValue(1));
       return Load;

diff  --git a/llvm/lib/Target/ARM/ARMInstrFormats.td b/llvm/lib/Target/ARM/ARMInstrFormats.td
index 2cd4e052a440..e13f3437cc7b 100644
--- a/llvm/lib/Target/ARM/ARMInstrFormats.td
+++ b/llvm/lib/Target/ARM/ARMInstrFormats.td
@@ -1128,6 +1128,9 @@ class Thumb2DSPPat<dag pattern, dag result> : Pat<pattern, result> {
 class Thumb2DSPMulPat<dag pattern, dag result> : Pat<pattern, result> {
   list<Predicate> Predicates = [IsThumb2, UseMulOps, HasDSP];
 }
+class FPRegs16Pat<dag pattern, dag result> : Pat<pattern, result> {
+  list<Predicate> Predicates = [HasFPRegs16];
+}
 class FP16Pat<dag pattern, dag result> : Pat<pattern, result> {
   list<Predicate> Predicates = [HasFP16];
 }

diff  --git a/llvm/lib/Target/ARM/ARMInstrNEON.td b/llvm/lib/Target/ARM/ARMInstrNEON.td
index 5dd82f637798..80743210e3a5 100644
--- a/llvm/lib/Target/ARM/ARMInstrNEON.td
+++ b/llvm/lib/Target/ARM/ARMInstrNEON.td
@@ -7395,6 +7395,9 @@ def : Pat<(v2i32 (bitconvert (v2f32 DPR:$src))), (v2i32 DPR:$src)>;
 def : Pat<(v4i16 (bitconvert (v4f16 DPR:$src))), (v4i16  DPR:$src)>;
 def : Pat<(v4f16 (bitconvert (v4i16 DPR:$src))), (v4f16  DPR:$src)>;
 
+def : Pat<(v4i16 (bitconvert (v4bf16 DPR:$src))), (v4i16  DPR:$src)>;
+def : Pat<(v4bf16 (bitconvert (v4i16 DPR:$src))), (v4bf16  DPR:$src)>;
+
 // 128 bit conversions
 def : Pat<(v2f64 (bitconvert (v2i64 QPR:$src))), (v2f64 QPR:$src)>;
 def : Pat<(v2i64 (bitconvert (v2f64 QPR:$src))), (v2i64 QPR:$src)>;
@@ -7404,6 +7407,9 @@ def : Pat<(v4f32 (bitconvert (v4i32 QPR:$src))), (v4f32 QPR:$src)>;
 
 def : Pat<(v8i16 (bitconvert (v8f16 QPR:$src))), (v8i16  QPR:$src)>;
 def : Pat<(v8f16 (bitconvert (v8i16 QPR:$src))), (v8f16  QPR:$src)>;
+
+def : Pat<(v8i16 (bitconvert (v8bf16 QPR:$src))), (v8i16  QPR:$src)>;
+def : Pat<(v8bf16 (bitconvert (v8i16 QPR:$src))), (v8bf16  QPR:$src)>;
 }
 
 let Predicates = [IsLE,HasNEON] in {
@@ -7411,24 +7417,28 @@ let Predicates = [IsLE,HasNEON] in {
   def : Pat<(f64   (bitconvert (v2f32 DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(f64   (bitconvert (v2i32 DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(f64   (bitconvert (v4f16 DPR:$src))), (f64   DPR:$src)>;
+  def : Pat<(f64   (bitconvert (v4bf16 DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(f64   (bitconvert (v4i16 DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(f64   (bitconvert (v8i8  DPR:$src))), (f64   DPR:$src)>;
 
   def : Pat<(v1i64 (bitconvert (v2f32 DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v2i32 DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v4f16 DPR:$src))), (v1i64 DPR:$src)>;
+  def : Pat<(v1i64 (bitconvert (v4bf16 DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v4i16 DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v8i8  DPR:$src))), (v1i64 DPR:$src)>;
 
   def : Pat<(v2f32 (bitconvert (f64   DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v1i64 DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v4f16 DPR:$src))), (v2f32 DPR:$src)>;
+  def : Pat<(v2f32 (bitconvert (v4bf16 DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v4i16 DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v8i8  DPR:$src))), (v2f32 DPR:$src)>;
 
   def : Pat<(v2i32 (bitconvert (f64   DPR:$src))), (v2i32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v1i64 DPR:$src))), (v2i32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v4f16 DPR:$src))), (v2i32 DPR:$src)>;
+  def : Pat<(v2i32 (bitconvert (v4bf16 DPR:$src))), (v2i32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v4i16 DPR:$src))), (v2i32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v8i8  DPR:$src))), (v2i32 DPR:$src)>;
 
@@ -7438,6 +7448,12 @@ let Predicates = [IsLE,HasNEON] in {
   def : Pat<(v4f16 (bitconvert (v2i32 DPR:$src))), (v4f16 DPR:$src)>;
   def : Pat<(v4f16 (bitconvert (v8i8  DPR:$src))), (v4f16 DPR:$src)>;
 
+  def : Pat<(v4bf16 (bitconvert (f64   DPR:$src))), (v4bf16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v1i64 DPR:$src))), (v4bf16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v2f32 DPR:$src))), (v4bf16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v2i32 DPR:$src))), (v4bf16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v8i8  DPR:$src))), (v4bf16 DPR:$src)>;
+
   def : Pat<(v4i16 (bitconvert (f64   DPR:$src))), (v4i16 DPR:$src)>;
   def : Pat<(v4i16 (bitconvert (v1i64 DPR:$src))), (v4i16 DPR:$src)>;
   def : Pat<(v4i16 (bitconvert (v2f32 DPR:$src))), (v4i16 DPR:$src)>;
@@ -7449,30 +7465,35 @@ let Predicates = [IsLE,HasNEON] in {
   def : Pat<(v8i8  (bitconvert (v2f32 DPR:$src))), (v8i8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v2i32 DPR:$src))), (v8i8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v4f16 DPR:$src))), (v8i8  DPR:$src)>;
+  def : Pat<(v8i8  (bitconvert (v4bf16 DPR:$src))), (v8i8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v4i16 DPR:$src))), (v8i8  DPR:$src)>;
 
   // 128 bit conversions
   def : Pat<(v2f64 (bitconvert (v4f32 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v4i32 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v8f16 QPR:$src))), (v2f64 QPR:$src)>;
+  def : Pat<(v2f64 (bitconvert (v8bf16 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v8i16 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v16i8 QPR:$src))), (v2f64 QPR:$src)>;
 
   def : Pat<(v2i64 (bitconvert (v4f32 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v4i32 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v8f16 QPR:$src))), (v2i64 QPR:$src)>;
+  def : Pat<(v2i64 (bitconvert (v8bf16 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v8i16 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v16i8 QPR:$src))), (v2i64 QPR:$src)>;
 
   def : Pat<(v4f32 (bitconvert (v2f64 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v2i64 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v8f16 QPR:$src))), (v4f32 QPR:$src)>;
+  def : Pat<(v4f32 (bitconvert (v8bf16 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v8i16 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v16i8 QPR:$src))), (v4f32 QPR:$src)>;
 
   def : Pat<(v4i32 (bitconvert (v2f64 QPR:$src))), (v4i32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v2i64 QPR:$src))), (v4i32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v8f16 QPR:$src))), (v4i32 QPR:$src)>;
+  def : Pat<(v4i32 (bitconvert (v8bf16 QPR:$src))), (v4i32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v8i16 QPR:$src))), (v4i32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v16i8 QPR:$src))), (v4i32 QPR:$src)>;
 
@@ -7482,6 +7503,12 @@ let Predicates = [IsLE,HasNEON] in {
   def : Pat<(v8f16 (bitconvert (v4i32 QPR:$src))), (v8f16 QPR:$src)>;
   def : Pat<(v8f16 (bitconvert (v16i8 QPR:$src))), (v8f16 QPR:$src)>;
 
+  def : Pat<(v8bf16 (bitconvert (v2f64 QPR:$src))), (v8bf16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v2i64 QPR:$src))), (v8bf16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v4f32 QPR:$src))), (v8bf16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v4i32 QPR:$src))), (v8bf16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v16i8 QPR:$src))), (v8bf16 QPR:$src)>;
+
   def : Pat<(v8i16 (bitconvert (v2f64 QPR:$src))), (v8i16 QPR:$src)>;
   def : Pat<(v8i16 (bitconvert (v2i64 QPR:$src))), (v8i16 QPR:$src)>;
   def : Pat<(v8i16 (bitconvert (v4f32 QPR:$src))), (v8i16 QPR:$src)>;
@@ -7493,6 +7520,7 @@ let Predicates = [IsLE,HasNEON] in {
   def : Pat<(v16i8 (bitconvert (v4f32 QPR:$src))), (v16i8 QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v4i32 QPR:$src))), (v16i8 QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v8f16 QPR:$src))), (v16i8 QPR:$src)>;
+  def : Pat<(v16i8 (bitconvert (v8bf16 QPR:$src))), (v16i8 QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v8i16 QPR:$src))), (v16i8 QPR:$src)>;
 }
 
@@ -7501,24 +7529,28 @@ let Predicates = [IsBE,HasNEON] in {
   def : Pat<(f64   (bitconvert (v2f32 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(f64   (bitconvert (v2i32 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(f64   (bitconvert (v4f16 DPR:$src))), (VREV64d16 DPR:$src)>;
+  def : Pat<(f64   (bitconvert (v4bf16 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(f64   (bitconvert (v4i16 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(f64   (bitconvert (v8i8  DPR:$src))), (VREV64d8  DPR:$src)>;
 
   def : Pat<(v1i64 (bitconvert (v2f32 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v2i32 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v4f16 DPR:$src))), (VREV64d16 DPR:$src)>;
+  def : Pat<(v1i64 (bitconvert (v4bf16 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v4i16 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v8i8  DPR:$src))), (VREV64d8  DPR:$src)>;
 
   def : Pat<(v2f32 (bitconvert (f64   DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v1i64 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v4f16 DPR:$src))), (VREV32d16 DPR:$src)>;
+  def : Pat<(v2f32 (bitconvert (v4bf16 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v4i16 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v8i8  DPR:$src))), (VREV32d8  DPR:$src)>;
 
   def : Pat<(v2i32 (bitconvert (f64   DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v1i64 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v4f16 DPR:$src))), (VREV32d16 DPR:$src)>;
+  def : Pat<(v2i32 (bitconvert (v4bf16 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v4i16 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v8i8  DPR:$src))), (VREV32d8  DPR:$src)>;
 
@@ -7528,6 +7560,12 @@ let Predicates = [IsBE,HasNEON] in {
   def : Pat<(v4f16 (bitconvert (v2i32 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v4f16 (bitconvert (v8i8  DPR:$src))), (VREV16d8  DPR:$src)>;
 
+  def : Pat<(v4bf16 (bitconvert (f64   DPR:$src))), (VREV64d16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v1i64 DPR:$src))), (VREV64d16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v2f32 DPR:$src))), (VREV32d16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v2i32 DPR:$src))), (VREV32d16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v8i8  DPR:$src))), (VREV16d8  DPR:$src)>;
+
   def : Pat<(v4i16 (bitconvert (f64   DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(v4i16 (bitconvert (v1i64 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(v4i16 (bitconvert (v2f32 DPR:$src))), (VREV32d16 DPR:$src)>;
@@ -7539,30 +7577,35 @@ let Predicates = [IsBE,HasNEON] in {
   def : Pat<(v8i8  (bitconvert (v2f32 DPR:$src))), (VREV32d8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v2i32 DPR:$src))), (VREV32d8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v4f16 DPR:$src))), (VREV16d8  DPR:$src)>;
+  def : Pat<(v8i8  (bitconvert (v4bf16 DPR:$src))), (VREV16d8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v4i16 DPR:$src))), (VREV16d8  DPR:$src)>;
 
   // 128 bit conversions
   def : Pat<(v2f64 (bitconvert (v4f32 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v4i32 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v8f16 QPR:$src))), (VREV64q16 QPR:$src)>;
+  def : Pat<(v2f64 (bitconvert (v8bf16 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v8i16 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v16i8 QPR:$src))), (VREV64q8  QPR:$src)>;
 
   def : Pat<(v2i64 (bitconvert (v4f32 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v4i32 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v8f16 QPR:$src))), (VREV64q16 QPR:$src)>;
+  def : Pat<(v2i64 (bitconvert (v8bf16 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v8i16 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v16i8 QPR:$src))), (VREV64q8  QPR:$src)>;
 
   def : Pat<(v4f32 (bitconvert (v2f64 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v2i64 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v8f16 QPR:$src))), (VREV32q16 QPR:$src)>;
+  def : Pat<(v4f32 (bitconvert (v8bf16 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v8i16 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v16i8 QPR:$src))), (VREV32q8  QPR:$src)>;
 
   def : Pat<(v4i32 (bitconvert (v2f64 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v2i64 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v8f16 QPR:$src))), (VREV32q16 QPR:$src)>;
+  def : Pat<(v4i32 (bitconvert (v8bf16 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v8i16 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v16i8 QPR:$src))), (VREV32q8  QPR:$src)>;
 
@@ -7572,6 +7615,12 @@ let Predicates = [IsBE,HasNEON] in {
   def : Pat<(v8f16 (bitconvert (v4i32 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v8f16 (bitconvert (v16i8 QPR:$src))), (VREV16q8  QPR:$src)>;
 
+  def : Pat<(v8bf16 (bitconvert (v2f64 QPR:$src))), (VREV64q16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v2i64 QPR:$src))), (VREV64q16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v4f32 QPR:$src))), (VREV32q16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v4i32 QPR:$src))), (VREV32q16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v16i8 QPR:$src))), (VREV16q8  QPR:$src)>;
+
   def : Pat<(v8i16 (bitconvert (v2f64 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v8i16 (bitconvert (v2i64 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v8i16 (bitconvert (v4f32 QPR:$src))), (VREV32q16 QPR:$src)>;
@@ -7583,6 +7632,7 @@ let Predicates = [IsBE,HasNEON] in {
   def : Pat<(v16i8 (bitconvert (v4f32 QPR:$src))), (VREV32q8  QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v4i32 QPR:$src))), (VREV32q8  QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v8f16 QPR:$src))), (VREV16q8  QPR:$src)>;
+  def : Pat<(v16i8 (bitconvert (v8bf16 QPR:$src))), (VREV16q8  QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v8i16 QPR:$src))), (VREV16q8  QPR:$src)>;
 }
 
@@ -7593,12 +7643,12 @@ let Predicates = [HasNEON] in {
   // input and output types are the same, the bitconvert gets elided
   // and we end up generating a nonsense match of nothing.
 
-  foreach VT = [ v16i8, v8i16, v8f16, v4i32, v4f32, v2i64, v2f64 ] in
-    foreach VT2 = [ v16i8, v8i16, v8f16, v4i32, v4f32, v2i64, v2f64 ] in
+  foreach VT = [ v16i8, v8i16, v8f16, v8bf16, v4i32, v4f32, v2i64, v2f64 ] in
+    foreach VT2 = [ v16i8, v8i16, v8f16, v8bf16, v4i32, v4f32, v2i64, v2f64 ] in
       def : Pat<(VT (ARMVectorRegCastImpl (VT2 QPR:$src))), (VT QPR:$src)>;
 
-  foreach VT = [ v8i8, v4i16, v4f16, v2i32, v2f32, v1i64, f64 ] in
-    foreach VT2 = [ v8i8, v4i16, v4f16, v2i32, v2f32, v1i64, f64 ] in
+  foreach VT = [ v8i8, v4i16, v4f16, v4bf16, v2i32, v2f32, v1i64, f64 ] in
+    foreach VT2 = [ v8i8, v4i16, v4f16, v4bf16, v2i32, v2f32, v1i64, f64 ] in
       def : Pat<(VT (ARMVectorRegCastImpl (VT2 DPR:$src))), (VT DPR:$src)>;
 }
 

diff  --git a/llvm/lib/Target/ARM/ARMInstrVFP.td b/llvm/lib/Target/ARM/ARMInstrVFP.td
index 80008c59a56a..5611ddb57541 100644
--- a/llvm/lib/Target/ARM/ARMInstrVFP.td
+++ b/llvm/lib/Target/ARM/ARMInstrVFP.td
@@ -158,11 +158,16 @@ def VLDRS : ASI5<0b1101, 0b01, (outs SPR:$Sd), (ins addrmode5:$addr),
 let isUnpredicable = 1 in
 def VLDRH : AHI5<0b1101, 0b01, (outs HPR:$Sd), (ins addrmode5fp16:$addr),
                  IIC_fpLoad16, "vldr", ".16\t$Sd, $addr",
-                 [(set HPR:$Sd, (alignedload16 addrmode5fp16:$addr))]>,
+                 []>,
             Requires<[HasFPRegs16]>;
 
 } // End of 'let canFoldAsLoad = 1, isReMaterializable = 1 in'
 
+def : FPRegs16Pat<(f16 (alignedload16 addrmode5fp16:$addr)),
+                  (VLDRH addrmode5fp16:$addr)>;
+def : FPRegs16Pat<(bf16 (alignedload16 addrmode5fp16:$addr)),
+                  (VLDRH addrmode5fp16:$addr)>;
+
 def VSTRD : ADI5<0b1101, 0b00, (outs), (ins DPR:$Dd, addrmode5:$addr),
                  IIC_fpStore64, "vstr", "\t$Dd, $addr",
                  [(alignedstore32 (f64 DPR:$Dd), addrmode5:$addr)]>,
@@ -180,9 +185,14 @@ def VSTRS : ASI5<0b1101, 0b00, (outs), (ins SPR:$Sd, addrmode5:$addr),
 let isUnpredicable = 1 in
 def VSTRH : AHI5<0b1101, 0b00, (outs), (ins HPR:$Sd, addrmode5fp16:$addr),
                  IIC_fpStore16, "vstr", ".16\t$Sd, $addr",
-                 [(alignedstore16 HPR:$Sd, addrmode5fp16:$addr)]>,
+                 []>,
             Requires<[HasFPRegs16]>;
 
+def : FPRegs16Pat<(alignedstore16 (f16 HPR:$Sd), addrmode5fp16:$addr),
+                  (VSTRH (f16 HPR:$Sd), addrmode5fp16:$addr)>;
+def : FPRegs16Pat<(alignedstore16 (bf16 HPR:$Sd), addrmode5fp16:$addr),
+                  (VSTRH (bf16 HPR:$Sd), addrmode5fp16:$addr)>;
+
 //===----------------------------------------------------------------------===//
 //  Load / store multiple Instructions.
 //
@@ -1250,7 +1260,7 @@ def VMOVSRR : AVConv5I<0b11000100, 0b1010,
 def VMOVRH : AVConv2I<0b11100001, 0b1001,
                       (outs rGPR:$Rt), (ins HPR:$Sn),
                       IIC_fpMOVSI, "vmov", ".f16\t$Rt, $Sn",
-                      [(set rGPR:$Rt, (arm_vmovrh HPR:$Sn))]>,
+                      []>,
              Requires<[HasFPRegs16]>,
              Sched<[WriteFPMOV]> {
   // Instruction operands.
@@ -1272,7 +1282,7 @@ def VMOVRH : AVConv2I<0b11100001, 0b1001,
 def VMOVHR : AVConv4I<0b11100000, 0b1001,
                       (outs HPR:$Sn), (ins rGPR:$Rt),
                       IIC_fpMOVIS, "vmov", ".f16\t$Sn, $Rt",
-                      [(set HPR:$Sn, (arm_vmovhr rGPR:$Rt))]>,
+                      []>,
              Requires<[HasFPRegs16]>,
              Sched<[WriteFPMOV]> {
   // Instruction operands.
@@ -1290,6 +1300,11 @@ def VMOVHR : AVConv4I<0b11100000, 0b1001,
   let isUnpredicable = 1;
 }
 
+def : FPRegs16Pat<(arm_vmovrh (f16 HPR:$Sn)), (VMOVRH (f16 HPR:$Sn))>;
+def : FPRegs16Pat<(arm_vmovrh (bf16 HPR:$Sn)), (VMOVRH (bf16 HPR:$Sn))>;
+def : FPRegs16Pat<(f16 (arm_vmovhr rGPR:$Rt)), (VMOVHR rGPR:$Rt)>;
+def : FPRegs16Pat<(bf16 (arm_vmovhr rGPR:$Rt)), (VMOVHR rGPR:$Rt)>;
+
 // FMRDH: SPR -> GPR
 // FMRDL: SPR -> GPR
 // FMRRS: SPR -> GPR

diff  --git a/llvm/lib/Target/ARM/ARMRegisterInfo.td b/llvm/lib/Target/ARM/ARMRegisterInfo.td
index 39cdb685c492..a384b0dc757c 100644
--- a/llvm/lib/Target/ARM/ARMRegisterInfo.td
+++ b/llvm/lib/Target/ARM/ARMRegisterInfo.td
@@ -390,7 +390,7 @@ def SPR : RegisterClass<"ARM", [f32], 32, (sequence "S%u", 0, 31)> {
   let DiagnosticString = "operand must be a register in range [s0, s31]";
 }
 
-def HPR : RegisterClass<"ARM", [f16], 32, (sequence "S%u", 0, 31)> {
+def HPR : RegisterClass<"ARM", [f16, bf16], 32, (sequence "S%u", 0, 31)> {
   let AltOrders = [(add (decimate HPR, 2), SPR),
                    (add (decimate HPR, 4),
                         (decimate HPR, 2),
@@ -412,7 +412,7 @@ def SPR_8 : RegisterClass<"ARM", [f32], 32, (sequence "S%u", 0, 15)> {
 // class.
 // ARM requires only word alignment for double. It's more performant if it
 // is double-word alignment though.
-def DPR : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16], 64,
+def DPR : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16, v4bf16], 64,
                         (sequence "D%u", 0, 31)> {
   // Allocate non-VFP2 registers D16-D31 first, and prefer even registers on
   // Darwin platforms.
@@ -433,20 +433,20 @@ def FPWithVPR : RegisterClass<"ARM", [f32], 32, (add SPR, DPR, VPR)> {
 
 // Subset of DPR that are accessible with VFP2 (and so that also have
 // 32-bit SPR subregs).
-def DPR_VFP2 : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16], 64,
+def DPR_VFP2 : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16, v4bf16], 64,
                              (trunc DPR, 16)> {
   let DiagnosticString = "operand must be a register in range [d0, d15]";
 }
 
 // Subset of DPR which can be used as a source of NEON scalars for 16-bit
 // operations
-def DPR_8 : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16], 64,
+def DPR_8 : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16, v4bf16], 64,
                           (trunc DPR, 8)> {
   let DiagnosticString = "operand must be a register in range [d0, d7]";
 }
 
 // Generic 128-bit vector register class.
-def QPR : RegisterClass<"ARM", [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16], 128,
+def QPR : RegisterClass<"ARM", [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16, v8bf16], 128,
                         (sequence "Q%u", 0, 15)> {
   // Allocate non-VFP2 aliases Q8-Q15 first.
   let AltOrders = [(rotl QPR, 8), (trunc QPR, 8)];

diff  --git a/llvm/lib/Target/ARM/ARMSubtarget.h b/llvm/lib/Target/ARM/ARMSubtarget.h
index b49b953a84b1..566765205899 100644
--- a/llvm/lib/Target/ARM/ARMSubtarget.h
+++ b/llvm/lib/Target/ARM/ARMSubtarget.h
@@ -702,6 +702,7 @@ class ARMSubtarget : public ARMGenSubtargetInfo {
   bool hasD32() const { return HasD32; }
   bool hasFullFP16() const { return HasFullFP16; }
   bool hasFP16FML() const { return HasFP16FML; }
+  bool hasBF16() const { return HasBF16; }
 
   bool hasFuseAES() const { return HasFuseAES; }
   bool hasFuseLiterals() const { return HasFuseLiterals; }

diff  --git a/llvm/test/CodeGen/ARM/bfloat.ll b/llvm/test/CodeGen/ARM/bfloat.ll
new file mode 100644
index 000000000000..53b7cd6018f6
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/bfloat.ll
@@ -0,0 +1,106 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -float-abi hard -mattr=+bf16,+fullfp16 < %s | FileCheck %s --check-prefix=HARD
+; RUN: llc -float-abi soft -mattr=+bf16,+fullfp16 < %s | FileCheck %s --check-prefix=SOFT
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "armv8.6a-arm-none-eabi"
+
+define bfloat @load_scalar_bf(bfloat* %addr) {
+; HARD-LABEL: load_scalar_bf:
+; HARD:       @ %bb.0: @ %entry
+; HARD-NEXT:    vldr.16 s0, [r0]
+; HARD-NEXT:    bx lr
+;
+; SOFT-LABEL: load_scalar_bf:
+; SOFT:       @ %bb.0: @ %entry
+; SOFT-NEXT:    vldr.16 s0, [r0]
+; SOFT-NEXT:    vmov r0, s0
+; SOFT-NEXT:    bx lr
+entry:
+  %0 = load bfloat, bfloat* %addr, align 2
+  ret bfloat %0
+}
+
+define void @store_scalar_bf(bfloat %v, bfloat* %addr) {
+; HARD-LABEL: store_scalar_bf:
+; HARD:       @ %bb.0: @ %entry
+; HARD-NEXT:    vstr.16 s0, [r0]
+; HARD-NEXT:    bx lr
+;
+; SOFT-LABEL: store_scalar_bf:
+; SOFT:       @ %bb.0: @ %entry
+; SOFT-NEXT:    vmov.f16 s0, r0
+; SOFT-NEXT:    vstr.16 s0, [r1]
+; SOFT-NEXT:    bx lr
+entry:
+  store bfloat %v, bfloat* %addr, align 2
+  ret void
+}
+
+define <4 x bfloat> @load_vector4_bf(<4 x bfloat>* %addr) {
+; HARD-LABEL: load_vector4_bf:
+; HARD:       @ %bb.0: @ %entry
+; HARD-NEXT:    vldr d0, [r0]
+; HARD-NEXT:    bx lr
+;
+; SOFT-LABEL: load_vector4_bf:
+; SOFT:       @ %bb.0: @ %entry
+; SOFT-NEXT:    vldr d16, [r0]
+; SOFT-NEXT:    vmov r0, r1, d16
+; SOFT-NEXT:    bx lr
+entry:
+  %0 = load <4 x bfloat>, <4 x bfloat>* %addr, align 8
+  ret <4 x bfloat> %0
+}
+
+define void @store_vector4_bf(<4 x bfloat> %v, <4 x bfloat>* %addr) {
+; HARD-LABEL: store_vector4_bf:
+; HARD:       @ %bb.0: @ %entry
+; HARD-NEXT:    vstr d0, [r0]
+; HARD-NEXT:    bx lr
+;
+; SOFT-LABEL: store_vector4_bf:
+; SOFT:       @ %bb.0: @ %entry
+; SOFT-NEXT:    strd r0, r1, [r2]
+; SOFT-NEXT:    bx lr
+entry:
+  store <4 x bfloat> %v, <4 x bfloat>* %addr, align 8
+  ret void
+}
+
+define <8 x bfloat> @load_vector8_bf(<8 x bfloat>* %addr) {
+; HARD-LABEL: load_vector8_bf:
+; HARD:       @ %bb.0: @ %entry
+; HARD-NEXT:    vld1.64 {d0, d1}, [r0]
+; HARD-NEXT:    bx lr
+;
+; SOFT-LABEL: load_vector8_bf:
+; SOFT:       @ %bb.0: @ %entry
+; SOFT-NEXT:    vld1.64 {d16, d17}, [r0]
+; SOFT-NEXT:    vmov r0, r1, d16
+; SOFT-NEXT:    vmov r2, r3, d17
+; SOFT-NEXT:    bx lr
+entry:
+  %0 = load <8 x bfloat>, <8 x bfloat>* %addr, align 8
+  ret <8 x bfloat> %0
+}
+
+define void @store_vector8_bf(<8 x bfloat> %v, <8 x bfloat>* %addr) {
+; HARD-LABEL: store_vector8_bf:
+; HARD:       @ %bb.0: @ %entry
+; HARD-NEXT:    vst1.64 {d0, d1}, [r0]
+; HARD-NEXT:    bx lr
+;
+; SOFT-LABEL: store_vector8_bf:
+; SOFT:       @ %bb.0: @ %entry
+; SOFT-NEXT:    vmov d17, r2, r3
+; SOFT-NEXT:    ldr r12, [sp]
+; SOFT-NEXT:    vmov d16, r0, r1
+; SOFT-NEXT:    vst1.64 {d16, d17}, [r12]
+; SOFT-NEXT:    bx lr
+entry:
+  store <8 x bfloat> %v, <8 x bfloat>* %addr, align 8
+  ret void
+}
+
+


        


More information about the llvm-commits mailing list