[llvm] 3dd6750 - [AArch64] Add more complete support for BF16
David Majnemer via llvm-commits
llvm-commits at lists.llvm.org
Sun Mar 3 14:55:41 PST 2024
Author: David Majnemer
Date: 2024-03-03T22:39:50Z
New Revision: 3dd6750027cd168fce5fd9894b0bac0739652cf5
URL: https://github.com/llvm/llvm-project/commit/3dd6750027cd168fce5fd9894b0bac0739652cf5
DIFF: https://github.com/llvm/llvm-project/commit/3dd6750027cd168fce5fd9894b0bac0739652cf5.diff
LOG: [AArch64] Add more complete support for BF16
We can use a small amount of integer arithmetic to round FP32 to BF16
and extend BF16 to FP32.
While a number of operations still require promotion, that could be avoided
for some fairly simple operations such as abs, copysign, and fneg; those can
be handled in a follow-up.
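At the bit level, the idea is roughly the following (a standalone scalar C++
sketch of the technique, not the actual DAG lowering; the helper names are
illustrative):

#include <cstdint>
#include <cstring>

// Extend bf16 -> f32: a bf16 value is the high 16 bits of an f32, so the
// extension is just a 16-bit left shift of the raw bits.
float bf16_to_f32(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Round f32 -> bf16 with round-to-nearest-even using only integer arithmetic.
// NaNs take a separate path that sets the quiet bit instead of adding the
// rounding bias, so a NaN payload cannot be carried into the infinity
// encoding.
uint16_t f32_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  if ((bits & 0x7fffffffu) > 0x7f800000u)   // NaN: quiet it and truncate.
    return static_cast<uint16_t>((bits | 0x00400000u) >> 16);
  uint32_t lsb = (bits >> 16) & 1u;         // ties go to even
  bits += 0x7fffu + lsb;                    // rounding bias
  return static_cast<uint16_t>(bits >> 16);
}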
A few neat optimizations are implemented:
- round-inexact-to-odd is used as the intermediate step for F64 to BF16
  rounding, which avoids double-rounding errors (see the sketch after this
  list).
- quieting signaling NaNs for f32 -> bf16 tries to detect whether a prior
  operation already guarantees the input cannot be a signaling NaN, in which
  case the quieting step is skipped.
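For F64 sources, narrowing straight through f32 with round-to-nearest at both
steps can double-round; rounding the first step to odd keeps the sticky
information in the f32 so the final bf16 rounding behaves as if it had been
done directly from the f64 (the patch uses FCVTXN for this step). A portable
sketch of that intermediate conversion, assuming the default round-to-nearest
environment and an illustrative helper name:

#include <cmath>
#include <cstdint>
#include <cstring>

// f64 -> f32 with "round to odd": truncate toward zero, then, if any bits
// were lost, force the significand LSB to 1 (jam the sticky bit). The result
// can then be narrowed to bf16 with the round-to-nearest-even step above
// without double-rounding errors.
float f64_to_f32_round_to_odd(double d) {
  float f = static_cast<float>(d);          // round to nearest (default)
  double back = static_cast<double>(f);
  if (back != d) {                          // the conversion was inexact
    if (std::fabs(back) > std::fabs(d))     // nearest rounded away from zero,
      f = std::nextafterf(f, 0.0f);         // so step back to the truncation
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    bits |= 1u;                             // jam the sticky bit
    std::memcpy(&f, &bits, sizeof bits);
  }
  return f;
}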
Added:
Modified:
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/include/llvm/CodeGen/ValueTypes.h
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64InstrFormats.td
llvm/lib/Target/AArch64/AArch64InstrInfo.td
llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir
llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll
llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll
llvm/test/CodeGen/AArch64/vector-fcopysign.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index f2e00aab8d5da2..0438abc7c3061a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1573,13 +1573,14 @@ class TargetLoweringBase {
assert((VT.isInteger() || VT.isFloatingPoint()) &&
"Cannot autopromote this type, add it with AddPromotedToType.");
+ uint64_t VTBits = VT.getScalarSizeInBits();
MVT NVT = VT;
do {
NVT = (MVT::SimpleValueType)(NVT.SimpleTy+1);
assert(NVT.isInteger() == VT.isInteger() && NVT != MVT::isVoid &&
"Didn't find type to promote to!");
- } while (!isTypeLegal(NVT) ||
- getOperationAction(Op, NVT) == Promote);
+ } while (VTBits >= NVT.getScalarSizeInBits() || !isTypeLegal(NVT) ||
+ getOperationAction(Op, NVT) == Promote);
return NVT;
}
diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
index 1e0356ee69ff52..b66c66d1bfc45a 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -107,6 +107,13 @@ namespace llvm {
return changeExtendedVectorElementType(EltVT);
}
+ /// Return a VT for a type whose attributes match ourselves with the
+ /// exception of the element type that is chosen by the caller.
+ EVT changeElementType(EVT EltVT) const {
+ EltVT = EltVT.getScalarType();
+ return isVector() ? changeVectorElementType(EltVT) : EltVT;
+ }
+
/// Return the type converted to an equivalently sized integer or vector
/// with integer element type. Similar to changeVectorElementTypeToInteger,
/// but also handles scalars.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7f80e877cb2406..475c73c3588dbc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -368,8 +368,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addDRTypeForNEON(MVT::v1i64);
addDRTypeForNEON(MVT::v1f64);
addDRTypeForNEON(MVT::v4f16);
- if (Subtarget->hasBF16())
- addDRTypeForNEON(MVT::v4bf16);
+ addDRTypeForNEON(MVT::v4bf16);
addQRTypeForNEON(MVT::v4f32);
addQRTypeForNEON(MVT::v2f64);
@@ -378,8 +377,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addQRTypeForNEON(MVT::v4i32);
addQRTypeForNEON(MVT::v2i64);
addQRTypeForNEON(MVT::v8f16);
- if (Subtarget->hasBF16())
- addQRTypeForNEON(MVT::v8bf16);
+ addQRTypeForNEON(MVT::v8bf16);
}
if (Subtarget->hasSVEorSME()) {
@@ -403,11 +401,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
- if (Subtarget->hasBF16()) {
- addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
- addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
- addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
- }
+ addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
+ addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
+ addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
if (Subtarget->useSVEForFixedLengthVectors()) {
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
@@ -437,9 +433,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::GlobalTLSAddress, MVT::i64, Custom);
setOperationAction(ISD::SETCC, MVT::i32, Custom);
setOperationAction(ISD::SETCC, MVT::i64, Custom);
+ setOperationAction(ISD::SETCC, MVT::bf16, Custom);
setOperationAction(ISD::SETCC, MVT::f16, Custom);
setOperationAction(ISD::SETCC, MVT::f32, Custom);
setOperationAction(ISD::SETCC, MVT::f64, Custom);
+ setOperationAction(ISD::STRICT_FSETCC, MVT::bf16, Custom);
setOperationAction(ISD::STRICT_FSETCC, MVT::f16, Custom);
setOperationAction(ISD::STRICT_FSETCC, MVT::f32, Custom);
setOperationAction(ISD::STRICT_FSETCC, MVT::f64, Custom);
@@ -463,7 +461,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
- setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand);
+ setOperationAction(ISD::SELECT_CC, MVT::bf16, Custom);
setOperationAction(ISD::SELECT_CC, MVT::f32, Custom);
setOperationAction(ISD::SELECT_CC, MVT::f64, Custom);
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
@@ -539,12 +537,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i32, Custom);
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i64, Custom);
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i128, Custom);
- if (Subtarget->hasFPARMv8())
+ if (Subtarget->hasFPARMv8()) {
setOperationAction(ISD::FP_ROUND, MVT::f16, Custom);
+ setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+ }
setOperationAction(ISD::FP_ROUND, MVT::f32, Custom);
setOperationAction(ISD::FP_ROUND, MVT::f64, Custom);
- if (Subtarget->hasFPARMv8())
+ if (Subtarget->hasFPARMv8()) {
setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom);
+ setOperationAction(ISD::STRICT_FP_ROUND, MVT::bf16, Custom);
+ }
setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom);
setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom);
@@ -678,6 +680,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
else
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
+ setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
@@ -690,9 +693,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(Op, MVT::f16, Promote);
setOperationAction(Op, MVT::v4f16, Expand);
setOperationAction(Op, MVT::v8f16, Expand);
+ setOperationAction(Op, MVT::bf16, Promote);
+ setOperationAction(Op, MVT::v4bf16, Expand);
+ setOperationAction(Op, MVT::v8bf16, Expand);
}
- if (!Subtarget->hasFullFP16()) {
+ auto LegalizeNarrowFP = [this](MVT ScalarVT) {
for (auto Op :
{ISD::SETCC, ISD::SELECT_CC,
ISD::BR_CC, ISD::FADD, ISD::FSUB,
@@ -708,60 +714,69 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::STRICT_FROUND, ISD::STRICT_FTRUNC, ISD::STRICT_FROUNDEVEN,
ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM, ISD::STRICT_FMINIMUM,
ISD::STRICT_FMAXIMUM})
- setOperationAction(Op, MVT::f16, Promote);
+ setOperationAction(Op, ScalarVT, Promote);
// Round-to-integer need custom lowering for fp16, as Promote doesn't work
// because the result type is integer.
for (auto Op : {ISD::LROUND, ISD::LLROUND, ISD::LRINT, ISD::LLRINT,
ISD::STRICT_LROUND, ISD::STRICT_LLROUND, ISD::STRICT_LRINT,
ISD::STRICT_LLRINT})
- setOperationAction(Op, MVT::f16, Custom);
+ setOperationAction(Op, ScalarVT, Custom);
// promote v4f16 to v4f32 when that is known to be safe.
- setOperationPromotedToType(ISD::FADD, MVT::v4f16, MVT::v4f32);
- setOperationPromotedToType(ISD::FSUB, MVT::v4f16, MVT::v4f32);
- setOperationPromotedToType(ISD::FMUL, MVT::v4f16, MVT::v4f32);
- setOperationPromotedToType(ISD::FDIV, MVT::v4f16, MVT::v4f32);
-
- setOperationAction(ISD::FABS, MVT::v4f16, Expand);
- setOperationAction(ISD::FNEG, MVT::v4f16, Expand);
- setOperationAction(ISD::FROUND, MVT::v4f16, Expand);
- setOperationAction(ISD::FROUNDEVEN, MVT::v4f16, Expand);
- setOperationAction(ISD::FMA, MVT::v4f16, Expand);
- setOperationAction(ISD::SETCC, MVT::v4f16, Custom);
- setOperationAction(ISD::BR_CC, MVT::v4f16, Expand);
- setOperationAction(ISD::SELECT, MVT::v4f16, Expand);
- setOperationAction(ISD::SELECT_CC, MVT::v4f16, Expand);
- setOperationAction(ISD::FTRUNC, MVT::v4f16, Expand);
- setOperationAction(ISD::FCOPYSIGN, MVT::v4f16, Expand);
- setOperationAction(ISD::FFLOOR, MVT::v4f16, Expand);
- setOperationAction(ISD::FCEIL, MVT::v4f16, Expand);
- setOperationAction(ISD::FRINT, MVT::v4f16, Expand);
- setOperationAction(ISD::FNEARBYINT, MVT::v4f16, Expand);
- setOperationAction(ISD::FSQRT, MVT::v4f16, Expand);
-
- setOperationAction(ISD::FABS, MVT::v8f16, Expand);
- setOperationAction(ISD::FADD, MVT::v8f16, Expand);
- setOperationAction(ISD::FCEIL, MVT::v8f16, Expand);
- setOperationAction(ISD::FCOPYSIGN, MVT::v8f16, Expand);
- setOperationAction(ISD::FDIV, MVT::v8f16, Expand);
- setOperationAction(ISD::FFLOOR, MVT::v8f16, Expand);
- setOperationAction(ISD::FMA, MVT::v8f16, Expand);
- setOperationAction(ISD::FMUL, MVT::v8f16, Expand);
- setOperationAction(ISD::FNEARBYINT, MVT::v8f16, Expand);
- setOperationAction(ISD::FNEG, MVT::v8f16, Expand);
- setOperationAction(ISD::FROUND, MVT::v8f16, Expand);
- setOperationAction(ISD::FROUNDEVEN, MVT::v8f16, Expand);
- setOperationAction(ISD::FRINT, MVT::v8f16, Expand);
- setOperationAction(ISD::FSQRT, MVT::v8f16, Expand);
- setOperationAction(ISD::FSUB, MVT::v8f16, Expand);
- setOperationAction(ISD::FTRUNC, MVT::v8f16, Expand);
- setOperationAction(ISD::SETCC, MVT::v8f16, Expand);
- setOperationAction(ISD::BR_CC, MVT::v8f16, Expand);
- setOperationAction(ISD::SELECT, MVT::v8f16, Expand);
- setOperationAction(ISD::SELECT_CC, MVT::v8f16, Expand);
- setOperationAction(ISD::FP_EXTEND, MVT::v8f16, Expand);
+ auto V4Narrow = MVT::getVectorVT(ScalarVT, 4);
+ setOperationPromotedToType(ISD::FADD, V4Narrow, MVT::v4f32);
+ setOperationPromotedToType(ISD::FSUB, V4Narrow, MVT::v4f32);
+ setOperationPromotedToType(ISD::FMUL, V4Narrow, MVT::v4f32);
+ setOperationPromotedToType(ISD::FDIV, V4Narrow, MVT::v4f32);
+
+ setOperationAction(ISD::FABS, V4Narrow, Expand);
+ setOperationAction(ISD::FNEG, V4Narrow, Expand);
+ setOperationAction(ISD::FROUND, V4Narrow, Expand);
+ setOperationAction(ISD::FROUNDEVEN, V4Narrow, Expand);
+ setOperationAction(ISD::FMA, V4Narrow, Expand);
+ setOperationAction(ISD::SETCC, V4Narrow, Custom);
+ setOperationAction(ISD::BR_CC, V4Narrow, Expand);
+ setOperationAction(ISD::SELECT, V4Narrow, Expand);
+ setOperationAction(ISD::SELECT_CC, V4Narrow, Expand);
+ setOperationAction(ISD::FTRUNC, V4Narrow, Expand);
+ setOperationAction(ISD::FCOPYSIGN, V4Narrow, Expand);
+ setOperationAction(ISD::FFLOOR, V4Narrow, Expand);
+ setOperationAction(ISD::FCEIL, V4Narrow, Expand);
+ setOperationAction(ISD::FRINT, V4Narrow, Expand);
+ setOperationAction(ISD::FNEARBYINT, V4Narrow, Expand);
+ setOperationAction(ISD::FSQRT, V4Narrow, Expand);
+
+ auto V8Narrow = MVT::getVectorVT(ScalarVT, 8);
+ setOperationAction(ISD::FABS, V8Narrow, Expand);
+ setOperationAction(ISD::FADD, V8Narrow, Expand);
+ setOperationAction(ISD::FCEIL, V8Narrow, Expand);
+ setOperationAction(ISD::FCOPYSIGN, V8Narrow, Expand);
+ setOperationAction(ISD::FDIV, V8Narrow, Expand);
+ setOperationAction(ISD::FFLOOR, V8Narrow, Expand);
+ setOperationAction(ISD::FMA, V8Narrow, Expand);
+ setOperationAction(ISD::FMUL, V8Narrow, Expand);
+ setOperationAction(ISD::FNEARBYINT, V8Narrow, Expand);
+ setOperationAction(ISD::FNEG, V8Narrow, Expand);
+ setOperationAction(ISD::FROUND, V8Narrow, Expand);
+ setOperationAction(ISD::FROUNDEVEN, V8Narrow, Expand);
+ setOperationAction(ISD::FRINT, V8Narrow, Expand);
+ setOperationAction(ISD::FSQRT, V8Narrow, Expand);
+ setOperationAction(ISD::FSUB, V8Narrow, Expand);
+ setOperationAction(ISD::FTRUNC, V8Narrow, Expand);
+ setOperationAction(ISD::SETCC, V8Narrow, Expand);
+ setOperationAction(ISD::BR_CC, V8Narrow, Expand);
+ setOperationAction(ISD::SELECT, V8Narrow, Expand);
+ setOperationAction(ISD::SELECT_CC, V8Narrow, Expand);
+ setOperationAction(ISD::FP_EXTEND, V8Narrow, Expand);
+ };
+
+ if (!Subtarget->hasFullFP16()) {
+ LegalizeNarrowFP(MVT::f16);
}
+ LegalizeNarrowFP(MVT::bf16);
+ setOperationAction(ISD::FP_ROUND, MVT::v4f32, Custom);
+ setOperationAction(ISD::FP_ROUND, MVT::v4bf16, Custom);
// AArch64 has implementations of a lot of rounding-like FP operations.
for (auto Op :
@@ -886,6 +901,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::STORE, MVT::v32i8, Custom);
setOperationAction(ISD::STORE, MVT::v16i16, Custom);
setOperationAction(ISD::STORE, MVT::v16f16, Custom);
+ setOperationAction(ISD::STORE, MVT::v16bf16, Custom);
setOperationAction(ISD::STORE, MVT::v8i32, Custom);
setOperationAction(ISD::STORE, MVT::v8f32, Custom);
setOperationAction(ISD::STORE, MVT::v4f64, Custom);
@@ -897,6 +913,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::LOAD, MVT::v32i8, Custom);
setOperationAction(ISD::LOAD, MVT::v16i16, Custom);
setOperationAction(ISD::LOAD, MVT::v16f16, Custom);
+ setOperationAction(ISD::LOAD, MVT::v16bf16, Custom);
setOperationAction(ISD::LOAD, MVT::v8i32, Custom);
setOperationAction(ISD::LOAD, MVT::v8f32, Custom);
setOperationAction(ISD::LOAD, MVT::v4f64, Custom);
@@ -931,6 +948,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// AArch64 does not have floating-point extending loads, i1 sign-extending
// load, floating-point truncating stores, or v2i32->v2i16 truncating store.
for (MVT VT : MVT::fp_valuetypes()) {
+ setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
setLoadExtAction(ISD::EXTLOAD, VT, MVT::f64, Expand);
@@ -939,13 +957,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (MVT VT : MVT::integer_valuetypes())
setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Expand);
- setTruncStoreAction(MVT::f32, MVT::f16, Expand);
- setTruncStoreAction(MVT::f64, MVT::f32, Expand);
- setTruncStoreAction(MVT::f64, MVT::f16, Expand);
- setTruncStoreAction(MVT::f128, MVT::f80, Expand);
- setTruncStoreAction(MVT::f128, MVT::f64, Expand);
- setTruncStoreAction(MVT::f128, MVT::f32, Expand);
- setTruncStoreAction(MVT::f128, MVT::f16, Expand);
+ for (MVT WideVT : MVT::fp_valuetypes()) {
+ for (MVT NarrowVT : MVT::fp_valuetypes()) {
+ if (WideVT.getScalarSizeInBits() > NarrowVT.getScalarSizeInBits()) {
+ setTruncStoreAction(WideVT, NarrowVT, Expand);
+ }
+ }
+ }
if (Subtarget->hasFPARMv8()) {
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
@@ -1553,7 +1571,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
}
- if (!Subtarget->isNeonAvailable()) {
+ if (!Subtarget->isNeonAvailable()) {// TODO(majnemer)
setTruncStoreAction(MVT::v2f32, MVT::v2f16, Custom);
setTruncStoreAction(MVT::v4f32, MVT::v4f16, Custom);
setTruncStoreAction(MVT::v8f32, MVT::v8f16, Custom);
@@ -1652,6 +1670,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FLDEXP, MVT::f64, Custom);
setOperationAction(ISD::FLDEXP, MVT::f32, Custom);
setOperationAction(ISD::FLDEXP, MVT::f16, Custom);
+ setOperationAction(ISD::FLDEXP, MVT::bf16, Custom);
}
PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
@@ -2451,6 +2470,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::FCMP)
MAKE_CASE(AArch64ISD::STRICT_FCMP)
MAKE_CASE(AArch64ISD::STRICT_FCMPE)
+ MAKE_CASE(AArch64ISD::FCVTXN)
MAKE_CASE(AArch64ISD::SME_ZA_LDR)
MAKE_CASE(AArch64ISD::SME_ZA_STR)
MAKE_CASE(AArch64ISD::DUP)
@@ -3181,7 +3201,7 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &dl,
const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
- if (VT == MVT::f16 && !FullFP16) {
+ if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
LHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
{Chain, LHS});
RHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
@@ -3201,7 +3221,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
if (VT.isFloatingPoint()) {
assert(VT != MVT::f128);
- if (VT == MVT::f16 && !FullFP16) {
+ if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
VT = MVT::f32;
@@ -3309,7 +3329,8 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
if (LHS.getValueType().isFloatingPoint()) {
assert(LHS.getValueType() != MVT::f128);
- if (LHS.getValueType() == MVT::f16 && !FullFP16) {
+ if ((LHS.getValueType() == MVT::f16 && !FullFP16) ||
+ LHS.getValueType() == MVT::bf16) {
LHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, LHS);
RHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, RHS);
}
@@ -4009,16 +4030,73 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
SelectionDAG &DAG) const {
- if (Op.getValueType().isScalableVector())
+ EVT VT = Op.getValueType();
+ if (VT.isScalableVector())
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
bool IsStrict = Op->isStrictFPOpcode();
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
EVT SrcVT = SrcVal.getValueType();
+ bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
return LowerFixedLengthFPRoundToSVE(Op, DAG);
+ // Expand cases where the result type is BF16 but we don't have hardware
+ // instructions to lower it.
+ if (VT.getScalarType() == MVT::bf16 &&
+ !((Subtarget->hasNEON() || Subtarget->hasSME()) &&
+ Subtarget->hasBF16())) {
+ SDLoc dl(Op);
+ SDValue Narrow = SrcVal;
+ SDValue NaN;
+ EVT I32 = SrcVT.changeElementType(MVT::i32);
+ EVT F32 = SrcVT.changeElementType(MVT::f32);
+ if (SrcVT.getScalarType() == MVT::f32) {
+ bool NeverSNaN = DAG.isKnownNeverSNaN(Narrow);
+ Narrow = DAG.getNode(ISD::BITCAST, dl, I32, Narrow);
+ if (!NeverSNaN) {
+ // Set the quiet bit.
+ NaN = DAG.getNode(ISD::OR, dl, I32, Narrow,
+ DAG.getConstant(0x400000, dl, I32));
+ }
+ } else if (SrcVT.getScalarType() == MVT::f64) {
+ Narrow = DAG.getNode(AArch64ISD::FCVTXN, dl, F32, Narrow);
+ Narrow = DAG.getNode(ISD::BITCAST, dl, I32, Narrow);
+ } else {
+ return SDValue();
+ }
+ if (!Trunc) {
+ SDValue One = DAG.getConstant(1, dl, I32);
+ SDValue Lsb = DAG.getNode(ISD::SRL, dl, I32, Narrow,
+ DAG.getShiftAmountConstant(16, I32, dl));
+ Lsb = DAG.getNode(ISD::AND, dl, I32, Lsb, One);
+ SDValue RoundingBias =
+ DAG.getNode(ISD::ADD, dl, I32, DAG.getConstant(0x7fff, dl, I32), Lsb);
+ Narrow = DAG.getNode(ISD::ADD, dl, I32, Narrow, RoundingBias);
+ }
+
+ // Don't round if we had a NaN, we don't want to turn 0x7fffffff into
+ // 0x80000000.
+ if (NaN) {
+ SDValue IsNaN = DAG.getSetCC(
+ dl, getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT),
+ SrcVal, SrcVal, ISD::SETUO);
+ Narrow = DAG.getSelect(dl, I32, IsNaN, NaN, Narrow);
+ }
+
+ // Now that we have rounded, shift the bits into position.
+ Narrow = DAG.getNode(ISD::SRL, dl, I32, Narrow,
+ DAG.getShiftAmountConstant(16, I32, dl));
+ if (VT.isVector()) {
+ EVT I16 = I32.changeVectorElementType(MVT::i16);
+ Narrow = DAG.getNode(ISD::TRUNCATE, dl, I16, Narrow);
+ return DAG.getNode(ISD::BITCAST, dl, VT, Narrow);
+ }
+ Narrow = DAG.getNode(ISD::BITCAST, dl, F32, Narrow);
+ return DAG.getTargetExtractSubreg(AArch64::hsub, dl, VT, Narrow);
+ }
+
if (SrcVT != MVT::f128) {
// Expand cases where the input is a vector bigger than NEON.
if (useSVEForFixedLengthVectorVT(SrcVT))
@@ -4054,8 +4132,8 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
unsigned NumElts = InVT.getVectorNumElements();
// f16 conversions are promoted to f32 when full fp16 is not supported.
- if (InVT.getVectorElementType() == MVT::f16 &&
- !Subtarget->hasFullFP16()) {
+ if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
+ InVT.getVectorElementType() == MVT::bf16) {
MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
SDLoc dl(Op);
if (IsStrict) {
@@ -4128,7 +4206,8 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
return LowerVectorFP_TO_INT(Op, DAG);
// f16 conversions are promoted to f32 when full fp16 is not supported.
- if (SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
+ if ((SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
+ SrcVal.getValueType() == MVT::bf16) {
SDLoc dl(Op);
if (IsStrict) {
SDValue Ext =
@@ -4175,15 +4254,16 @@ AArch64TargetLowering::LowerVectorFP_TO_INT_SAT(SDValue Op,
EVT SrcElementVT = SrcVT.getVectorElementType();
// In the absence of FP16 support, promote f16 to f32 and saturate the result.
- if (SrcElementVT == MVT::f16 &&
- (!Subtarget->hasFullFP16() || DstElementWidth > 16)) {
+ if ((SrcElementVT == MVT::f16 &&
+ (!Subtarget->hasFullFP16() || DstElementWidth > 16)) ||
+ SrcElementVT == MVT::bf16) {
MVT F32VT = MVT::getVectorVT(MVT::f32, SrcVT.getVectorNumElements());
SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), F32VT, SrcVal);
SrcVT = F32VT;
SrcElementVT = MVT::f32;
SrcElementWidth = 32;
} else if (SrcElementVT != MVT::f64 && SrcElementVT != MVT::f32 &&
- SrcElementVT != MVT::f16)
+ SrcElementVT != MVT::f16 && SrcElementVT != MVT::bf16)
return SDValue();
SDLoc DL(Op);
@@ -4236,10 +4316,11 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
assert(SatWidth <= DstWidth && "Saturation width cannot exceed result width");
// In the absence of FP16 support, promote f16 to f32 and saturate the result.
- if (SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) {
+ if ((SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) || SrcVT == MVT::bf16) {
SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal);
SrcVT = MVT::f32;
- } else if (SrcVT != MVT::f64 && SrcVT != MVT::f32 && SrcVT != MVT::f16)
+ } else if (SrcVT != MVT::f64 && SrcVT != MVT::f32 && SrcVT != MVT::f16 &&
+ SrcVT != MVT::bf16)
return SDValue();
SDLoc DL(Op);
@@ -4358,17 +4439,17 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
// f16 conversions are promoted to f32 when full fp16 is not supported.
- if (Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
+ if ((Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) || Op.getValueType() == MVT::bf16) {
SDLoc dl(Op);
if (IsStrict) {
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {MVT::f32, MVT::Other},
{Op.getOperand(0), SrcVal});
return DAG.getNode(
- ISD::STRICT_FP_ROUND, dl, {MVT::f16, MVT::Other},
+ ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
{Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
}
return DAG.getNode(
- ISD::FP_ROUND, dl, MVT::f16,
+ ISD::FP_ROUND, dl, Op.getValueType(),
DAG.getNode(Op.getOpcode(), dl, MVT::f32, SrcVal),
DAG.getIntPtrConstant(0, dl));
}
@@ -6100,6 +6181,7 @@ static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) {
switch (Op.getSimpleValueType().SimpleTy) {
default:
return SDValue();
+ case MVT::bf16:
case MVT::f16:
X = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, X);
[[fallthrough]];
@@ -6416,7 +6498,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::LLROUND:
case ISD::LRINT:
case ISD::LLRINT: {
- assert(Op.getOperand(0).getValueType() == MVT::f16 &&
+ assert((Op.getOperand(0).getValueType() == MVT::f16 ||
+ Op.getOperand(0).getValueType() == MVT::bf16) &&
"Expected custom lowering of rounding operations only for f16");
SDLoc DL(Op);
SDValue Ext = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
@@ -6426,7 +6509,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::STRICT_LLROUND:
case ISD::STRICT_LRINT:
case ISD::STRICT_LLRINT: {
- assert(Op.getOperand(1).getValueType() == MVT::f16 &&
+ assert((Op.getOperand(1).getValueType() == MVT::f16 ||
+ Op.getOperand(1).getValueType() == MVT::bf16) &&
"Expected custom lowering of rounding operations only for f16");
SDLoc DL(Op);
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
@@ -9459,8 +9543,8 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
}
// Now we know we're dealing with FP values.
- assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::f32 ||
- LHS.getValueType() == MVT::f64);
+ assert(LHS.getValueType() == MVT::bf16 || LHS.getValueType() == MVT::f16 ||
+ LHS.getValueType() == MVT::f32 || LHS.getValueType() == MVT::f64);
// If that fails, we'll need to perform an FCMP + CSEL sequence. Go ahead
// and do the comparison.
@@ -9547,7 +9631,8 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
}
// Also handle f16, for which we need to do a f32 comparison.
- if (LHS.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
+ if ((LHS.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
+ LHS.getValueType() == MVT::bf16) {
LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
}
@@ -10306,6 +10391,7 @@ bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero();
else if (VT == MVT::f32)
IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero();
+ // TODO(majnemer): double check this...
else if (VT == MVT::f16 || VT == MVT::bf16)
IsLegal =
(Subtarget->hasFullFP16() && AArch64_AM::getFP16Imm(ImmInt) != -1) ||
@@ -14161,11 +14247,30 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
}
+ // Lower isnan(x) | isnan(never-nan) to x != x.
+ // Lower !isnan(x) & !isnan(never-nan) to x == x.
+ if (CC == ISD::SETUO || CC == ISD::SETO) {
+ bool OneNaN = false;
+ if (LHS == RHS) {
+ OneNaN = true;
+ } else if (DAG.isKnownNeverNaN(RHS)) {
+ OneNaN = true;
+ RHS = LHS;
+ } else if (DAG.isKnownNeverNaN(LHS)) {
+ OneNaN = true;
+ LHS = RHS;
+ }
+ if (OneNaN) {
+ CC = CC == ISD::SETUO ? ISD::SETUNE : ISD::SETOEQ;
+ }
+ }
+
const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
// Make v4f16 (only) fcmp operations utilise vector instructions
// v8f16 support will be a little more complicated
- if (!FullFP16 && LHS.getValueType().getVectorElementType() == MVT::f16) {
+ if ((!FullFP16 && LHS.getValueType().getVectorElementType() == MVT::f16) ||
+ LHS.getValueType().getVectorElementType() == MVT::bf16) {
if (LHS.getValueType().getVectorNumElements() == 4) {
LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, LHS);
RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, RHS);
@@ -14177,7 +14282,8 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
}
assert((!FullFP16 && LHS.getValueType().getVectorElementType() != MVT::f16) ||
- LHS.getValueType().getVectorElementType() != MVT::f128);
+ LHS.getValueType().getVectorElementType() != MVT::bf16 ||
+ LHS.getValueType().getVectorElementType() != MVT::f128);
// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
// clean. Some of them require two branches to implement.
@@ -24684,7 +24790,8 @@ static void ReplaceAddWithADDP(SDNode *N, SmallVectorImpl<SDValue> &Results,
if (!VT.is256BitVector() ||
(VT.getScalarType().isFloatingPoint() &&
!N->getFlags().hasAllowReassociation()) ||
- (VT.getScalarType() == MVT::f16 && !Subtarget->hasFullFP16()))
+ (VT.getScalarType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
+ VT.getScalarType() == MVT::bf16)
return;
SDValue X = N->getOperand(0);
@@ -25736,6 +25843,8 @@ bool AArch64TargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,
// legalize.
if (FPVT == MVT::v8f16 && !Subtarget->hasFullFP16())
return false;
+ if (FPVT == MVT::v8bf16)
+ return false;
return TargetLowering::shouldConvertFpToSat(Op, FPVT, VT);
}
@@ -25928,6 +26037,8 @@ static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT) {
return EVT(MVT::nxv4i32);
case MVT::i64:
return EVT(MVT::nxv2i64);
+ case MVT::bf16:
+ return EVT(MVT::nxv8bf16);
case MVT::f16:
return EVT(MVT::nxv8f16);
case MVT::f32:
@@ -25967,6 +26078,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
break;
case MVT::i16:
case MVT::f16:
+ case MVT::bf16:
MaskVT = MVT::nxv8i1;
break;
case MVT::i32:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index c1fe76c07cba87..68341c199e0a2a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -249,6 +249,9 @@ enum NodeType : unsigned {
FCMLEz,
FCMLTz,
+ // Round wide FP to narrow FP with inexact results to odd.
+ FCVTXN,
+
// Vector across-lanes addition
// Only the lower result lane is defined.
SADDV,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 7f8856db6c6e61..091db559a33708 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7547,7 +7547,7 @@ class BaseSIMDCmpTwoScalar<bit U, bits<2> size, bits<2> size2, bits<5> opcode,
let mayRaiseFPException = 1, Uses = [FPCR] in
class SIMDInexactCvtTwoScalar<bits<5> opcode, string asm>
: I<(outs FPR32:$Rd), (ins FPR64:$Rn), asm, "\t$Rd, $Rn", "",
- [(set (f32 FPR32:$Rd), (int_aarch64_sisd_fcvtxn (f64 FPR64:$Rn)))]>,
+ [(set (f32 FPR32:$Rd), (AArch64fcvtxn (f64 FPR64:$Rn)))]>,
Sched<[WriteVd]> {
bits<5> Rd;
bits<5> Rn;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 52137c1f4065bc..c153bb3f0145bc 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -756,6 +756,11 @@ def AArch64fcmgtz: SDNode<"AArch64ISD::FCMGTz", SDT_AArch64fcmpz>;
def AArch64fcmlez: SDNode<"AArch64ISD::FCMLEz", SDT_AArch64fcmpz>;
def AArch64fcmltz: SDNode<"AArch64ISD::FCMLTz", SDT_AArch64fcmpz>;
+def AArch64fcvtxn_n: SDNode<"AArch64ISD::FCVTXN", SDTFPRoundOp>;
+def AArch64fcvtxn: PatFrags<(ops node:$Rn),
+ [(f32 (int_aarch64_sisd_fcvtxn (f64 node:$Rn))),
+ (f32 (AArch64fcvtxn_n (f64 node:$Rn)))]>;
+
def AArch64bici: SDNode<"AArch64ISD::BICi", SDT_AArch64vecimm>;
def AArch64orri: SDNode<"AArch64ISD::ORRi", SDT_AArch64vecimm>;
@@ -1276,6 +1281,9 @@ def BFMLALTIdx : SIMDBF16MLALIndex<1, "bfmlalt", int_aarch64_neon_bfmlalt>;
def BFCVTN : SIMD_BFCVTN;
def BFCVTN2 : SIMD_BFCVTN2;
+def : Pat<(v4bf16 (any_fpround (v4f32 V128:$Rn))),
+ (EXTRACT_SUBREG (BFCVTN V128:$Rn), dsub)>;
+
// Vector-scalar BFDOT:
// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
// register (the instruction uses a single 32-bit lane from it), so the pattern
@@ -1296,6 +1304,8 @@ def : Pat<(v2f32 (int_aarch64_neon_bfdot
let Predicates = [HasNEONorSME, HasBF16] in {
def BFCVT : BF16ToSinglePrecision<"bfcvt">;
+// Round FP32 to BF16.
+def : Pat<(bf16 (any_fpround (f32 FPR32:$Rn))), (BFCVT $Rn)>;
}
// ARMv8.6A AArch64 matrix multiplication
@@ -4648,6 +4658,22 @@ let Predicates = [HasFullFP16] in {
//===----------------------------------------------------------------------===//
defm FCVT : FPConversion<"fcvt">;
+// Helper to get bf16 into fp32.
+def cvt_bf16_to_fp32 :
+ OutPatFrag<(ops node:$Rn),
+ (f32 (COPY_TO_REGCLASS
+ (i32 (UBFMWri
+ (i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
+ node:$Rn, hsub), GPR32)),
+ (i64 (i32shift_a (i64 16))),
+ (i64 (i32shift_b (i64 16))))),
+ FPR32))>;
+// Pattern for bf16 -> fp32.
+def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
+ (cvt_bf16_to_fp32 FPR16:$Rn)>;
+// Pattern for bf16 -> fp64.
+def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))),
+ (FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>;
//===----------------------------------------------------------------------===//
// Floating point single operand instructions.
@@ -5002,6 +5028,9 @@ defm FCVTNU : SIMDTwoVectorFPToInt<1,0,0b11010, "fcvtnu",int_aarch64_neon_fcvtnu
defm FCVTN : SIMDFPNarrowTwoVector<0, 0, 0b10110, "fcvtn">;
def : Pat<(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn))),
(FCVTNv4i16 V128:$Rn)>;
+//def : Pat<(concat_vectors V64:$Rd,
+// (v4bf16 (any_fpround (v4f32 V128:$Rn)))),
+// (FCVTNv8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
def : Pat<(concat_vectors V64:$Rd,
(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn)))),
(FCVTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
@@ -5686,6 +5715,11 @@ defm USQADD : SIMDTwoScalarBHSDTied< 1, 0b00011, "usqadd",
def : Pat<(v1i64 (AArch64vashr (v1i64 V64:$Rn), (i32 63))),
(CMLTv1i64rz V64:$Rn)>;
+// Round FP64 to BF16.
+let Predicates = [HasNEONorSME, HasBF16] in
+def : Pat<(bf16 (any_fpround (f64 FPR64:$Rn))),
+ (BFCVT (FCVTXNv1i64 $Rn))>;
+
def : Pat<(v1i64 (int_aarch64_neon_fcvtas (v1f64 FPR64:$Rn))),
(FCVTASv1i64 FPR64:$Rn)>;
def : Pat<(v1i64 (int_aarch64_neon_fcvtau (v1f64 FPR64:$Rn))),
@@ -7698,6 +7732,9 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))>
def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
+// Vector bf16 -> fp32 is implemented morally as a zext + shift.
+def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))),
+ (USHLLv4i16_shift V64:$Rn, (i32 16))>;
// Also match an extend from the upper half of a 128 bit source register.
def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))),
(USHLLv16i8_shift V128:$Rn, (i32 0))>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
index 9bc5815ae05371..5a8031641ae099 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -1022,13 +1022,17 @@ void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
bool Invert = false;
AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
- if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
+ if ((Pred == CmpInst::Predicate::FCMP_ORD ||
+ Pred == CmpInst::Predicate::FCMP_UNO) &&
+ IsZero) {
// The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
// NaN, so equivalent to a == a and doesn't need the two comparisons an
// "ord" normally would.
+ // Similarly, "fcmp uno %a, 0" is the canonical check that LHS is NaN and is
+ // thus equivalent to a != a.
RHS = LHS;
IsZero = false;
- CC = AArch64CC::EQ;
+ CC = Pred == CmpInst::Predicate::FCMP_ORD ? AArch64CC::EQ : AArch64CC::NE;
} else
changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
diff --git a/llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll b/llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
index 954a836e44f283..a68c21f7943432 100644
--- a/llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
+++ b/llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
@@ -13,7 +13,7 @@ define void @strict_fp_reductions() {
; CHECK-NEXT: Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
-; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -24,7 +24,7 @@ define void @strict_fp_reductions() {
; FP16-NEXT: Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
-; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
+; FP16-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -72,7 +72,7 @@ define void @fast_fp_reductions() {
; CHECK-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
-; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -95,7 +95,7 @@ define void @fast_fp_reductions() {
; FP16-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
-; FP16-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
+; FP16-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir b/llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir
index 8f01c009c4967a..1f5fb892df5820 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir
@@ -321,18 +321,15 @@ body: |
bb.0:
liveins: $q0, $q1
- ; Should be inverted. Needs two compares.
; CHECK-LABEL: name: uno_zero
; CHECK: liveins: $q0, $q1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: %lhs:_(<2 x s64>) = COPY $q0
- ; CHECK-NEXT: [[FCMGEZ:%[0-9]+]]:_(<2 x s64>) = G_FCMGEZ %lhs
- ; CHECK-NEXT: [[FCMLTZ:%[0-9]+]]:_(<2 x s64>) = G_FCMLTZ %lhs
- ; CHECK-NEXT: [[OR:%[0-9]+]]:_(<2 x s64>) = G_OR [[FCMLTZ]], [[FCMGEZ]]
+ ; CHECK-NEXT: [[FCMEQ:%[0-9]+]]:_(<2 x s64>) = G_FCMEQ %lhs, %lhs(<2 x s64>)
; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 -1
; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<2 x s64>) = G_BUILD_VECTOR [[C]](s64), [[C]](s64)
- ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[OR]], [[BUILD_VECTOR]]
+ ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[FCMEQ]], [[BUILD_VECTOR]]
; CHECK-NEXT: $q0 = COPY [[XOR]](<2 x s64>)
; CHECK-NEXT: RET_ReallyLR implicit $q0
%lhs:_(<2 x s64>) = COPY $q0
diff --git a/llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll b/llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll
index a949eaac5cfa29..adde5429a6d93d 100644
--- a/llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll
+++ b/llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll
@@ -187,10 +187,7 @@ entry:
define <8 x bfloat> @insertzero_v4bf16(<4 x bfloat> %a) {
; CHECK-LABEL: insertzero_v4bf16:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: movi d4, #0000000000000000
-; CHECK-NEXT: movi d5, #0000000000000000
-; CHECK-NEXT: movi d6, #0000000000000000
-; CHECK-NEXT: movi d7, #0000000000000000
+; CHECK-NEXT: fmov d0, d0
; CHECK-NEXT: ret
entry:
%shuffle.i = shufflevector <4 x bfloat> %a, <4 x bfloat> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
diff --git a/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll b/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
index 765c81e26e13ca..632b6b32625022 100644
--- a/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
@@ -3057,55 +3057,34 @@ define <2 x i64> @fcmonez2xdouble(<2 x double> %A) {
ret <2 x i64> %tmp4
}
-; ORD with zero = OLT | OGE
+; ORD A, zero = EQ A, A
define <2 x i32> @fcmordz2xfloat(<2 x float> %A) {
-; CHECK-SD-LABEL: fcmordz2xfloat:
-; CHECK-SD: // %bb.0:
-; CHECK-SD-NEXT: fcmge v1.2s, v0.2s, #0.0
-; CHECK-SD-NEXT: fcmlt v0.2s, v0.2s, #0.0
-; CHECK-SD-NEXT: orr v0.8b, v0.8b, v1.8b
-; CHECK-SD-NEXT: ret
-;
-; CHECK-GI-LABEL: fcmordz2xfloat:
-; CHECK-GI: // %bb.0:
-; CHECK-GI-NEXT: fcmeq v0.2s, v0.2s, v0.2s
-; CHECK-GI-NEXT: ret
+; CHECK-LABEL: fcmordz2xfloat:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fcmeq v0.2s, v0.2s, v0.2s
+; CHECK-NEXT: ret
%tmp3 = fcmp ord <2 x float> %A, zeroinitializer
%tmp4 = sext <2 x i1> %tmp3 to <2 x i32>
ret <2 x i32> %tmp4
}
-; ORD with zero = OLT | OGE
+; ORD A, zero = EQ A, A
define <4 x i32> @fcmordz4xfloat(<4 x float> %A) {
-; CHECK-SD-LABEL: fcmordz4xfloat:
-; CHECK-SD: // %bb.0:
-; CHECK-SD-NEXT: fcmge v1.4s, v0.4s, #0.0
-; CHECK-SD-NEXT: fcmlt v0.4s, v0.4s, #0.0
-; CHECK-SD-NEXT: orr v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT: ret
-;
-; CHECK-GI-LABEL: fcmordz4xfloat:
-; CHECK-GI: // %bb.0:
-; CHECK-GI-NEXT: fcmeq v0.4s, v0.4s, v0.4s
-; CHECK-GI-NEXT: ret
+; CHECK-LABEL: fcmordz4xfloat:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-NEXT: ret
%tmp3 = fcmp ord <4 x float> %A, zeroinitializer
%tmp4 = sext <4 x i1> %tmp3 to <4 x i32>
ret <4 x i32> %tmp4
}
-; ORD with zero = OLT | OGE
+; ORD A, zero = EQ A, A
define <2 x i64> @fcmordz2xdouble(<2 x double> %A) {
-; CHECK-SD-LABEL: fcmordz2xdouble:
-; CHECK-SD: // %bb.0:
-; CHECK-SD-NEXT: fcmge v1.2d, v0.2d, #0.0
-; CHECK-SD-NEXT: fcmlt v0.2d, v0.2d, #0.0
-; CHECK-SD-NEXT: orr v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT: ret
-;
-; CHECK-GI-LABEL: fcmordz2xdouble:
-; CHECK-GI: // %bb.0:
-; CHECK-GI-NEXT: fcmeq v0.2d, v0.2d, v0.2d
-; CHECK-GI-NEXT: ret
+; CHECK-LABEL: fcmordz2xdouble:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fcmeq v0.2d, v0.2d, v0.2d
+; CHECK-NEXT: ret
%tmp3 = fcmp ord <2 x double> %A, zeroinitializer
%tmp4 = sext <2 x i1> %tmp3 to <2 x i64>
ret <2 x i64> %tmp4
@@ -3331,13 +3310,11 @@ define <2 x i64> @fcmunez2xdouble(<2 x double> %A) {
ret <2 x i64> %tmp4
}
-; UNO with zero = !ORD = !(OLT | OGE)
+; UNO A, zero = !(ORD A, zero) = !(EQ A, A)
define <2 x i32> @fcmunoz2xfloat(<2 x float> %A) {
; CHECK-LABEL: fcmunoz2xfloat:
; CHECK: // %bb.0:
-; CHECK-NEXT: fcmge v1.2s, v0.2s, #0.0
-; CHECK-NEXT: fcmlt v0.2s, v0.2s, #0.0
-; CHECK-NEXT: orr v0.8b, v0.8b, v1.8b
+; CHECK-NEXT: fcmeq v0.2s, v0.2s, v0.2s
; CHECK-NEXT: mvn v0.8b, v0.8b
; CHECK-NEXT: ret
%tmp3 = fcmp uno <2 x float> %A, zeroinitializer
@@ -3345,13 +3322,11 @@ define <2 x i32> @fcmunoz2xfloat(<2 x float> %A) {
ret <2 x i32> %tmp4
}
-; UNO with zero = !ORD = !(OLT | OGE)
+; UNO A, zero = !(ORD A, zero) = !(EQ A, A)
define <4 x i32> @fcmunoz4xfloat(<4 x float> %A) {
; CHECK-LABEL: fcmunoz4xfloat:
; CHECK: // %bb.0:
-; CHECK-NEXT: fcmge v1.4s, v0.4s, #0.0
-; CHECK-NEXT: fcmlt v0.4s, v0.4s, #0.0
-; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: fcmeq v0.4s, v0.4s, v0.4s
; CHECK-NEXT: mvn v0.16b, v0.16b
; CHECK-NEXT: ret
%tmp3 = fcmp uno <4 x float> %A, zeroinitializer
@@ -3359,13 +3334,11 @@ define <4 x i32> @fcmunoz4xfloat(<4 x float> %A) {
ret <4 x i32> %tmp4
}
-; UNO with zero = !ORD = !(OLT | OGE)
+; UNO A, zero = !(ORD A, zero) = !(EQ A, A)
define <2 x i64> @fcmunoz2xdouble(<2 x double> %A) {
; CHECK-LABEL: fcmunoz2xdouble:
; CHECK: // %bb.0:
-; CHECK-NEXT: fcmge v1.2d, v0.2d, #0.0
-; CHECK-NEXT: fcmlt v0.2d, v0.2d, #0.0
-; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: fcmeq v0.2d, v0.2d, v0.2d
; CHECK-NEXT: mvn v0.16b, v0.16b
; CHECK-NEXT: ret
%tmp3 = fcmp uno <2 x double> %A, zeroinitializer
@@ -4133,51 +4106,30 @@ define <2 x i64> @fcmonez2xdouble_fast(<2 x double> %A) {
}
define <2 x i32> @fcmordz2xfloat_fast(<2 x float> %A) {
-; CHECK-SD-LABEL: fcmordz2xfloat_fast:
-; CHECK-SD: // %bb.0:
-; CHECK-SD-NEXT: fcmge v1.2s, v0.2s, #0.0
-; CHECK-SD-NEXT: fcmlt v0.2s, v0.2s, #0.0
-; CHECK-SD-NEXT: orr v0.8b, v0.8b, v1.8b
-; CHECK-SD-NEXT: ret
-;
-; CHECK-GI-LABEL: fcmordz2xfloat_fast:
-; CHECK-GI: // %bb.0:
-; CHECK-GI-NEXT: fcmeq v0.2s, v0.2s, v0.2s
-; CHECK-GI-NEXT: ret
+; CHECK-LABEL: fcmordz2xfloat_fast:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fcmeq v0.2s, v0.2s, v0.2s
+; CHECK-NEXT: ret
%tmp3 = fcmp fast ord <2 x float> %A, zeroinitializer
%tmp4 = sext <2 x i1> %tmp3 to <2 x i32>
ret <2 x i32> %tmp4
}
define <4 x i32> @fcmordz4xfloat_fast(<4 x float> %A) {
-; CHECK-SD-LABEL: fcmordz4xfloat_fast:
-; CHECK-SD: // %bb.0:
-; CHECK-SD-NEXT: fcmge v1.4s, v0.4s, #0.0
-; CHECK-SD-NEXT: fcmlt v0.4s, v0.4s, #0.0
-; CHECK-SD-NEXT: orr v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT: ret
-;
-; CHECK-GI-LABEL: fcmordz4xfloat_fast:
-; CHECK-GI: // %bb.0:
-; CHECK-GI-NEXT: fcmeq v0.4s, v0.4s, v0.4s
-; CHECK-GI-NEXT: ret
+; CHECK-LABEL: fcmordz4xfloat_fast:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-NEXT: ret
%tmp3 = fcmp fast ord <4 x float> %A, zeroinitializer
%tmp4 = sext <4 x i1> %tmp3 to <4 x i32>
ret <4 x i32> %tmp4
}
define <2 x i64> @fcmordz2xdouble_fast(<2 x double> %A) {
-; CHECK-SD-LABEL: fcmordz2xdouble_fast:
-; CHECK-SD: // %bb.0:
-; CHECK-SD-NEXT: fcmge v1.2d, v0.2d, #0.0
-; CHECK-SD-NEXT: fcmlt v0.2d, v0.2d, #0.0
-; CHECK-SD-NEXT: orr v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT: ret
-;
-; CHECK-GI-LABEL: fcmordz2xdouble_fast:
-; CHECK-GI: // %bb.0:
-; CHECK-GI-NEXT: fcmeq v0.2d, v0.2d, v0.2d
-; CHECK-GI-NEXT: ret
+; CHECK-LABEL: fcmordz2xdouble_fast:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fcmeq v0.2d, v0.2d, v0.2d
+; CHECK-NEXT: ret
%tmp3 = fcmp fast ord <2 x double> %A, zeroinitializer
%tmp4 = sext <2 x i1> %tmp3 to <2 x i64>
ret <2 x i64> %tmp4
@@ -4466,9 +4418,7 @@ define <2 x i64> @fcmunez2xdouble_fast(<2 x double> %A) {
define <2 x i32> @fcmunoz2xfloat_fast(<2 x float> %A) {
; CHECK-LABEL: fcmunoz2xfloat_fast:
; CHECK: // %bb.0:
-; CHECK-NEXT: fcmge v1.2s, v0.2s, #0.0
-; CHECK-NEXT: fcmlt v0.2s, v0.2s, #0.0
-; CHECK-NEXT: orr v0.8b, v0.8b, v1.8b
+; CHECK-NEXT: fcmeq v0.2s, v0.2s, v0.2s
; CHECK-NEXT: mvn v0.8b, v0.8b
; CHECK-NEXT: ret
%tmp3 = fcmp fast uno <2 x float> %A, zeroinitializer
@@ -4479,9 +4429,7 @@ define <2 x i32> @fcmunoz2xfloat_fast(<2 x float> %A) {
define <4 x i32> @fcmunoz4xfloat_fast(<4 x float> %A) {
; CHECK-LABEL: fcmunoz4xfloat_fast:
; CHECK: // %bb.0:
-; CHECK-NEXT: fcmge v1.4s, v0.4s, #0.0
-; CHECK-NEXT: fcmlt v0.4s, v0.4s, #0.0
-; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: fcmeq v0.4s, v0.4s, v0.4s
; CHECK-NEXT: mvn v0.16b, v0.16b
; CHECK-NEXT: ret
%tmp3 = fcmp fast uno <4 x float> %A, zeroinitializer
@@ -4492,9 +4440,7 @@ define <4 x i32> @fcmunoz4xfloat_fast(<4 x float> %A) {
define <2 x i64> @fcmunoz2xdouble_fast(<2 x double> %A) {
; CHECK-LABEL: fcmunoz2xdouble_fast:
; CHECK: // %bb.0:
-; CHECK-NEXT: fcmge v1.2d, v0.2d, #0.0
-; CHECK-NEXT: fcmlt v0.2d, v0.2d, #0.0
-; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT: fcmeq v0.2d, v0.2d, v0.2d
; CHECK-NEXT: mvn v0.16b, v0.16b
; CHECK-NEXT: ret
%tmp3 = fcmp fast uno <2 x double> %A, zeroinitializer
diff --git a/llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll b/llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll
index 4f23df1d168197..ec7548e1e65410 100644
--- a/llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll
+++ b/llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll
@@ -4,6 +4,54 @@
; Round towards minus infinity (fcvtms).
+define i32 @testmswbf(bfloat %a) {
+; CHECK-LABEL: testmswbf:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: mov w8, #32767 // =0x7fff
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: fmov s0, w9
+; CHECK-NEXT: frintm s0, s0
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: ubfx w10, w9, #16, #1
+; CHECK-NEXT: add w8, w9, w8
+; CHECK-NEXT: add w8, w10, w8
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: fcvtzs w0, s0
+; CHECK-NEXT: ret
+entry:
+ %r = call bfloat @llvm.floor.bf16(bfloat %a) nounwind readnone
+ %i = call i32 @llvm.fptosi.sat.i32.bf16(bfloat %r)
+ ret i32 %i
+}
+
+define i64 @testmsxbf(bfloat %a) {
+; CHECK-LABEL: testmsxbf:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: mov w8, #32767 // =0x7fff
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: fmov s0, w9
+; CHECK-NEXT: frintm s0, s0
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: ubfx w10, w9, #16, #1
+; CHECK-NEXT: add w8, w9, w8
+; CHECK-NEXT: add w8, w10, w8
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: fcvtzs x0, s0
+; CHECK-NEXT: ret
+entry:
+ %r = call bfloat @llvm.floor.bf16(bfloat %a) nounwind readnone
+ %i = call i64 @llvm.fptosi.sat.i64.bf16(bfloat %r)
+ ret i64 %i
+}
+
define i32 @testmswh(half %a) {
; CHECK-CVT-LABEL: testmswh:
; CHECK-CVT: // %bb.0: // %entry
@@ -90,6 +138,54 @@ entry:
; Round towards plus infinity (fcvtps).
+define i32 @testpswbf(bfloat %a) {
+; CHECK-LABEL: testpswbf:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: mov w8, #32767 // =0x7fff
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: fmov s0, w9
+; CHECK-NEXT: frintp s0, s0
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: ubfx w10, w9, #16, #1
+; CHECK-NEXT: add w8, w9, w8
+; CHECK-NEXT: add w8, w10, w8
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: fcvtzs w0, s0
+; CHECK-NEXT: ret
+entry:
+ %r = call bfloat @llvm.ceil.bf16(bfloat %a) nounwind readnone
+ %i = call i32 @llvm.fptosi.sat.i32.bf16(bfloat %r)
+ ret i32 %i
+}
+
+define i64 @testpsxbf(bfloat %a) {
+; CHECK-LABEL: testpsxbf:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: mov w8, #32767 // =0x7fff
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: fmov s0, w9
+; CHECK-NEXT: frintp s0, s0
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: ubfx w10, w9, #16, #1
+; CHECK-NEXT: add w8, w9, w8
+; CHECK-NEXT: add w8, w10, w8
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: fcvtzs x0, s0
+; CHECK-NEXT: ret
+entry:
+ %r = call bfloat @llvm.ceil.bf16(bfloat %a) nounwind readnone
+ %i = call i64 @llvm.fptosi.sat.i64.bf16(bfloat %r)
+ ret i64 %i
+}
+
define i32 @testpswh(half %a) {
; CHECK-CVT-LABEL: testpswh:
; CHECK-CVT: // %bb.0: // %entry
@@ -346,6 +442,8 @@ entry:
ret i64 %i
}
+declare i32 @llvm.fptosi.sat.i32.bf16 (bfloat)
+declare i64 @llvm.fptosi.sat.i64.bf16 (bfloat)
declare i32 @llvm.fptosi.sat.i32.f16 (half)
declare i64 @llvm.fptosi.sat.i64.f16 (half)
declare i32 @llvm.fptosi.sat.i32.f32 (float)
@@ -353,6 +451,10 @@ declare i64 @llvm.fptosi.sat.i64.f32 (float)
declare i32 @llvm.fptosi.sat.i32.f64 (double)
declare i64 @llvm.fptosi.sat.i64.f64 (double)
+declare bfloat @llvm.floor.bf16(bfloat) nounwind readnone
+declare bfloat @llvm.ceil.bf16(bfloat) nounwind readnone
+declare bfloat @llvm.trunc.bf16(bfloat) nounwind readnone
+declare bfloat @llvm.round.bf16(bfloat) nounwind readnone
declare half @llvm.floor.f16(half) nounwind readnone
declare half @llvm.ceil.f16(half) nounwind readnone
declare half @llvm.trunc.f16(half) nounwind readnone
diff --git a/llvm/test/CodeGen/AArch64/vector-fcopysign.ll b/llvm/test/CodeGen/AArch64/vector-fcopysign.ll
index c7134508883b11..d01ca881545c02 100644
--- a/llvm/test/CodeGen/AArch64/vector-fcopysign.ll
+++ b/llvm/test/CodeGen/AArch64/vector-fcopysign.ll
@@ -477,4 +477,443 @@ define <8 x half> @test_copysign_v8f16_v8f32(<8 x half> %a, <8 x float> %b) #0 {
declare <8 x half> @llvm.copysign.v8f16(<8 x half> %a, <8 x half> %b) #0
+;============ v4bf16
+
+define <4 x bfloat> @test_copysign_v4bf16_v4bf16(<4 x bfloat> %a, <4 x bfloat> %b) #0 {
+; CHECK-LABEL: test_copysign_v4bf16_v4bf16:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: ; kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: ; kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: mov h3, v1[1]
+; CHECK-NEXT: mov h4, v0[1]
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: mov h5, v1[2]
+; CHECK-NEXT: mov h6, v0[2]
+; CHECK-NEXT: fmov w11, s0
+; CHECK-NEXT: mvni.4s v2, #128, lsl #24
+; CHECK-NEXT: mov h1, v1[3]
+; CHECK-NEXT: mov h0, v0[3]
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov w9, s3
+; CHECK-NEXT: lsl w11, w11, #16
+; CHECK-NEXT: fmov w10, s4
+; CHECK-NEXT: fmov s7, w8
+; CHECK-NEXT: fmov w8, s5
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov s3, w9
+; CHECK-NEXT: fmov s4, w10
+; CHECK-NEXT: fmov w9, s6
+; CHECK-NEXT: fmov w10, s1
+; CHECK-NEXT: bit.16b v3, v4, v2
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: fmov s4, w11
+; CHECK-NEXT: fmov w11, s0
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: fmov s1, w9
+; CHECK-NEXT: bif.16b v4, v7, v2
+; CHECK-NEXT: fmov w8, s3
+; CHECK-NEXT: lsl w11, w11, #16
+; CHECK-NEXT: bif.16b v1, v0, v2
+; CHECK-NEXT: fmov s5, w11
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov w9, s4
+; CHECK-NEXT: fmov s4, w10
+; CHECK-NEXT: fmov s3, w8
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: mov.16b v1, v2
+; CHECK-NEXT: lsr w9, w9, #16
+; CHECK-NEXT: bsl.16b v1, v5, v4
+; CHECK-NEXT: fmov s0, w9
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s2, w8
+; CHECK-NEXT: mov.h v0[1], v3[0]
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: mov.h v0[2], v2[0]
+; CHECK-NEXT: fmov s1, w8
+; CHECK-NEXT: mov.h v0[3], v1[0]
+; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: ret
+ %r = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %a, <4 x bfloat> %b)
+ ret <4 x bfloat> %r
+}
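[Editorial note, not part of the patch] The `bif`/`bit`/`bsl` instructions above are a per-lane bit select: the `mvni.4s v2, #128, lsl #24` constant is the 32-bit mask ~0x80000000, so each lane keeps the magnitude of one operand and the sign bit of the other (the lanes are 32-bit because each bf16 element has been widened into an f32 register first). A hedged scalar sketch on the 16-bit bf16 payload, helper name illustrative:

```cpp
// Sketch: bf16 copysign as a sign-bit merge, the same select the NEON
// bif/bit/bsl sequence above performs lane by lane.
#include <cstdint>

static uint16_t bf16_copysign(uint16_t mag, uint16_t sgn) {
  return static_cast<uint16_t>((mag & 0x7fffu) | (sgn & 0x8000u));
}
```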
+
+define <4 x bfloat> @test_copysign_v4bf16_v4f32(<4 x bfloat> %a, <4 x float> %b) #0 {
+; CHECK-LABEL: test_copysign_v4bf16_v4f32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: movi.4s v2, #127, msl #8
+; CHECK-NEXT: movi.4s v3, #1
+; CHECK-NEXT: ; kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: ushr.4s v4, v1, #16
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: mov h5, v0[2]
+; CHECK-NEXT: mov h6, v0[3]
+; CHECK-NEXT: add.4s v2, v1, v2
+; CHECK-NEXT: and.16b v3, v4, v3
+; CHECK-NEXT: fcmeq.4s v4, v1, v1
+; CHECK-NEXT: orr.4s v1, #64, lsl #16
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: add.4s v2, v3, v2
+; CHECK-NEXT: mov h3, v0[1]
+; CHECK-NEXT: bit.16b v1, v2, v4
+; CHECK-NEXT: fmov w8, s3
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: shrn.4h v2, v1, #16
+; CHECK-NEXT: mvni.4s v1, #128, lsl #24
+; CHECK-NEXT: fmov s3, w8
+; CHECK-NEXT: fmov w8, s5
+; CHECK-NEXT: fmov s5, w9
+; CHECK-NEXT: mov h4, v2[1]
+; CHECK-NEXT: mov h0, v2[2]
+; CHECK-NEXT: fmov w11, s2
+; CHECK-NEXT: mov h2, v2[3]
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: lsl w11, w11, #16
+; CHECK-NEXT: fmov w10, s4
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: fmov w8, s2
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov s4, w10
+; CHECK-NEXT: fmov w10, s6
+; CHECK-NEXT: fmov s2, w9
+; CHECK-NEXT: bif.16b v3, v4, v1
+; CHECK-NEXT: fmov s4, w11
+; CHECK-NEXT: bit.16b v2, v0, v1
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: bit.16b v4, v5, v1
+; CHECK-NEXT: fmov s5, w8
+; CHECK-NEXT: fmov w9, s3
+; CHECK-NEXT: fmov w8, s2
+; CHECK-NEXT: fmov w11, s4
+; CHECK-NEXT: fmov s4, w10
+; CHECK-NEXT: lsr w9, w9, #16
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s3, w9
+; CHECK-NEXT: lsr w11, w11, #16
+; CHECK-NEXT: bsl.16b v1, v4, v5
+; CHECK-NEXT: fmov s2, w8
+; CHECK-NEXT: fmov s0, w11
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: mov.h v0[1], v3[0]
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: mov.h v0[2], v2[0]
+; CHECK-NEXT: fmov s1, w8
+; CHECK-NEXT: mov.h v0[3], v1[0]
+; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: ret
+ %tmp0 = fptrunc <4 x float> %b to <4 x bfloat>
+ %r = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %a, <4 x bfloat> %tmp0)
+ ret <4 x bfloat> %r
+}
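[Editorial note, not part of the patch] The f32 -> bf16 truncation in this test is done entirely with integer arithmetic, as the commit message describes: `ushr`+`and` extract the ties-to-even bias bit, the `movi.4s #127, msl #8` constant (0x7fff) is added to round at bit 15, `fcmeq` plus `orr ... #64, lsl #16` substitute a quiet NaN for NaN inputs, and `shrn #16` keeps the high half. A hedged scalar sketch of the same sequence, helper name illustrative:

```cpp
// Sketch: round-to-nearest-even f32 -> bf16 using only integer arithmetic,
// mirroring the ushr/and/add/orr/shrn sequence in the CHECK lines above.
#include <cstdint>
#include <cstring>

static uint16_t f32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  if ((bits & 0x7fffffffu) > 0x7f800000u)                   // NaN input:
    return static_cast<uint16_t>((bits >> 16) | 0x0040u);   // return it quieted
  uint32_t lsb = (bits >> 16) & 1u;                         // bias for ties-to-even
  bits += 0x7fffu + lsb;                                    // round at bit 15
  return static_cast<uint16_t>(bits >> 16);                 // keep the high half
}
```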
+
+define <4 x bfloat> @test_copysign_v4bf16_v4f64(<4 x bfloat> %a, <4 x double> %b) #0 {
+; CHECK-LABEL: test_copysign_v4bf16_v4f64:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: ; kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: mov h4, v0[1]
+; CHECK-NEXT: mov h5, v0[2]
+; CHECK-NEXT: mov d3, v1[1]
+; CHECK-NEXT: fcvt s1, d1
+; CHECK-NEXT: mov h0, v0[3]
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov w9, s4
+; CHECK-NEXT: mvni.4s v4, #128, lsl #24
+; CHECK-NEXT: fmov s6, w8
+; CHECK-NEXT: fmov w8, s5
+; CHECK-NEXT: fcvt s3, d3
+; CHECK-NEXT: fmov w10, s0
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: bit.16b v1, v6, v4
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: mov d6, v2[1]
+; CHECK-NEXT: fmov s7, w9
+; CHECK-NEXT: fcvt s2, d2
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: fmov s5, w8
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: mov.16b v1, v4
+; CHECK-NEXT: bit.16b v3, v7, v4
+; CHECK-NEXT: bsl.16b v1, v5, v2
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fcvt s2, d6
+; CHECK-NEXT: fmov w9, s3
+; CHECK-NEXT: fmov s5, w10
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: mov.16b v1, v4
+; CHECK-NEXT: lsr w9, w9, #16
+; CHECK-NEXT: fmov s3, w9
+; CHECK-NEXT: bsl.16b v1, v5, v2
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: mov.h v0[1], v3[0]
+; CHECK-NEXT: fmov s2, w8
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: mov.h v0[2], v2[0]
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s1, w8
+; CHECK-NEXT: mov.h v0[3], v1[0]
+; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT: ret
+ %tmp0 = fptrunc <4 x double> %b to <4 x bfloat>
+ %r = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %a, <4 x bfloat> %tmp0)
+ ret <4 x bfloat> %r
+}
+
+declare <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %a, <4 x bfloat> %b) #0
+
+;============ v8bf16
+
+define <8 x bfloat> @test_copysign_v8bf16_v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) #0 {
+; CHECK-LABEL: test_copysign_v8bf16_v8bf16:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: mov h2, v1[1]
+; CHECK-NEXT: mov h4, v0[1]
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: mov h6, v1[2]
+; CHECK-NEXT: mov h7, v0[2]
+; CHECK-NEXT: mvni.4s v3, #128, lsl #24
+; CHECK-NEXT: mov h5, v1[3]
+; CHECK-NEXT: mov h16, v0[3]
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: mov h17, v1[4]
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: fmov w10, s4
+; CHECK-NEXT: mov h4, v0[4]
+; CHECK-NEXT: fmov s18, w8
+; CHECK-NEXT: fmov w8, s2
+; CHECK-NEXT: fmov w11, s7
+; CHECK-NEXT: fmov s2, w9
+; CHECK-NEXT: lsl w9, w10, #16
+; CHECK-NEXT: fmov w10, s6
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov s7, w9
+; CHECK-NEXT: bif.16b v2, v18, v3
+; CHECK-NEXT: lsl w9, w11, #16
+; CHECK-NEXT: fmov s6, w8
+; CHECK-NEXT: lsl w8, w10, #16
+; CHECK-NEXT: fmov w10, s5
+; CHECK-NEXT: fmov w11, s16
+; CHECK-NEXT: fmov s16, w9
+; CHECK-NEXT: mov h18, v0[5]
+; CHECK-NEXT: fmov s5, w8
+; CHECK-NEXT: bit.16b v6, v7, v3
+; CHECK-NEXT: fmov w8, s2
+; CHECK-NEXT: lsl w9, w10, #16
+; CHECK-NEXT: lsl w10, w11, #16
+; CHECK-NEXT: mov h7, v1[5]
+; CHECK-NEXT: bit.16b v5, v16, v3
+; CHECK-NEXT: fmov s16, w10
+; CHECK-NEXT: fmov w10, s4
+; CHECK-NEXT: mov.16b v4, v3
+; CHECK-NEXT: fmov w11, s6
+; CHECK-NEXT: fmov s6, w9
+; CHECK-NEXT: fmov w9, s17
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: lsr w11, w11, #16
+; CHECK-NEXT: fmov s2, w8
+; CHECK-NEXT: lsl w8, w9, #16
+; CHECK-NEXT: bsl.16b v4, v16, v6
+; CHECK-NEXT: lsl w9, w10, #16
+; CHECK-NEXT: fmov w10, s5
+; CHECK-NEXT: fmov s6, w11
+; CHECK-NEXT: fmov s5, w8
+; CHECK-NEXT: lsr w8, w10, #16
+; CHECK-NEXT: fmov w10, s7
+; CHECK-NEXT: mov.h v2[1], v6[0]
+; CHECK-NEXT: fmov s6, w9
+; CHECK-NEXT: fmov w9, s18
+; CHECK-NEXT: fmov s7, w8
+; CHECK-NEXT: fmov w8, s4
+; CHECK-NEXT: mov h4, v1[6]
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: mov h1, v1[7]
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: bit.16b v5, v6, v3
+; CHECK-NEXT: mov h6, v0[6]
+; CHECK-NEXT: mov.h v2[2], v7[0]
+; CHECK-NEXT: fmov s7, w10
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s16, w9
+; CHECK-NEXT: fmov w9, s4
+; CHECK-NEXT: mov h0, v0[7]
+; CHECK-NEXT: fmov w10, s6
+; CHECK-NEXT: bit.16b v7, v16, v3
+; CHECK-NEXT: fmov s16, w8
+; CHECK-NEXT: fmov w8, s5
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: mov.h v2[3], v16[0]
+; CHECK-NEXT: fmov s5, w9
+; CHECK-NEXT: fmov w9, s1
+; CHECK-NEXT: fmov s4, w8
+; CHECK-NEXT: fmov w8, s7
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: mov.h v2[4], v4[0]
+; CHECK-NEXT: fmov s4, w10
+; CHECK-NEXT: fmov w10, s0
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s1, w9
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: bif.16b v4, v5, v3
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: fmov s5, w10
+; CHECK-NEXT: mov.h v2[5], v0[0]
+; CHECK-NEXT: mov.16b v0, v3
+; CHECK-NEXT: fmov w8, s4
+; CHECK-NEXT: bsl.16b v0, v5, v1
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s1, w8
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: mov.h v2[6], v1[0]
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: mov.h v2[7], v0[0]
+; CHECK-NEXT: mov.16b v0, v2
+; CHECK-NEXT: ret
+ %r = call <8 x bfloat> @llvm.copysign.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b)
+ ret <8 x bfloat> %r
+}
+
+define <8 x bfloat> @test_copysign_v8bf16_v8f32(<8 x bfloat> %a, <8 x float> %b) #0 {
+; CHECK-LABEL: test_copysign_v8bf16_v8f32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: movi.4s v3, #127, msl #8
+; CHECK-NEXT: movi.4s v4, #1
+; CHECK-NEXT: ushr.4s v5, v1, #16
+; CHECK-NEXT: fcmeq.4s v7, v1, v1
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: add.4s v6, v1, v3
+; CHECK-NEXT: and.16b v5, v5, v4
+; CHECK-NEXT: orr.4s v1, #64, lsl #16
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: add.4s v5, v5, v6
+; CHECK-NEXT: ushr.4s v6, v2, #16
+; CHECK-NEXT: and.16b v4, v6, v4
+; CHECK-NEXT: mov h6, v0[2]
+; CHECK-NEXT: bit.16b v1, v5, v7
+; CHECK-NEXT: add.4s v7, v2, v3
+; CHECK-NEXT: mov h5, v0[1]
+; CHECK-NEXT: fcmeq.4s v3, v2, v2
+; CHECK-NEXT: orr.4s v2, #64, lsl #16
+; CHECK-NEXT: shrn.4h v1, v1, #16
+; CHECK-NEXT: add.4s v4, v4, v7
+; CHECK-NEXT: fmov w8, s5
+; CHECK-NEXT: mov h7, v0[3]
+; CHECK-NEXT: mov h5, v0[4]
+; CHECK-NEXT: mov h16, v1[1]
+; CHECK-NEXT: fmov w10, s1
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: bsl.16b v3, v4, v2
+; CHECK-NEXT: mov h4, v1[2]
+; CHECK-NEXT: mov h17, v1[3]
+; CHECK-NEXT: mvni.4s v2, #128, lsl #24
+; CHECK-NEXT: fmov s1, w9
+; CHECK-NEXT: fmov w9, s6
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: fmov s6, w8
+; CHECK-NEXT: fmov w8, s7
+; CHECK-NEXT: fmov w11, s16
+; CHECK-NEXT: fmov s7, w10
+; CHECK-NEXT: fmov w10, s4
+; CHECK-NEXT: mov.16b v4, v2
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: shrn.4h v3, v3, #16
+; CHECK-NEXT: lsl w11, w11, #16
+; CHECK-NEXT: bif.16b v1, v7, v2
+; CHECK-NEXT: fmov s16, w8
+; CHECK-NEXT: fmov s7, w11
+; CHECK-NEXT: bsl.16b v4, v6, v7
+; CHECK-NEXT: fmov s7, w9
+; CHECK-NEXT: lsl w9, w10, #16
+; CHECK-NEXT: fmov w10, s17
+; CHECK-NEXT: mov h6, v0[5]
+; CHECK-NEXT: lsl w8, w10, #16
+; CHECK-NEXT: fmov w10, s1
+; CHECK-NEXT: fmov s1, w9
+; CHECK-NEXT: lsr w9, w10, #16
+; CHECK-NEXT: fmov w10, s4
+; CHECK-NEXT: fmov s4, w8
+; CHECK-NEXT: bif.16b v7, v1, v2
+; CHECK-NEXT: fmov w8, s5
+; CHECK-NEXT: mov h5, v3[1]
+; CHECK-NEXT: fmov s1, w9
+; CHECK-NEXT: fmov w9, s3
+; CHECK-NEXT: lsr w10, w10, #16
+; CHECK-NEXT: bit.16b v4, v16, v2
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: fmov s16, w10
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: fmov w10, s7
+; CHECK-NEXT: mov h7, v0[6]
+; CHECK-NEXT: mov h0, v0[7]
+; CHECK-NEXT: mov.h v1[1], v16[0]
+; CHECK-NEXT: fmov s16, w8
+; CHECK-NEXT: fmov w8, s6
+; CHECK-NEXT: fmov s6, w9
+; CHECK-NEXT: fmov w9, s5
+; CHECK-NEXT: lsr w10, w10, #16
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: bit.16b v6, v16, v2
+; CHECK-NEXT: fmov s16, w10
+; CHECK-NEXT: fmov w10, s4
+; CHECK-NEXT: fmov s5, w8
+; CHECK-NEXT: fmov w8, s7
+; CHECK-NEXT: fmov s7, w9
+; CHECK-NEXT: mov h4, v3[2]
+; CHECK-NEXT: mov h3, v3[3]
+; CHECK-NEXT: mov.h v1[2], v16[0]
+; CHECK-NEXT: lsr w10, w10, #16
+; CHECK-NEXT: fmov w9, s6
+; CHECK-NEXT: lsl w8, w8, #16
+; CHECK-NEXT: bif.16b v5, v7, v2
+; CHECK-NEXT: fmov s16, w10
+; CHECK-NEXT: fmov w10, s4
+; CHECK-NEXT: fmov s4, w8
+; CHECK-NEXT: lsr w9, w9, #16
+; CHECK-NEXT: mov.h v1[3], v16[0]
+; CHECK-NEXT: fmov w8, s5
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: fmov s6, w9
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: fmov s5, w10
+; CHECK-NEXT: fmov w10, s3
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: mov.h v1[4], v6[0]
+; CHECK-NEXT: lsl w9, w9, #16
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: bif.16b v4, v5, v2
+; CHECK-NEXT: lsl w10, w10, #16
+; CHECK-NEXT: fmov s3, w9
+; CHECK-NEXT: fmov s5, w10
+; CHECK-NEXT: mov.h v1[5], v0[0]
+; CHECK-NEXT: mov.16b v0, v2
+; CHECK-NEXT: fmov w8, s4
+; CHECK-NEXT: bsl.16b v0, v3, v5
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s2, w8
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: mov.h v1[6], v2[0]
+; CHECK-NEXT: lsr w8, w8, #16
+; CHECK-NEXT: fmov s0, w8
+; CHECK-NEXT: mov.h v1[7], v0[0]
+; CHECK-NEXT: mov.16b v0, v1
+; CHECK-NEXT: ret
+ %tmp0 = fptrunc <8 x float> %b to <8 x bfloat>
+ %r = call <8 x bfloat> @llvm.copysign.v8bf16(<8 x bfloat> %a, <8 x bfloat> %tmp0)
+ ret <8 x bfloat> %r
+}
+
+declare <8 x bfloat> @llvm.copysign.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) #0
+
attributes #0 = { nounwind }