[llvm] 42eba9b - [AArch64][BFloat] basic AArch64 bfloat support

Ties Stuij via llvm-commits llvm-commits at lists.llvm.org
Wed May 27 07:26:53 PDT 2020


Author: Ties Stuij
Date: 2020-05-27T15:26:40+01:00
New Revision: 42eba9b40b25cceeb3e6d432047c5ef99d4a7b50

URL: https://github.com/llvm/llvm-project/commit/42eba9b40b25cceeb3e6d432047c5ef99d4a7b50
DIFF: https://github.com/llvm/llvm-project/commit/42eba9b40b25cceeb3e6d432047c5ef99d4a7b50.diff

LOG: [AArch64][BFloat] basic AArch64 bfloat support

Summary:
This patch adds the bfloat type to the AArch64 backend:
- adds it as part of the FPR16 register class
- adds bfloat calling conventions
- since f16 is no longer the only FPR16 type, a number of instruction
  patterns using FPR16Op must be constrained with explicit f16 casts to help
  TableGen's type inference (see the sketch after this list)
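
To illustrate the last point: once FPR16 carries both f16 and bf16, a bare
FPR16 operand in a pattern no longer determines a single value type. A
minimal before/after sketch, mirroring the changes in this patch ("op" is a
placeholder, not a real SelectionDAG node):

  // Ambiguous: $Rn could be f16 or bf16 now that FPR16 holds both.
  [(set GPR32:$Rd, (op FPR16:$Rn))]

  // Constrained: the explicit f16 cast lets TableGen infer the type.
  [(set GPR32:$Rd, (op (f16 FPR16:$Rn)))]

On the calling-convention side, scalar bf16 follows the same rules as f16
(H registers, shadowed by the Q registers, per the CCIfType<[bf16], ...>
entries below). An illustrative IR function, not taken from this patch:

  define bfloat @bf16_id(bfloat %x) {
    ret bfloat %x
  }

Under AAPCS, %x and the return value would both be passed in h0.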

This patch is part of a series implementing the Bfloat16 extension of the
Armv8.6-A architecture, as detailed here:

https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/arm-architecture-developments-armv8-6-a

The bfloat type and its properties are specified in the Arm Architecture
Reference Manual:

https://developer.arm.com/docs/ddi0487/latest/arm-architecture-reference-manual-armv8-for-armv8-a-architecture-profile

Reviewers: t.p.northover, c-rhodes, fpetrogalli, sdesmalen, ostannard, LukeGeeson, ab

Reviewed By: fpetrogalli

Subscribers: pbarrio, LukeGeeson, kristof.beyls, hiraditya, danielkiss, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79709

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64CallingConvention.td
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/lib/Target/AArch64/AArch64RegisterInfo.td

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index 6eb9ba486462..eed87946dab9 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -33,9 +33,9 @@ def CC_AArch64_AAPCS : CallingConv<[
 
   // Big endian vectors must be passed as if they were 1-element vectors so that
   // their lanes are in a consistent order.
-  CCIfBigEndian<CCIfType<[v2i32, v2f32, v4i16, v4f16, v8i8],
+  CCIfBigEndian<CCIfType<[v2i32, v2f32, v4i16, v4f16, v4bf16, v8i8],
                          CCBitConvertToType<f64>>>,
-  CCIfBigEndian<CCIfType<[v2i64, v2f64, v4i32, v4f32, v8i16, v8f16, v16i8],
+  CCIfBigEndian<CCIfType<[v2i64, v2f64, v4i32, v4f32, v8i16, v8f16, v8bf16, v16i8],
                          CCBitConvertToType<f128>>>,
 
   // In AAPCS, an SRet is passed in X8, not X0 like a normal pointer parameter.
@@ -75,10 +75,10 @@ def CC_AArch64_AAPCS : CallingConv<[
   CCIfConsecutiveRegs<CCCustom<"CC_AArch64_Custom_Block">>,
 
   CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
-            nxv2f32, nxv4f32, nxv2f64],
+            nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
            CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,
   CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
-            nxv2f32, nxv4f32, nxv2f64],
+            nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
            CCPassIndirect<i64>>,
 
   CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
@@ -102,22 +102,24 @@ def CC_AArch64_AAPCS : CallingConv<[
                                           [W0, W1, W2, W3, W4, W5, W6, W7]>>,
   CCIfType<[f16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
+  CCIfType<[bf16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
+                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
   CCIfType<[f32], CCAssignToRegWithShadow<[S0, S1, S2, S3, S4, S5, S6, S7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
   CCIfType<[f64], CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
-  CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+  CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
            CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
                                    [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
-  CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+  CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
            CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
 
   // If more than will fit in registers, pass them on the stack instead.
-  CCIfType<[i1, i8, i16, f16], CCAssignToStack<8, 8>>,
+  CCIfType<[i1, i8, i16, f16, bf16], CCAssignToStack<8, 8>>,
   CCIfType<[i32, f32], CCAssignToStack<8, 8>>,
-  CCIfType<[i64, f64, v1f64, v2f32, v1i64, v2i32, v4i16, v8i8, v4f16],
+  CCIfType<[i64, f64, v1f64, v2f32, v1i64, v2i32, v4i16, v8i8, v4f16, v4bf16],
            CCAssignToStack<8, 8>>,
-  CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+  CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
            CCAssignToStack<16, 16>>
 ]>;
 
@@ -132,9 +134,9 @@ def RetCC_AArch64_AAPCS : CallingConv<[
 
   // Big endian vectors must be passed as if they were 1-element vectors so that
   // their lanes are in a consistent order.
-  CCIfBigEndian<CCIfType<[v2i32, v2f32, v4i16, v4f16, v8i8],
+  CCIfBigEndian<CCIfType<[v2i32, v2f32, v4i16, v4f16, v4bf16, v8i8],
                          CCBitConvertToType<f64>>>,
-  CCIfBigEndian<CCIfType<[v2i64, v2f64, v4i32, v4f32, v8i16, v8f16, v16i8],
+  CCIfBigEndian<CCIfType<[v2i64, v2f64, v4i32, v4f32, v8i16, v8f16, v8bf16, v16i8],
                          CCBitConvertToType<f128>>>,
 
   CCIfType<[i1, i8, i16], CCPromoteToType<i32>>,
@@ -144,18 +146,20 @@ def RetCC_AArch64_AAPCS : CallingConv<[
                                           [W0, W1, W2, W3, W4, W5, W6, W7]>>,
   CCIfType<[f16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
+  CCIfType<[bf16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
+                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
   CCIfType<[f32], CCAssignToRegWithShadow<[S0, S1, S2, S3, S4, S5, S6, S7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
   CCIfType<[f64], CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
-  CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+  CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
       CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
                               [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
-  CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+  CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
       CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
 
   CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
-            nxv2f32, nxv4f32, nxv2f64],
+            nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
            CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,
 
   CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
@@ -165,7 +169,7 @@ def RetCC_AArch64_AAPCS : CallingConv<[
 // Vararg functions on windows pass floats in integer registers
 let Entry = 1 in
 def CC_AArch64_Win64_VarArg : CallingConv<[
-  CCIfType<[f16, f32], CCPromoteToType<f64>>,
+  CCIfType<[f16, bf16, f32], CCPromoteToType<f64>>,
   CCIfType<[f64], CCBitConvertToType<i64>>,
   CCDelegateTo<CC_AArch64_AAPCS>
 ]>;
@@ -219,19 +223,22 @@ def CC_AArch64_DarwinPCS : CallingConv<[
                                           [W0, W1, W2, W3, W4, W5, W6, W7]>>,
   CCIfType<[f16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
+  CCIfType<[bf16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
+                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
   CCIfType<[f32], CCAssignToRegWithShadow<[S0, S1, S2, S3, S4, S5, S6, S7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
   CCIfType<[f64], CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
                                           [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
-  CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+  CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
            CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
                                    [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
-  CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+  CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
            CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
 
   // If more than will fit in registers, pass them on the stack instead.
   CCIf<"ValVT == MVT::i1 || ValVT == MVT::i8", CCAssignToStack<1, 1>>,
-  CCIf<"ValVT == MVT::i16 || ValVT == MVT::f16", CCAssignToStack<2, 2>>,
+  CCIf<"ValVT == MVT::i16 || ValVT == MVT::f16 || ValVT == MVT::bf16",
+  CCAssignToStack<2, 2>>,
   CCIfType<[i32, f32], CCAssignToStack<4, 4>>,
 
   // Re-demote pointers to 32-bits so we don't end up storing 64-bit
@@ -239,9 +246,9 @@ def CC_AArch64_DarwinPCS : CallingConv<[
   CCIfPtr<CCIfILP32<CCTruncToType<i32>>>,
   CCIfPtr<CCIfILP32<CCAssignToStack<4, 4>>>,
 
-  CCIfType<[i64, f64, v1f64, v2f32, v1i64, v2i32, v4i16, v8i8, v4f16],
+  CCIfType<[i64, f64, v1f64, v2f32, v1i64, v2i32, v4i16, v8i8, v4f16, v4bf16],
            CCAssignToStack<8, 8>>,
-  CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+  CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
            CCAssignToStack<16, 16>>
 ]>;
 
@@ -255,14 +262,14 @@ def CC_AArch64_DarwinPCS_VarArg : CallingConv<[
 
   // Handle all scalar types as either i64 or f64.
   CCIfType<[i8, i16, i32], CCPromoteToType<i64>>,
-  CCIfType<[f16, f32],     CCPromoteToType<f64>>,
+  CCIfType<[f16, bf16, f32], CCPromoteToType<f64>>,
 
   // Everything is on the stack.
   // i128 is split to two i64s, and its stack alignment is 16 bytes.
   CCIfType<[i64], CCIfSplit<CCAssignToStack<8, 16>>>,
-  CCIfType<[i64, f64, v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+  CCIfType<[i64, f64, v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
            CCAssignToStack<8, 8>>,
-  CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+  CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
            CCAssignToStack<16, 16>>
 ]>;
 
@@ -275,16 +282,16 @@ def CC_AArch64_DarwinPCS_ILP32_VarArg : CallingConv<[
 
   // Handle all scalar types as either i32 or f32.
   CCIfType<[i8, i16], CCPromoteToType<i32>>,
-  CCIfType<[f16],     CCPromoteToType<f32>>,
+  CCIfType<[f16, bf16], CCPromoteToType<f32>>,
 
   // Everything is on the stack.
   // i128 is split to two i64s, and its stack alignment is 16 bytes.
   CCIfPtr<CCIfILP32<CCTruncToType<i32>>>,
   CCIfType<[i32, f32], CCAssignToStack<4, 4>>,
   CCIfType<[i64], CCIfSplit<CCAssignToStack<8, 16>>>,
-  CCIfType<[i64, f64, v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+  CCIfType<[i64, f64, v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
            CCAssignToStack<8, 8>>,
-  CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+  CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
            CCAssignToStack<16, 16>>
 ]>;
 

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5eb9b7463411..187f133669e6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -132,6 +132,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   if (Subtarget->hasFPARMv8()) {
     addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
+    addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
     addRegisterClass(MVT::f32, &AArch64::FPR32RegClass);
     addRegisterClass(MVT::f64, &AArch64::FPR64RegClass);
     addRegisterClass(MVT::f128, &AArch64::FPR128RegClass);
@@ -148,6 +149,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     addDRTypeForNEON(MVT::v1i64);
     addDRTypeForNEON(MVT::v1f64);
     addDRTypeForNEON(MVT::v4f16);
+    addDRTypeForNEON(MVT::v4bf16);
 
     addQRTypeForNEON(MVT::v4f32);
     addQRTypeForNEON(MVT::v2f64);
@@ -156,6 +158,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     addQRTypeForNEON(MVT::v4i32);
     addQRTypeForNEON(MVT::v2i64);
     addQRTypeForNEON(MVT::v8f16);
+    addQRTypeForNEON(MVT::v8bf16);
   }
 
   if (Subtarget->hasSVE()) {
@@ -174,6 +177,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     addRegisterClass(MVT::nxv2f16, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv4f16, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv8f16, &AArch64::ZPRRegClass);
+    addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
+    addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
+    addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv2f32, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
@@ -3578,6 +3584,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
         RC = &AArch64::GPR64RegClass;
       else if (RegVT == MVT::f16)
         RC = &AArch64::FPR16RegClass;
+      else if (RegVT == MVT::bf16)
+        RC = &AArch64::FPR16RegClass;
       else if (RegVT == MVT::f32)
         RC = &AArch64::FPR32RegClass;
       else if (RegVT == MVT::f64 || RegVT.is64BitVector())

diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index a06394a2898d..713bf0bf3cad 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -4447,14 +4447,14 @@ multiclass FPToIntegerUnscaled<bits<2> rmode, bits<3> opcode, string asm,
            SDPatternOperator OpN> {
   // Unscaled half-precision to 32-bit
   def UWHr : BaseFPToIntegerUnscaled<0b11, rmode, opcode, FPR16, GPR32, asm,
-                                     [(set GPR32:$Rd, (OpN FPR16:$Rn))]> {
+                                     [(set GPR32:$Rd, (OpN (f16 FPR16:$Rn)))]> {
     let Inst{31} = 0; // 32-bit GPR flag
     let Predicates = [HasFullFP16];
   }
 
   // Unscaled half-precision to 64-bit
   def UXHr : BaseFPToIntegerUnscaled<0b11, rmode, opcode, FPR16, GPR64, asm,
-                                     [(set GPR64:$Rd, (OpN FPR16:$Rn))]> {
+                                     [(set GPR64:$Rd, (OpN (f16 FPR16:$Rn)))]> {
     let Inst{31} = 1; // 64-bit GPR flag
     let Predicates = [HasFullFP16];
   }
@@ -4489,7 +4489,7 @@ multiclass FPToIntegerScaled<bits<2> rmode, bits<3> opcode, string asm,
   // Scaled half-precision to 32-bit
   def SWHri : BaseFPToInteger<0b11, rmode, opcode, FPR16, GPR32,
                               fixedpoint_f16_i32, asm,
-              [(set GPR32:$Rd, (OpN (fmul FPR16:$Rn,
+              [(set GPR32:$Rd, (OpN (fmul (f16 FPR16:$Rn),
                                           fixedpoint_f16_i32:$scale)))]> {
     let Inst{31} = 0; // 32-bit GPR flag
     let scale{5} = 1;
@@ -4499,7 +4499,7 @@ multiclass FPToIntegerScaled<bits<2> rmode, bits<3> opcode, string asm,
   // Scaled half-precision to 64-bit
   def SXHri : BaseFPToInteger<0b11, rmode, opcode, FPR16, GPR64,
                               fixedpoint_f16_i64, asm,
-              [(set GPR64:$Rd, (OpN (fmul FPR16:$Rn,
+              [(set GPR64:$Rd, (OpN (fmul (f16 FPR16:$Rn),
                                           fixedpoint_f16_i64:$scale)))]> {
     let Inst{31} = 1; // 64-bit GPR flag
     let Predicates = [HasFullFP16];
@@ -4615,7 +4615,7 @@ multiclass IntegerToFP<bit isUnsigned, string asm, SDNode node> {
 
   // Scaled
   def SWHri: BaseIntegerToFP<isUnsigned, GPR32, FPR16, fixedpoint_f16_i32, asm,
-                             [(set FPR16:$Rd,
+                             [(set (f16 FPR16:$Rd),
                                    (fdiv (node GPR32:$Rn),
                                          fixedpoint_f16_i32:$scale))]> {
     let Inst{31} = 0; // 32-bit GPR flag
@@ -4643,7 +4643,7 @@ multiclass IntegerToFP<bit isUnsigned, string asm, SDNode node> {
   }
 
   def SXHri: BaseIntegerToFP<isUnsigned, GPR64, FPR16, fixedpoint_f16_i64, asm,
-                             [(set FPR16:$Rd,
+                             [(set (f16 FPR16:$Rd),
                                    (fdiv (node GPR64:$Rn),
                                          fixedpoint_f16_i64:$scale))]> {
     let Inst{31} = 1; // 64-bit GPR flag
@@ -4816,7 +4816,7 @@ class BaseFPConversion<bits<2> type, bits<2> opcode, RegisterClass dstType,
 multiclass FPConversion<string asm> {
   // Double-precision to Half-precision
   def HDr : BaseFPConversion<0b01, 0b11, FPR16, FPR64, asm,
-                             [(set FPR16:$Rd, (any_fpround FPR64:$Rn))]>;
+                             [(set (f16 FPR16:$Rd), (any_fpround FPR64:$Rn))]>;
 
   // Double-precision to Single-precision
   def SDr : BaseFPConversion<0b01, 0b00, FPR32, FPR64, asm,
@@ -4824,11 +4824,11 @@ multiclass FPConversion<string asm> {
 
   // Half-precision to Double-precision
   def DHr : BaseFPConversion<0b11, 0b01, FPR64, FPR16, asm,
-                             [(set FPR64:$Rd, (fpextend FPR16:$Rn))]>;
+                             [(set FPR64:$Rd, (fpextend (f16 FPR16:$Rn)))]>;
 
   // Half-precision to Single-precision
   def SHr : BaseFPConversion<0b11, 0b00, FPR32, FPR16, asm,
-                             [(set FPR32:$Rd, (fpextend FPR16:$Rn))]>;
+                             [(set FPR32:$Rd, (fpextend (f16 FPR16:$Rn)))]>;
 
   // Single-precision to Double-precision
   def DSr : BaseFPConversion<0b00, 0b01, FPR64, FPR32, asm,
@@ -4836,7 +4836,7 @@ multiclass FPConversion<string asm> {
 
   // Single-precision to Half-precision
   def HSr : BaseFPConversion<0b00, 0b11, FPR16, FPR32, asm,
-                             [(set FPR16:$Rd, (any_fpround FPR32:$Rn))]>;
+                             [(set (f16 FPR16:$Rd), (any_fpround FPR32:$Rn))]>;
 }
 
 //---
@@ -4938,7 +4938,7 @@ multiclass TwoOperandFPData<bits<4> opcode, string asm,
 
 multiclass TwoOperandFPDataNeg<bits<4> opcode, string asm, SDNode node> {
   def Hrr : BaseTwoOperandFPData<opcode, FPR16, asm,
-                  [(set FPR16:$Rd, (fneg (node FPR16:$Rn, (f16 FPR16:$Rm))))]> {
+                  [(set (f16 FPR16:$Rd), (fneg (node (f16 FPR16:$Rn), (f16 FPR16:$Rm))))]> {
     let Inst{23-22} = 0b11; // 16-bit size flag
     let Predicates = [HasFullFP16];
   }
@@ -4980,7 +4980,7 @@ class BaseThreeOperandFPData<bit isNegated, bit isSub,
 multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
                               SDPatternOperator node> {
   def Hrrr : BaseThreeOperandFPData<isNegated, isSub, FPR16, asm,
-            [(set FPR16:$Rd,
+            [(set (f16 FPR16:$Rd),
                   (node (f16 FPR16:$Rn), (f16 FPR16:$Rm), (f16 FPR16:$Ra)))]> {
     let Inst{23-22} = 0b11; // 16-bit size flag
     let Predicates = [HasFullFP16];
@@ -5042,7 +5042,7 @@ multiclass FPComparison<bit signalAllNans, string asm,
                         SDPatternOperator OpNode = null_frag> {
   let Defs = [NZCV] in {
   def Hrr : BaseTwoOperandFPComparison<signalAllNans, FPR16, asm,
-      [(OpNode FPR16:$Rn, (f16 FPR16:$Rm)), (implicit NZCV)]> {
+      [(OpNode (f16 FPR16:$Rn), (f16 FPR16:$Rm)), (implicit NZCV)]> {
     let Inst{23-22} = 0b11;
     let Predicates = [HasFullFP16];
   }
@@ -6742,7 +6742,7 @@ multiclass SIMDFPThreeScalar<bit U, bit S, bits<3> opc, string asm,
       [(set FPR32:$Rd, (OpNode FPR32:$Rn, FPR32:$Rm))]>;
     let Predicates = [HasNEON, HasFullFP16] in {
     def NAME#16 : BaseSIMDThreeScalar<U, {S,0b10}, {0b00,opc}, FPR16, asm,
-      [(set FPR16:$Rd, (OpNode FPR16:$Rn, FPR16:$Rm))]>;
+      [(set (f16 FPR16:$Rd), (OpNode (f16 FPR16:$Rn), (f16 FPR16:$Rm)))]>;
     } // Predicates = [HasNEON, HasFullFP16]
   }
 
@@ -6949,7 +6949,7 @@ multiclass SIMDFPTwoScalarCVT<bit U, bit S, bits<5> opc, string asm,
                                 [(set FPR32:$Rd, (OpNode (f32 FPR32:$Rn)))]>;
   let Predicates = [HasNEON, HasFullFP16] in {
   def v1i16 : BaseSIMDTwoScalar<U, {S,1}, 0b11, opc, FPR16, FPR16, asm,
-                                [(set FPR16:$Rd, (OpNode (f16 FPR16:$Rn)))]>;
+                                [(set (f16 FPR16:$Rd), (OpNode (f16 FPR16:$Rn)))]>;
   }
 }
 
@@ -7091,10 +7091,10 @@ multiclass SIMDFPAcrossLanes<bits<5> opcode, bit sz1, string asm,
   let Predicates = [HasNEON, HasFullFP16] in {
   def v4i16v : BaseSIMDAcrossLanes<0, 0, {sz1, 0}, opcode, FPR16, V64,
                                    asm, ".4h",
-        [(set FPR16:$Rd, (intOp (v4f16 V64:$Rn)))]>;
+        [(set (f16 FPR16:$Rd), (intOp (v4f16 V64:$Rn)))]>;
   def v8i16v : BaseSIMDAcrossLanes<1, 0, {sz1, 0}, opcode, FPR16, V128,
                                    asm, ".8h",
-        [(set FPR16:$Rd, (intOp (v8f16 V128:$Rn)))]>;
+        [(set (f16 FPR16:$Rd), (intOp (v8f16 V128:$Rn)))]>;
   } // Predicates = [HasNEON, HasFullFP16]
   def v4i32v : BaseSIMDAcrossLanes<1, 1, {sz1, 0}, opcode, FPR32, V128,
                                    asm, ".4s",
@@ -8095,7 +8095,7 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
   def : Pat<(v8f16 (OpNode (v8f16 V128:$Rd), (v8f16 V128:$Rn),
                            (AArch64dup (f16 FPR16Op_lo:$Rm)))),
             (!cast<Instruction>(INST # "v8i16_indexed") V128:$Rd, V128:$Rn,
-                (SUBREG_TO_REG (i32 0), FPR16Op_lo:$Rm, hsub), (i64 0))>;
+                (SUBREG_TO_REG (i32 0), (f16 FPR16Op_lo:$Rm), hsub), (i64 0))>;
 
   def : Pat<(v4f16 (OpNode (v4f16 V64:$Rd), (v4f16 V64:$Rn),
                            (AArch64duplane16 (v8f16 V128_lo:$Rm),
@@ -8105,7 +8105,7 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
   def : Pat<(v4f16 (OpNode (v4f16 V64:$Rd), (v4f16 V64:$Rn),
                            (AArch64dup (f16 FPR16Op_lo:$Rm)))),
             (!cast<Instruction>(INST # "v4i16_indexed") V64:$Rd, V64:$Rn,
-                (SUBREG_TO_REG (i32 0), FPR16Op_lo:$Rm, hsub), (i64 0))>;
+                (SUBREG_TO_REG (i32 0), (f16 FPR16Op_lo:$Rm), hsub), (i64 0))>;
 
   def : Pat<(f16 (OpNode (f16 FPR16:$Rd), (f16 FPR16:$Rn),
                          (vector_extract (v8f16 V128_lo:$Rm), VectorIndexH:$idx))),

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 03ac81e2462b..07bca441529e 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -2529,7 +2529,7 @@ defm LDURB : LoadUnscaled<0b00, 1, 0b01, FPR8Op, "ldur",
                     [(set FPR8Op:$Rt,
                           (load (am_unscaled8 GPR64sp:$Rn, simm9:$offset)))]>;
 defm LDURH : LoadUnscaled<0b01, 1, 0b01, FPR16Op, "ldur",
-                    [(set FPR16Op:$Rt,
+                    [(set (f16 FPR16Op:$Rt),
                           (load (am_unscaled16 GPR64sp:$Rn, simm9:$offset)))]>;
 defm LDURS : LoadUnscaled<0b10, 1, 0b01, FPR32Op, "ldur",
                     [(set (f32 FPR32Op:$Rt),

diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 93b6aa0cdb7f..bd05c56009a1 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -422,18 +422,20 @@ def Q31   : AArch64Reg<31, "q31", [D31], ["v31", ""]>, DwarfRegAlias<B31>;
 def FPR8  : RegisterClass<"AArch64", [untyped], 8, (sequence "B%u", 0, 31)> {
   let Size = 8;
 }
-def FPR16 : RegisterClass<"AArch64", [f16], 16, (sequence "H%u", 0, 31)> {
+def FPR16 : RegisterClass<"AArch64", [f16, bf16], 16, (sequence "H%u", 0, 31)> {
   let Size = 16;
 }
+
 def FPR16_lo : RegisterClass<"AArch64", [f16], 16, (trunc FPR16, 16)> {
   let Size = 16;
 }
 def FPR32 : RegisterClass<"AArch64", [f32, i32], 32,(sequence "S%u", 0, 31)>;
 def FPR64 : RegisterClass<"AArch64", [f64, i64, v2f32, v1f64, v8i8, v4i16, v2i32,
-                                    v1i64, v4f16],
-                                    64, (sequence "D%u", 0, 31)>;
+                                      v1i64, v4f16, v4bf16],
+                                     64, (sequence "D%u", 0, 31)>;
 def FPR64_lo : RegisterClass<"AArch64",
-                             [v8i8, v4i16, v2i32, v1i64, v4f16, v2f32, v1f64],
+                             [v8i8, v4i16, v2i32, v1i64, v4f16, v4bf16, v2f32,
+                              v1f64],
                              64, (trunc FPR64, 16)>;
 
 // We don't (yet) have an f128 legal type, so don't use that here. We
@@ -441,13 +443,14 @@ def FPR64_lo : RegisterClass<"AArch64",
 // that here.
 def FPR128 : RegisterClass<"AArch64",
                            [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, f128,
-                            v8f16],
+                            v8f16, v8bf16],
                            128, (sequence "Q%u", 0, 31)>;
 
 // The lower 16 vector registers.  Some instructions can only take registers
 // in this range.
 def FPR128_lo : RegisterClass<"AArch64",
-                              [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16],
+                              [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16,
+                               v8bf16],
                               128, (trunc FPR128, 16)>;
 
 // Pairs, triples, and quads of 64-bit vector registers.
@@ -876,6 +879,7 @@ def PPR3b64  : PPRRegOp<"d", PPRAsmOp3b64,  ElementSizeD, PPR_3b>;
 class ZPRClass<int lastreg> : RegisterClass<"AArch64",
                                             [nxv16i8, nxv8i16, nxv4i32, nxv2i64,
                                              nxv2f16, nxv4f16, nxv8f16,
+                                             nxv2bf16, nxv4bf16, nxv8bf16,
                                              nxv2f32, nxv4f32,
                                              nxv2f64],
                                             128, (sequence "Z%u", 0, lastreg)> {


        

