[llvm] 3dd6750 - [AArch64] Add more complete support for BF16

David Majnemer via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 3 14:55:41 PST 2024


Author: David Majnemer
Date: 2024-03-03T22:39:50Z
New Revision: 3dd6750027cd168fce5fd9894b0bac0739652cf5

URL: https://github.com/llvm/llvm-project/commit/3dd6750027cd168fce5fd9894b0bac0739652cf5
DIFF: https://github.com/llvm/llvm-project/commit/3dd6750027cd168fce5fd9894b0bac0739652cf5.diff

LOG: [AArch64] Add more complete support for BF16

We can use a small amount of integer arithmetic to round FP32 to BF16
and extend BF16 to FP32.

While a number of operations still require promotion, this could be
reduced for some rather simple operations like abs, copysign, and fneg,
but these can be done in a follow-up.

A few neat optimizations are implemented:
- round-inexact-to-odd is used for F64 to BF16 rounding.
- quieting signaling NaNs for f32 -> bf16 tries to detect if a prior
  operation makes it unnecessary.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/include/llvm/CodeGen/ValueTypes.h
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
    llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
    llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir
    llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll
    llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
    llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll
    llvm/test/CodeGen/AArch64/vector-fcopysign.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index f2e00aab8d5da2..0438abc7c3061a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1573,13 +1573,14 @@ class TargetLoweringBase {
     assert((VT.isInteger() || VT.isFloatingPoint()) &&
            "Cannot autopromote this type, add it with AddPromotedToType.");
 
+    uint64_t VTBits = VT.getScalarSizeInBits();
     MVT NVT = VT;
     do {
       NVT = (MVT::SimpleValueType)(NVT.SimpleTy+1);
       assert(NVT.isInteger() == VT.isInteger() && NVT != MVT::isVoid &&
              "Didn't find type to promote to!");
-    } while (!isTypeLegal(NVT) ||
-              getOperationAction(Op, NVT) == Promote);
+    } while (VTBits >= NVT.getScalarSizeInBits() || !isTypeLegal(NVT) ||
+             getOperationAction(Op, NVT) == Promote);
     return NVT;
   }
 

diff  --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
index 1e0356ee69ff52..b66c66d1bfc45a 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -107,6 +107,13 @@ namespace llvm {
       return changeExtendedVectorElementType(EltVT);
     }
 
+    /// Return a VT for a type whose attributes match ourselves with the
+    /// exception of the element type that is chosen by the caller.
+    EVT changeElementType(EVT EltVT) const {
+      EltVT = EltVT.getScalarType();
+      return isVector() ? changeVectorElementType(EltVT) : EltVT;
+    }
+
     /// Return the type converted to an equivalently sized integer or vector
     /// with integer element type. Similar to changeVectorElementTypeToInteger,
     /// but also handles scalars.

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7f80e877cb2406..475c73c3588dbc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -368,8 +368,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     addDRTypeForNEON(MVT::v1i64);
     addDRTypeForNEON(MVT::v1f64);
     addDRTypeForNEON(MVT::v4f16);
-    if (Subtarget->hasBF16())
-      addDRTypeForNEON(MVT::v4bf16);
+    addDRTypeForNEON(MVT::v4bf16);
 
     addQRTypeForNEON(MVT::v4f32);
     addQRTypeForNEON(MVT::v2f64);
@@ -378,8 +377,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     addQRTypeForNEON(MVT::v4i32);
     addQRTypeForNEON(MVT::v2i64);
     addQRTypeForNEON(MVT::v8f16);
-    if (Subtarget->hasBF16())
-      addQRTypeForNEON(MVT::v8bf16);
+    addQRTypeForNEON(MVT::v8bf16);
   }
 
   if (Subtarget->hasSVEorSME()) {
@@ -403,11 +401,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
 
-    if (Subtarget->hasBF16()) {
-      addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
-      addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
-      addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
-    }
+    addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
+    addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
+    addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
 
     if (Subtarget->useSVEForFixedLengthVectors()) {
       for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
@@ -437,9 +433,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::GlobalTLSAddress, MVT::i64, Custom);
   setOperationAction(ISD::SETCC, MVT::i32, Custom);
   setOperationAction(ISD::SETCC, MVT::i64, Custom);
+  setOperationAction(ISD::SETCC, MVT::bf16, Custom);
   setOperationAction(ISD::SETCC, MVT::f16, Custom);
   setOperationAction(ISD::SETCC, MVT::f32, Custom);
   setOperationAction(ISD::SETCC, MVT::f64, Custom);
+  setOperationAction(ISD::STRICT_FSETCC, MVT::bf16, Custom);
   setOperationAction(ISD::STRICT_FSETCC, MVT::f16, Custom);
   setOperationAction(ISD::STRICT_FSETCC, MVT::f32, Custom);
   setOperationAction(ISD::STRICT_FSETCC, MVT::f64, Custom);
@@ -463,7 +461,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
-  setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand);
+  setOperationAction(ISD::SELECT_CC, MVT::bf16, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::f32, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::f64, Custom);
   setOperationAction(ISD::BR_JT, MVT::Other, Custom);
@@ -539,12 +537,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i32, Custom);
   setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i64, Custom);
   setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i128, Custom);
-  if (Subtarget->hasFPARMv8())
+  if (Subtarget->hasFPARMv8()) {
     setOperationAction(ISD::FP_ROUND, MVT::f16, Custom);
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+  }
   setOperationAction(ISD::FP_ROUND, MVT::f32, Custom);
   setOperationAction(ISD::FP_ROUND, MVT::f64, Custom);
-  if (Subtarget->hasFPARMv8())
+  if (Subtarget->hasFPARMv8()) {
     setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom);
+    setOperationAction(ISD::STRICT_FP_ROUND, MVT::bf16, Custom);
+  }
   setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom);
   setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom);
 
@@ -678,6 +680,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
   else
     setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
+  setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
 
   for (auto Op : {ISD::FREM,        ISD::FPOW,         ISD::FPOWI,
                   ISD::FCOS,        ISD::FSIN,         ISD::FSINCOS,
@@ -690,9 +693,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(Op, MVT::f16, Promote);
     setOperationAction(Op, MVT::v4f16, Expand);
     setOperationAction(Op, MVT::v8f16, Expand);
+    setOperationAction(Op, MVT::bf16, Promote);
+    setOperationAction(Op, MVT::v4bf16, Expand);
+    setOperationAction(Op, MVT::v8bf16, Expand);
   }
 
-  if (!Subtarget->hasFullFP16()) {
+  auto LegalizeNarrowFP = [this](MVT ScalarVT) {
     for (auto Op :
          {ISD::SETCC,          ISD::SELECT_CC,
           ISD::BR_CC,          ISD::FADD,           ISD::FSUB,
@@ -708,60 +714,69 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
           ISD::STRICT_FROUND,  ISD::STRICT_FTRUNC,  ISD::STRICT_FROUNDEVEN,
           ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM, ISD::STRICT_FMINIMUM,
           ISD::STRICT_FMAXIMUM})
-      setOperationAction(Op, MVT::f16, Promote);
+      setOperationAction(Op, ScalarVT, Promote);
 
     // Round-to-integer need custom lowering for fp16, as Promote doesn't work
     // because the result type is integer.
     for (auto Op : {ISD::LROUND, ISD::LLROUND, ISD::LRINT, ISD::LLRINT,
                     ISD::STRICT_LROUND, ISD::STRICT_LLROUND, ISD::STRICT_LRINT,
                     ISD::STRICT_LLRINT})
-      setOperationAction(Op, MVT::f16, Custom);
+      setOperationAction(Op, ScalarVT, Custom);
 
     // promote v4f16 to v4f32 when that is known to be safe.
-    setOperationPromotedToType(ISD::FADD, MVT::v4f16, MVT::v4f32);
-    setOperationPromotedToType(ISD::FSUB, MVT::v4f16, MVT::v4f32);
-    setOperationPromotedToType(ISD::FMUL, MVT::v4f16, MVT::v4f32);
-    setOperationPromotedToType(ISD::FDIV, MVT::v4f16, MVT::v4f32);
-
-    setOperationAction(ISD::FABS,        MVT::v4f16, Expand);
-    setOperationAction(ISD::FNEG,        MVT::v4f16, Expand);
-    setOperationAction(ISD::FROUND,      MVT::v4f16, Expand);
-    setOperationAction(ISD::FROUNDEVEN,  MVT::v4f16, Expand);
-    setOperationAction(ISD::FMA,         MVT::v4f16, Expand);
-    setOperationAction(ISD::SETCC,       MVT::v4f16, Custom);
-    setOperationAction(ISD::BR_CC,       MVT::v4f16, Expand);
-    setOperationAction(ISD::SELECT,      MVT::v4f16, Expand);
-    setOperationAction(ISD::SELECT_CC,   MVT::v4f16, Expand);
-    setOperationAction(ISD::FTRUNC,      MVT::v4f16, Expand);
-    setOperationAction(ISD::FCOPYSIGN,   MVT::v4f16, Expand);
-    setOperationAction(ISD::FFLOOR,      MVT::v4f16, Expand);
-    setOperationAction(ISD::FCEIL,       MVT::v4f16, Expand);
-    setOperationAction(ISD::FRINT,       MVT::v4f16, Expand);
-    setOperationAction(ISD::FNEARBYINT,  MVT::v4f16, Expand);
-    setOperationAction(ISD::FSQRT,       MVT::v4f16, Expand);
-
-    setOperationAction(ISD::FABS,        MVT::v8f16, Expand);
-    setOperationAction(ISD::FADD,        MVT::v8f16, Expand);
-    setOperationAction(ISD::FCEIL,       MVT::v8f16, Expand);
-    setOperationAction(ISD::FCOPYSIGN,   MVT::v8f16, Expand);
-    setOperationAction(ISD::FDIV,        MVT::v8f16, Expand);
-    setOperationAction(ISD::FFLOOR,      MVT::v8f16, Expand);
-    setOperationAction(ISD::FMA,         MVT::v8f16, Expand);
-    setOperationAction(ISD::FMUL,        MVT::v8f16, Expand);
-    setOperationAction(ISD::FNEARBYINT,  MVT::v8f16, Expand);
-    setOperationAction(ISD::FNEG,        MVT::v8f16, Expand);
-    setOperationAction(ISD::FROUND,      MVT::v8f16, Expand);
-    setOperationAction(ISD::FROUNDEVEN,  MVT::v8f16, Expand);
-    setOperationAction(ISD::FRINT,       MVT::v8f16, Expand);
-    setOperationAction(ISD::FSQRT,       MVT::v8f16, Expand);
-    setOperationAction(ISD::FSUB,        MVT::v8f16, Expand);
-    setOperationAction(ISD::FTRUNC,      MVT::v8f16, Expand);
-    setOperationAction(ISD::SETCC,       MVT::v8f16, Expand);
-    setOperationAction(ISD::BR_CC,       MVT::v8f16, Expand);
-    setOperationAction(ISD::SELECT,      MVT::v8f16, Expand);
-    setOperationAction(ISD::SELECT_CC,   MVT::v8f16, Expand);
-    setOperationAction(ISD::FP_EXTEND,   MVT::v8f16, Expand);
+    auto V4Narrow = MVT::getVectorVT(ScalarVT, 4);
+    setOperationPromotedToType(ISD::FADD, V4Narrow, MVT::v4f32);
+    setOperationPromotedToType(ISD::FSUB, V4Narrow, MVT::v4f32);
+    setOperationPromotedToType(ISD::FMUL, V4Narrow, MVT::v4f32);
+    setOperationPromotedToType(ISD::FDIV, V4Narrow, MVT::v4f32);
+
+    setOperationAction(ISD::FABS,        V4Narrow, Expand);
+    setOperationAction(ISD::FNEG,        V4Narrow, Expand);
+    setOperationAction(ISD::FROUND,      V4Narrow, Expand);
+    setOperationAction(ISD::FROUNDEVEN,  V4Narrow, Expand);
+    setOperationAction(ISD::FMA,         V4Narrow, Expand);
+    setOperationAction(ISD::SETCC,       V4Narrow, Custom);
+    setOperationAction(ISD::BR_CC,       V4Narrow, Expand);
+    setOperationAction(ISD::SELECT,      V4Narrow, Expand);
+    setOperationAction(ISD::SELECT_CC,   V4Narrow, Expand);
+    setOperationAction(ISD::FTRUNC,      V4Narrow, Expand);
+    setOperationAction(ISD::FCOPYSIGN,   V4Narrow, Expand);
+    setOperationAction(ISD::FFLOOR,      V4Narrow, Expand);
+    setOperationAction(ISD::FCEIL,       V4Narrow, Expand);
+    setOperationAction(ISD::FRINT,       V4Narrow, Expand);
+    setOperationAction(ISD::FNEARBYINT,  V4Narrow, Expand);
+    setOperationAction(ISD::FSQRT,       V4Narrow, Expand);
+
+    auto V8Narrow = MVT::getVectorVT(ScalarVT, 8);
+    setOperationAction(ISD::FABS,        V8Narrow, Expand);
+    setOperationAction(ISD::FADD,        V8Narrow, Expand);
+    setOperationAction(ISD::FCEIL,       V8Narrow, Expand);
+    setOperationAction(ISD::FCOPYSIGN,   V8Narrow, Expand);
+    setOperationAction(ISD::FDIV,        V8Narrow, Expand);
+    setOperationAction(ISD::FFLOOR,      V8Narrow, Expand);
+    setOperationAction(ISD::FMA,         V8Narrow, Expand);
+    setOperationAction(ISD::FMUL,        V8Narrow, Expand);
+    setOperationAction(ISD::FNEARBYINT,  V8Narrow, Expand);
+    setOperationAction(ISD::FNEG,        V8Narrow, Expand);
+    setOperationAction(ISD::FROUND,      V8Narrow, Expand);
+    setOperationAction(ISD::FROUNDEVEN,  V8Narrow, Expand);
+    setOperationAction(ISD::FRINT,       V8Narrow, Expand);
+    setOperationAction(ISD::FSQRT,       V8Narrow, Expand);
+    setOperationAction(ISD::FSUB,        V8Narrow, Expand);
+    setOperationAction(ISD::FTRUNC,      V8Narrow, Expand);
+    setOperationAction(ISD::SETCC,       V8Narrow, Expand);
+    setOperationAction(ISD::BR_CC,       V8Narrow, Expand);
+    setOperationAction(ISD::SELECT,      V8Narrow, Expand);
+    setOperationAction(ISD::SELECT_CC,   V8Narrow, Expand);
+    setOperationAction(ISD::FP_EXTEND,   V8Narrow, Expand);
+  };
+
+  if (!Subtarget->hasFullFP16()) {
+    LegalizeNarrowFP(MVT::f16);
   }
+  LegalizeNarrowFP(MVT::bf16);
+  setOperationAction(ISD::FP_ROUND, MVT::v4f32, Custom);
+  setOperationAction(ISD::FP_ROUND, MVT::v4bf16, Custom);
 
   // AArch64 has implementations of a lot of rounding-like FP operations.
   for (auto Op :
@@ -886,6 +901,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::STORE, MVT::v32i8, Custom);
   setOperationAction(ISD::STORE, MVT::v16i16, Custom);
   setOperationAction(ISD::STORE, MVT::v16f16, Custom);
+  setOperationAction(ISD::STORE, MVT::v16bf16, Custom);
   setOperationAction(ISD::STORE, MVT::v8i32, Custom);
   setOperationAction(ISD::STORE, MVT::v8f32, Custom);
   setOperationAction(ISD::STORE, MVT::v4f64, Custom);
@@ -897,6 +913,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::LOAD, MVT::v32i8, Custom);
   setOperationAction(ISD::LOAD, MVT::v16i16, Custom);
   setOperationAction(ISD::LOAD, MVT::v16f16, Custom);
+  setOperationAction(ISD::LOAD, MVT::v16bf16, Custom);
   setOperationAction(ISD::LOAD, MVT::v8i32, Custom);
   setOperationAction(ISD::LOAD, MVT::v8f32, Custom);
   setOperationAction(ISD::LOAD, MVT::v4f64, Custom);
@@ -931,6 +948,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   // AArch64 does not have floating-point extending loads, i1 sign-extending
   // load, floating-point truncating stores, or v2i32->v2i16 truncating store.
   for (MVT VT : MVT::fp_valuetypes()) {
+    setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f64, Expand);
@@ -939,13 +957,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   for (MVT VT : MVT::integer_valuetypes())
     setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Expand);
 
-  setTruncStoreAction(MVT::f32, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f64, MVT::f32, Expand);
-  setTruncStoreAction(MVT::f64, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f128, MVT::f80, Expand);
-  setTruncStoreAction(MVT::f128, MVT::f64, Expand);
-  setTruncStoreAction(MVT::f128, MVT::f32, Expand);
-  setTruncStoreAction(MVT::f128, MVT::f16, Expand);
+  for (MVT WideVT : MVT::fp_valuetypes()) {
+    for (MVT NarrowVT : MVT::fp_valuetypes()) {
+      if (WideVT.getScalarSizeInBits() > NarrowVT.getScalarSizeInBits()) {
+        setTruncStoreAction(WideVT, NarrowVT, Expand);
+      }
+    }
+  }
 
   if (Subtarget->hasFPARMv8()) {
     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
@@ -1553,7 +1571,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
     }
 
-    if (!Subtarget->isNeonAvailable()) {
+    if (!Subtarget->isNeonAvailable()) {// TODO(majnemer)
       setTruncStoreAction(MVT::v2f32, MVT::v2f16, Custom);
       setTruncStoreAction(MVT::v4f32, MVT::v4f16, Custom);
       setTruncStoreAction(MVT::v8f32, MVT::v8f16, Custom);
@@ -1652,6 +1670,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::FLDEXP, MVT::f64, Custom);
     setOperationAction(ISD::FLDEXP, MVT::f32, Custom);
     setOperationAction(ISD::FLDEXP, MVT::f16, Custom);
+    setOperationAction(ISD::FLDEXP, MVT::bf16, Custom);
   }
 
   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
@@ -2451,6 +2470,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FCMP)
     MAKE_CASE(AArch64ISD::STRICT_FCMP)
     MAKE_CASE(AArch64ISD::STRICT_FCMPE)
+    MAKE_CASE(AArch64ISD::FCVTXN)
     MAKE_CASE(AArch64ISD::SME_ZA_LDR)
     MAKE_CASE(AArch64ISD::SME_ZA_STR)
     MAKE_CASE(AArch64ISD::DUP)
@@ -3181,7 +3201,7 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &dl,
 
   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
 
-  if (VT == MVT::f16 && !FullFP16) {
+  if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
     LHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
                       {Chain, LHS});
     RHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
@@ -3201,7 +3221,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
 
   if (VT.isFloatingPoint()) {
     assert(VT != MVT::f128);
-    if (VT == MVT::f16 && !FullFP16) {
+    if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
       LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
       RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
       VT = MVT::f32;
@@ -3309,7 +3329,8 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
 
   if (LHS.getValueType().isFloatingPoint()) {
     assert(LHS.getValueType() != MVT::f128);
-    if (LHS.getValueType() == MVT::f16 && !FullFP16) {
+    if ((LHS.getValueType() == MVT::f16 && !FullFP16) ||
+        LHS.getValueType() == MVT::bf16) {
       LHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, LHS);
       RHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, RHS);
     }
@@ -4009,16 +4030,73 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
 
 SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
                                              SelectionDAG &DAG) const {
-  if (Op.getValueType().isScalableVector())
+  EVT VT = Op.getValueType();
+  if (VT.isScalableVector())
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
 
   bool IsStrict = Op->isStrictFPOpcode();
   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
   EVT SrcVT = SrcVal.getValueType();
+  bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
 
   if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPRoundToSVE(Op, DAG);
 
+  // Expand cases where the result type is BF16 but we don't have hardware
+  // instructions to lower it.
+  if (VT.getScalarType() == MVT::bf16 &&
+      !((Subtarget->hasNEON() || Subtarget->hasSME()) &&
+        Subtarget->hasBF16())) {
+    SDLoc dl(Op);
+    SDValue Narrow = SrcVal;
+    SDValue NaN;
+    EVT I32 = SrcVT.changeElementType(MVT::i32);
+    EVT F32 = SrcVT.changeElementType(MVT::f32);
+    if (SrcVT.getScalarType() == MVT::f32) {
+      bool NeverSNaN = DAG.isKnownNeverSNaN(Narrow);
+      Narrow = DAG.getNode(ISD::BITCAST, dl, I32, Narrow);
+      if (!NeverSNaN) {
+        // Set the quiet bit.
+        NaN = DAG.getNode(ISD::OR, dl, I32, Narrow,
+                          DAG.getConstant(0x400000, dl, I32));
+      }
+    } else if (SrcVT.getScalarType() == MVT::f64) {
+      Narrow = DAG.getNode(AArch64ISD::FCVTXN, dl, F32, Narrow);
+      Narrow = DAG.getNode(ISD::BITCAST, dl, I32, Narrow);
+    } else {
+      return SDValue();
+    }
+    if (!Trunc) {
+      SDValue One = DAG.getConstant(1, dl, I32);
+      SDValue Lsb = DAG.getNode(ISD::SRL, dl, I32, Narrow,
+                                DAG.getShiftAmountConstant(16, I32, dl));
+      Lsb = DAG.getNode(ISD::AND, dl, I32, Lsb, One);
+      SDValue RoundingBias =
+          DAG.getNode(ISD::ADD, dl, I32, DAG.getConstant(0x7fff, dl, I32), Lsb);
+      Narrow = DAG.getNode(ISD::ADD, dl, I32, Narrow, RoundingBias);
+    }
+
+    // Don't round if we had a NaN, we don't want to turn 0x7fffffff into
+    // 0x80000000.
+    if (NaN) {
+      SDValue IsNaN = DAG.getSetCC(
+          dl, getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT),
+          SrcVal, SrcVal, ISD::SETUO);
+      Narrow = DAG.getSelect(dl, I32, IsNaN, NaN, Narrow);
+    }
+
+    // Now that we have rounded, shift the bits into position.
+    Narrow = DAG.getNode(ISD::SRL, dl, I32, Narrow,
+                     DAG.getShiftAmountConstant(16, I32, dl));
+    if (VT.isVector()) {
+      EVT I16 = I32.changeVectorElementType(MVT::i16);
+      Narrow = DAG.getNode(ISD::TRUNCATE, dl, I16, Narrow);
+      return DAG.getNode(ISD::BITCAST, dl, VT, Narrow);
+    }
+    Narrow = DAG.getNode(ISD::BITCAST, dl, F32, Narrow);
+    return DAG.getTargetExtractSubreg(AArch64::hsub, dl, VT, Narrow);
+  }
+
   if (SrcVT != MVT::f128) {
     // Expand cases where the input is a vector bigger than NEON.
     if (useSVEForFixedLengthVectorVT(SrcVT))
@@ -4054,8 +4132,8 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
   unsigned NumElts = InVT.getVectorNumElements();
 
   // f16 conversions are promoted to f32 when full fp16 is not supported.
-  if (InVT.getVectorElementType() == MVT::f16 &&
-      !Subtarget->hasFullFP16()) {
+  if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
+      InVT.getVectorElementType() == MVT::bf16) {
     MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
     SDLoc dl(Op);
     if (IsStrict) {
@@ -4128,7 +4206,8 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
     return LowerVectorFP_TO_INT(Op, DAG);
 
   // f16 conversions are promoted to f32 when full fp16 is not supported.
-  if (SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
+  if ((SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
+      SrcVal.getValueType() == MVT::bf16) {
     SDLoc dl(Op);
     if (IsStrict) {
       SDValue Ext =
@@ -4175,15 +4254,16 @@ AArch64TargetLowering::LowerVectorFP_TO_INT_SAT(SDValue Op,
   EVT SrcElementVT = SrcVT.getVectorElementType();
 
   // In the absence of FP16 support, promote f16 to f32 and saturate the result.
-  if (SrcElementVT == MVT::f16 &&
-      (!Subtarget->hasFullFP16() || DstElementWidth > 16)) {
+  if ((SrcElementVT == MVT::f16 &&
+       (!Subtarget->hasFullFP16() || DstElementWidth > 16)) ||
+      SrcElementVT == MVT::bf16) {
     MVT F32VT = MVT::getVectorVT(MVT::f32, SrcVT.getVectorNumElements());
     SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), F32VT, SrcVal);
     SrcVT = F32VT;
     SrcElementVT = MVT::f32;
     SrcElementWidth = 32;
   } else if (SrcElementVT != MVT::f64 && SrcElementVT != MVT::f32 &&
-             SrcElementVT != MVT::f16)
+             SrcElementVT != MVT::f16 && SrcElementVT != MVT::bf16)
     return SDValue();
 
   SDLoc DL(Op);
@@ -4236,10 +4316,11 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
   assert(SatWidth <= DstWidth && "Saturation width cannot exceed result width");
 
   // In the absence of FP16 support, promote f16 to f32 and saturate the result.
-  if (SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) {
+  if ((SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) || SrcVT == MVT::bf16) {
     SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal);
     SrcVT = MVT::f32;
-  } else if (SrcVT != MVT::f64 && SrcVT != MVT::f32 && SrcVT != MVT::f16)
+  } else if (SrcVT != MVT::f64 && SrcVT != MVT::f32 && SrcVT != MVT::f16 &&
+             SrcVT != MVT::bf16)
     return SDValue();
 
   SDLoc DL(Op);
@@ -4358,17 +4439,17 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
 
   // f16 conversions are promoted to f32 when full fp16 is not supported.
-  if (Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
+  if ((Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) || Op.getValueType() == MVT::bf16) {
     SDLoc dl(Op);
     if (IsStrict) {
       SDValue Val = DAG.getNode(Op.getOpcode(), dl, {MVT::f32, MVT::Other},
                                 {Op.getOperand(0), SrcVal});
       return DAG.getNode(
-          ISD::STRICT_FP_ROUND, dl, {MVT::f16, MVT::Other},
+          ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
           {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
     }
     return DAG.getNode(
-        ISD::FP_ROUND, dl, MVT::f16,
+        ISD::FP_ROUND, dl, Op.getValueType(),
         DAG.getNode(Op.getOpcode(), dl, MVT::f32, SrcVal),
         DAG.getIntPtrConstant(0, dl));
   }
@@ -6100,6 +6181,7 @@ static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) {
   switch (Op.getSimpleValueType().SimpleTy) {
   default:
     return SDValue();
+  case MVT::bf16:
   case MVT::f16:
     X = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, X);
     [[fallthrough]];
@@ -6416,7 +6498,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::LLROUND:
   case ISD::LRINT:
   case ISD::LLRINT: {
-    assert(Op.getOperand(0).getValueType() == MVT::f16 &&
+    assert((Op.getOperand(0).getValueType() == MVT::f16 ||
+            Op.getOperand(0).getValueType() == MVT::bf16) &&
            "Expected custom lowering of rounding operations only for f16");
     SDLoc DL(Op);
     SDValue Ext = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
@@ -6426,7 +6509,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::STRICT_LLROUND:
   case ISD::STRICT_LRINT:
   case ISD::STRICT_LLRINT: {
-    assert(Op.getOperand(1).getValueType() == MVT::f16 &&
+    assert((Op.getOperand(1).getValueType() == MVT::f16 ||
+            Op.getOperand(1).getValueType() == MVT::bf16) &&
            "Expected custom lowering of rounding operations only for f16");
     SDLoc DL(Op);
     SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
@@ -9459,8 +9543,8 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
   }
 
   // Now we know we're dealing with FP values.
-  assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::f32 ||
-         LHS.getValueType() == MVT::f64);
+  assert(LHS.getValueType() == MVT::bf16 || LHS.getValueType() == MVT::f16 ||
+         LHS.getValueType() == MVT::f32 || LHS.getValueType() == MVT::f64);
 
   // If that fails, we'll need to perform an FCMP + CSEL sequence.  Go ahead
   // and do the comparison.
@@ -9547,7 +9631,8 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
   }
 
   // Also handle f16, for which we need to do a f32 comparison.
-  if (LHS.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
+  if ((LHS.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
+      LHS.getValueType() == MVT::bf16) {
     LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
     RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
   }
@@ -10306,6 +10391,7 @@ bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
     IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero();
   else if (VT == MVT::f32)
     IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero();
+  // TODO(majnemer): double check this...
   else if (VT == MVT::f16 || VT == MVT::bf16)
     IsLegal =
         (Subtarget->hasFullFP16() && AArch64_AM::getFP16Imm(ImmInt) != -1) ||
@@ -14161,11 +14247,30 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
     return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
   }
 
+  // Lower isnan(x) | isnan(never-nan) to x != x.
+  // Lower !isnan(x) & !isnan(never-nan) to x == x.
+  if (CC == ISD::SETUO || CC == ISD::SETO) {
+    bool OneNaN = false;
+    if (LHS == RHS) {
+      OneNaN = true;
+    } else if (DAG.isKnownNeverNaN(RHS)) {
+      OneNaN = true;
+      RHS = LHS;
+    } else if (DAG.isKnownNeverNaN(LHS)) {
+      OneNaN = true;
+      LHS = RHS;
+    }
+    if (OneNaN) {
+      CC = CC == ISD::SETUO ? ISD::SETUNE : ISD::SETOEQ;
+    }
+  }
+
   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
 
   // Make v4f16 (only) fcmp operations utilise vector instructions
   // v8f16 support will be a litle more complicated
-  if (!FullFP16 && LHS.getValueType().getVectorElementType() == MVT::f16) {
+  if ((!FullFP16 && LHS.getValueType().getVectorElementType() == MVT::f16) ||
+      LHS.getValueType().getVectorElementType() == MVT::bf16) {
     if (LHS.getValueType().getVectorNumElements() == 4) {
       LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, LHS);
       RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, RHS);
@@ -14177,7 +14282,8 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
   }
 
   assert((!FullFP16 && LHS.getValueType().getVectorElementType() != MVT::f16) ||
-          LHS.getValueType().getVectorElementType() != MVT::f128);
+         LHS.getValueType().getVectorElementType() != MVT::bf16 ||
+         LHS.getValueType().getVectorElementType() != MVT::f128);
 
   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
   // clean.  Some of them require two branches to implement.
@@ -24684,7 +24790,8 @@ static void ReplaceAddWithADDP(SDNode *N, SmallVectorImpl<SDValue> &Results,
   if (!VT.is256BitVector() ||
       (VT.getScalarType().isFloatingPoint() &&
        !N->getFlags().hasAllowReassociation()) ||
-      (VT.getScalarType() == MVT::f16 && !Subtarget->hasFullFP16()))
+      (VT.getScalarType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
+      VT.getScalarType() == MVT::bf16)
     return;
 
   SDValue X = N->getOperand(0);
@@ -25736,6 +25843,8 @@ bool AArch64TargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,
   // legalize.
   if (FPVT == MVT::v8f16 && !Subtarget->hasFullFP16())
     return false;
+  if (FPVT == MVT::v8bf16)
+    return false;
   return TargetLowering::shouldConvertFpToSat(Op, FPVT, VT);
 }
 
@@ -25928,6 +26037,8 @@ static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT) {
     return EVT(MVT::nxv4i32);
   case MVT::i64:
     return EVT(MVT::nxv2i64);
+  case MVT::bf16:
+    return EVT(MVT::nxv8bf16);
   case MVT::f16:
     return EVT(MVT::nxv8f16);
   case MVT::f32:
@@ -25967,6 +26078,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
     break;
   case MVT::i16:
   case MVT::f16:
+  case MVT::bf16:
     MaskVT = MVT::nxv8i1;
     break;
   case MVT::i32:

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index c1fe76c07cba87..68341c199e0a2a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -249,6 +249,9 @@ enum NodeType : unsigned {
   FCMLEz,
   FCMLTz,
 
+  // Round wide FP to narrow FP with inexact results to odd.
+  FCVTXN,
+
   // Vector across-lanes addition
   // Only the lower result lane is defined.
   SADDV,

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 7f8856db6c6e61..091db559a33708 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7547,7 +7547,7 @@ class BaseSIMDCmpTwoScalar<bit U, bits<2> size, bits<2> size2, bits<5> opcode,
 let mayRaiseFPException = 1, Uses = [FPCR] in
 class SIMDInexactCvtTwoScalar<bits<5> opcode, string asm>
   : I<(outs FPR32:$Rd), (ins FPR64:$Rn), asm, "\t$Rd, $Rn", "",
-     [(set (f32 FPR32:$Rd), (int_aarch64_sisd_fcvtxn (f64 FPR64:$Rn)))]>,
+     [(set (f32 FPR32:$Rd), (AArch64fcvtxn (f64 FPR64:$Rn)))]>,
     Sched<[WriteVd]> {
   bits<5> Rd;
   bits<5> Rn;

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 52137c1f4065bc..c153bb3f0145bc 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -756,6 +756,11 @@ def AArch64fcmgtz: SDNode<"AArch64ISD::FCMGTz", SDT_AArch64fcmpz>;
 def AArch64fcmlez: SDNode<"AArch64ISD::FCMLEz", SDT_AArch64fcmpz>;
 def AArch64fcmltz: SDNode<"AArch64ISD::FCMLTz", SDT_AArch64fcmpz>;
 
+def AArch64fcvtxn_n: SDNode<"AArch64ISD::FCVTXN", SDTFPRoundOp>;
+def AArch64fcvtxn: PatFrags<(ops node:$Rn),
+                            [(f32 (int_aarch64_sisd_fcvtxn (f64 node:$Rn))),
+                             (f32 (AArch64fcvtxn_n (f64 node:$Rn)))]>;
+
 def AArch64bici: SDNode<"AArch64ISD::BICi", SDT_AArch64vecimm>;
 def AArch64orri: SDNode<"AArch64ISD::ORRi", SDT_AArch64vecimm>;
 
@@ -1276,6 +1281,9 @@ def BFMLALTIdx   : SIMDBF16MLALIndex<1, "bfmlalt", int_aarch64_neon_bfmlalt>;
 def BFCVTN       : SIMD_BFCVTN;
 def BFCVTN2      : SIMD_BFCVTN2;
 
+def : Pat<(v4bf16 (any_fpround (v4f32 V128:$Rn))),
+          (EXTRACT_SUBREG (BFCVTN V128:$Rn), dsub)>;
+
 // Vector-scalar BFDOT:
 // The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
 // register (the instruction uses a single 32-bit lane from it), so the pattern
@@ -1296,6 +1304,8 @@ def : Pat<(v2f32 (int_aarch64_neon_bfdot
 
 let Predicates = [HasNEONorSME, HasBF16] in {
 def BFCVT : BF16ToSinglePrecision<"bfcvt">;
+// Round FP32 to BF16.
+def : Pat<(bf16 (any_fpround (f32 FPR32:$Rn))), (BFCVT $Rn)>;
 }
 
 // ARMv8.6A AArch64 matrix multiplication
@@ -4648,6 +4658,22 @@ let Predicates = [HasFullFP16] in {
 //===----------------------------------------------------------------------===//
 
 defm FCVT : FPConversion<"fcvt">;
+// Helper to get bf16 into fp32.
+def cvt_bf16_to_fp32 :
+  OutPatFrag<(ops node:$Rn),
+             (f32 (COPY_TO_REGCLASS
+	         (i32 (UBFMWri
+		   (i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
+		                           node:$Rn, hsub), GPR32)),
+	           (i64 (i32shift_a (i64 16))),
+                   (i64 (i32shift_b (i64 16))))),
+	         FPR32))>;
+// Pattern for bf16 -> fp32.
+def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
+          (cvt_bf16_to_fp32 FPR16:$Rn)>;
+// Pattern for bf16 -> fp64.
+def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))),
+          (FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>;
 
 //===----------------------------------------------------------------------===//
 // Floating point single operand instructions.
@@ -5002,6 +5028,9 @@ defm FCVTNU : SIMDTwoVectorFPToInt<1,0,0b11010, "fcvtnu",int_aarch64_neon_fcvtnu
 defm FCVTN  : SIMDFPNarrowTwoVector<0, 0, 0b10110, "fcvtn">;
 def : Pat<(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn))),
           (FCVTNv4i16 V128:$Rn)>;
+//def : Pat<(concat_vectors V64:$Rd,
+//                          (v4bf16 (any_fpround (v4f32 V128:$Rn)))),
+//          (FCVTNv8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
 def : Pat<(concat_vectors V64:$Rd,
                           (v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn)))),
           (FCVTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
@@ -5686,6 +5715,11 @@ defm USQADD : SIMDTwoScalarBHSDTied< 1, 0b00011, "usqadd",
 def : Pat<(v1i64 (AArch64vashr (v1i64 V64:$Rn), (i32 63))),
           (CMLTv1i64rz V64:$Rn)>;
 
+// Round FP64 to BF16.
+let Predicates = [HasNEONorSME, HasBF16] in
+def : Pat<(bf16 (any_fpround (f64 FPR64:$Rn))),
+          (BFCVT (FCVTXNv1i64 $Rn))>;
+
 def : Pat<(v1i64 (int_aarch64_neon_fcvtas (v1f64 FPR64:$Rn))),
           (FCVTASv1i64 FPR64:$Rn)>;
 def : Pat<(v1i64 (int_aarch64_neon_fcvtau (v1f64 FPR64:$Rn))),
@@ -7698,6 +7732,9 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))>
 def : Pat<(v2i64 (sext   (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>;
 def : Pat<(v2i64 (zext   (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
 def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
+// Vector bf16 -> fp32 is implemented morally as a zext + shift.
+def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))),
+          (USHLLv4i16_shift V64:$Rn, (i32 16))>;
 // Also match an extend from the upper half of a 128 bit source register.
 def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))),
           (USHLLv16i8_shift V128:$Rn, (i32 0))>;

diff  --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
index 9bc5815ae05371..5a8031641ae099 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -1022,13 +1022,17 @@ void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
 
   bool Invert = false;
   AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
-  if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
+  if ((Pred == CmpInst::Predicate::FCMP_ORD ||
+       Pred == CmpInst::Predicate::FCMP_UNO) &&
+      IsZero) {
     // The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
     // NaN, so equivalent to a == a and doesn't need the two comparisons an
     // "ord" normally would.
+    // Similarly, "fcmp uno %a, 0" is the canonical check that LHS is NaN and is
+    // thus equivalent to a != a.
     RHS = LHS;
     IsZero = false;
-    CC = AArch64CC::EQ;
+    CC = Pred == CmpInst::Predicate::FCMP_ORD ? AArch64CC::EQ : AArch64CC::NE;
   } else
     changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
 

diff  --git a/llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll b/llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
index 954a836e44f283..a68c21f7943432 100644
--- a/llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
+++ b/llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
@@ -13,7 +13,7 @@ define void @strict_fp_reductions() {
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
 ;
@@ -24,7 +24,7 @@ define void @strict_fp_reductions() {
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
-; FP16-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
+; FP16-NEXT:  Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
 ;
@@ -72,7 +72,7 @@ define void @fast_fp_reductions() {
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
 ;
@@ -95,7 +95,7 @@ define void @fast_fp_reductions() {
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
-; FP16-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
+; FP16-NEXT:  Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
 ; FP16-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
 ;

diff  --git a/llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir b/llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir
index 8f01c009c4967a..1f5fb892df5820 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir
@@ -321,18 +321,15 @@ body:             |
   bb.0:
     liveins: $q0, $q1
 
-    ; Should be inverted. Needs two compares.
 
     ; CHECK-LABEL: name: uno_zero
     ; CHECK: liveins: $q0, $q1
     ; CHECK-NEXT: {{  $}}
     ; CHECK-NEXT: %lhs:_(<2 x s64>) = COPY $q0
-    ; CHECK-NEXT: [[FCMGEZ:%[0-9]+]]:_(<2 x s64>) = G_FCMGEZ %lhs
-    ; CHECK-NEXT: [[FCMLTZ:%[0-9]+]]:_(<2 x s64>) = G_FCMLTZ %lhs
-    ; CHECK-NEXT: [[OR:%[0-9]+]]:_(<2 x s64>) = G_OR [[FCMLTZ]], [[FCMGEZ]]
+    ; CHECK-NEXT: [[FCMEQ:%[0-9]+]]:_(<2 x s64>) = G_FCMEQ %lhs, %lhs(<2 x s64>)
     ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 -1
     ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<2 x s64>) = G_BUILD_VECTOR [[C]](s64), [[C]](s64)
-    ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[OR]], [[BUILD_VECTOR]]
+    ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[FCMEQ]], [[BUILD_VECTOR]]
     ; CHECK-NEXT: $q0 = COPY [[XOR]](<2 x s64>)
     ; CHECK-NEXT: RET_ReallyLR implicit $q0
     %lhs:_(<2 x s64>) = COPY $q0

diff  --git a/llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll b/llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll
index a949eaac5cfa29..adde5429a6d93d 100644
--- a/llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll
+++ b/llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll
@@ -187,10 +187,7 @@ entry:
 define <8 x bfloat> @insertzero_v4bf16(<4 x bfloat> %a) {
 ; CHECK-LABEL: insertzero_v4bf16:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi d4, #0000000000000000
-; CHECK-NEXT:    movi d5, #0000000000000000
-; CHECK-NEXT:    movi d6, #0000000000000000
-; CHECK-NEXT:    movi d7, #0000000000000000
+; CHECK-NEXT:    fmov d0, d0
 ; CHECK-NEXT:    ret
 entry:
   %shuffle.i = shufflevector <4 x bfloat> %a, <4 x bfloat> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>

diff  --git a/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll b/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
index 765c81e26e13ca..632b6b32625022 100644
--- a/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
@@ -3057,55 +3057,34 @@ define <2 x i64> @fcmonez2xdouble(<2 x double> %A) {
   ret <2 x i64> %tmp4
 }
 
-; ORD with zero = OLT | OGE
+; ORD A, zero = EQ A, A
 define <2 x i32> @fcmordz2xfloat(<2 x float> %A) {
-; CHECK-SD-LABEL: fcmordz2xfloat:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    fcmge v1.2s, v0.2s, #0.0
-; CHECK-SD-NEXT:    fcmlt v0.2s, v0.2s, #0.0
-; CHECK-SD-NEXT:    orr v0.8b, v0.8b, v1.8b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: fcmordz2xfloat:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    fcmeq v0.2s, v0.2s, v0.2s
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: fcmordz2xfloat:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fcmeq v0.2s, v0.2s, v0.2s
+; CHECK-NEXT:    ret
   %tmp3 = fcmp ord <2 x float> %A, zeroinitializer
   %tmp4 = sext <2 x i1> %tmp3 to <2 x i32>
   ret <2 x i32> %tmp4
 }
 
-; ORD with zero = OLT | OGE
+; ORD A, zero = EQ A, A
 define <4 x i32> @fcmordz4xfloat(<4 x float> %A) {
-; CHECK-SD-LABEL: fcmordz4xfloat:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    fcmge v1.4s, v0.4s, #0.0
-; CHECK-SD-NEXT:    fcmlt v0.4s, v0.4s, #0.0
-; CHECK-SD-NEXT:    orr v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: fcmordz4xfloat:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: fcmordz4xfloat:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-NEXT:    ret
   %tmp3 = fcmp ord <4 x float> %A, zeroinitializer
   %tmp4 = sext <4 x i1> %tmp3 to <4 x i32>
   ret <4 x i32> %tmp4
 }
 
-; ORD with zero = OLT | OGE
+; ORD A, zero = EQ A, A
 define <2 x i64> @fcmordz2xdouble(<2 x double> %A) {
-; CHECK-SD-LABEL: fcmordz2xdouble:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    fcmge v1.2d, v0.2d, #0.0
-; CHECK-SD-NEXT:    fcmlt v0.2d, v0.2d, #0.0
-; CHECK-SD-NEXT:    orr v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: fcmordz2xdouble:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    fcmeq v0.2d, v0.2d, v0.2d
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: fcmordz2xdouble:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fcmeq v0.2d, v0.2d, v0.2d
+; CHECK-NEXT:    ret
   %tmp3 = fcmp ord <2 x double> %A, zeroinitializer
   %tmp4 = sext <2 x i1> %tmp3 to <2 x i64>
   ret <2 x i64> %tmp4
@@ -3331,13 +3310,11 @@ define <2 x i64> @fcmunez2xdouble(<2 x double> %A) {
   ret <2 x i64> %tmp4
 }
 
-; UNO with zero = !ORD = !(OLT | OGE)
+; UNO A, zero = !(ORD A, zero) = !(EQ A, A)
 define <2 x i32> @fcmunoz2xfloat(<2 x float> %A) {
 ; CHECK-LABEL: fcmunoz2xfloat:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fcmge v1.2s, v0.2s, #0.0
-; CHECK-NEXT:    fcmlt v0.2s, v0.2s, #0.0
-; CHECK-NEXT:    orr v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    fcmeq v0.2s, v0.2s, v0.2s
 ; CHECK-NEXT:    mvn v0.8b, v0.8b
 ; CHECK-NEXT:    ret
   %tmp3 = fcmp uno <2 x float> %A, zeroinitializer
@@ -3345,13 +3322,11 @@ define <2 x i32> @fcmunoz2xfloat(<2 x float> %A) {
   ret <2 x i32> %tmp4
 }
 
-; UNO with zero = !ORD = !(OLT | OGE)
+; UNO A, zero = !(ORD A, zero) = !(EQ A, A)
 define <4 x i32> @fcmunoz4xfloat(<4 x float> %A) {
 ; CHECK-LABEL: fcmunoz4xfloat:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fcmge v1.4s, v0.4s, #0.0
-; CHECK-NEXT:    fcmlt v0.4s, v0.4s, #0.0
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
 ; CHECK-NEXT:    mvn v0.16b, v0.16b
 ; CHECK-NEXT:    ret
   %tmp3 = fcmp uno <4 x float> %A, zeroinitializer
@@ -3359,13 +3334,11 @@ define <4 x i32> @fcmunoz4xfloat(<4 x float> %A) {
   ret <4 x i32> %tmp4
 }
 
-; UNO with zero = !ORD = !(OLT | OGE)
+; UNO A, zero = !(ORD A, zero) = !(EQ A, A)
 define <2 x i64> @fcmunoz2xdouble(<2 x double> %A) {
 ; CHECK-LABEL: fcmunoz2xdouble:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fcmge v1.2d, v0.2d, #0.0
-; CHECK-NEXT:    fcmlt v0.2d, v0.2d, #0.0
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    fcmeq v0.2d, v0.2d, v0.2d
 ; CHECK-NEXT:    mvn v0.16b, v0.16b
 ; CHECK-NEXT:    ret
   %tmp3 = fcmp uno <2 x double> %A, zeroinitializer
@@ -4133,51 +4106,30 @@ define <2 x i64> @fcmonez2xdouble_fast(<2 x double> %A) {
 }
 
 define <2 x i32> @fcmordz2xfloat_fast(<2 x float> %A) {
-; CHECK-SD-LABEL: fcmordz2xfloat_fast:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    fcmge v1.2s, v0.2s, #0.0
-; CHECK-SD-NEXT:    fcmlt v0.2s, v0.2s, #0.0
-; CHECK-SD-NEXT:    orr v0.8b, v0.8b, v1.8b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: fcmordz2xfloat_fast:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    fcmeq v0.2s, v0.2s, v0.2s
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: fcmordz2xfloat_fast:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fcmeq v0.2s, v0.2s, v0.2s
+; CHECK-NEXT:    ret
   %tmp3 = fcmp fast ord <2 x float> %A, zeroinitializer
   %tmp4 = sext <2 x i1> %tmp3 to <2 x i32>
   ret <2 x i32> %tmp4
 }
 
 define <4 x i32> @fcmordz4xfloat_fast(<4 x float> %A) {
-; CHECK-SD-LABEL: fcmordz4xfloat_fast:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    fcmge v1.4s, v0.4s, #0.0
-; CHECK-SD-NEXT:    fcmlt v0.4s, v0.4s, #0.0
-; CHECK-SD-NEXT:    orr v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: fcmordz4xfloat_fast:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: fcmordz4xfloat_fast:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-NEXT:    ret
   %tmp3 = fcmp fast ord <4 x float> %A, zeroinitializer
   %tmp4 = sext <4 x i1> %tmp3 to <4 x i32>
   ret <4 x i32> %tmp4
 }
 
 define <2 x i64> @fcmordz2xdouble_fast(<2 x double> %A) {
-; CHECK-SD-LABEL: fcmordz2xdouble_fast:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    fcmge v1.2d, v0.2d, #0.0
-; CHECK-SD-NEXT:    fcmlt v0.2d, v0.2d, #0.0
-; CHECK-SD-NEXT:    orr v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: fcmordz2xdouble_fast:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    fcmeq v0.2d, v0.2d, v0.2d
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: fcmordz2xdouble_fast:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fcmeq v0.2d, v0.2d, v0.2d
+; CHECK-NEXT:    ret
   %tmp3 = fcmp fast ord <2 x double> %A, zeroinitializer
   %tmp4 = sext <2 x i1> %tmp3 to <2 x i64>
   ret <2 x i64> %tmp4
@@ -4466,9 +4418,7 @@ define <2 x i64> @fcmunez2xdouble_fast(<2 x double> %A) {
 define <2 x i32> @fcmunoz2xfloat_fast(<2 x float> %A) {
 ; CHECK-LABEL: fcmunoz2xfloat_fast:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fcmge v1.2s, v0.2s, #0.0
-; CHECK-NEXT:    fcmlt v0.2s, v0.2s, #0.0
-; CHECK-NEXT:    orr v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    fcmeq v0.2s, v0.2s, v0.2s
 ; CHECK-NEXT:    mvn v0.8b, v0.8b
 ; CHECK-NEXT:    ret
   %tmp3 = fcmp fast uno <2 x float> %A, zeroinitializer
@@ -4479,9 +4429,7 @@ define <2 x i32> @fcmunoz2xfloat_fast(<2 x float> %A) {
 define <4 x i32> @fcmunoz4xfloat_fast(<4 x float> %A) {
 ; CHECK-LABEL: fcmunoz4xfloat_fast:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fcmge v1.4s, v0.4s, #0.0
-; CHECK-NEXT:    fcmlt v0.4s, v0.4s, #0.0
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
 ; CHECK-NEXT:    mvn v0.16b, v0.16b
 ; CHECK-NEXT:    ret
   %tmp3 = fcmp fast uno <4 x float> %A, zeroinitializer
@@ -4492,9 +4440,7 @@ define <4 x i32> @fcmunoz4xfloat_fast(<4 x float> %A) {
 define <2 x i64> @fcmunoz2xdouble_fast(<2 x double> %A) {
 ; CHECK-LABEL: fcmunoz2xdouble_fast:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fcmge v1.2d, v0.2d, #0.0
-; CHECK-NEXT:    fcmlt v0.2d, v0.2d, #0.0
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    fcmeq v0.2d, v0.2d, v0.2d
 ; CHECK-NEXT:    mvn v0.16b, v0.16b
 ; CHECK-NEXT:    ret
   %tmp3 = fcmp fast uno <2 x double> %A, zeroinitializer

diff  --git a/llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll b/llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll
index 4f23df1d168197..ec7548e1e65410 100644
--- a/llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll
+++ b/llvm/test/CodeGen/AArch64/round-fptosi-sat-scalar.ll
@@ -4,6 +4,54 @@
 
 ; Round towards minus infinity (fcvtms).
 
+define i32 @testmswbf(bfloat %a) {
+; CHECK-LABEL: testmswbf:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    mov w8, #32767 // =0x7fff
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    fmov s0, w9
+; CHECK-NEXT:    frintm s0, s0
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    ubfx w10, w9, #16, #1
+; CHECK-NEXT:    add w8, w9, w8
+; CHECK-NEXT:    add w8, w10, w8
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    fcvtzs w0, s0
+; CHECK-NEXT:    ret
+entry:
+  %r = call bfloat @llvm.floor.bf16(bfloat %a) nounwind readnone
+  %i = call i32 @llvm.fptosi.sat.i32.bf16(bfloat %r)
+  ret i32 %i
+}
+
+define i64 @testmsxbf(bfloat %a) {
+; CHECK-LABEL: testmsxbf:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    mov w8, #32767 // =0x7fff
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    fmov s0, w9
+; CHECK-NEXT:    frintm s0, s0
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    ubfx w10, w9, #16, #1
+; CHECK-NEXT:    add w8, w9, w8
+; CHECK-NEXT:    add w8, w10, w8
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    fcvtzs x0, s0
+; CHECK-NEXT:    ret
+entry:
+  %r = call bfloat @llvm.floor.bf16(bfloat %a) nounwind readnone
+  %i = call i64 @llvm.fptosi.sat.i64.bf16(bfloat %r)
+  ret i64 %i
+}
+
 define i32 @testmswh(half %a) {
 ; CHECK-CVT-LABEL: testmswh:
 ; CHECK-CVT:       // %bb.0: // %entry
@@ -90,6 +138,54 @@ entry:
 
 ; Round towards plus infinity (fcvtps).
 
+define i32 @testpswbf(bfloat %a) {
+; CHECK-LABEL: testpswbf:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    mov w8, #32767 // =0x7fff
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    fmov s0, w9
+; CHECK-NEXT:    frintp s0, s0
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    ubfx w10, w9, #16, #1
+; CHECK-NEXT:    add w8, w9, w8
+; CHECK-NEXT:    add w8, w10, w8
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    fcvtzs w0, s0
+; CHECK-NEXT:    ret
+entry:
+  %r = call bfloat @llvm.ceil.bf16(bfloat %a) nounwind readnone
+  %i = call i32 @llvm.fptosi.sat.i32.bf16(bfloat %r)
+  ret i32 %i
+}
+
+define i64 @testpsxbf(bfloat %a) {
+; CHECK-LABEL: testpsxbf:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $s0
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    mov w8, #32767 // =0x7fff
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    fmov s0, w9
+; CHECK-NEXT:    frintp s0, s0
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    ubfx w10, w9, #16, #1
+; CHECK-NEXT:    add w8, w9, w8
+; CHECK-NEXT:    add w8, w10, w8
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    fcvtzs x0, s0
+; CHECK-NEXT:    ret
+entry:
+  %r = call bfloat @llvm.ceil.bf16(bfloat %a) nounwind readnone
+  %i = call i64 @llvm.fptosi.sat.i64.bf16(bfloat %r)
+  ret i64 %i
+}
+
 define i32 @testpswh(half %a) {
 ; CHECK-CVT-LABEL: testpswh:
 ; CHECK-CVT:       // %bb.0: // %entry
@@ -346,6 +442,8 @@ entry:
   ret i64 %i
 }
 
+declare i32 @llvm.fptosi.sat.i32.bf16 (bfloat)
+declare i64 @llvm.fptosi.sat.i64.bf16 (bfloat)
 declare i32 @llvm.fptosi.sat.i32.f16 (half)
 declare i64 @llvm.fptosi.sat.i64.f16 (half)
 declare i32 @llvm.fptosi.sat.i32.f32 (float)
@@ -353,6 +451,10 @@ declare i64 @llvm.fptosi.sat.i64.f32 (float)
 declare i32 @llvm.fptosi.sat.i32.f64 (double)
 declare i64 @llvm.fptosi.sat.i64.f64 (double)
 
+declare bfloat @llvm.floor.bf16(bfloat) nounwind readnone
+declare bfloat @llvm.ceil.bf16(bfloat) nounwind readnone
+declare bfloat @llvm.trunc.bf16(bfloat) nounwind readnone
+declare bfloat @llvm.round.bf16(bfloat) nounwind readnone
 declare half @llvm.floor.f16(half) nounwind readnone
 declare half @llvm.ceil.f16(half) nounwind readnone
 declare half @llvm.trunc.f16(half) nounwind readnone

diff  --git a/llvm/test/CodeGen/AArch64/vector-fcopysign.ll b/llvm/test/CodeGen/AArch64/vector-fcopysign.ll
index c7134508883b11..d01ca881545c02 100644
--- a/llvm/test/CodeGen/AArch64/vector-fcopysign.ll
+++ b/llvm/test/CodeGen/AArch64/vector-fcopysign.ll
@@ -477,4 +477,443 @@ define <8 x half> @test_copysign_v8f16_v8f32(<8 x half> %a, <8 x float> %b) #0 {
 
 declare <8 x half> @llvm.copysign.v8f16(<8 x half> %a, <8 x half> %b) #0
 
+;============ v4bf16
+
+; copysign on <4 x bfloat> is currently scalarized, one lane at a time:
+; - extract the lane (mov hN, vM[i] / fmov wN, sM);
+; - widen to f32 by shifting the bits left 16;
+; - merge with bit/bif/bsl under the 0x7fffffff mask from mvni.4s #128, lsl #24
+;   (magnitude from %a, sign from %b);
+; - narrow back with a plain lsr #16 (no rounding sequence, since the result is
+;   an exact bf16 value);
+; - reinsert with mov.h.
+; NOTE(review): the commit message lists copysign as a candidate for avoiding
+; promotion in a follow-up, so expect these assertions to shrink then.
+define <4 x bfloat> @test_copysign_v4bf16_v4bf16(<4 x bfloat> %a, <4 x bfloat> %b) #0 {
+; CHECK-LABEL: test_copysign_v4bf16_v4bf16:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    ; kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    mov h3, v1[1]
+; CHECK-NEXT:    mov h4, v0[1]
+; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    mov h5, v1[2]
+; CHECK-NEXT:    mov h6, v0[2]
+; CHECK-NEXT:    fmov w11, s0
+; CHECK-NEXT:    mvni.4s v2, #128, lsl #24
+; CHECK-NEXT:    mov h1, v1[3]
+; CHECK-NEXT:    mov h0, v0[3]
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov w9, s3
+; CHECK-NEXT:    lsl w11, w11, #16
+; CHECK-NEXT:    fmov w10, s4
+; CHECK-NEXT:    fmov s7, w8
+; CHECK-NEXT:    fmov w8, s5
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov s3, w9
+; CHECK-NEXT:    fmov s4, w10
+; CHECK-NEXT:    fmov w9, s6
+; CHECK-NEXT:    fmov w10, s1
+; CHECK-NEXT:    bit.16b v3, v4, v2
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    fmov s4, w11
+; CHECK-NEXT:    fmov w11, s0
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    fmov s1, w9
+; CHECK-NEXT:    bif.16b v4, v7, v2
+; CHECK-NEXT:    fmov w8, s3
+; CHECK-NEXT:    lsl w11, w11, #16
+; CHECK-NEXT:    bif.16b v1, v0, v2
+; CHECK-NEXT:    fmov s5, w11
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov w9, s4
+; CHECK-NEXT:    fmov s4, w10
+; CHECK-NEXT:    fmov s3, w8
+; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    mov.16b v1, v2
+; CHECK-NEXT:    lsr w9, w9, #16
+; CHECK-NEXT:    bsl.16b v1, v5, v4
+; CHECK-NEXT:    fmov s0, w9
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s2, w8
+; CHECK-NEXT:    mov.h v0[1], v3[0]
+; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    mov.h v0[2], v2[0]
+; CHECK-NEXT:    fmov s1, w8
+; CHECK-NEXT:    mov.h v0[3], v1[0]
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    ret
+  %r = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %a, <4 x bfloat> %b)
+  ret <4 x bfloat> %r
+}
+
+; The sign operand first goes through a vector fptrunc <4 x float> to
+; <4 x bfloat>:
+; - add 0x7fff (movi.4s #127, msl #8) plus the low bit of each lane's upper
+;   half (ushr #16, and with 1) for round-to-nearest-even;
+; - lanes failing fcmeq x, x (NaNs) instead get the 0x400000 quiet bit ORed
+;   in (orr.4s #64, lsl #16);
+; - shrn.4h #16 keeps the high halves.
+; The copysign itself is then scalarized per lane: widen with lsl #16, select
+; under the 0x7fffffff mask, narrow with lsr #16.
+define <4 x bfloat> @test_copysign_v4bf16_v4f32(<4 x bfloat> %a, <4 x float> %b) #0 {
+; CHECK-LABEL: test_copysign_v4bf16_v4f32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    movi.4s v2, #127, msl #8
+; CHECK-NEXT:    movi.4s v3, #1
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    ushr.4s v4, v1, #16
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    mov h5, v0[2]
+; CHECK-NEXT:    mov h6, v0[3]
+; CHECK-NEXT:    add.4s v2, v1, v2
+; CHECK-NEXT:    and.16b v3, v4, v3
+; CHECK-NEXT:    fcmeq.4s v4, v1, v1
+; CHECK-NEXT:    orr.4s v1, #64, lsl #16
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    add.4s v2, v3, v2
+; CHECK-NEXT:    mov h3, v0[1]
+; CHECK-NEXT:    bit.16b v1, v2, v4
+; CHECK-NEXT:    fmov w8, s3
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    shrn.4h v2, v1, #16
+; CHECK-NEXT:    mvni.4s v1, #128, lsl #24
+; CHECK-NEXT:    fmov s3, w8
+; CHECK-NEXT:    fmov w8, s5
+; CHECK-NEXT:    fmov s5, w9
+; CHECK-NEXT:    mov h4, v2[1]
+; CHECK-NEXT:    mov h0, v2[2]
+; CHECK-NEXT:    fmov w11, s2
+; CHECK-NEXT:    mov h2, v2[3]
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    lsl w11, w11, #16
+; CHECK-NEXT:    fmov w10, s4
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    fmov w8, s2
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov s4, w10
+; CHECK-NEXT:    fmov w10, s6
+; CHECK-NEXT:    fmov s2, w9
+; CHECK-NEXT:    bif.16b v3, v4, v1
+; CHECK-NEXT:    fmov s4, w11
+; CHECK-NEXT:    bit.16b v2, v0, v1
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    bit.16b v4, v5, v1
+; CHECK-NEXT:    fmov s5, w8
+; CHECK-NEXT:    fmov w9, s3
+; CHECK-NEXT:    fmov w8, s2
+; CHECK-NEXT:    fmov w11, s4
+; CHECK-NEXT:    fmov s4, w10
+; CHECK-NEXT:    lsr w9, w9, #16
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s3, w9
+; CHECK-NEXT:    lsr w11, w11, #16
+; CHECK-NEXT:    bsl.16b v1, v4, v5
+; CHECK-NEXT:    fmov s2, w8
+; CHECK-NEXT:    fmov s0, w11
+; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    mov.h v0[1], v3[0]
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    mov.h v0[2], v2[0]
+; CHECK-NEXT:    fmov s1, w8
+; CHECK-NEXT:    mov.h v0[3], v1[0]
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    ret
+  %tmp0 = fptrunc <4 x float> %b to <4 x bfloat>
+  %r = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %a, <4 x bfloat> %tmp0)
+  ret <4 x bfloat> %r
+}
+
+; The sign operand comes from fptrunc <4 x double> to <4 x bfloat>, but no
+; bf16 rounding sequence is emitted for %b. Each double is only converted with
+; fcvt sN, dN and fed straight into the mask select. This is presumably because
+; copysign consumes just the sign bit of its second operand, which the f64 ->
+; f32 conversion preserves. The round-inexact-to-odd f64 -> bf16 path from the
+; commit message therefore does not appear here.
+define <4 x bfloat> @test_copysign_v4bf16_v4f64(<4 x bfloat> %a, <4 x double> %b) #0 {
+; CHECK-LABEL: test_copysign_v4bf16_v4f64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    mov h4, v0[1]
+; CHECK-NEXT:    mov h5, v0[2]
+; CHECK-NEXT:    mov d3, v1[1]
+; CHECK-NEXT:    fcvt s1, d1
+; CHECK-NEXT:    mov h0, v0[3]
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov w9, s4
+; CHECK-NEXT:    mvni.4s v4, #128, lsl #24
+; CHECK-NEXT:    fmov s6, w8
+; CHECK-NEXT:    fmov w8, s5
+; CHECK-NEXT:    fcvt s3, d3
+; CHECK-NEXT:    fmov w10, s0
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    bit.16b v1, v6, v4
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    mov d6, v2[1]
+; CHECK-NEXT:    fmov s7, w9
+; CHECK-NEXT:    fcvt s2, d2
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    fmov s5, w8
+; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    mov.16b v1, v4
+; CHECK-NEXT:    bit.16b v3, v7, v4
+; CHECK-NEXT:    bsl.16b v1, v5, v2
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fcvt s2, d6
+; CHECK-NEXT:    fmov w9, s3
+; CHECK-NEXT:    fmov s5, w10
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    mov.16b v1, v4
+; CHECK-NEXT:    lsr w9, w9, #16
+; CHECK-NEXT:    fmov s3, w9
+; CHECK-NEXT:    bsl.16b v1, v5, v2
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    mov.h v0[1], v3[0]
+; CHECK-NEXT:    fmov s2, w8
+; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    mov.h v0[2], v2[0]
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s1, w8
+; CHECK-NEXT:    mov.h v0[3], v1[0]
+; CHECK-NEXT:    ; kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    ret
+  %tmp0 = fptrunc <4 x double> %b to <4 x bfloat>
+  %r = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %a, <4 x bfloat> %tmp0)
+  ret <4 x bfloat> %r
+}
+
+declare <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %a, <4 x bfloat> %b) #0
+
+;============ v8bf16
+
+; 128-bit variant of the scalarized bf16 copysign. Each of the eight lanes is
+; handled one at a time:
+; - widen via lsl #16;
+; - select with the 0x7fffffff mask from mvni.4s #128, lsl #24;
+; - narrow via lsr #16.
+; The lanes are reassembled into v2 with mov.h and then copied to v0.
+define <8 x bfloat> @test_copysign_v8bf16_v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) #0 {
+; CHECK-LABEL: test_copysign_v8bf16_v8bf16:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    mov h2, v1[1]
+; CHECK-NEXT:    mov h4, v0[1]
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    mov h6, v1[2]
+; CHECK-NEXT:    mov h7, v0[2]
+; CHECK-NEXT:    mvni.4s v3, #128, lsl #24
+; CHECK-NEXT:    mov h5, v1[3]
+; CHECK-NEXT:    mov h16, v0[3]
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    mov h17, v1[4]
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    fmov w10, s4
+; CHECK-NEXT:    mov h4, v0[4]
+; CHECK-NEXT:    fmov s18, w8
+; CHECK-NEXT:    fmov w8, s2
+; CHECK-NEXT:    fmov w11, s7
+; CHECK-NEXT:    fmov s2, w9
+; CHECK-NEXT:    lsl w9, w10, #16
+; CHECK-NEXT:    fmov w10, s6
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov s7, w9
+; CHECK-NEXT:    bif.16b v2, v18, v3
+; CHECK-NEXT:    lsl w9, w11, #16
+; CHECK-NEXT:    fmov s6, w8
+; CHECK-NEXT:    lsl w8, w10, #16
+; CHECK-NEXT:    fmov w10, s5
+; CHECK-NEXT:    fmov w11, s16
+; CHECK-NEXT:    fmov s16, w9
+; CHECK-NEXT:    mov h18, v0[5]
+; CHECK-NEXT:    fmov s5, w8
+; CHECK-NEXT:    bit.16b v6, v7, v3
+; CHECK-NEXT:    fmov w8, s2
+; CHECK-NEXT:    lsl w9, w10, #16
+; CHECK-NEXT:    lsl w10, w11, #16
+; CHECK-NEXT:    mov h7, v1[5]
+; CHECK-NEXT:    bit.16b v5, v16, v3
+; CHECK-NEXT:    fmov s16, w10
+; CHECK-NEXT:    fmov w10, s4
+; CHECK-NEXT:    mov.16b v4, v3
+; CHECK-NEXT:    fmov w11, s6
+; CHECK-NEXT:    fmov s6, w9
+; CHECK-NEXT:    fmov w9, s17
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    lsr w11, w11, #16
+; CHECK-NEXT:    fmov s2, w8
+; CHECK-NEXT:    lsl w8, w9, #16
+; CHECK-NEXT:    bsl.16b v4, v16, v6
+; CHECK-NEXT:    lsl w9, w10, #16
+; CHECK-NEXT:    fmov w10, s5
+; CHECK-NEXT:    fmov s6, w11
+; CHECK-NEXT:    fmov s5, w8
+; CHECK-NEXT:    lsr w8, w10, #16
+; CHECK-NEXT:    fmov w10, s7
+; CHECK-NEXT:    mov.h v2[1], v6[0]
+; CHECK-NEXT:    fmov s6, w9
+; CHECK-NEXT:    fmov w9, s18
+; CHECK-NEXT:    fmov s7, w8
+; CHECK-NEXT:    fmov w8, s4
+; CHECK-NEXT:    mov h4, v1[6]
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    mov h1, v1[7]
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    bit.16b v5, v6, v3
+; CHECK-NEXT:    mov h6, v0[6]
+; CHECK-NEXT:    mov.h v2[2], v7[0]
+; CHECK-NEXT:    fmov s7, w10
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s16, w9
+; CHECK-NEXT:    fmov w9, s4
+; CHECK-NEXT:    mov h0, v0[7]
+; CHECK-NEXT:    fmov w10, s6
+; CHECK-NEXT:    bit.16b v7, v16, v3
+; CHECK-NEXT:    fmov s16, w8
+; CHECK-NEXT:    fmov w8, s5
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    mov.h v2[3], v16[0]
+; CHECK-NEXT:    fmov s5, w9
+; CHECK-NEXT:    fmov w9, s1
+; CHECK-NEXT:    fmov s4, w8
+; CHECK-NEXT:    fmov w8, s7
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    mov.h v2[4], v4[0]
+; CHECK-NEXT:    fmov s4, w10
+; CHECK-NEXT:    fmov w10, s0
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s1, w9
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    bif.16b v4, v5, v3
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    fmov s5, w10
+; CHECK-NEXT:    mov.h v2[5], v0[0]
+; CHECK-NEXT:    mov.16b v0, v3
+; CHECK-NEXT:    fmov w8, s4
+; CHECK-NEXT:    bsl.16b v0, v5, v1
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s1, w8
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    mov.h v2[6], v1[0]
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    mov.h v2[7], v0[0]
+; CHECK-NEXT:    mov.16b v0, v2
+; CHECK-NEXT:    ret
+  %r = call <8 x bfloat> @llvm.copysign.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b)
+  ret <8 x bfloat> %r
+}
+
+; %b arrives as two <4 x float> halves (v1 and v2). Each half is rounded to
+; bf16 with the same vector sequence:
+; - add 0x7fff plus the lsb of the upper half;
+; - NaN lanes, detected by fcmeq x, x, get the 0x400000 quiet bit ORed in;
+; - shrn.4h #16 narrows.
+; The eight-lane copysign is then scalarized (widen with lsl #16, select under
+; the 0x7fffffff mask, narrow with lsr #16) and reassembled in v1 with mov.h
+; before being copied to v0.
+define <8 x bfloat> @test_copysign_v8bf16_v8f32(<8 x bfloat> %a, <8 x float> %b) #0 {
+; CHECK-LABEL: test_copysign_v8bf16_v8f32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    movi.4s v3, #127, msl #8
+; CHECK-NEXT:    movi.4s v4, #1
+; CHECK-NEXT:    ushr.4s v5, v1, #16
+; CHECK-NEXT:    fcmeq.4s v7, v1, v1
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    add.4s v6, v1, v3
+; CHECK-NEXT:    and.16b v5, v5, v4
+; CHECK-NEXT:    orr.4s v1, #64, lsl #16
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    add.4s v5, v5, v6
+; CHECK-NEXT:    ushr.4s v6, v2, #16
+; CHECK-NEXT:    and.16b v4, v6, v4
+; CHECK-NEXT:    mov h6, v0[2]
+; CHECK-NEXT:    bit.16b v1, v5, v7
+; CHECK-NEXT:    add.4s v7, v2, v3
+; CHECK-NEXT:    mov h5, v0[1]
+; CHECK-NEXT:    fcmeq.4s v3, v2, v2
+; CHECK-NEXT:    orr.4s v2, #64, lsl #16
+; CHECK-NEXT:    shrn.4h v1, v1, #16
+; CHECK-NEXT:    add.4s v4, v4, v7
+; CHECK-NEXT:    fmov w8, s5
+; CHECK-NEXT:    mov h7, v0[3]
+; CHECK-NEXT:    mov h5, v0[4]
+; CHECK-NEXT:    mov h16, v1[1]
+; CHECK-NEXT:    fmov w10, s1
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    bsl.16b v3, v4, v2
+; CHECK-NEXT:    mov h4, v1[2]
+; CHECK-NEXT:    mov h17, v1[3]
+; CHECK-NEXT:    mvni.4s v2, #128, lsl #24
+; CHECK-NEXT:    fmov s1, w9
+; CHECK-NEXT:    fmov w9, s6
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    fmov s6, w8
+; CHECK-NEXT:    fmov w8, s7
+; CHECK-NEXT:    fmov w11, s16
+; CHECK-NEXT:    fmov s7, w10
+; CHECK-NEXT:    fmov w10, s4
+; CHECK-NEXT:    mov.16b v4, v2
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    shrn.4h v3, v3, #16
+; CHECK-NEXT:    lsl w11, w11, #16
+; CHECK-NEXT:    bif.16b v1, v7, v2
+; CHECK-NEXT:    fmov s16, w8
+; CHECK-NEXT:    fmov s7, w11
+; CHECK-NEXT:    bsl.16b v4, v6, v7
+; CHECK-NEXT:    fmov s7, w9
+; CHECK-NEXT:    lsl w9, w10, #16
+; CHECK-NEXT:    fmov w10, s17
+; CHECK-NEXT:    mov h6, v0[5]
+; CHECK-NEXT:    lsl w8, w10, #16
+; CHECK-NEXT:    fmov w10, s1
+; CHECK-NEXT:    fmov s1, w9
+; CHECK-NEXT:    lsr w9, w10, #16
+; CHECK-NEXT:    fmov w10, s4
+; CHECK-NEXT:    fmov s4, w8
+; CHECK-NEXT:    bif.16b v7, v1, v2
+; CHECK-NEXT:    fmov w8, s5
+; CHECK-NEXT:    mov h5, v3[1]
+; CHECK-NEXT:    fmov s1, w9
+; CHECK-NEXT:    fmov w9, s3
+; CHECK-NEXT:    lsr w10, w10, #16
+; CHECK-NEXT:    bit.16b v4, v16, v2
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    fmov s16, w10
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    fmov w10, s7
+; CHECK-NEXT:    mov h7, v0[6]
+; CHECK-NEXT:    mov h0, v0[7]
+; CHECK-NEXT:    mov.h v1[1], v16[0]
+; CHECK-NEXT:    fmov s16, w8
+; CHECK-NEXT:    fmov w8, s6
+; CHECK-NEXT:    fmov s6, w9
+; CHECK-NEXT:    fmov w9, s5
+; CHECK-NEXT:    lsr w10, w10, #16
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    bit.16b v6, v16, v2
+; CHECK-NEXT:    fmov s16, w10
+; CHECK-NEXT:    fmov w10, s4
+; CHECK-NEXT:    fmov s5, w8
+; CHECK-NEXT:    fmov w8, s7
+; CHECK-NEXT:    fmov s7, w9
+; CHECK-NEXT:    mov h4, v3[2]
+; CHECK-NEXT:    mov h3, v3[3]
+; CHECK-NEXT:    mov.h v1[2], v16[0]
+; CHECK-NEXT:    lsr w10, w10, #16
+; CHECK-NEXT:    fmov w9, s6
+; CHECK-NEXT:    lsl w8, w8, #16
+; CHECK-NEXT:    bif.16b v5, v7, v2
+; CHECK-NEXT:    fmov s16, w10
+; CHECK-NEXT:    fmov w10, s4
+; CHECK-NEXT:    fmov s4, w8
+; CHECK-NEXT:    lsr w9, w9, #16
+; CHECK-NEXT:    mov.h v1[3], v16[0]
+; CHECK-NEXT:    fmov w8, s5
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    fmov s6, w9
+; CHECK-NEXT:    fmov w9, s0
+; CHECK-NEXT:    fmov s5, w10
+; CHECK-NEXT:    fmov w10, s3
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    mov.h v1[4], v6[0]
+; CHECK-NEXT:    lsl w9, w9, #16
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    bif.16b v4, v5, v2
+; CHECK-NEXT:    lsl w10, w10, #16
+; CHECK-NEXT:    fmov s3, w9
+; CHECK-NEXT:    fmov s5, w10
+; CHECK-NEXT:    mov.h v1[5], v0[0]
+; CHECK-NEXT:    mov.16b v0, v2
+; CHECK-NEXT:    fmov w8, s4
+; CHECK-NEXT:    bsl.16b v0, v3, v5
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s2, w8
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    mov.h v1[6], v2[0]
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    fmov s0, w8
+; CHECK-NEXT:    mov.h v1[7], v0[0]
+; CHECK-NEXT:    mov.16b v0, v1
+; CHECK-NEXT:    ret
+  %tmp0 = fptrunc <8 x float> %b to <8 x bfloat>
+  %r = call <8 x bfloat> @llvm.copysign.v8bf16(<8 x bfloat> %a, <8 x bfloat> %tmp0)
+  ret <8 x bfloat> %r
+}
+
+declare <8 x bfloat> @llvm.copysign.v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) #0
+
 attributes #0 = { nounwind }


        


More information about the llvm-commits mailing list