[llvm] 42eba9b - [AArch64][BFloat] basic AArch64 bfloat support
Author: Ties Stuij
Date: 2020-05-27T15:26:40+01:00
New Revision: 42eba9b40b25cceeb3e6d432047c5ef99d4a7b50
URL: https://github.com/llvm/llvm-project/commit/42eba9b40b25cceeb3e6d432047c5ef99d4a7b50
DIFF: https://github.com/llvm/llvm-project/commit/42eba9b40b25cceeb3e6d432047c5ef99d4a7b50.diff
LOG: [AArch64][BFloat] basic AArch64 bfloat support
Summary:
This patch adds the bfloat type to the AArch64 backend:
- adds it as part of the FPR16 register class
- adds bfloat calling conventions
- constrains a number of instruction patterns that use FPR16/FPR16Op with explicit
value types, since f16 is no longer the only FPR16 type and the TableGen type
inferrer needs the disambiguation (sketched below)
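A minimal illustration of that last point (this sketch is not part of the patch;
it is distilled from the FPToIntegerUnscaled hunk in AArch64InstrFormats.td
further down). Once FPR16 carries both f16 and bf16, a bare FPR16 operand in a
selection pattern no longer has a unique value type, so the pattern pins it with
an explicit cast:

  // Before: $Rn's value type is ambiguous once FPR16 holds f16 and bf16.
  [(set GPR32:$Rd, (OpN FPR16:$Rn))]
  // After: the operand type is stated explicitly for the type inferrer.
  [(set GPR32:$Rd, (OpN (f16 FPR16:$Rn)))]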
This patch is part of a series implementing the Bfloat16 extension of the
Armv8.6-a architecture, as detailed here:
https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/arm-architecture-developments-armv8-6-a
The bfloat type and its properties are specified in the Arm Architecture
Reference Manual:
https://developer.arm.com/docs/ddi0487/latest/arm-architecture-reference-manual-armv8-for-armv8-a-architecture-profile
Reviewers: t.p.northover, c-rhodes, fpetrogalli, sdesmalen, ostannard, LukeGeeson, ab
Reviewed By: fpetrogalli
Subscribers: pbarrio, LukeGeeson, kristof.beyls, hiraditya, danielkiss, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D79709
Added:
Modified:
llvm/lib/Target/AArch64/AArch64CallingConvention.td
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64InstrFormats.td
llvm/lib/Target/AArch64/AArch64InstrInfo.td
llvm/lib/Target/AArch64/AArch64RegisterInfo.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index 6eb9ba486462..eed87946dab9 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -33,9 +33,9 @@ def CC_AArch64_AAPCS : CallingConv<[
// Big endian vectors must be passed as if they were 1-element vectors so that
// their lanes are in a consistent order.
- CCIfBigEndian<CCIfType<[v2i32, v2f32, v4i16, v4f16, v8i8],
+ CCIfBigEndian<CCIfType<[v2i32, v2f32, v4i16, v4f16, v4bf16, v8i8],
CCBitConvertToType<f64>>>,
- CCIfBigEndian<CCIfType<[v2i64, v2f64, v4i32, v4f32, v8i16, v8f16, v16i8],
+ CCIfBigEndian<CCIfType<[v2i64, v2f64, v4i32, v4f32, v8i16, v8f16, v8bf16, v16i8],
CCBitConvertToType<f128>>>,
// In AAPCS, an SRet is passed in X8, not X0 like a normal pointer parameter.
@@ -75,10 +75,10 @@ def CC_AArch64_AAPCS : CallingConv<[
CCIfConsecutiveRegs<CCCustom<"CC_AArch64_Custom_Block">>,
CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
- nxv2f32, nxv4f32, nxv2f64],
+ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,
CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
- nxv2f32, nxv4f32, nxv2f64],
+ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
CCPassIndirect<i64>>,
CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
@@ -102,22 +102,24 @@ def CC_AArch64_AAPCS : CallingConv<[
[W0, W1, W2, W3, W4, W5, W6, W7]>>,
CCIfType<[f16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
+ CCIfType<[bf16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
+ [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
CCIfType<[f32], CCAssignToRegWithShadow<[S0, S1, S2, S3, S4, S5, S6, S7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
CCIfType<[f64], CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
- CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+ CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
- CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+ CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
// If more than will fit in registers, pass them on the stack instead.
- CCIfType<[i1, i8, i16, f16], CCAssignToStack<8, 8>>,
+ CCIfType<[i1, i8, i16, f16, bf16], CCAssignToStack<8, 8>>,
CCIfType<[i32, f32], CCAssignToStack<8, 8>>,
- CCIfType<[i64, f64, v1f64, v2f32, v1i64, v2i32, v4i16, v8i8, v4f16],
+ CCIfType<[i64, f64, v1f64, v2f32, v1i64, v2i32, v4i16, v8i8, v4f16, v4bf16],
CCAssignToStack<8, 8>>,
- CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+ CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
CCAssignToStack<16, 16>>
]>;
@@ -132,9 +134,9 @@ def RetCC_AArch64_AAPCS : CallingConv<[
// Big endian vectors must be passed as if they were 1-element vectors so that
// their lanes are in a consistent order.
- CCIfBigEndian<CCIfType<[v2i32, v2f32, v4i16, v4f16, v8i8],
+ CCIfBigEndian<CCIfType<[v2i32, v2f32, v4i16, v4f16, v4bf16, v8i8],
CCBitConvertToType<f64>>>,
- CCIfBigEndian<CCIfType<[v2i64, v2f64, v4i32, v4f32, v8i16, v8f16, v16i8],
+ CCIfBigEndian<CCIfType<[v2i64, v2f64, v4i32, v4f32, v8i16, v8f16, v8bf16, v16i8],
CCBitConvertToType<f128>>>,
CCIfType<[i1, i8, i16], CCPromoteToType<i32>>,
@@ -144,18 +146,20 @@ def RetCC_AArch64_AAPCS : CallingConv<[
[W0, W1, W2, W3, W4, W5, W6, W7]>>,
CCIfType<[f16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
+ CCIfType<[bf16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
+ [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
CCIfType<[f32], CCAssignToRegWithShadow<[S0, S1, S2, S3, S4, S5, S6, S7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
CCIfType<[f64], CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
- CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+ CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
- CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+ CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
- nxv2f32, nxv4f32, nxv2f64],
+ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,
CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
@@ -165,7 +169,7 @@ def RetCC_AArch64_AAPCS : CallingConv<[
// Vararg functions on windows pass floats in integer registers
let Entry = 1 in
def CC_AArch64_Win64_VarArg : CallingConv<[
- CCIfType<[f16, f32], CCPromoteToType<f64>>,
+ CCIfType<[f16, bf16, f32], CCPromoteToType<f64>>,
CCIfType<[f64], CCBitConvertToType<i64>>,
CCDelegateTo<CC_AArch64_AAPCS>
]>;
@@ -219,19 +223,22 @@ def CC_AArch64_DarwinPCS : CallingConv<[
[W0, W1, W2, W3, W4, W5, W6, W7]>>,
CCIfType<[f16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
+ CCIfType<[bf16], CCAssignToRegWithShadow<[H0, H1, H2, H3, H4, H5, H6, H7],
+ [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
CCIfType<[f32], CCAssignToRegWithShadow<[S0, S1, S2, S3, S4, S5, S6, S7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
CCIfType<[f64], CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
- CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+ CCIfType<[v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7],
[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
- CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+ CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
// If more than will fit in registers, pass them on the stack instead.
CCIf<"ValVT == MVT::i1 || ValVT == MVT::i8", CCAssignToStack<1, 1>>,
- CCIf<"ValVT == MVT::i16 || ValVT == MVT::f16", CCAssignToStack<2, 2>>,
+ CCIf<"ValVT == MVT::i16 || ValVT == MVT::f16 || ValVT == MVT::bf16",
+ CCAssignToStack<2, 2>>,
CCIfType<[i32, f32], CCAssignToStack<4, 4>>,
// Re-demote pointers to 32-bits so we don't end up storing 64-bit
@@ -239,9 +246,9 @@ def CC_AArch64_DarwinPCS : CallingConv<[
CCIfPtr<CCIfILP32<CCTruncToType<i32>>>,
CCIfPtr<CCIfILP32<CCAssignToStack<4, 4>>>,
- CCIfType<[i64, f64, v1f64, v2f32, v1i64, v2i32, v4i16, v8i8, v4f16],
+ CCIfType<[i64, f64, v1f64, v2f32, v1i64, v2i32, v4i16, v8i8, v4f16, v4bf16],
CCAssignToStack<8, 8>>,
- CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+ CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
CCAssignToStack<16, 16>>
]>;
@@ -255,14 +262,14 @@ def CC_AArch64_DarwinPCS_VarArg : CallingConv<[
// Handle all scalar types as either i64 or f64.
CCIfType<[i8, i16, i32], CCPromoteToType<i64>>,
- CCIfType<[f16, f32], CCPromoteToType<f64>>,
+ CCIfType<[f16, bf16, f32], CCPromoteToType<f64>>,
// Everything is on the stack.
// i128 is split to two i64s, and its stack alignment is 16 bytes.
CCIfType<[i64], CCIfSplit<CCAssignToStack<8, 16>>>,
- CCIfType<[i64, f64, v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+ CCIfType<[i64, f64, v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
CCAssignToStack<8, 8>>,
- CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+ CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
CCAssignToStack<16, 16>>
]>;
@@ -275,16 +282,16 @@ def CC_AArch64_DarwinPCS_ILP32_VarArg : CallingConv<[
// Handle all scalar types as either i32 or f32.
CCIfType<[i8, i16], CCPromoteToType<i32>>,
- CCIfType<[f16], CCPromoteToType<f32>>,
+ CCIfType<[f16, bf16], CCPromoteToType<f32>>,
// Everything is on the stack.
// i128 is split to two i64s, and its stack alignment is 16 bytes.
CCIfPtr<CCIfILP32<CCTruncToType<i32>>>,
CCIfType<[i32, f32], CCAssignToStack<4, 4>>,
CCIfType<[i64], CCIfSplit<CCAssignToStack<8, 16>>>,
- CCIfType<[i64, f64, v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16],
+ CCIfType<[i64, f64, v1i64, v2i32, v4i16, v8i8, v1f64, v2f32, v4f16, v4bf16],
CCAssignToStack<8, 8>>,
- CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16],
+ CCIfType<[v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16, v8bf16],
CCAssignToStack<16, 16>>
]>;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5eb9b7463411..187f133669e6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -132,6 +132,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (Subtarget->hasFPARMv8()) {
addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
+ addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
addRegisterClass(MVT::f32, &AArch64::FPR32RegClass);
addRegisterClass(MVT::f64, &AArch64::FPR64RegClass);
addRegisterClass(MVT::f128, &AArch64::FPR128RegClass);
@@ -148,6 +149,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addDRTypeForNEON(MVT::v1i64);
addDRTypeForNEON(MVT::v1f64);
addDRTypeForNEON(MVT::v4f16);
+ addDRTypeForNEON(MVT::v4bf16);
addQRTypeForNEON(MVT::v4f32);
addQRTypeForNEON(MVT::v2f64);
@@ -156,6 +158,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addQRTypeForNEON(MVT::v4i32);
addQRTypeForNEON(MVT::v2i64);
addQRTypeForNEON(MVT::v8f16);
+ addQRTypeForNEON(MVT::v8bf16);
}
if (Subtarget->hasSVE()) {
@@ -174,6 +177,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::nxv2f16, &AArch64::ZPRRegClass);
addRegisterClass(MVT::nxv4f16, &AArch64::ZPRRegClass);
addRegisterClass(MVT::nxv8f16, &AArch64::ZPRRegClass);
+ addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
+ addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
+ addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
addRegisterClass(MVT::nxv2f32, &AArch64::ZPRRegClass);
addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
@@ -3578,6 +3584,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
RC = &AArch64::GPR64RegClass;
else if (RegVT == MVT::f16)
RC = &AArch64::FPR16RegClass;
+ else if (RegVT == MVT::bf16)
+ RC = &AArch64::FPR16RegClass;
else if (RegVT == MVT::f32)
RC = &AArch64::FPR32RegClass;
else if (RegVT == MVT::f64 || RegVT.is64BitVector())
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index a06394a2898d..713bf0bf3cad 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -4447,14 +4447,14 @@ multiclass FPToIntegerUnscaled<bits<2> rmode, bits<3> opcode, string asm,
SDPatternOperator OpN> {
// Unscaled half-precision to 32-bit
def UWHr : BaseFPToIntegerUnscaled<0b11, rmode, opcode, FPR16, GPR32, asm,
- [(set GPR32:$Rd, (OpN FPR16:$Rn))]> {
+ [(set GPR32:$Rd, (OpN (f16 FPR16:$Rn)))]> {
let Inst{31} = 0; // 32-bit GPR flag
let Predicates = [HasFullFP16];
}
// Unscaled half-precision to 64-bit
def UXHr : BaseFPToIntegerUnscaled<0b11, rmode, opcode, FPR16, GPR64, asm,
- [(set GPR64:$Rd, (OpN FPR16:$Rn))]> {
+ [(set GPR64:$Rd, (OpN (f16 FPR16:$Rn)))]> {
let Inst{31} = 1; // 64-bit GPR flag
let Predicates = [HasFullFP16];
}
@@ -4489,7 +4489,7 @@ multiclass FPToIntegerScaled<bits<2> rmode, bits<3> opcode, string asm,
// Scaled half-precision to 32-bit
def SWHri : BaseFPToInteger<0b11, rmode, opcode, FPR16, GPR32,
fixedpoint_f16_i32, asm,
- [(set GPR32:$Rd, (OpN (fmul FPR16:$Rn,
+ [(set GPR32:$Rd, (OpN (fmul (f16 FPR16:$Rn),
fixedpoint_f16_i32:$scale)))]> {
let Inst{31} = 0; // 32-bit GPR flag
let scale{5} = 1;
@@ -4499,7 +4499,7 @@ multiclass FPToIntegerScaled<bits<2> rmode, bits<3> opcode, string asm,
// Scaled half-precision to 64-bit
def SXHri : BaseFPToInteger<0b11, rmode, opcode, FPR16, GPR64,
fixedpoint_f16_i64, asm,
- [(set GPR64:$Rd, (OpN (fmul FPR16:$Rn,
+ [(set GPR64:$Rd, (OpN (fmul (f16 FPR16:$Rn),
fixedpoint_f16_i64:$scale)))]> {
let Inst{31} = 1; // 64-bit GPR flag
let Predicates = [HasFullFP16];
@@ -4615,7 +4615,7 @@ multiclass IntegerToFP<bit isUnsigned, string asm, SDNode node> {
// Scaled
def SWHri: BaseIntegerToFP<isUnsigned, GPR32, FPR16, fixedpoint_f16_i32, asm,
- [(set FPR16:$Rd,
+ [(set (f16 FPR16:$Rd),
(fdiv (node GPR32:$Rn),
fixedpoint_f16_i32:$scale))]> {
let Inst{31} = 0; // 32-bit GPR flag
@@ -4643,7 +4643,7 @@ multiclass IntegerToFP<bit isUnsigned, string asm, SDNode node> {
}
def SXHri: BaseIntegerToFP<isUnsigned, GPR64, FPR16, fixedpoint_f16_i64, asm,
- [(set FPR16:$Rd,
+ [(set (f16 FPR16:$Rd),
(fdiv (node GPR64:$Rn),
fixedpoint_f16_i64:$scale))]> {
let Inst{31} = 1; // 64-bit GPR flag
@@ -4816,7 +4816,7 @@ class BaseFPConversion<bits<2> type, bits<2> opcode, RegisterClass dstType,
multiclass FPConversion<string asm> {
// Double-precision to Half-precision
def HDr : BaseFPConversion<0b01, 0b11, FPR16, FPR64, asm,
- [(set FPR16:$Rd, (any_fpround FPR64:$Rn))]>;
+ [(set (f16 FPR16:$Rd), (any_fpround FPR64:$Rn))]>;
// Double-precision to Single-precision
def SDr : BaseFPConversion<0b01, 0b00, FPR32, FPR64, asm,
@@ -4824,11 +4824,11 @@ multiclass FPConversion<string asm> {
// Half-precision to Double-precision
def DHr : BaseFPConversion<0b11, 0b01, FPR64, FPR16, asm,
- [(set FPR64:$Rd, (fpextend FPR16:$Rn))]>;
+ [(set FPR64:$Rd, (fpextend (f16 FPR16:$Rn)))]>;
// Half-precision to Single-precision
def SHr : BaseFPConversion<0b11, 0b00, FPR32, FPR16, asm,
- [(set FPR32:$Rd, (fpextend FPR16:$Rn))]>;
+ [(set FPR32:$Rd, (fpextend (f16 FPR16:$Rn)))]>;
// Single-precision to Double-precision
def DSr : BaseFPConversion<0b00, 0b01, FPR64, FPR32, asm,
@@ -4836,7 +4836,7 @@ multiclass FPConversion<string asm> {
// Single-precision to Half-precision
def HSr : BaseFPConversion<0b00, 0b11, FPR16, FPR32, asm,
- [(set FPR16:$Rd, (any_fpround FPR32:$Rn))]>;
+ [(set (f16 FPR16:$Rd), (any_fpround FPR32:$Rn))]>;
}
//---
@@ -4938,7 +4938,7 @@ multiclass TwoOperandFPData<bits<4> opcode, string asm,
multiclass TwoOperandFPDataNeg<bits<4> opcode, string asm, SDNode node> {
def Hrr : BaseTwoOperandFPData<opcode, FPR16, asm,
- [(set FPR16:$Rd, (fneg (node FPR16:$Rn, (f16 FPR16:$Rm))))]> {
+ [(set (f16 FPR16:$Rd), (fneg (node (f16 FPR16:$Rn), (f16 FPR16:$Rm))))]> {
let Inst{23-22} = 0b11; // 16-bit size flag
let Predicates = [HasFullFP16];
}
@@ -4980,7 +4980,7 @@ class BaseThreeOperandFPData<bit isNegated, bit isSub,
multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
SDPatternOperator node> {
def Hrrr : BaseThreeOperandFPData<isNegated, isSub, FPR16, asm,
- [(set FPR16:$Rd,
+ [(set (f16 FPR16:$Rd),
(node (f16 FPR16:$Rn), (f16 FPR16:$Rm), (f16 FPR16:$Ra)))]> {
let Inst{23-22} = 0b11; // 16-bit size flag
let Predicates = [HasFullFP16];
@@ -5042,7 +5042,7 @@ multiclass FPComparison<bit signalAllNans, string asm,
SDPatternOperator OpNode = null_frag> {
let Defs = [NZCV] in {
def Hrr : BaseTwoOperandFPComparison<signalAllNans, FPR16, asm,
- [(OpNode FPR16:$Rn, (f16 FPR16:$Rm)), (implicit NZCV)]> {
+ [(OpNode (f16 FPR16:$Rn), (f16 FPR16:$Rm)), (implicit NZCV)]> {
let Inst{23-22} = 0b11;
let Predicates = [HasFullFP16];
}
@@ -6742,7 +6742,7 @@ multiclass SIMDFPThreeScalar<bit U, bit S, bits<3> opc, string asm,
[(set FPR32:$Rd, (OpNode FPR32:$Rn, FPR32:$Rm))]>;
let Predicates = [HasNEON, HasFullFP16] in {
def NAME#16 : BaseSIMDThreeScalar<U, {S,0b10}, {0b00,opc}, FPR16, asm,
- [(set FPR16:$Rd, (OpNode FPR16:$Rn, FPR16:$Rm))]>;
+ [(set (f16 FPR16:$Rd), (OpNode (f16 FPR16:$Rn), (f16 FPR16:$Rm)))]>;
} // Predicates = [HasNEON, HasFullFP16]
}
@@ -6949,7 +6949,7 @@ multiclass SIMDFPTwoScalarCVT<bit U, bit S, bits<5> opc, string asm,
[(set FPR32:$Rd, (OpNode (f32 FPR32:$Rn)))]>;
let Predicates = [HasNEON, HasFullFP16] in {
def v1i16 : BaseSIMDTwoScalar<U, {S,1}, 0b11, opc, FPR16, FPR16, asm,
- [(set FPR16:$Rd, (OpNode (f16 FPR16:$Rn)))]>;
+ [(set (f16 FPR16:$Rd), (OpNode (f16 FPR16:$Rn)))]>;
}
}
@@ -7091,10 +7091,10 @@ multiclass SIMDFPAcrossLanes<bits<5> opcode, bit sz1, string asm,
let Predicates = [HasNEON, HasFullFP16] in {
def v4i16v : BaseSIMDAcrossLanes<0, 0, {sz1, 0}, opcode, FPR16, V64,
asm, ".4h",
- [(set FPR16:$Rd, (intOp (v4f16 V64:$Rn)))]>;
+ [(set (f16 FPR16:$Rd), (intOp (v4f16 V64:$Rn)))]>;
def v8i16v : BaseSIMDAcrossLanes<1, 0, {sz1, 0}, opcode, FPR16, V128,
asm, ".8h",
- [(set FPR16:$Rd, (intOp (v8f16 V128:$Rn)))]>;
+ [(set (f16 FPR16:$Rd), (intOp (v8f16 V128:$Rn)))]>;
} // Predicates = [HasNEON, HasFullFP16]
def v4i32v : BaseSIMDAcrossLanes<1, 1, {sz1, 0}, opcode, FPR32, V128,
asm, ".4s",
@@ -8095,7 +8095,7 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
def : Pat<(v8f16 (OpNode (v8f16 V128:$Rd), (v8f16 V128:$Rn),
(AArch64dup (f16 FPR16Op_lo:$Rm)))),
(!cast<Instruction>(INST # "v8i16_indexed") V128:$Rd, V128:$Rn,
- (SUBREG_TO_REG (i32 0), FPR16Op_lo:$Rm, hsub), (i64 0))>;
+ (SUBREG_TO_REG (i32 0), (f16 FPR16Op_lo:$Rm), hsub), (i64 0))>;
def : Pat<(v4f16 (OpNode (v4f16 V64:$Rd), (v4f16 V64:$Rn),
(AArch64duplane16 (v8f16 V128_lo:$Rm),
@@ -8105,7 +8105,7 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
def : Pat<(v4f16 (OpNode (v4f16 V64:$Rd), (v4f16 V64:$Rn),
(AArch64dup (f16 FPR16Op_lo:$Rm)))),
(!cast<Instruction>(INST # "v4i16_indexed") V64:$Rd, V64:$Rn,
- (SUBREG_TO_REG (i32 0), FPR16Op_lo:$Rm, hsub), (i64 0))>;
+ (SUBREG_TO_REG (i32 0), (f16 FPR16Op_lo:$Rm), hsub), (i64 0))>;
def : Pat<(f16 (OpNode (f16 FPR16:$Rd), (f16 FPR16:$Rn),
(vector_extract (v8f16 V128_lo:$Rm), VectorIndexH:$idx))),
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 03ac81e2462b..07bca441529e 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -2529,7 +2529,7 @@ defm LDURB : LoadUnscaled<0b00, 1, 0b01, FPR8Op, "ldur",
[(set FPR8Op:$Rt,
(load (am_unscaled8 GPR64sp:$Rn, simm9:$offset)))]>;
defm LDURH : LoadUnscaled<0b01, 1, 0b01, FPR16Op, "ldur",
- [(set FPR16Op:$Rt,
+ [(set (f16 FPR16Op:$Rt),
(load (am_unscaled16 GPR64sp:$Rn, simm9:$offset)))]>;
defm LDURS : LoadUnscaled<0b10, 1, 0b01, FPR32Op, "ldur",
[(set (f32 FPR32Op:$Rt),
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 93b6aa0cdb7f..bd05c56009a1 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -422,18 +422,20 @@ def Q31 : AArch64Reg<31, "q31", [D31], ["v31", ""]>, DwarfRegAlias<B31>;
def FPR8 : RegisterClass<"AArch64", [untyped], 8, (sequence "B%u", 0, 31)> {
let Size = 8;
}
-def FPR16 : RegisterClass<"AArch64", [f16], 16, (sequence "H%u", 0, 31)> {
+def FPR16 : RegisterClass<"AArch64", [f16, bf16], 16, (sequence "H%u", 0, 31)> {
let Size = 16;
}
+
def FPR16_lo : RegisterClass<"AArch64", [f16], 16, (trunc FPR16, 16)> {
let Size = 16;
}
def FPR32 : RegisterClass<"AArch64", [f32, i32], 32,(sequence "S%u", 0, 31)>;
def FPR64 : RegisterClass<"AArch64", [f64, i64, v2f32, v1f64, v8i8, v4i16, v2i32,
- v1i64, v4f16],
- 64, (sequence "D%u", 0, 31)>;
+ v1i64, v4f16, v4bf16],
+ 64, (sequence "D%u", 0, 31)>;
def FPR64_lo : RegisterClass<"AArch64",
- [v8i8, v4i16, v2i32, v1i64, v4f16, v2f32, v1f64],
+ [v8i8, v4i16, v2i32, v1i64, v4f16, v4bf16, v2f32,
+ v1f64],
64, (trunc FPR64, 16)>;
// We don't (yet) have an f128 legal type, so don't use that here. We
@@ -441,13 +443,14 @@ def FPR64_lo : RegisterClass<"AArch64",
// that here.
def FPR128 : RegisterClass<"AArch64",
[v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, f128,
- v8f16],
+ v8f16, v8bf16],
128, (sequence "Q%u", 0, 31)>;
// The lower 16 vector registers. Some instructions can only take registers
// in this range.
def FPR128_lo : RegisterClass<"AArch64",
- [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16],
+ [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16,
+ v8bf16],
128, (trunc FPR128, 16)>;
// Pairs, triples, and quads of 64-bit vector registers.
@@ -876,6 +879,7 @@ def PPR3b64 : PPRRegOp<"d", PPRAsmOp3b64, ElementSizeD, PPR_3b>;
class ZPRClass<int lastreg> : RegisterClass<"AArch64",
[nxv16i8, nxv8i16, nxv4i32, nxv2i64,
nxv2f16, nxv4f16, nxv8f16,
+ nxv2bf16, nxv4bf16, nxv8bf16,
nxv2f32, nxv4f32,
nxv2f64],
128, (sequence "Z%u", 0, lastreg)> {