[llvm] bc18193 - [X86][RFC] Using `__bf16` for AVX512_BF16 intrinsics

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 19 08:47:43 PDT 2022


Author: Phoebe Wang
Date: 2022-10-19T23:47:04+08:00
New Revision: bc1819389fb4701cdeba5e093278e32dd668d6d5

URL: https://github.com/llvm/llvm-project/commit/bc1819389fb4701cdeba5e093278e32dd668d6d5
DIFF: https://github.com/llvm/llvm-project/commit/bc1819389fb4701cdeba5e093278e32dd668d6d5.diff

LOG: [X86][RFC] Using `__bf16` for AVX512_BF16 intrinsics

This is an alternative to D120395 and D120411.

Previously, we used `__bfloat16` as a typedef of `unsigned short`. The
name may give users the impression that it is a brand-new type
representing BF16, so they may use it in arithmetic operations, and we
have no good way to block that.
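
As an illustration (a hypothetical user function, not part of this
patch, assuming the old headers and -mavx512bf16): the following
compiles without any diagnostic, even though adding the raw bit
patterns of two bfloat16 values is numerically meaningless.

  #include <immintrin.h>

  // Old behavior: __bfloat16 is plain unsigned short, so this is
  // silently accepted as integer arithmetic on the bit patterns.
  __bfloat16 add_bf16(__bfloat16 a, __bfloat16 b) {
    return a + b;
  }

After this patch the same code is rejected, as checked by the new
clang/test/CodeGen/X86/avx512bf16-error.c test below.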

To solve the problem, we introduced `__bf16` into the X86 psABI and
landed support for it in Clang in D130964. Now we can fix the issue by
switching the intrinsics to the new type.
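
A minimal usage sketch with the new type (hypothetical helper, assuming
a target compiled with -mavx512bf16 -mavx512vl):

  #include <immintrin.h>

  // Scalar BF16 values are now carried in __bf16 rather than a bare
  // unsigned short; the conversion intrinsics take and return __bf16.
  float round_trip(float x) {
    __bf16 b = _mm_cvtness_sbh(x); // truncate float to bfloat16
    return _mm_cvtsbh_ss(b);       // extend bfloat16 back to float
  }

Unlike the old typedef, arithmetic such as `b + b` on `__bf16` values
is diagnosed by the compiler, which is exactly the guarantee the old
typedef could not provide.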

Reviewed By: LuoYuanke, RKSimon

Differential Revision: https://reviews.llvm.org/D132329

Added: 
    clang/test/CodeGen/X86/avx512bf16-error.c
    llvm/test/CodeGen/X86/avx512bf16-intrinsics-upgrade.ll
    llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics-upgrade.ll

Modified: 
    clang/docs/ReleaseNotes.rst
    clang/include/clang/Basic/BuiltinsX86.def
    clang/lib/CodeGen/CGBuiltin.cpp
    clang/lib/Headers/avx512bf16intrin.h
    clang/lib/Headers/avx512vlbf16intrin.h
    clang/test/CodeGen/X86/avx512bf16-builtins.c
    clang/test/CodeGen/X86/avx512vlbf16-builtins.c
    llvm/include/llvm/IR/Intrinsics.td
    llvm/include/llvm/IR/IntrinsicsX86.td
    llvm/lib/IR/AutoUpgrade.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/lib/Target/X86/X86InstrAVX512.td
    llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
    llvm/lib/Target/X86/X86RegisterInfo.td
    llvm/test/CodeGen/X86/avx512bf16-intrinsics.ll
    llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll
    llvm/test/CodeGen/X86/bfloat.ll
    llvm/test/CodeGen/X86/stack-folding-avx512bf16.ll

Removed: 
    


################################################################################
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index cb704266df407..60f6562cee67c 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -583,6 +583,7 @@ X86 Support in Clang
 --------------------
 - Support ``-mindirect-branch-cs-prefix`` for call and jmp to indirect thunk.
 - Fix 32-bit ``__fastcall`` and ``__vectorcall`` ABI mismatch with MSVC.
+- Switch ``AVX512-BF16`` intrinsics types from ``short`` to ``__bf16``.
 
 DWARF Support in Clang
 ----------------------

diff --git a/clang/include/clang/Basic/BuiltinsX86.def b/clang/include/clang/Basic/BuiltinsX86.def
index ad8509e6124d4..8ff0eaec35736 100644
--- a/clang/include/clang/Basic/BuiltinsX86.def
+++ b/clang/include/clang/Basic/BuiltinsX86.def
@@ -1749,16 +1749,16 @@ TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb128, "V16cV16cV16c", "ncV:128:", "av
 TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb256, "V32cV32cV32c", "ncV:256:", "avx512vbmi,avx512vl")
 
 // bf16 intrinsics
-TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_128, "V8sV4fV4f", "ncV:128:", "avx512bf16,avx512vl")
-TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_256, "V16sV8fV8f", "ncV:256:", "avx512bf16,avx512vl")
-TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_512, "V32sV16fV16f", "ncV:512:", "avx512bf16")
-TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_128_mask, "V8sV4fV8sUc", "ncV:128:", "avx512bf16,avx512vl")
-TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_256_mask, "V8sV8fV8sUc", "ncV:256:", "avx512bf16,avx512vl")
-TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_512_mask, "V16sV16fV16sUs", "ncV:512:", "avx512bf16")
-TARGET_BUILTIN(__builtin_ia32_dpbf16ps_128, "V4fV4fV4iV4i", "ncV:128:", "avx512bf16,avx512vl")
-TARGET_BUILTIN(__builtin_ia32_dpbf16ps_256, "V8fV8fV8iV8i", "ncV:256:", "avx512bf16,avx512vl")
-TARGET_BUILTIN(__builtin_ia32_dpbf16ps_512, "V16fV16fV16iV16i", "ncV:512:", "avx512bf16")
-TARGET_BUILTIN(__builtin_ia32_cvtsbf162ss_32, "fUs", "nc", "avx512bf16")
+TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_128, "V8yV4fV4f", "ncV:128:", "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_256, "V16yV8fV8f", "ncV:256:", "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_512, "V32yV16fV16f", "ncV:512:", "avx512bf16")
+TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_128_mask, "V8yV4fV8yUc", "ncV:128:", "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_256_mask, "V8yV8fV8yUc", "ncV:256:", "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_512_mask, "V16yV16fV16yUs", "ncV:512:", "avx512bf16")
+TARGET_BUILTIN(__builtin_ia32_dpbf16ps_128, "V4fV4fV8yV8y", "ncV:128:", "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_dpbf16ps_256, "V8fV8fV16yV16y", "ncV:256:", "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_dpbf16ps_512, "V16fV16fV32yV32y", "ncV:512:", "avx512bf16")
+TARGET_BUILTIN(__builtin_ia32_cvtsbf162ss_32, "fy", "nc", "avx512bf16")
 
 TARGET_BUILTIN(__builtin_ia32_vp2intersect_q_512, "vV8OiV8OiUc*Uc*", "nV:512:", "avx512vp2intersect")
 TARGET_BUILTIN(__builtin_ia32_vp2intersect_q_256, "vV4OiV4OiUc*Uc*", "nV:256:", "avx512vp2intersect,avx512vl")
@@ -1977,6 +1977,9 @@ TARGET_BUILTIN(__builtin_ia32_selectd_512, "V16iUsV16iV16i", "ncV:512:", "avx512
 TARGET_BUILTIN(__builtin_ia32_selectph_128, "V8xUcV8xV8x", "ncV:128:", "avx512fp16,avx512vl")
 TARGET_BUILTIN(__builtin_ia32_selectph_256, "V16xUsV16xV16x", "ncV:256:", "avx512fp16,avx512vl")
 TARGET_BUILTIN(__builtin_ia32_selectph_512, "V32xUiV32xV32x", "ncV:512:", "avx512fp16")
+TARGET_BUILTIN(__builtin_ia32_selectpbf_128, "V8yUcV8yV8y", "ncV:128:", "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_selectpbf_256, "V16yUsV16yV16y", "ncV:256:", "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_selectpbf_512, "V32yUiV32yV32y", "ncV:512:", "avx512bf16")
 TARGET_BUILTIN(__builtin_ia32_selectq_128, "V2OiUcV2OiV2Oi", "ncV:128:", "avx512vl")
 TARGET_BUILTIN(__builtin_ia32_selectq_256, "V4OiUcV4OiV4Oi", "ncV:256:", "avx512vl")
 TARGET_BUILTIN(__builtin_ia32_selectq_512, "V8OiUcV8OiV8Oi", "ncV:512:", "avx512f")
@@ -1987,6 +1990,7 @@ TARGET_BUILTIN(__builtin_ia32_selectpd_128, "V2dUcV2dV2d", "ncV:128:", "avx512vl
 TARGET_BUILTIN(__builtin_ia32_selectpd_256, "V4dUcV4dV4d", "ncV:256:", "avx512vl")
 TARGET_BUILTIN(__builtin_ia32_selectpd_512, "V8dUcV8dV8d", "ncV:512:", "avx512f")
 TARGET_BUILTIN(__builtin_ia32_selectsh_128, "V8xUcV8xV8x", "ncV:128:", "avx512fp16")
+TARGET_BUILTIN(__builtin_ia32_selectsbf_128, "V8yUcV8yV8y", "ncV:128:", "avx512bf16")
 TARGET_BUILTIN(__builtin_ia32_selectss_128, "V4fUcV4fV4f", "ncV:128:", "avx512f")
 TARGET_BUILTIN(__builtin_ia32_selectsd_128, "V2dUcV2dV2d", "ncV:128:", "avx512f")
 

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index cfae2f05b0d46..6bf021888c88b 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -12873,18 +12873,6 @@ static Value *EmitX86CvtF16ToFloatExpr(CodeGenFunction &CGF,
   return Res;
 }
 
-// Convert a BF16 to a float.
-static Value *EmitX86CvtBF16ToFloatExpr(CodeGenFunction &CGF,
-                                        const CallExpr *E,
-                                        ArrayRef<Value *> Ops) {
-  llvm::Type *Int32Ty = CGF.Builder.getInt32Ty();
-  Value *ZeroExt = CGF.Builder.CreateZExt(Ops[0], Int32Ty);
-  Value *Shl = CGF.Builder.CreateShl(ZeroExt, 16);
-  llvm::Type *ResultType = CGF.ConvertType(E->getType());
-  Value *BitCast = CGF.Builder.CreateBitCast(Shl, ResultType);
-  return BitCast;
-}
-
 Value *CodeGenFunction::EmitX86CpuIs(StringRef CPUStr) {
 
   llvm::Type *Int32Ty = Builder.getInt32Ty();
@@ -14291,6 +14279,9 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
   case X86::BI__builtin_ia32_selectph_128:
   case X86::BI__builtin_ia32_selectph_256:
   case X86::BI__builtin_ia32_selectph_512:
+  case X86::BI__builtin_ia32_selectpbf_128:
+  case X86::BI__builtin_ia32_selectpbf_256:
+  case X86::BI__builtin_ia32_selectpbf_512:
   case X86::BI__builtin_ia32_selectps_128:
   case X86::BI__builtin_ia32_selectps_256:
   case X86::BI__builtin_ia32_selectps_512:
@@ -14299,6 +14290,7 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
   case X86::BI__builtin_ia32_selectpd_512:
     return EmitX86Select(*this, Ops[0], Ops[1], Ops[2]);
   case X86::BI__builtin_ia32_selectsh_128:
+  case X86::BI__builtin_ia32_selectsbf_128:
   case X86::BI__builtin_ia32_selectss_128:
   case X86::BI__builtin_ia32_selectsd_128: {
     Value *A = Builder.CreateExtractElement(Ops[1], (uint64_t)0);
@@ -15135,7 +15127,7 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
     return EmitX86CvtF16ToFloatExpr(*this, Ops, ConvertType(E->getType()));
   }
 
-// AVX512 bf16 intrinsics
+  // AVX512 bf16 intrinsics
   case X86::BI__builtin_ia32_cvtneps2bf16_128_mask: {
     Ops[2] = getMaskVecValue(
         *this, Ops[2],
@@ -15144,7 +15136,7 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
     return Builder.CreateCall(CGM.getIntrinsic(IID), Ops);
   }
   case X86::BI__builtin_ia32_cvtsbf162ss_32:
-    return EmitX86CvtBF16ToFloatExpr(*this, E, Ops);
+    return Builder.CreateFPExt(Ops[0], Builder.getFloatTy());
 
   case X86::BI__builtin_ia32_cvtneps2bf16_256_mask:
   case X86::BI__builtin_ia32_cvtneps2bf16_512_mask: {

diff --git a/clang/lib/Headers/avx512bf16intrin.h b/clang/lib/Headers/avx512bf16intrin.h
index 09653738d40ab..4fc99951ed3ba 100644
--- a/clang/lib/Headers/avx512bf16intrin.h
+++ b/clang/lib/Headers/avx512bf16intrin.h
@@ -10,12 +10,16 @@
 #error "Never use <avx512bf16intrin.h> directly; include <immintrin.h> instead."
 #endif
 
+#ifdef __SSE2__
+
 #ifndef __AVX512BF16INTRIN_H
 #define __AVX512BF16INTRIN_H
 
-typedef short __m512bh __attribute__((__vector_size__(64), __aligned__(64)));
-typedef short __m256bh __attribute__((__vector_size__(32), __aligned__(32)));
-typedef unsigned short __bfloat16;
+typedef __bf16 __v32bf __attribute__((__vector_size__(64), __aligned__(64)));
+typedef __bf16 __m512bh __attribute__((__vector_size__(64), __aligned__(64)));
+typedef __bf16 __v16bf __attribute__((__vector_size__(32), __aligned__(32)));
+typedef __bf16 __m256bh __attribute__((__vector_size__(32), __aligned__(32)));
+typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));
 
 #define __DEFAULT_FN_ATTRS512 \
   __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"), \
@@ -33,7 +37,7 @@ typedef unsigned short __bfloat16;
 ///    A bfloat data.
 /// \returns A float data whose sign field and exponent field keep unchanged,
 ///    and fraction field is extended to 23 bits.
-static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bfloat16 __A) {
+static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bf16 __A) {
   return __builtin_ia32_cvtsbf162ss_32(__A);
 }
 
@@ -74,9 +78,9 @@ _mm512_cvtne2ps_pbh(__m512 __A, __m512 __B) {
 ///    conversion of __B, and higher 256 bits come from conversion of __A.
 static __inline__ __m512bh __DEFAULT_FN_ATTRS512
 _mm512_mask_cvtne2ps_pbh(__m512bh __W, __mmask32 __U, __m512 __A, __m512 __B) {
-  return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U,
-                                        (__v32hi)_mm512_cvtne2ps_pbh(__A, __B),
-                                        (__v32hi)__W);
+  return (__m512bh)__builtin_ia32_selectpbf_512((__mmask32)__U,
+                                        (__v32bf)_mm512_cvtne2ps_pbh(__A, __B),
+                                        (__v32bf)__W);
 }
 
 /// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -96,9 +100,9 @@ _mm512_mask_cvtne2ps_pbh(__m512bh __W, __mmask32 __U, __m512 __A, __m512 __B) {
 ///    conversion of __B, and higher 256 bits come from conversion of __A.
 static __inline__ __m512bh __DEFAULT_FN_ATTRS512
 _mm512_maskz_cvtne2ps_pbh(__mmask32 __U, __m512 __A, __m512 __B) {
-  return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U,
-                                        (__v32hi)_mm512_cvtne2ps_pbh(__A, __B),
-                                        (__v32hi)_mm512_setzero_si512());
+  return (__m512bh)__builtin_ia32_selectpbf_512((__mmask32)__U,
+                                        (__v32bf)_mm512_cvtne2ps_pbh(__A, __B),
+                                        (__v32bf)_mm512_setzero_si512());
 }
 
 /// Convert Packed Single Data to Packed BF16 Data.
@@ -113,7 +117,7 @@ _mm512_maskz_cvtne2ps_pbh(__mmask32 __U, __m512 __A, __m512 __B) {
 static __inline__ __m256bh __DEFAULT_FN_ATTRS512
 _mm512_cvtneps_pbh(__m512 __A) {
   return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
-                                              (__v16hi)_mm256_undefined_si256(),
+                                              (__v16bf)_mm256_undefined_si256(),
                                               (__mmask16)-1);
 }
 
@@ -134,7 +138,7 @@ _mm512_cvtneps_pbh(__m512 __A) {
 static __inline__ __m256bh __DEFAULT_FN_ATTRS512
 _mm512_mask_cvtneps_pbh(__m256bh __W, __mmask16 __U, __m512 __A) {
   return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
-                                                        (__v16hi)__W,
+                                                        (__v16bf)__W,
                                                         (__mmask16)__U);
 }
 
@@ -153,7 +157,7 @@ _mm512_mask_cvtneps_pbh(__m256bh __W, __mmask16 __U, __m512 __A) {
 static __inline__ __m256bh __DEFAULT_FN_ATTRS512
 _mm512_maskz_cvtneps_pbh(__mmask16 __U, __m512 __A) {
   return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
-                                                (__v16hi)_mm256_setzero_si256(),
+                                                (__v16bf)_mm256_setzero_si256(),
                                                 (__mmask16)__U);
 }
 
@@ -174,8 +178,8 @@ _mm512_maskz_cvtneps_pbh(__mmask16 __U, __m512 __A) {
 static __inline__ __m512 __DEFAULT_FN_ATTRS512
 _mm512_dpbf16_ps(__m512 __D, __m512bh __A, __m512bh __B) {
   return (__m512)__builtin_ia32_dpbf16ps_512((__v16sf) __D,
-                                             (__v16si) __A,
-                                             (__v16si) __B);
+                                             (__v32bf) __A,
+                                             (__v32bf) __B);
 }
 
 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
@@ -277,3 +281,4 @@ _mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) {
 #undef __DEFAULT_FN_ATTRS512
 
 #endif
+#endif

diff --git a/clang/lib/Headers/avx512vlbf16intrin.h b/clang/lib/Headers/avx512vlbf16intrin.h
index 1cdbb28484acf..b76927d6737a3 100644
--- a/clang/lib/Headers/avx512vlbf16intrin.h
+++ b/clang/lib/Headers/avx512vlbf16intrin.h
@@ -10,10 +10,13 @@
 #error "Never use <avx512vlbf16intrin.h> directly; include <immintrin.h> instead."
 #endif
 
+#ifdef __SSE2__
+
 #ifndef __AVX512VLBF16INTRIN_H
 #define __AVX512VLBF16INTRIN_H
 
-typedef short __m128bh __attribute__((__vector_size__(16), __aligned__(16)));
+typedef __bf16 __v8bf __attribute__((__vector_size__(16), __aligned__(16)));
+typedef __bf16 __m128bh __attribute__((__vector_size__(16), __aligned__(16)));
 
 #define __DEFAULT_FN_ATTRS128 \
   __attribute__((__always_inline__, __nodebug__, \
@@ -59,9 +62,9 @@ _mm_cvtne2ps_pbh(__m128 __A, __m128 __B) {
 ///    conversion of __B, and higher 64 bits come from conversion of __A.
 static __inline__ __m128bh __DEFAULT_FN_ATTRS128
 _mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
-  return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
-                                             (__v8hi)_mm_cvtne2ps_pbh(__A, __B),
-                                             (__v8hi)__W);
+  return (__m128bh)__builtin_ia32_selectpbf_128((__mmask8)__U,
+                                             (__v8bf)_mm_cvtne2ps_pbh(__A, __B),
+                                             (__v8bf)__W);
 }
 
 /// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -81,9 +84,9 @@ _mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
 ///    conversion of __B, and higher 64 bits come from conversion of __A.
 static __inline__ __m128bh __DEFAULT_FN_ATTRS128
 _mm_maskz_cvtne2ps_pbh(__mmask8 __U, __m128 __A, __m128 __B) {
-  return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
-                                             (__v8hi)_mm_cvtne2ps_pbh(__A, __B),
-                                             (__v8hi)_mm_setzero_si128());
+  return (__m128bh)__builtin_ia32_selectpbf_128((__mmask8)__U,
+                                             (__v8bf)_mm_cvtne2ps_pbh(__A, __B),
+                                             (__v8bf)_mm_setzero_si128());
 }
 
 /// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -123,9 +126,9 @@ _mm256_cvtne2ps_pbh(__m256 __A, __m256 __B) {
 ///    conversion of __B, and higher 128 bits come from conversion of __A.
 static __inline__ __m256bh __DEFAULT_FN_ATTRS256
 _mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
-  return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
-                                         (__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
-                                         (__v16hi)__W);
+  return (__m256bh)__builtin_ia32_selectpbf_256((__mmask16)__U,
+                                         (__v16bf)_mm256_cvtne2ps_pbh(__A, __B),
+                                         (__v16bf)__W);
 }
 
 /// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -145,9 +148,9 @@ _mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
 ///    conversion of __B, and higher 128 bits come from conversion of __A.
 static __inline__ __m256bh __DEFAULT_FN_ATTRS256
 _mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
-  return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
-                                         (__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
-                                         (__v16hi)_mm256_setzero_si256());
+  return (__m256bh)__builtin_ia32_selectpbf_256((__mmask16)__U,
+                                         (__v16bf)_mm256_cvtne2ps_pbh(__A, __B),
+                                         (__v16bf)_mm256_setzero_si256());
 }
 
 /// Convert Packed Single Data to Packed BF16 Data.
@@ -163,7 +166,7 @@ _mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
 static __inline__ __m128bh __DEFAULT_FN_ATTRS128
 _mm_cvtneps_pbh(__m128 __A) {
   return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
-                                                  (__v8hi)_mm_undefined_si128(),
+                                                  (__v8bf)_mm_undefined_si128(),
                                                   (__mmask8)-1);
 }
 
@@ -185,7 +188,7 @@ _mm_cvtneps_pbh(__m128 __A) {
 static __inline__ __m128bh __DEFAULT_FN_ATTRS128
 _mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
   return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
-                                                        (__v8hi)__W,
+                                                        (__v8bf)__W,
                                                         (__mmask8)__U);
 }
 
@@ -205,7 +208,7 @@ _mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
 static __inline__ __m128bh __DEFAULT_FN_ATTRS128
 _mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
   return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
-                                                    (__v8hi)_mm_setzero_si128(),
+                                                    (__v8bf)_mm_setzero_si128(),
                                                     (__mmask8)__U);
 }
 
@@ -221,7 +224,7 @@ _mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
 static __inline__ __m128bh __DEFAULT_FN_ATTRS256
 _mm256_cvtneps_pbh(__m256 __A) {
   return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
-                                                  (__v8hi)_mm_undefined_si128(),
+                                                  (__v8bf)_mm_undefined_si128(),
                                                   (__mmask8)-1);
 }
 
@@ -242,7 +245,7 @@ _mm256_cvtneps_pbh(__m256 __A) {
 static __inline__ __m128bh __DEFAULT_FN_ATTRS256
 _mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
   return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
-                                                        (__v8hi)__W,
+                                                        (__v8bf)__W,
                                                         (__mmask8)__U);
 }
 
@@ -261,7 +264,7 @@ _mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
 static __inline__ __m128bh __DEFAULT_FN_ATTRS256
 _mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
   return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
-                                                    (__v8hi)_mm_setzero_si128(),
+                                                    (__v8bf)_mm_setzero_si128(),
                                                     (__mmask8)__U);
 }
 
@@ -282,8 +285,8 @@ _mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
 static __inline__ __m128 __DEFAULT_FN_ATTRS128
 _mm_dpbf16_ps(__m128 __D, __m128bh __A, __m128bh __B) {
   return (__m128)__builtin_ia32_dpbf16ps_128((__v4sf)__D,
-                                             (__v4si)__A,
-                                             (__v4si)__B);
+                                             (__v8bf)__A,
+                                             (__v8bf)__B);
 }
 
 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
@@ -351,8 +354,8 @@ _mm_maskz_dpbf16_ps(__mmask8 __U, __m128 __D, __m128bh __A, __m128bh __B) {
 static __inline__ __m256 __DEFAULT_FN_ATTRS256
 _mm256_dpbf16_ps(__m256 __D, __m256bh __A, __m256bh __B) {
   return (__m256)__builtin_ia32_dpbf16ps_256((__v8sf)__D,
-                                             (__v8si)__A,
-                                             (__v8si)__B);
+                                             (__v16bf)__A,
+                                             (__v16bf)__B);
 }
 
 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
@@ -413,11 +416,11 @@ _mm256_maskz_dpbf16_ps(__mmask8 __U, __m256 __D, __m256bh __A, __m256bh __B) {
 ///    A float data.
 /// \returns A bf16 data whose sign field and exponent field keep unchanged,
 ///    and fraction field is truncated to 7 bits.
-static __inline__ __bfloat16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
+static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
   __v4sf __V = {__A, 0, 0, 0};
-  __v8hi __R = __builtin_ia32_cvtneps2bf16_128_mask(
-      (__v4sf)__V, (__v8hi)_mm_undefined_si128(), (__mmask8)-1);
-  return (__bfloat16)__R[0];
+  __v8bf __R = __builtin_ia32_cvtneps2bf16_128_mask(
+      (__v4sf)__V, (__v8bf)_mm_undefined_si128(), (__mmask8)-1);
+  return (__bf16)__R[0];
 }
 
 /// Convert Packed BF16 Data to Packed float Data.
@@ -520,3 +523,4 @@ _mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
 #undef __DEFAULT_FN_ATTRS256
 
 #endif
+#endif

diff --git a/clang/test/CodeGen/X86/avx512bf16-builtins.c b/clang/test/CodeGen/X86/avx512bf16-builtins.c
index 6e0afc709ec29..8eb93e6889bea 100644
--- a/clang/test/CodeGen/X86/avx512bf16-builtins.c
+++ b/clang/test/CodeGen/X86/avx512bf16-builtins.c
@@ -4,10 +4,9 @@
 
 #include <immintrin.h>
 
-float test_mm_cvtsbh_ss(__bfloat16 A) {
+float test_mm_cvtsbh_ss(__bf16 A) {
   // CHECK-LABEL: @test_mm_cvtsbh_ss
-  // CHECK: zext i16 %{{.*}} to i32
-  // CHECK: shl i32 %{{.*}}, 16
+  // CHECK: fpext bfloat %{{.*}} to float
   // CHECK: ret float %{{.*}}
   return _mm_cvtsbh_ss(A);
 }
@@ -15,46 +14,46 @@ float test_mm_cvtsbh_ss(__bfloat16 A) {
 __m512bh test_mm512_cvtne2ps_pbh(__m512 A, __m512 B) {
   // CHECK-LABEL: @test_mm512_cvtne2ps_pbh
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
-  // CHECK: ret <32 x i16> %{{.*}}
+  // CHECK: ret <32 x bfloat> %{{.*}}
   return _mm512_cvtne2ps_pbh(A, B);
 }
 
 __m512bh test_mm512_maskz_cvtne2ps_pbh(__m512 A, __m512 B, __mmask32 U) {
   // CHECK-LABEL: @test_mm512_maskz_cvtne2ps_pbh
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
-  // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
-  // CHECK: ret <32 x i16> %{{.*}}
+  // CHECK: select <32 x i1> %{{.*}}, <32 x bfloat> %{{.*}}, <32 x bfloat> %{{.*}}
+  // CHECK: ret <32 x bfloat> %{{.*}}
   return _mm512_maskz_cvtne2ps_pbh(U, A, B);
 }
 
 __m512bh test_mm512_mask_cvtne2ps_pbh(__m512bh C, __mmask32 U, __m512 A, __m512 B) {
   // CHECK-LABEL: @test_mm512_mask_cvtne2ps_pbh
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
-  // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
-  // CHECK: ret <32 x i16> %{{.*}}
+  // CHECK: select <32 x i1> %{{.*}}, <32 x bfloat> %{{.*}}, <32 x bfloat> %{{.*}}
+  // CHECK: ret <32 x bfloat> %{{.*}}
   return _mm512_mask_cvtne2ps_pbh(C, U, A, B);
 }
 
 __m256bh test_mm512_cvtneps_pbh(__m512 A) {
   // CHECK-LABEL: @test_mm512_cvtneps_pbh
   // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512
-  // CHECK: ret <16 x i16> %{{.*}}
+  // CHECK: ret <16 x bfloat> %{{.*}}
   return _mm512_cvtneps_pbh(A);
 }
 
 __m256bh test_mm512_mask_cvtneps_pbh(__m256bh C, __mmask16 U, __m512 A) {
   // CHECK-LABEL: @test_mm512_mask_cvtneps_pbh
   // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512
-  // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}
-  // CHECK: ret <16 x i16> %{{.*}}
+  // CHECK: select <16 x i1> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}
+  // CHECK: ret <16 x bfloat> %{{.*}}
   return _mm512_mask_cvtneps_pbh(C, U, A);
 }
 
 __m256bh test_mm512_maskz_cvtneps_pbh(__m512 A, __mmask16 U) {
   // CHECK-LABEL: @test_mm512_maskz_cvtneps_pbh
   // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512
-  // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}
-  // CHECK: ret <16 x i16> %{{.*}}
+  // CHECK: select <16 x i1> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}
+  // CHECK: ret <16 x bfloat> %{{.*}}
   return _mm512_maskz_cvtneps_pbh(U, A);
 }
 

diff --git a/clang/test/CodeGen/X86/avx512bf16-error.c b/clang/test/CodeGen/X86/avx512bf16-error.c
new file mode 100644
index 0000000000000..8e0916539cab6
--- /dev/null
+++ b/clang/test/CodeGen/X86/avx512bf16-error.c
@@ -0,0 +1,15 @@
+// RUN: %clang_cc1 -fsyntax-only -verify -ffreestanding -triple x86_64-linux-pc %s
+
+// expected-error@+1 3 {{unknown type name '__bfloat16'}}
+__bfloat16 foo(__bfloat16 a, __bfloat16 b) {
+  return a + b;
+}
+
+#include <immintrin.h>
+
+// expected-error@+4 {{invalid operands to binary expression ('__bfloat16' (aka '__bf16') and '__bfloat16')}}
+// expected-warning@+2 3 {{'__bfloat16' is deprecated: use __bf16 instead}}
+// expected-note@* 3 {{'__bfloat16' has been explicitly marked deprecated here}}
+__bfloat16 bar(__bfloat16 a, __bfloat16 b) {
+  return a + b;
+}

diff --git a/clang/test/CodeGen/X86/avx512vlbf16-builtins.c b/clang/test/CodeGen/X86/avx512vlbf16-builtins.c
index a5a0ab1a8a4b1..539c6f8c43b2b 100644
--- a/clang/test/CodeGen/X86/avx512vlbf16-builtins.c
+++ b/clang/test/CodeGen/X86/avx512vlbf16-builtins.c
@@ -7,113 +7,113 @@
 __m128bh test_mm_cvtne2ps2bf16(__m128 A, __m128 B) {
   // CHECK-LABEL: @test_mm_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.128
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm_cvtne2ps_pbh(A, B);
 }
 
 __m128bh test_mm_maskz_cvtne2ps2bf16(__m128 A, __m128 B, __mmask8 U) {
   // CHECK-LABEL: @test_mm_maskz_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.128
-  // CHECK: select <8 x i1> %{{.*}}, <8 x i16> %{{.*}}, <8 x i16> %{{.*}}
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: select <8 x i1> %{{.*}}, <8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm_maskz_cvtne2ps_pbh(U, A, B);
 }
 
 __m128bh test_mm_mask_cvtne2ps2bf16(__m128bh C, __mmask8 U, __m128 A, __m128 B) {
   // CHECK-LABEL: @test_mm_mask_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.128
-  // CHECK: select <8 x i1> %{{.*}}, <8 x i16> %{{.*}}, <8 x i16> %{{.*}}
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: select <8 x i1> %{{.*}}, <8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm_mask_cvtne2ps_pbh(C, U, A, B);
 }
 
 __m256bh test_mm256_cvtne2ps2bf16(__m256 A, __m256 B) {
   // CHECK-LABEL: @test_mm256_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.256
-  // CHECK: ret <16 x i16> %{{.*}}
+  // CHECK: ret <16 x bfloat> %{{.*}}
   return _mm256_cvtne2ps_pbh(A, B);
 }
 
 __m256bh test_mm256_maskz_cvtne2ps2bf16(__m256 A, __m256 B, __mmask16 U) {
   // CHECK-LABEL: @test_mm256_maskz_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.256
-  // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}
-  // CHECK: ret <16 x i16> %{{.*}}
+  // CHECK: select <16 x i1> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}
+  // CHECK: ret <16 x bfloat> %{{.*}}
   return _mm256_maskz_cvtne2ps_pbh(U, A, B);
 }
 
 __m256bh test_mm256_mask_cvtne2ps2bf16(__m256bh C, __mmask16 U, __m256 A, __m256 B) {
   // CHECK-LABEL: @test_mm256_mask_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.256
-  // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}
-  // CHECK: ret <16 x i16> %{{.*}}
+  // CHECK: select <16 x i1> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}
+  // CHECK: ret <16 x bfloat> %{{.*}}
   return _mm256_mask_cvtne2ps_pbh(C, U, A, B);
 }
 
 __m512bh test_mm512_cvtne2ps2bf16(__m512 A, __m512 B) {
   // CHECK-LABEL: @test_mm512_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
-  // CHECK: ret <32 x i16> %{{.*}}
+  // CHECK: ret <32 x bfloat> %{{.*}}
   return _mm512_cvtne2ps_pbh(A, B);
 }
 
 __m512bh test_mm512_maskz_cvtne2ps2bf16(__m512 A, __m512 B, __mmask32 U) {
   // CHECK-LABEL: @test_mm512_maskz_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
-  // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
-  // CHECK: ret <32 x i16> %{{.*}}
+  // CHECK: select <32 x i1> %{{.*}}, <32 x bfloat> %{{.*}}, <32 x bfloat> %{{.*}}
+  // CHECK: ret <32 x bfloat> %{{.*}}
   return _mm512_maskz_cvtne2ps_pbh(U, A, B);
 }
 
 __m512bh test_mm512_mask_cvtne2ps2bf16(__m512bh C, __mmask32 U, __m512 A, __m512 B) {
   // CHECK-LABEL: @test_mm512_mask_cvtne2ps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
-  // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
-  // CHECK: ret <32 x i16> %{{.*}}
+  // CHECK: select <32 x i1> %{{.*}}, <32 x bfloat> %{{.*}}, <32 x bfloat> %{{.*}}
+  // CHECK: ret <32 x bfloat> %{{.*}}
   return _mm512_mask_cvtne2ps_pbh(C, U, A, B);
 }
 
 __m128bh test_mm_cvtneps2bf16(__m128 A) {
   // CHECK-LABEL: @test_mm_cvtneps2bf16
   // CHECK: @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm_cvtneps_pbh(A);
 }
 
 __m128bh test_mm_mask_cvtneps2bf16(__m128bh C, __mmask8 U, __m128 A) {
   // CHECK-LABEL: @test_mm_mask_cvtneps2bf16
   // CHECK: @llvm.x86.avx512bf16.mask.cvtneps2bf16.
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm_mask_cvtneps_pbh(C, U, A);
 }
 
 __m128bh test_mm_maskz_cvtneps2bf16(__m128 A, __mmask8 U) {
   // CHECK-LABEL: @test_mm_maskz_cvtneps2bf16
   // CHECK: @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm_maskz_cvtneps_pbh(U, A);
 }
 
 __m128bh test_mm256_cvtneps2bf16(__m256 A) {
   // CHECK-LABEL: @test_mm256_cvtneps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.256
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm256_cvtneps_pbh(A);
 }
 
 __m128bh test_mm256_mask_cvtneps2bf16(__m128bh C, __mmask8 U, __m256 A) {
   // CHECK-LABEL: @test_mm256_mask_cvtneps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.256
-  // CHECK: select <8 x i1> %{{.*}}, <8 x i16> %{{.*}}, <8 x i16> %{{.*}}
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: select <8 x i1> %{{.*}}, <8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm256_mask_cvtneps_pbh(C, U, A);
 }
 
 __m128bh test_mm256_maskz_cvtneps2bf16(__m256 A, __mmask8 U) {
   // CHECK-LABEL: @test_mm256_maskz_cvtneps2bf16
   // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.256
-  // CHECK: select <8 x i1> %{{.*}}, <8 x i16> %{{.*}}, <8 x i16> %{{.*}}
-  // CHECK: ret <8 x i16> %{{.*}}
+  // CHECK: select <8 x i1> %{{.*}}, <8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}
+  // CHECK: ret <8 x bfloat> %{{.*}}
   return _mm256_maskz_cvtneps_pbh(U, A);
 }
 
@@ -162,10 +162,10 @@ __m256 test_mm256_mask_dpbf16_ps(__m256 D, __m256bh A, __m256bh B, __mmask8 U) {
   return _mm256_mask_dpbf16_ps(D, U, A, B);
 }
 
-__bfloat16 test_mm_cvtness_sbh(float A) {
+__bf16 test_mm_cvtness_sbh(float A) {
   // CHECK-LABEL: @test_mm_cvtness_sbh
   // CHECK: @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
-  // CHECK: ret i16 %{{.*}}
+  // CHECK: ret bfloat %{{.*}}
   return _mm_cvtness_sbh(A);
 }
 

diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8817b909c4591..2d0b8bc876757 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -327,6 +327,8 @@ def llvm_v32f16_ty     : LLVMType<v32f16>;   // 32 x half (__fp16)
 def llvm_v2bf16_ty     : LLVMType<v2bf16>;   //  2 x bfloat (__bf16)
 def llvm_v4bf16_ty     : LLVMType<v4bf16>;   //  4 x bfloat (__bf16)
 def llvm_v8bf16_ty     : LLVMType<v8bf16>;   //  8 x bfloat (__bf16)
+def llvm_v16bf16_ty    : LLVMType<v16bf16>;  // 16 x bfloat (__bf16)
+def llvm_v32bf16_ty    : LLVMType<v32bf16>;  // 32 x bfloat (__bf16)
 def llvm_v1f32_ty      : LLVMType<v1f32>;    //  1 x float
 def llvm_v2f32_ty      : LLVMType<v2f32>;    //  2 x float
 def llvm_v3f32_ty      : LLVMType<v3f32>;    //  3 x float

diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index c274e35042502..a3ec128b75022 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -4901,39 +4901,39 @@ let TargetPrefix = "x86" in {
 let TargetPrefix = "x86" in {
   def int_x86_avx512bf16_cvtne2ps2bf16_128:
               ClangBuiltin<"__builtin_ia32_cvtne2ps2bf16_128">,
-              Intrinsic<[llvm_v8i16_ty], [llvm_v4f32_ty, llvm_v4f32_ty],
+              Intrinsic<[llvm_v8bf16_ty], [llvm_v4f32_ty, llvm_v4f32_ty],
               [IntrNoMem]>;
   def int_x86_avx512bf16_cvtne2ps2bf16_256:
               ClangBuiltin<"__builtin_ia32_cvtne2ps2bf16_256">,
-              Intrinsic<[llvm_v16i16_ty], [llvm_v8f32_ty, llvm_v8f32_ty],
+              Intrinsic<[llvm_v16bf16_ty], [llvm_v8f32_ty, llvm_v8f32_ty],
               [IntrNoMem]>;
   def int_x86_avx512bf16_cvtne2ps2bf16_512:
               ClangBuiltin<"__builtin_ia32_cvtne2ps2bf16_512">,
-              Intrinsic<[llvm_v32i16_ty], [llvm_v16f32_ty, llvm_v16f32_ty],
+              Intrinsic<[llvm_v32bf16_ty], [llvm_v16f32_ty, llvm_v16f32_ty],
               [IntrNoMem]>;
   // Intrinsic must be masked due to it producing less than 128 bits of results.
   def int_x86_avx512bf16_mask_cvtneps2bf16_128:
-              Intrinsic<[llvm_v8i16_ty],
-                        [llvm_v4f32_ty, llvm_v8i16_ty, llvm_v4i1_ty],
+              Intrinsic<[llvm_v8bf16_ty],
+                        [llvm_v4f32_ty, llvm_v8bf16_ty, llvm_v4i1_ty],
                         [IntrNoMem]>;
   def int_x86_avx512bf16_cvtneps2bf16_256:
               ClangBuiltin<"__builtin_ia32_cvtneps2bf16_256">,
-              Intrinsic<[llvm_v8i16_ty], [llvm_v8f32_ty], [IntrNoMem]>;
+              Intrinsic<[llvm_v8bf16_ty], [llvm_v8f32_ty], [IntrNoMem]>;
   def int_x86_avx512bf16_cvtneps2bf16_512:
               ClangBuiltin<"__builtin_ia32_cvtneps2bf16_512">,
-              Intrinsic<[llvm_v16i16_ty], [llvm_v16f32_ty], [IntrNoMem]>;
+              Intrinsic<[llvm_v16bf16_ty], [llvm_v16f32_ty], [IntrNoMem]>;
   def int_x86_avx512bf16_dpbf16ps_128:
               ClangBuiltin<"__builtin_ia32_dpbf16ps_128">,
               Intrinsic<[llvm_v4f32_ty],
-              [llvm_v4f32_ty, llvm_v4i32_ty, llvm_v4i32_ty], [IntrNoMem]>;
+              [llvm_v4f32_ty, llvm_v8bf16_ty, llvm_v8bf16_ty], [IntrNoMem]>;
   def int_x86_avx512bf16_dpbf16ps_256:
               ClangBuiltin<"__builtin_ia32_dpbf16ps_256">,
               Intrinsic<[llvm_v8f32_ty],
-              [llvm_v8f32_ty, llvm_v8i32_ty, llvm_v8i32_ty], [IntrNoMem]>;
+              [llvm_v8f32_ty, llvm_v16bf16_ty, llvm_v16bf16_ty], [IntrNoMem]>;
   def int_x86_avx512bf16_dpbf16ps_512:
               ClangBuiltin<"__builtin_ia32_dpbf16ps_512">,
               Intrinsic<[llvm_v16f32_ty],
-              [llvm_v16f32_ty, llvm_v16i32_ty, llvm_v16i32_ty], [IntrNoMem]>;
+              [llvm_v16f32_ty, llvm_v32bf16_ty, llvm_v32bf16_ty], [IntrNoMem]>;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index a4a33efb055a4..ebc362a4f8a09 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -82,6 +82,26 @@ static bool UpgradeX86MaskedFPCompare(Function *F, Intrinsic::ID IID,
   return true;
 }
 
+static bool UpgradeX86BF16Intrinsic(Function *F, Intrinsic::ID IID,
+                                    Function *&NewFn) {
+  if (F->getReturnType()->getScalarType()->isBFloatTy())
+    return false;
+
+  rename(F);
+  NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
+  return true;
+}
+
+static bool UpgradeX86BF16DPIntrinsic(Function *F, Intrinsic::ID IID,
+                                      Function *&NewFn) {
+  if (F->getFunctionType()->getParamType(1)->getScalarType()->isBFloatTy())
+    return false;
+
+  rename(F);
+  NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
+  return true;
+}
+
 static bool ShouldUpgradeX86Intrinsic(Function *F, StringRef Name) {
   // All of the intrinsics matches below should be marked with which llvm
   // version started autoupgrading them. At some point in the future we would
@@ -488,6 +508,33 @@ static bool UpgradeX86IntrinsicFunction(Function *F, StringRef Name,
   if (Name == "avx512.mask.cmp.ps.512") // Added in 7.0
     return UpgradeX86MaskedFPCompare(F, Intrinsic::x86_avx512_mask_cmp_ps_512,
                                      NewFn);
+  if (Name == "avx512bf16.cvtne2ps2bf16.128") // Added in 9.0
+    return UpgradeX86BF16Intrinsic(
+        F, Intrinsic::x86_avx512bf16_cvtne2ps2bf16_128, NewFn);
+  if (Name == "avx512bf16.cvtne2ps2bf16.256") // Added in 9.0
+    return UpgradeX86BF16Intrinsic(
+        F, Intrinsic::x86_avx512bf16_cvtne2ps2bf16_256, NewFn);
+  if (Name == "avx512bf16.cvtne2ps2bf16.512") // Added in 9.0
+    return UpgradeX86BF16Intrinsic(
+        F, Intrinsic::x86_avx512bf16_cvtne2ps2bf16_512, NewFn);
+  if (Name == "avx512bf16.mask.cvtneps2bf16.128") // Added in 9.0
+    return UpgradeX86BF16Intrinsic(
+        F, Intrinsic::x86_avx512bf16_mask_cvtneps2bf16_128, NewFn);
+  if (Name == "avx512bf16.cvtneps2bf16.256") // Added in 9.0
+    return UpgradeX86BF16Intrinsic(
+        F, Intrinsic::x86_avx512bf16_cvtneps2bf16_256, NewFn);
+  if (Name == "avx512bf16.cvtneps2bf16.512") // Added in 9.0
+    return UpgradeX86BF16Intrinsic(
+        F, Intrinsic::x86_avx512bf16_cvtneps2bf16_512, NewFn);
+  if (Name == "avx512bf16.dpbf16ps.128") // Added in 9.0
+    return UpgradeX86BF16DPIntrinsic(
+        F, Intrinsic::x86_avx512bf16_dpbf16ps_128, NewFn);
+  if (Name == "avx512bf16.dpbf16ps.256") // Added in 9.0
+    return UpgradeX86BF16DPIntrinsic(
+        F, Intrinsic::x86_avx512bf16_dpbf16ps_256, NewFn);
+  if (Name == "avx512bf16.dpbf16ps.512") // Added in 9.0
+    return UpgradeX86BF16DPIntrinsic(
+        F, Intrinsic::x86_avx512bf16_dpbf16ps_512, NewFn);
 
   // frcz.ss/sd may need to have an argument dropped. Added in 3.2
   if (Name.startswith("xop.vfrcz.ss") && F->arg_size() == 2) {
@@ -4170,6 +4217,43 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
     return;
   }
 
+  case Intrinsic::x86_avx512bf16_cvtne2ps2bf16_128:
+  case Intrinsic::x86_avx512bf16_cvtne2ps2bf16_256:
+  case Intrinsic::x86_avx512bf16_cvtne2ps2bf16_512:
+  case Intrinsic::x86_avx512bf16_mask_cvtneps2bf16_128:
+  case Intrinsic::x86_avx512bf16_cvtneps2bf16_256:
+  case Intrinsic::x86_avx512bf16_cvtneps2bf16_512: {
+    SmallVector<Value *, 4> Args(CI->args());
+    unsigned NumElts = cast<FixedVectorType>(CI->getType())->getNumElements();
+    if (NewFn->getIntrinsicID() ==
+        Intrinsic::x86_avx512bf16_mask_cvtneps2bf16_128)
+      Args[1] = Builder.CreateBitCast(
+          Args[1], FixedVectorType::get(Builder.getBFloatTy(), NumElts));
+
+    NewCall = Builder.CreateCall(NewFn, Args);
+    Value *Res = Builder.CreateBitCast(
+        NewCall, FixedVectorType::get(Builder.getInt16Ty(), NumElts));
+
+    NewCall->takeName(CI);
+    CI->replaceAllUsesWith(Res);
+    CI->eraseFromParent();
+    return;
+  }
+  case Intrinsic::x86_avx512bf16_dpbf16ps_128:
+  case Intrinsic::x86_avx512bf16_dpbf16ps_256:
+  case Intrinsic::x86_avx512bf16_dpbf16ps_512:{
+    SmallVector<Value *, 4> Args(CI->args());
+    unsigned NumElts =
+        cast<FixedVectorType>(CI->getType())->getNumElements() * 2;
+    Args[1] = Builder.CreateBitCast(
+        Args[1], FixedVectorType::get(Builder.getBFloatTy(), NumElts));
+    Args[2] = Builder.CreateBitCast(
+        Args[2], FixedVectorType::get(Builder.getBFloatTy(), NumElts));
+
+    NewCall = Builder.CreateCall(NewFn, Args);
+    break;
+  }
+
   case Intrinsic::thread_pointer: {
     NewCall = Builder.CreateCall(NewFn, {});
     break;

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e7cc2ea50545f..d219f82a7a97a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2178,6 +2178,25 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     }
   }
 
+  if (!Subtarget.useSoftFloat() && Subtarget.hasBF16()) {
+    addRegisterClass(MVT::v8bf16, &X86::VR128XRegClass);
+    addRegisterClass(MVT::v16bf16, &X86::VR256XRegClass);
+    addRegisterClass(MVT::v32bf16, &X86::VR512RegClass);
+    // We set the type action of bf16 to TypeSoftPromoteHalf, but we don't
+    // provide the method to promote BUILD_VECTOR. Set the operation action
+    // Custom to do the customization later.
+    setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
+    for (auto VT : { MVT::v8bf16, MVT::v16bf16, MVT::v32bf16 }) {
+      setF16Action(VT, Expand);
+      setOperationAction(ISD::FADD, VT, Expand);
+      setOperationAction(ISD::FSUB, VT, Expand);
+      setOperationAction(ISD::FMUL, VT, Expand);
+      setOperationAction(ISD::FDIV, VT, Expand);
+      setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
+    }
+    addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
+  }
+
   if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) {
     setTruncStoreAction(MVT::v4i64, MVT::v4i8,  Legal);
     setTruncStoreAction(MVT::v4i64, MVT::v4i16, Legal);
@@ -9921,6 +9940,18 @@ static SDValue buildFromShuffleMostly(SDValue Op, SelectionDAG &DAG) {
   return NV;
 }
 
+// Lower BUILD_VECTOR operation for v8bf16, v16bf16 and v32bf16 types.
+static SDValue LowerBUILD_VECTORvXbf16(SDValue Op, SelectionDAG &DAG,
+                                       const X86Subtarget &Subtarget) {
+  MVT VT = Op.getSimpleValueType();
+  MVT IVT = VT.changeVectorElementTypeToInteger();
+  SmallVector<SDValue, 16> NewOps;
+  for (unsigned I = 0, E = Op.getNumOperands(); I != E; ++I)
+    NewOps.push_back(DAG.getBitcast(MVT::i16, Op.getOperand(I)));
+  SDValue Res = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(), IVT, NewOps);
+  return DAG.getBitcast(VT, Res);
+}
+
 // Lower BUILD_VECTOR operation for v8i1 and v16i1 types.
 static SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG,
                                      const X86Subtarget &Subtarget) {
@@ -11075,6 +11106,9 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
   if (VT.getVectorElementType() == MVT::i1 && Subtarget.hasAVX512())
     return LowerBUILD_VECTORvXi1(Op, DAG, Subtarget);
 
+  if (VT.getVectorElementType() == MVT::bf16 && Subtarget.hasBF16())
+    return LowerBUILD_VECTORvXbf16(Op, DAG, Subtarget);
+
   if (SDValue VectorConstant = materializeVectorConstant(Op, DAG, Subtarget))
     return VectorConstant;
 

diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 788bed3eb6896..b0996c75222e7 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -53,34 +53,36 @@ class X86VectorVTInfo<int numelts, ValueType eltvt, RegisterClass rc,
 
   string EltTypeName = !cast<string>(EltVT);
   // Size of the element type in bits, e.g. 32 for v16i32.
-  string EltSizeName = !subst("i", "", !subst("f", "", EltTypeName));
+  string EltSizeName = !subst("i", "", !subst("f", "", !subst("b", "", EltTypeName)));
   int EltSize = EltVT.Size;
 
   // "i" for integer types and "f" for floating-point types
-  string TypeVariantName = !subst(EltSizeName, "", EltTypeName);
+  string TypeVariantName = !subst("b", "", !subst(EltSizeName, "", EltTypeName));
 
   // Size of RC in bits, e.g. 512 for VR512.
   int Size = VT.Size;
 
   // The corresponding memory operand, e.g. i512mem for VR512.
   X86MemOperand MemOp = !cast<X86MemOperand>(TypeVariantName # Size # "mem");
-  X86MemOperand ScalarMemOp = !cast<X86MemOperand>(EltVT # "mem");
+  X86MemOperand ScalarMemOp = !cast<X86MemOperand>(!subst("b", "", EltTypeName) # "mem");
   // FP scalar memory operand for intrinsics - ssmem/sdmem.
   Operand IntScalarMemOp = !if (!eq (EltTypeName, "f16"), !cast<Operand>("shmem"),
+                           !if (!eq (EltTypeName, "bf16"), !cast<Operand>("shmem"),
                            !if (!eq (EltTypeName, "f32"), !cast<Operand>("ssmem"),
-                           !if (!eq (EltTypeName, "f64"), !cast<Operand>("sdmem"), ?)));
+                           !if (!eq (EltTypeName, "f64"), !cast<Operand>("sdmem"), ?))));
 
   // Load patterns
   PatFrag LdFrag = !cast<PatFrag>("load" # VTName);
 
   PatFrag AlignedLdFrag = !cast<PatFrag>("alignedload" # VTName);
 
-  PatFrag ScalarLdFrag = !cast<PatFrag>("load" # EltVT);
+  PatFrag ScalarLdFrag = !cast<PatFrag>("load" # !subst("b", "", EltTypeName));
   PatFrag BroadcastLdFrag = !cast<PatFrag>("X86VBroadcastld" # EltSizeName);
 
   PatFrags ScalarIntMemFrags = !if (!eq (EltTypeName, "f16"), !cast<PatFrags>("sse_load_f16"),
+                               !if (!eq (EltTypeName, "bf16"), !cast<PatFrags>("sse_load_f16"),
                                !if (!eq (EltTypeName, "f32"), !cast<PatFrags>("sse_load_f32"),
-                               !if (!eq (EltTypeName, "f64"), !cast<PatFrags>("sse_load_f64"), ?)));
+                               !if (!eq (EltTypeName, "f64"), !cast<PatFrags>("sse_load_f64"), ?))));
 
   // The string to specify embedded broadcast in assembly.
   string BroadcastStr = "{1to" # NumElts # "}";
@@ -96,11 +98,13 @@ class X86VectorVTInfo<int numelts, ValueType eltvt, RegisterClass rc,
   Domain ExeDomain = !if (!eq (EltTypeName, "f32"), SSEPackedSingle,
                      !if (!eq (EltTypeName, "f64"), SSEPackedDouble,
                      !if (!eq (EltTypeName, "f16"), SSEPackedSingle, // FIXME?
-                     SSEPackedInt)));
+                     !if (!eq (EltTypeName, "bf16"), SSEPackedSingle, // FIXME?
+                     SSEPackedInt))));
 
   RegisterClass FRC = !if (!eq (EltTypeName, "f32"), FR32X,
                       !if (!eq (EltTypeName, "f16"), FR16X,
-                      FR64X));
+                      !if (!eq (EltTypeName, "bf16"), FR16X,
+                      FR64X)));
 
   dag ImmAllZerosV = (VT immAllZerosV);
 
@@ -113,6 +117,7 @@ def v32i16_info : X86VectorVTInfo<32, i16, VR512, "w">;
 def v16i32_info : X86VectorVTInfo<16, i32, VR512, "d">;
 def v8i64_info  : X86VectorVTInfo<8,  i64, VR512, "q">;
 def v32f16_info : X86VectorVTInfo<32, f16, VR512, "ph">;
+def v32bf16_info: X86VectorVTInfo<32, bf16, VR512, "pbf">;
 def v16f32_info : X86VectorVTInfo<16, f32, VR512, "ps">;
 def v8f64_info  : X86VectorVTInfo<8,  f64, VR512, "pd">;
 
@@ -122,6 +127,7 @@ def v16i16x_info : X86VectorVTInfo<16, i16, VR256X, "w">;
 def v8i32x_info  : X86VectorVTInfo<8,  i32, VR256X, "d">;
 def v4i64x_info  : X86VectorVTInfo<4,  i64, VR256X, "q">;
 def v16f16x_info : X86VectorVTInfo<16, f16, VR256X, "ph">;
+def v16bf16x_info: X86VectorVTInfo<16, bf16, VR256X, "pbf">;
 def v8f32x_info  : X86VectorVTInfo<8,  f32, VR256X, "ps">;
 def v4f64x_info  : X86VectorVTInfo<4,  f64, VR256X, "pd">;
 
@@ -130,6 +136,7 @@ def v8i16x_info  : X86VectorVTInfo<8,  i16, VR128X, "w">;
 def v4i32x_info  : X86VectorVTInfo<4,  i32, VR128X, "d">;
 def v2i64x_info  : X86VectorVTInfo<2,  i64, VR128X, "q">;
 def v8f16x_info  : X86VectorVTInfo<8,  f16, VR128X, "ph">;
+def v8bf16x_info : X86VectorVTInfo<8,  bf16, VR128X, "pbf">;
 def v4f32x_info  : X86VectorVTInfo<4,  f32, VR128X, "ps">;
 def v2f64x_info  : X86VectorVTInfo<2,  f64, VR128X, "pd">;
 
@@ -138,6 +145,7 @@ def v2f64x_info  : X86VectorVTInfo<2,  f64, VR128X, "pd">;
 def i32x_info    : X86VectorVTInfo<1,  i32, GR32, "si">;
 def i64x_info    : X86VectorVTInfo<1,  i64, GR64, "sq">;
 def f16x_info    : X86VectorVTInfo<1,  f16, VR128X, "sh">;
+def bf16x_info   : X86VectorVTInfo<1,  bf16, VR128X, "sbf">;
 def f32x_info    : X86VectorVTInfo<1,  f32, VR128X, "ss">;
 def f64x_info    : X86VectorVTInfo<1,  f64, VR128X, "sd">;
 
@@ -158,6 +166,8 @@ def avx512vl_i64_info : AVX512VLVectorVTInfo<v8i64_info, v4i64x_info,
                                              v2i64x_info>;
 def avx512vl_f16_info : AVX512VLVectorVTInfo<v32f16_info, v16f16x_info,
                                              v8f16x_info>;
+def avx512vl_bf16_info : AVX512VLVectorVTInfo<v32bf16_info, v16bf16x_info,
+                                             v8bf16x_info>;
 def avx512vl_f32_info : AVX512VLVectorVTInfo<v16f32_info, v8f32x_info,
                                              v4f32x_info>;
 def avx512vl_f64_info : AVX512VLVectorVTInfo<v8f64_info, v4f64x_info,
@@ -3761,6 +3771,9 @@ let Predicates = [HasBWI, NoVLX] in {
 
   defm : mask_move_lowering<"VMOVDQU16Z", v8f16x_info, v32f16_info>;
   defm : mask_move_lowering<"VMOVDQU16Z", v16f16x_info, v32f16_info>;
+
+  defm : mask_move_lowering<"VMOVDQU16Z", v8bf16x_info, v32bf16_info>;
+  defm : mask_move_lowering<"VMOVDQU16Z", v16bf16x_info, v32bf16_info>;
 }
 
 let Predicates = [HasAVX512] in {
@@ -3771,6 +3784,8 @@ let Predicates = [HasAVX512] in {
             (VMOVDQA64Zrm addr:$src)>;
   def : Pat<(alignedloadv32f16 addr:$src),
             (VMOVAPSZrm addr:$src)>;
+  def : Pat<(alignedloadv32bf16 addr:$src),
+            (VMOVAPSZrm addr:$src)>;
   def : Pat<(alignedloadv64i8 addr:$src),
             (VMOVDQA64Zrm addr:$src)>;
   def : Pat<(loadv16i32 addr:$src),
@@ -3779,6 +3794,8 @@ let Predicates = [HasAVX512] in {
             (VMOVDQU64Zrm addr:$src)>;
   def : Pat<(loadv32f16 addr:$src),
             (VMOVUPSZrm addr:$src)>;
+  def : Pat<(loadv32bf16 addr:$src),
+            (VMOVUPSZrm addr:$src)>;
   def : Pat<(loadv64i8 addr:$src),
             (VMOVDQU64Zrm addr:$src)>;
 
@@ -3789,6 +3806,8 @@ let Predicates = [HasAVX512] in {
             (VMOVDQA64Zmr addr:$dst, VR512:$src)>;
   def : Pat<(alignedstore (v32f16 VR512:$src), addr:$dst),
             (VMOVAPSZmr addr:$dst, VR512:$src)>;
+  def : Pat<(alignedstore (v32bf16 VR512:$src), addr:$dst),
+            (VMOVAPSZmr addr:$dst, VR512:$src)>;
   def : Pat<(alignedstore (v64i8 VR512:$src), addr:$dst),
             (VMOVDQA64Zmr addr:$dst, VR512:$src)>;
   def : Pat<(store (v16i32 VR512:$src), addr:$dst),
@@ -3797,6 +3816,8 @@ let Predicates = [HasAVX512] in {
             (VMOVDQU64Zmr addr:$dst, VR512:$src)>;
   def : Pat<(store (v32f16 VR512:$src), addr:$dst),
             (VMOVUPSZmr addr:$dst, VR512:$src)>;
+  def : Pat<(store (v32bf16 VR512:$src), addr:$dst),
+            (VMOVUPSZmr addr:$dst, VR512:$src)>;
   def : Pat<(store (v64i8 VR512:$src), addr:$dst),
             (VMOVDQU64Zmr addr:$dst, VR512:$src)>;
 }
@@ -3809,6 +3830,8 @@ let Predicates = [HasVLX] in {
             (VMOVDQA64Z128rm addr:$src)>;
   def : Pat<(alignedloadv8f16 addr:$src),
             (VMOVAPSZ128rm addr:$src)>;
+  def : Pat<(alignedloadv8bf16 addr:$src),
+            (VMOVAPSZ128rm addr:$src)>;
   def : Pat<(alignedloadv16i8 addr:$src),
             (VMOVDQA64Z128rm addr:$src)>;
   def : Pat<(loadv4i32 addr:$src),
@@ -3817,6 +3840,8 @@ let Predicates = [HasVLX] in {
             (VMOVDQU64Z128rm addr:$src)>;
   def : Pat<(loadv8f16 addr:$src),
             (VMOVUPSZ128rm addr:$src)>;
+  def : Pat<(loadv8bf16 addr:$src),
+            (VMOVUPSZ128rm addr:$src)>;
   def : Pat<(loadv16i8 addr:$src),
             (VMOVDQU64Z128rm addr:$src)>;
 
@@ -3827,6 +3852,8 @@ let Predicates = [HasVLX] in {
             (VMOVDQA64Z128mr addr:$dst, VR128X:$src)>;
   def : Pat<(alignedstore (v8f16 VR128X:$src), addr:$dst),
             (VMOVAPSZ128mr addr:$dst, VR128X:$src)>;
+  def : Pat<(alignedstore (v8bf16 VR128X:$src), addr:$dst),
+            (VMOVAPSZ128mr addr:$dst, VR128X:$src)>;
   def : Pat<(alignedstore (v16i8 VR128X:$src), addr:$dst),
             (VMOVDQA64Z128mr addr:$dst, VR128X:$src)>;
   def : Pat<(store (v4i32 VR128X:$src), addr:$dst),
@@ -3835,6 +3862,8 @@ let Predicates = [HasVLX] in {
             (VMOVDQU64Z128mr addr:$dst, VR128X:$src)>;
   def : Pat<(store (v8f16 VR128X:$src), addr:$dst),
             (VMOVUPSZ128mr addr:$dst, VR128X:$src)>;
+  def : Pat<(store (v8bf16 VR128X:$src), addr:$dst),
+            (VMOVUPSZ128mr addr:$dst, VR128X:$src)>;
   def : Pat<(store (v16i8 VR128X:$src), addr:$dst),
             (VMOVDQU64Z128mr addr:$dst, VR128X:$src)>;
 
@@ -3845,6 +3874,8 @@ let Predicates = [HasVLX] in {
             (VMOVDQA64Z256rm addr:$src)>;
   def : Pat<(alignedloadv16f16 addr:$src),
             (VMOVAPSZ256rm addr:$src)>;
+  def : Pat<(alignedloadv16bf16 addr:$src),
+            (VMOVAPSZ256rm addr:$src)>;
   def : Pat<(alignedloadv32i8 addr:$src),
             (VMOVDQA64Z256rm addr:$src)>;
   def : Pat<(loadv8i32 addr:$src),
@@ -3853,6 +3884,8 @@ let Predicates = [HasVLX] in {
             (VMOVDQU64Z256rm addr:$src)>;
   def : Pat<(loadv16f16 addr:$src),
             (VMOVUPSZ256rm addr:$src)>;
+  def : Pat<(loadv16bf16 addr:$src),
+            (VMOVUPSZ256rm addr:$src)>;
   def : Pat<(loadv32i8 addr:$src),
             (VMOVDQU64Z256rm addr:$src)>;
 
@@ -3863,6 +3896,8 @@ let Predicates = [HasVLX] in {
             (VMOVDQA64Z256mr addr:$dst, VR256X:$src)>;
   def : Pat<(alignedstore (v16f16 VR256X:$src), addr:$dst),
             (VMOVAPSZ256mr addr:$dst, VR256X:$src)>;
+  def : Pat<(alignedstore (v16bf16 VR256X:$src), addr:$dst),
+            (VMOVAPSZ256mr addr:$dst, VR256X:$src)>;
   def : Pat<(alignedstore (v32i8 VR256X:$src), addr:$dst),
             (VMOVDQA64Z256mr addr:$dst, VR256X:$src)>;
   def : Pat<(store (v8i32 VR256X:$src), addr:$dst),
@@ -3871,89 +3906,97 @@ let Predicates = [HasVLX] in {
             (VMOVDQU64Z256mr addr:$dst, VR256X:$src)>;
   def : Pat<(store (v16f16 VR256X:$src), addr:$dst),
             (VMOVUPSZ256mr addr:$dst, VR256X:$src)>;
+  def : Pat<(store (v16bf16 VR256X:$src), addr:$dst),
+            (VMOVUPSZ256mr addr:$dst, VR256X:$src)>;
   def : Pat<(store (v32i8 VR256X:$src), addr:$dst),
             (VMOVDQU64Z256mr addr:$dst, VR256X:$src)>;
 }
+
+multiclass mask_move_lowering_f16_bf16<AVX512VLVectorVTInfo _> {
 let Predicates = [HasBWI] in {
-  def : Pat<(v32f16 (vselect VK32WM:$mask, (v32f16 VR512:$src1), (v32f16 VR512:$src0))),
+  def : Pat<(_.info512.VT (vselect VK32WM:$mask, (_.info512.VT VR512:$src1), (_.info512.VT VR512:$src0))),
             (VMOVDQU16Zrrk VR512:$src0, VK32WM:$mask, VR512:$src1)>;
-  def : Pat<(v32f16 (vselect VK32WM:$mask, (v32f16 VR512:$src1), v32f16_info.ImmAllZerosV)),
+  def : Pat<(_.info512.VT (vselect VK32WM:$mask, (_.info512.VT VR512:$src1), _.info512.ImmAllZerosV)),
             (VMOVDQU16Zrrkz VK32WM:$mask, VR512:$src1)>;
-  def : Pat<(v32f16 (vselect VK32WM:$mask,
-                     (v32f16 (alignedloadv32f16 addr:$src)), (v32f16 VR512:$src0))),
+  def : Pat<(_.info512.VT (vselect VK32WM:$mask,
+                     (_.info512.VT (_.info512.AlignedLdFrag addr:$src)), (_.info512.VT VR512:$src0))),
             (VMOVDQU16Zrmk VR512:$src0, VK32WM:$mask, addr:$src)>;
-  def : Pat<(v32f16 (vselect VK32WM:$mask,
-                     (v32f16 (alignedloadv32f16 addr:$src)), v32f16_info.ImmAllZerosV)),
+  def : Pat<(_.info512.VT (vselect VK32WM:$mask,
+                     (_.info512.VT (_.info512.AlignedLdFrag addr:$src)), _.info512.ImmAllZerosV)),
             (VMOVDQU16Zrmkz VK32WM:$mask, addr:$src)>;
-  def : Pat<(v32f16 (vselect VK32WM:$mask,
-                     (v32f16 (loadv32f16 addr:$src)), (v32f16 VR512:$src0))),
+  def : Pat<(_.info512.VT (vselect VK32WM:$mask,
+                     (_.info512.VT (_.info512.LdFrag addr:$src)), (_.info512.VT VR512:$src0))),
             (VMOVDQU16Zrmk VR512:$src0, VK32WM:$mask, addr:$src)>;
-  def : Pat<(v32f16 (vselect VK32WM:$mask,
-                     (v32f16 (loadv32f16 addr:$src)), v32f16_info.ImmAllZerosV)),
+  def : Pat<(_.info512.VT (vselect VK32WM:$mask,
+                     (_.info512.VT (_.info512.LdFrag addr:$src)), _.info512.ImmAllZerosV)),
             (VMOVDQU16Zrmkz VK32WM:$mask, addr:$src)>;
-  def : Pat<(v32f16 (masked_load addr:$src, VK32WM:$mask, (v32f16 VR512:$src0))),
+  def : Pat<(_.info512.VT (masked_load addr:$src, VK32WM:$mask, (_.info512.VT VR512:$src0))),
             (VMOVDQU16Zrmk VR512:$src0, VK32WM:$mask, addr:$src)>;
-  def : Pat<(v32f16 (masked_load addr:$src, VK32WM:$mask, undef)),
+  def : Pat<(_.info512.VT (masked_load addr:$src, VK32WM:$mask, undef)),
             (VMOVDQU16Zrmkz VK32WM:$mask, addr:$src)>;
-  def : Pat<(v32f16 (masked_load addr:$src, VK32WM:$mask, v32f16_info.ImmAllZerosV)),
+  def : Pat<(_.info512.VT (masked_load addr:$src, VK32WM:$mask, _.info512.ImmAllZerosV)),
             (VMOVDQU16Zrmkz VK32WM:$mask, addr:$src)>;
 
-  def : Pat<(masked_store (v32f16 VR512:$src), addr:$dst, VK32WM:$mask),
+  def : Pat<(masked_store (_.info512.VT VR512:$src), addr:$dst, VK32WM:$mask),
             (VMOVDQU16Zmrk addr:$dst, VK32WM:$mask, VR512:$src)>;
 }
 let Predicates = [HasBWI, HasVLX] in {
-  def : Pat<(v16f16 (vselect VK16WM:$mask, (v16f16 VR256X:$src1), (v16f16 VR256X:$src0))),
+  def : Pat<(_.info256.VT (vselect VK16WM:$mask, (_.info256.VT VR256X:$src1), (_.info256.VT VR256X:$src0))),
             (VMOVDQU16Z256rrk VR256X:$src0, VK16WM:$mask, VR256X:$src1)>;
-  def : Pat<(v16f16 (vselect VK16WM:$mask, (v16f16 VR256X:$src1), v16f16x_info.ImmAllZerosV)),
+  def : Pat<(_.info256.VT (vselect VK16WM:$mask, (_.info256.VT VR256X:$src1), _.info256.ImmAllZerosV)),
             (VMOVDQU16Z256rrkz VK16WM:$mask, VR256X:$src1)>;
-  def : Pat<(v16f16 (vselect VK16WM:$mask,
-                     (v16f16 (alignedloadv16f16 addr:$src)), (v16f16 VR256X:$src0))),
+  def : Pat<(_.info256.VT (vselect VK16WM:$mask,
+                     (_.info256.VT (_.info256.AlignedLdFrag addr:$src)), (_.info256.VT VR256X:$src0))),
             (VMOVDQU16Z256rmk VR256X:$src0, VK16WM:$mask, addr:$src)>;
-  def : Pat<(v16f16 (vselect VK16WM:$mask,
-                     (v16f16 (alignedloadv16f16 addr:$src)), v16f16x_info.ImmAllZerosV)),
+  def : Pat<(_.info256.VT (vselect VK16WM:$mask,
+                     (_.info256.VT (_.info256.AlignedLdFrag addr:$src)), _.info256.ImmAllZerosV)),
             (VMOVDQU16Z256rmkz VK16WM:$mask, addr:$src)>;
-  def : Pat<(v16f16 (vselect VK16WM:$mask,
-                     (v16f16 (loadv16f16 addr:$src)), (v16f16 VR256X:$src0))),
+  def : Pat<(_.info256.VT (vselect VK16WM:$mask,
+                     (_.info256.VT (_.info256.LdFrag addr:$src)), (_.info256.VT VR256X:$src0))),
             (VMOVDQU16Z256rmk VR256X:$src0, VK16WM:$mask, addr:$src)>;
-  def : Pat<(v16f16 (vselect VK16WM:$mask,
-                     (v16f16 (loadv16f16 addr:$src)), v16f16x_info.ImmAllZerosV)),
+  def : Pat<(_.info256.VT (vselect VK16WM:$mask,
+                     (_.info256.VT (_.info256.LdFrag addr:$src)), _.info256.ImmAllZerosV)),
             (VMOVDQU16Z256rmkz VK16WM:$mask, addr:$src)>;
-  def : Pat<(v16f16 (masked_load addr:$src, VK16WM:$mask, (v16f16 VR256X:$src0))),
+  def : Pat<(_.info256.VT (masked_load addr:$src, VK16WM:$mask, (_.info256.VT VR256X:$src0))),
             (VMOVDQU16Z256rmk VR256X:$src0, VK16WM:$mask, addr:$src)>;
-  def : Pat<(v16f16 (masked_load addr:$src, VK16WM:$mask, undef)),
+  def : Pat<(_.info256.VT (masked_load addr:$src, VK16WM:$mask, undef)),
             (VMOVDQU16Z256rmkz VK16WM:$mask, addr:$src)>;
-  def : Pat<(v16f16 (masked_load addr:$src, VK16WM:$mask, v16f16x_info.ImmAllZerosV)),
+  def : Pat<(_.info256.VT (masked_load addr:$src, VK16WM:$mask, _.info256.ImmAllZerosV)),
             (VMOVDQU16Z256rmkz VK16WM:$mask, addr:$src)>;
 
-  def : Pat<(masked_store (v16f16 VR256X:$src), addr:$dst, VK16WM:$mask),
+  def : Pat<(masked_store (_.info256.VT VR256X:$src), addr:$dst, VK16WM:$mask),
             (VMOVDQU16Z256mrk addr:$dst, VK16WM:$mask, VR256X:$src)>;
 
-  def : Pat<(v8f16 (vselect VK8WM:$mask, (v8f16 VR128X:$src1), (v8f16 VR128X:$src0))),
+  def : Pat<(_.info128.VT (vselect VK8WM:$mask, (_.info128.VT VR128X:$src1), (_.info128.VT VR128X:$src0))),
             (VMOVDQU16Z128rrk VR128X:$src0, VK8WM:$mask, VR128X:$src1)>;
-  def : Pat<(v8f16 (vselect VK8WM:$mask, (v8f16 VR128X:$src1), v8f16x_info.ImmAllZerosV)),
+  def : Pat<(_.info128.VT (vselect VK8WM:$mask, (_.info128.VT VR128X:$src1), _.info128.ImmAllZerosV)),
             (VMOVDQU16Z128rrkz VK8WM:$mask, VR128X:$src1)>;
-  def : Pat<(v8f16 (vselect VK8WM:$mask,
-                     (v8f16 (alignedloadv8f16 addr:$src)), (v8f16 VR128X:$src0))),
+  def : Pat<(_.info128.VT (vselect VK8WM:$mask,
+                     (_.info128.VT (_.info128.AlignedLdFrag addr:$src)), (_.info128.VT VR128X:$src0))),
             (VMOVDQU16Z128rmk VR128X:$src0, VK8WM:$mask, addr:$src)>;
-  def : Pat<(v8f16 (vselect VK8WM:$mask,
-                     (v8f16 (alignedloadv8f16 addr:$src)), v8f16x_info.ImmAllZerosV)),
+  def : Pat<(_.info128.VT (vselect VK8WM:$mask,
+                     (_.info128.VT (_.info128.AlignedLdFrag addr:$src)), _.info128.ImmAllZerosV)),
             (VMOVDQU16Z128rmkz VK8WM:$mask, addr:$src)>;
-  def : Pat<(v8f16 (vselect VK8WM:$mask,
-                     (v8f16 (loadv8f16 addr:$src)), (v8f16 VR128X:$src0))),
+  def : Pat<(_.info128.VT (vselect VK8WM:$mask,
+                     (_.info128.VT (_.info128.LdFrag addr:$src)), (_.info128.VT VR128X:$src0))),
             (VMOVDQU16Z128rmk VR128X:$src0, VK8WM:$mask, addr:$src)>;
-  def : Pat<(v8f16 (vselect VK8WM:$mask,
-                     (v8f16 (loadv8f16 addr:$src)), v8f16x_info.ImmAllZerosV)),
+  def : Pat<(_.info128.VT (vselect VK8WM:$mask,
+                     (_.info128.VT (_.info128.LdFrag addr:$src)), _.info128.ImmAllZerosV)),
             (VMOVDQU16Z128rmkz VK8WM:$mask, addr:$src)>;
-  def : Pat<(v8f16 (masked_load addr:$src, VK8WM:$mask, (v8f16 VR128X:$src0))),
+  def : Pat<(_.info128.VT (masked_load addr:$src, VK8WM:$mask, (_.info128.VT VR128X:$src0))),
             (VMOVDQU16Z128rmk VR128X:$src0, VK8WM:$mask, addr:$src)>;
-  def : Pat<(v8f16 (masked_load addr:$src, VK8WM:$mask, undef)),
+  def : Pat<(_.info128.VT (masked_load addr:$src, VK8WM:$mask, undef)),
             (VMOVDQU16Z128rmkz VK8WM:$mask, addr:$src)>;
-  def : Pat<(v8f16 (masked_load addr:$src, VK8WM:$mask, v8f16x_info.ImmAllZerosV)),
+  def : Pat<(_.info128.VT (masked_load addr:$src, VK8WM:$mask, _.info128.ImmAllZerosV)),
             (VMOVDQU16Z128rmkz VK8WM:$mask, addr:$src)>;
 
-  def : Pat<(masked_store (v8f16 VR128X:$src), addr:$dst, VK8WM:$mask),
+  def : Pat<(masked_store (_.info128.VT VR128X:$src), addr:$dst, VK8WM:$mask),
             (VMOVDQU16Z128mrk addr:$dst, VK8WM:$mask, VR128X:$src)>;
 }
+}
+
+defm : mask_move_lowering_f16_bf16<avx512vl_f16_info>;
+defm : mask_move_lowering_f16_bf16<avx512vl_bf16_info>;
 
 // Move Int Doubleword to Packed Double Int
 //
@@ -12811,7 +12854,7 @@ multiclass avx512_binop_all2<bits<8> opc, string OpcodeStr,
 let ExeDomain = SSEPackedSingle in
 defm VCVTNE2PS2BF16 : avx512_binop_all2<0x72, "vcvtne2ps2bf16",
                                         SchedWriteCvtPD2PS, //FIXME: Should be SchedWriteCvtPS2BF
-                                        avx512vl_f32_info, avx512vl_i16_info,
+                                        avx512vl_f32_info, avx512vl_bf16_info,
                                         X86cvtne2ps2bf16, HasBF16, 0>, T8XD;
 
 // Truncate Float to BFloat16
@@ -12819,15 +12862,15 @@ multiclass avx512_cvtps2bf16<bits<8> opc, string OpcodeStr,
                              X86SchedWriteWidths sched> {
   let ExeDomain = SSEPackedSingle in {
   let Predicates = [HasBF16], Uses = []<Register>, mayRaiseFPException = 0 in {
-    defm Z : avx512_vcvt_fp<opc, OpcodeStr, v16i16x_info, v16f32_info,
+    defm Z : avx512_vcvt_fp<opc, OpcodeStr, v16bf16x_info, v16f32_info,
                             X86cvtneps2bf16, X86cvtneps2bf16, sched.ZMM>, EVEX_V512;
   }
   let Predicates = [HasBF16, HasVLX] in {
     let Uses = []<Register>, mayRaiseFPException = 0 in {
-    defm Z128 : avx512_vcvt_fp<opc, OpcodeStr, v8i16x_info, v4f32x_info,
+    defm Z128 : avx512_vcvt_fp<opc, OpcodeStr, v8bf16x_info, v4f32x_info,
                                null_frag, null_frag, sched.XMM, "{1to4}", "{x}", f128mem,
                                VK4WM>, EVEX_V128;
-    defm Z256 : avx512_vcvt_fp<opc, OpcodeStr, v8i16x_info, v8f32x_info,
+    defm Z256 : avx512_vcvt_fp<opc, OpcodeStr, v8bf16x_info, v8f32x_info,
                                X86cvtneps2bf16, X86cvtneps2bf16,
                                sched.YMM, "{1to8}", "{y}">, EVEX_V256;
     }
@@ -12855,32 +12898,32 @@ defm VCVTNEPS2BF16 : avx512_cvtps2bf16<0x72, "vcvtneps2bf16",
 let Predicates = [HasBF16, HasVLX] in {
   // Special patterns to allow use of X86mcvtneps2bf16 for masking. Instruction
   // patterns have been disabled with null_frag.
-  def : Pat<(v8i16 (X86cvtneps2bf16 (v4f32 VR128X:$src))),
+  def : Pat<(v8bf16 (X86cvtneps2bf16 (v4f32 VR128X:$src))),
             (VCVTNEPS2BF16Z128rr VR128X:$src)>;
-  def : Pat<(X86mcvtneps2bf16 (v4f32 VR128X:$src), (v8i16 VR128X:$src0),
+  def : Pat<(X86mcvtneps2bf16 (v4f32 VR128X:$src), (v8bf16 VR128X:$src0),
                               VK4WM:$mask),
             (VCVTNEPS2BF16Z128rrk VR128X:$src0, VK4WM:$mask, VR128X:$src)>;
-  def : Pat<(X86mcvtneps2bf16 (v4f32 VR128X:$src), v8i16x_info.ImmAllZerosV,
+  def : Pat<(X86mcvtneps2bf16 (v4f32 VR128X:$src), v8bf16x_info.ImmAllZerosV,
                               VK4WM:$mask),
             (VCVTNEPS2BF16Z128rrkz VK4WM:$mask, VR128X:$src)>;
 
-  def : Pat<(v8i16 (X86cvtneps2bf16 (loadv4f32 addr:$src))),
+  def : Pat<(v8bf16 (X86cvtneps2bf16 (loadv4f32 addr:$src))),
             (VCVTNEPS2BF16Z128rm addr:$src)>;
-  def : Pat<(X86mcvtneps2bf16 (loadv4f32 addr:$src), (v8i16 VR128X:$src0),
+  def : Pat<(X86mcvtneps2bf16 (loadv4f32 addr:$src), (v8bf16 VR128X:$src0),
                               VK4WM:$mask),
             (VCVTNEPS2BF16Z128rmk VR128X:$src0, VK4WM:$mask, addr:$src)>;
-  def : Pat<(X86mcvtneps2bf16 (loadv4f32 addr:$src), v8i16x_info.ImmAllZerosV,
+  def : Pat<(X86mcvtneps2bf16 (loadv4f32 addr:$src), v8bf16x_info.ImmAllZerosV,
                               VK4WM:$mask),
             (VCVTNEPS2BF16Z128rmkz VK4WM:$mask, addr:$src)>;
 
-  def : Pat<(v8i16 (X86cvtneps2bf16 (v4f32
+  def : Pat<(v8bf16 (X86cvtneps2bf16 (v4f32
                                      (X86VBroadcastld32 addr:$src)))),
             (VCVTNEPS2BF16Z128rmb addr:$src)>;
   def : Pat<(X86mcvtneps2bf16 (v4f32 (X86VBroadcastld32 addr:$src)),
-                              (v8i16 VR128X:$src0), VK4WM:$mask),
+                              (v8bf16 VR128X:$src0), VK4WM:$mask),
             (VCVTNEPS2BF16Z128rmbk VR128X:$src0, VK4WM:$mask, addr:$src)>;
   def : Pat<(X86mcvtneps2bf16 (v4f32 (X86VBroadcastld32 addr:$src)),
-                              v8i16x_info.ImmAllZerosV, VK4WM:$mask),
+                              v8bf16x_info.ImmAllZerosV, VK4WM:$mask),
             (VCVTNEPS2BF16Z128rmbkz VK4WM:$mask, addr:$src)>;
 }
 
@@ -12902,7 +12945,7 @@ multiclass avx512_dpbf16ps_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
                                Sched<[sched.Folded, sched.ReadAfterFold]>;
 
   defm mb: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
-                  (ins src_v.RC:$src2, src_v.ScalarMemOp:$src3),
+                  (ins src_v.RC:$src2, f32mem:$src3),
                   OpcodeStr,
                   !strconcat("${src3}", _.BroadcastStr,", $src2"),
                   !strconcat("$src2, ${src3}", _.BroadcastStr),
@@ -12930,7 +12973,7 @@ multiclass avx512_dpbf16ps_sizes<bits<8> opc, string OpcodeStr, SDNode OpNode,
 
 let ExeDomain = SSEPackedSingle in
 defm VDPBF16PS : avx512_dpbf16ps_sizes<0x52, "vdpbf16ps", X86dpbf16ps, SchedWriteFMA,
-                                       avx512vl_f32_info, avx512vl_i32_info,
+                                       avx512vl_f32_info, avx512vl_bf16_info,
                                        HasBF16>, T8XS, EVEX_CD8<32, CD8VF>;
 
 //===----------------------------------------------------------------------===//

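As a quick sanity check of the new VCVTNEPS2BF16 typing and the v16bf16 move patterns above, a standalone snippet like the one below (a sketch for illustration, not part of the patch; the function name and RUN line are assumptions) should now select the conversion plus a plain aligned vector move for the bf16 result:

; Illustrative only -- not taken from the patch.
; Assumed RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx512bf16,+avx512vl < %s
define void @cvt_and_store_v16bf16(<16 x float> %a, ptr %p) {
  ; The intrinsic now yields <16 x bfloat> directly (v16bf16x_info above).
  %r = call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %a)
  ; Expected to match the new alignedstore v16bf16 pattern (VMOVAPSZ256mr).
  store <16 x bfloat> %r, ptr %p, align 32
  ret void
}
declare <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float>)
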
diff  --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
index f35294da45f05..65c1adc38d7dc 100644
--- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -785,23 +785,23 @@ def X86vfproundRnd: SDNode<"X86ISD::VFPROUND_RND",
 
 // cvt fp to bfloat16
 def X86cvtne2ps2bf16 : SDNode<"X86ISD::CVTNE2PS2BF16",
-                       SDTypeProfile<1, 2, [SDTCVecEltisVT<0, i16>,
+                       SDTypeProfile<1, 2, [SDTCVecEltisVT<0, bf16>,
                                             SDTCVecEltisVT<1, f32>,
                                             SDTCisSameSizeAs<0,1>,
                                             SDTCisSameAs<1,2>]>>;
 def X86mcvtneps2bf16 : SDNode<"X86ISD::MCVTNEPS2BF16",
-                       SDTypeProfile<1, 3, [SDTCVecEltisVT<0, i16>,
+                       SDTypeProfile<1, 3, [SDTCVecEltisVT<0, bf16>,
                                             SDTCVecEltisVT<1, f32>,
                                             SDTCisSameAs<0, 2>,
                                             SDTCVecEltisVT<3, i1>,
                                             SDTCisSameNumEltsAs<1, 3>]>>;
 def X86cvtneps2bf16 :  SDNode<"X86ISD::CVTNEPS2BF16",
-                       SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i16>,
+                       SDTypeProfile<1, 1, [SDTCVecEltisVT<0, bf16>,
                                             SDTCVecEltisVT<1, f32>]>>;
 def X86dpbf16ps :      SDNode<"X86ISD::DPBF16PS",
                        SDTypeProfile<1, 3, [SDTCVecEltisVT<0, f32>,
                                             SDTCisSameAs<0,1>,
-                                            SDTCVecEltisVT<2, i32>,
+                                            SDTCVecEltisVT<2, bf16>,
                                             SDTCisSameAs<2,3>]>>;
 
 // galois field arithmetic
@@ -819,6 +819,7 @@ def SDTX86MaskedStore: SDTypeProfile<0, 3, [       // masked store
 
 // 128-bit load pattern fragments
 def loadv8f16    : PatFrag<(ops node:$ptr), (v8f16 (load node:$ptr))>;
+def loadv8bf16   : PatFrag<(ops node:$ptr), (v8bf16 (load node:$ptr))>;
 def loadv4f32    : PatFrag<(ops node:$ptr), (v4f32 (load node:$ptr))>;
 def loadv2f64    : PatFrag<(ops node:$ptr), (v2f64 (load node:$ptr))>;
 def loadv2i64    : PatFrag<(ops node:$ptr), (v2i64 (load node:$ptr))>;
@@ -828,6 +829,7 @@ def loadv16i8    : PatFrag<(ops node:$ptr), (v16i8 (load node:$ptr))>;
 
 // 256-bit load pattern fragments
 def loadv16f16   : PatFrag<(ops node:$ptr), (v16f16 (load node:$ptr))>;
+def loadv16bf16  : PatFrag<(ops node:$ptr), (v16bf16 (load node:$ptr))>;
 def loadv8f32    : PatFrag<(ops node:$ptr), (v8f32  (load node:$ptr))>;
 def loadv4f64    : PatFrag<(ops node:$ptr), (v4f64  (load node:$ptr))>;
 def loadv4i64    : PatFrag<(ops node:$ptr), (v4i64  (load node:$ptr))>;
@@ -837,6 +839,7 @@ def loadv32i8    : PatFrag<(ops node:$ptr), (v32i8  (load node:$ptr))>;
 
 // 512-bit load pattern fragments
 def loadv32f16   : PatFrag<(ops node:$ptr), (v32f16 (load node:$ptr))>;
+def loadv32bf16  : PatFrag<(ops node:$ptr), (v32bf16 (load node:$ptr))>;
 def loadv16f32   : PatFrag<(ops node:$ptr), (v16f32 (load node:$ptr))>;
 def loadv8f64    : PatFrag<(ops node:$ptr), (v8f64  (load node:$ptr))>;
 def loadv8i64    : PatFrag<(ops node:$ptr), (v8i64  (load node:$ptr))>;
@@ -870,6 +873,8 @@ def alignedload : PatFrag<(ops node:$ptr), (load node:$ptr), [{
 // NOTE: all 128-bit integer vector loads are promoted to v2i64
 def alignedloadv8f16 : PatFrag<(ops node:$ptr),
                                (v8f16 (alignedload node:$ptr))>;
+def alignedloadv8bf16 : PatFrag<(ops node:$ptr),
+                                (v8bf16 (alignedload node:$ptr))>;
 def alignedloadv4f32 : PatFrag<(ops node:$ptr),
                                (v4f32 (alignedload node:$ptr))>;
 def alignedloadv2f64 : PatFrag<(ops node:$ptr),
@@ -887,6 +892,8 @@ def alignedloadv16i8 : PatFrag<(ops node:$ptr),
 // NOTE: all 256-bit integer vector loads are promoted to v4i64
 def alignedloadv16f16 : PatFrag<(ops node:$ptr),
                                 (v16f16 (alignedload node:$ptr))>;
+def alignedloadv16bf16 : PatFrag<(ops node:$ptr),
+                                 (v16bf16 (alignedload node:$ptr))>;
 def alignedloadv8f32  : PatFrag<(ops node:$ptr),
                                 (v8f32  (alignedload node:$ptr))>;
 def alignedloadv4f64  : PatFrag<(ops node:$ptr),
@@ -903,6 +910,8 @@ def alignedloadv32i8  : PatFrag<(ops node:$ptr),
 // 512-bit aligned load pattern fragments
 def alignedloadv32f16 : PatFrag<(ops node:$ptr),
                                 (v32f16 (alignedload node:$ptr))>;
+def alignedloadv32bf16 : PatFrag<(ops node:$ptr),
+                                 (v32bf16 (alignedload node:$ptr))>;
 def alignedloadv16f32 : PatFrag<(ops node:$ptr),
                                 (v16f32 (alignedload node:$ptr))>;
 def alignedloadv8f64  : PatFrag<(ops node:$ptr),

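The bf16 load fragments introduced here mirror the existing f16 ones, so ordinary bf16 vector loads and stores can reuse the FP move instructions. A minimal sketch (again illustrative, assuming an AVX512BF16+VL target as in the tests below) that exercises the 128-bit fragments:

; Illustrative only; the function name is an assumption.
define void @copy_v8bf16(ptr %src, ptr %dst) {
  ; The unaligned load is expected to go through the new loadv8bf16 fragment
  ; (VMOVUPS form); the 16-byte-aligned store should match the new
  ; alignedstore v8bf16 pattern (VMOVAPS form) added in X86InstrAVX512.td.
  %v = load <8 x bfloat>, ptr %src, align 2
  store <8 x bfloat> %v, ptr %dst, align 16
  ret void
}
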
diff  --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td
index b5b151d3090e1..d1445384abdf6 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.td
+++ b/llvm/lib/Target/X86/X86RegisterInfo.td
@@ -561,9 +561,9 @@ def RSTi : RegisterOperand<RST, "printSTiRegOperand">;
 // Generic vector registers: VR64 and VR128.
 // Ensure that float types are declared first - only float is legal on SSE1.
 def VR64: RegisterClass<"X86", [x86mmx], 64, (sequence "MM%u", 0, 7)>;
-def VR128 : RegisterClass<"X86", [v4f32, v2f64, v8f16, v16i8, v8i16, v4i32, v2i64, f128],
+def VR128 : RegisterClass<"X86", [v4f32, v2f64, v8f16, v8bf16, v16i8, v8i16, v4i32, v2i64, f128],
                           128, (add FR32)>;
-def VR256 : RegisterClass<"X86", [v8f32, v4f64, v16f16, v32i8, v16i16, v8i32, v4i64],
+def VR256 : RegisterClass<"X86", [v8f32, v4f64, v16f16, v16bf16, v32i8, v16i16, v8i32, v4i64],
                           256, (sequence "YMM%u", 0, 15)>;
 
 // Status flags registers.
@@ -581,7 +581,7 @@ def DFCCR : RegisterClass<"X86", [i32], 32, (add DF)> {
 }
 
 // AVX-512 vector/mask registers.
-def VR512 : RegisterClass<"X86", [v16f32, v8f64, v32f16, v64i8, v32i16, v16i32, v8i64],
+def VR512 : RegisterClass<"X86", [v16f32, v8f64, v32f16, v32bf16, v64i8, v32i16, v16i32, v8i64],
                           512, (sequence "ZMM%u", 0, 31)>;
 
 // Represents the lower 16 registers that have VEX/legacy encodable subregs.
@@ -596,9 +596,9 @@ def FR64X : RegisterClass<"X86", [f64], 64, (add FR32X)>;
 def FR16X : RegisterClass<"X86", [f16], 16, (add FR32X)> {let Size = 32;}
 
 // Extended VR128 and VR256 for AVX-512 instructions
-def VR128X : RegisterClass<"X86", [v4f32, v2f64, v8f16, v16i8, v8i16, v4i32, v2i64, f128],
+def VR128X : RegisterClass<"X86", [v4f32, v2f64, v8f16, v8bf16, v16i8, v8i16, v4i32, v2i64, f128],
                            128, (add FR32X)>;
-def VR256X : RegisterClass<"X86", [v8f32, v4f64, v16f16, v32i8, v16i16, v8i32, v4i64],
+def VR256X : RegisterClass<"X86", [v8f32, v4f64, v16f16, v16bf16, v32i8, v16i16, v8i32, v4i64],
                            256, (sequence "YMM%u", 0, 31)>;
 
 // Mask registers

diff  --git a/llvm/test/CodeGen/X86/avx512bf16-intrinsics-upgrade.ll b/llvm/test/CodeGen/X86/avx512bf16-intrinsics-upgrade.ll
new file mode 100644
index 0000000000000..32f8c0f0be9f2
--- /dev/null
+++ b/llvm/test/CodeGen/X86/avx512bf16-intrinsics-upgrade.ll
@@ -0,0 +1,174 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=i686-unknown-unknown -mattr=+avx512bf16 --show-mc-encoding | FileCheck %s --check-prefixes=CHECK,X86
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512bf16 --show-mc-encoding | FileCheck %s --check-prefixes=CHECK,X64
+
+declare <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float>, <16 x float>) #3
+
+define <8 x i64> @test_mm512_cvtne2ps2bf16_512(<16 x float> %A, <16 x float> %B) local_unnamed_addr #2 {
+; CHECK-LABEL: test_mm512_cvtne2ps2bf16_512:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vcvtne2ps2bf16 %zmm1, %zmm0, %zmm0 # encoding: [0x62,0xf2,0x7f,0x48,0x72,0xc1]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
+  %1 = bitcast <32 x i16> %0 to <8 x i64>
+  ret <8 x i64> %1
+}
+
+define <8 x i64> @test_mm512_maskz_cvtne2ps2bf16_512(<16 x float> %A, <16 x float> %B, i32 %U) local_unnamed_addr #2 {
+; X86-LABEL: test_mm512_maskz_cvtne2ps2bf16_512:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtne2ps2bf16 %zmm1, %zmm0, %zmm0 # encoding: [0x62,0xf2,0x7f,0x48,0x72,0xc1]
+; X86-NEXT:    kmovd {{[0-9]+}}(%esp), %k1 # encoding: [0xc4,0xe1,0xf9,0x90,0x4c,0x24,0x04]
+; X86-NEXT:    vmovdqu16 %zmm0, %zmm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0xc9,0x6f,0xc0]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm512_maskz_cvtne2ps2bf16_512:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtne2ps2bf16 %zmm1, %zmm0, %zmm0 # encoding: [0x62,0xf2,0x7f,0x48,0x72,0xc1]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %zmm0, %zmm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0xc9,0x6f,0xc0]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
+  %1 = bitcast i32 %U to <32 x i1>
+  %2 = select <32 x i1> %1, <32 x i16> %0, <32 x i16> zeroinitializer
+  %3 = bitcast <32 x i16> %2 to <8 x i64>
+  ret <8 x i64> %3
+}
+
+define <8 x i64> @test_mm512_mask_cvtne2ps2bf16_512(<8 x i64> %C, i32 %U, <16 x float> %A, <16 x float> %B) local_unnamed_addr #2 {
+; X86-LABEL: test_mm512_mask_cvtne2ps2bf16_512:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtne2ps2bf16 %zmm2, %zmm1, %zmm1 # encoding: [0x62,0xf2,0x77,0x48,0x72,0xca]
+; X86-NEXT:    kmovd {{[0-9]+}}(%esp), %k1 # encoding: [0xc4,0xe1,0xf9,0x90,0x4c,0x24,0x04]
+; X86-NEXT:    vmovdqu16 %zmm1, %zmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x49,0x6f,0xc1]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm512_mask_cvtne2ps2bf16_512:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtne2ps2bf16 %zmm2, %zmm1, %zmm1 # encoding: [0x62,0xf2,0x77,0x48,0x72,0xca]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %zmm1, %zmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x49,0x6f,0xc1]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
+  %1 = bitcast <8 x i64> %C to <32 x i16>
+  %2 = bitcast i32 %U to <32 x i1>
+  %3 = select <32 x i1> %2, <32 x i16> %0, <32 x i16> %1
+  %4 = bitcast <32 x i16> %3 to <8 x i64>
+  ret <8 x i64> %4
+}
+
+declare <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float>) #3
+
+define <4 x i64> @test_mm512_cvtneps2bf16_512(<16 x float> %A) local_unnamed_addr #2 {
+; CHECK-LABEL: test_mm512_cvtneps2bf16_512:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vcvtneps2bf16 %zmm0, %ymm0 # encoding: [0x62,0xf2,0x7e,0x48,0x72,0xc0]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
+  %1 = bitcast <16 x i16> %0 to <4 x i64>
+  ret <4 x i64> %1
+}
+
+define <4 x i64> @test_mm512_maskz_cvtneps2bf16_512(<16 x float> %A, i16 %U) local_unnamed_addr #2 {
+; X86-LABEL: test_mm512_maskz_cvtneps2bf16_512:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtneps2bf16 %zmm0, %ymm0 # encoding: [0x62,0xf2,0x7e,0x48,0x72,0xc0]
+; X86-NEXT:    kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
+; X86-NEXT:    vmovdqu16 %zmm0, %zmm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0xc9,0x6f,0xc0]
+; X86-NEXT:    # kill: def $ymm0 killed $ymm0 killed $zmm0
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm512_maskz_cvtneps2bf16_512:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtneps2bf16 %zmm0, %ymm0 # encoding: [0x62,0xf2,0x7e,0x48,0x72,0xc0]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %zmm0, %zmm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0xc9,0x6f,0xc0]
+; X64-NEXT:    # kill: def $ymm0 killed $ymm0 killed $zmm0
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
+  %1 = bitcast i16 %U to <16 x i1>
+  %2 = select <16 x i1> %1, <16 x i16> %0, <16 x i16> zeroinitializer
+  %3 = bitcast <16 x i16> %2 to <4 x i64>
+  ret <4 x i64> %3
+}
+
+define <4 x i64> @test_mm512_mask_cvtneps2bf16_512(<4 x i64> %C, i16 %U, <16 x float> %A) local_unnamed_addr #2 {
+; X86-LABEL: test_mm512_mask_cvtneps2bf16_512:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
+; X86-NEXT:    vcvtneps2bf16 %zmm1, %ymm1 # encoding: [0x62,0xf2,0x7e,0x48,0x72,0xc9]
+; X86-NEXT:    kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
+; X86-NEXT:    vmovdqu16 %zmm1, %zmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x49,0x6f,0xc1]
+; X86-NEXT:    # kill: def $ymm0 killed $ymm0 killed $zmm0
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm512_mask_cvtneps2bf16_512:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
+; X64-NEXT:    vcvtneps2bf16 %zmm1, %ymm1 # encoding: [0x62,0xf2,0x7e,0x48,0x72,0xc9]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %zmm1, %zmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x49,0x6f,0xc1]
+; X64-NEXT:    # kill: def $ymm0 killed $ymm0 killed $zmm0
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
+  %1 = bitcast <4 x i64> %C to <16 x i16>
+  %2 = bitcast i16 %U to <16 x i1>
+  %3 = select <16 x i1> %2, <16 x i16> %0, <16 x i16> %1
+  %4 = bitcast <16 x i16> %3 to <4 x i64>
+  ret <4 x i64> %4
+}
+
+declare <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float>, <16 x i32>, <16 x i32>) #3
+
+define <16 x float> @test_mm512_dpbf16ps_512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B) local_unnamed_addr #2 {
+; CHECK-LABEL: test_mm512_dpbf16ps_512:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vdpbf16ps %zmm2, %zmm1, %zmm0 # encoding: [0x62,0xf2,0x76,0x48,0x52,0xc2]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B) #4
+  ret <16 x float> %0
+}
+
+define <16 x float> @test_mm512_maskz_dpbf16ps_512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B, i16 zeroext %U) local_unnamed_addr #2 {
+; X86-LABEL: test_mm512_maskz_dpbf16ps_512:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
+; X86-NEXT:    vdpbf16ps %zmm2, %zmm1, %zmm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0xc9,0x52,0xc2]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm512_maskz_dpbf16ps_512:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vdpbf16ps %zmm2, %zmm1, %zmm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0xc9,0x52,0xc2]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B) #4
+  %1 = bitcast i16 %U to <16 x i1>
+  %2 = select <16 x i1> %1, <16 x float> %0, <16 x float> zeroinitializer
+  ret <16 x float> %2
+}
+define <16 x float> @test_mm512_mask_dpbf16ps_512(i16 zeroext %U, <16 x float> %E, <16 x i32> %A, <16 x i32> %B) local_unnamed_addr #2 {
+; X86-LABEL: test_mm512_mask_dpbf16ps_512:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
+; X86-NEXT:    vdpbf16ps %zmm2, %zmm1, %zmm0 {%k1} # encoding: [0x62,0xf2,0x76,0x49,0x52,0xc2]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm512_mask_dpbf16ps_512:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vdpbf16ps %zmm2, %zmm1, %zmm0 {%k1} # encoding: [0x62,0xf2,0x76,0x49,0x52,0xc2]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B) #4
+  %1 = bitcast i16 %U to <16 x i1>
+  %2 = select <16 x i1> %1, <16 x float> %0, <16 x float> %E
+  ret <16 x float> %2
+}

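The test above deliberately keeps the old <32 x i16>/<16 x i16>/<16 x i32> signatures so the auto-upgrade path is exercised. For reference, a hand-written sketch of the shape the upgrader is expected to produce for the first conversion (value and function names are illustrative, not taken from AutoUpgrade.cpp):

; Illustrative sketch of the expected post-upgrade IR, not emitted verbatim by the patch.
define <8 x i64> @upgraded_cvtne2ps2bf16_512(<16 x float> %A, <16 x float> %B) {
  ; The call is retyped to the new <32 x bfloat> intrinsic and the result is
  ; bitcast back so the surrounding integer-typed IR keeps type-checking.
  %r = call <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B)
  %i = bitcast <32 x bfloat> %r to <32 x i16>
  %v = bitcast <32 x i16> %i to <8 x i64>
  ret <8 x i64> %v
}
declare <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float>, <16 x float>)
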
diff  --git a/llvm/test/CodeGen/X86/avx512bf16-intrinsics.ll b/llvm/test/CodeGen/X86/avx512bf16-intrinsics.ll
index 759e6ca5683ec..a5767ecf14a53 100644
--- a/llvm/test/CodeGen/X86/avx512bf16-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512bf16-intrinsics.ll
@@ -2,7 +2,7 @@
 ; RUN: llc < %s -mtriple=i686-unknown-unknown -mattr=+avx512bf16 --show-mc-encoding | FileCheck %s --check-prefixes=CHECK,X86
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512bf16 --show-mc-encoding | FileCheck %s --check-prefixes=CHECK,X64
 
-declare <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float>, <16 x float>) #3
+declare <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float>, <16 x float>) #3
 
 define <8 x i64> @test_mm512_cvtne2ps2bf16_512(<16 x float> %A, <16 x float> %B) local_unnamed_addr #2 {
 ; CHECK-LABEL: test_mm512_cvtne2ps2bf16_512:
@@ -10,8 +10,8 @@ define <8 x i64> @test_mm512_cvtne2ps2bf16_512(<16 x float> %A, <16 x float> %B)
 ; CHECK-NEXT:    vcvtne2ps2bf16 %zmm1, %zmm0, %zmm0 # encoding: [0x62,0xf2,0x7f,0x48,0x72,0xc1]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
-  %1 = bitcast <32 x i16> %0 to <8 x i64>
+  %0 = tail call <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
+  %1 = bitcast <32 x bfloat> %0 to <8 x i64>
   ret <8 x i64> %1
 }
 
@@ -28,10 +28,10 @@ define <8 x i64> @test_mm512_maskz_cvtne2ps2bf16_512(<16 x float> %A, <16 x floa
 ; X64-NEXT:    vcvtne2ps2bf16 %zmm1, %zmm0, %zmm0 {%k1} {z} # encoding: [0x62,0xf2,0x7f,0xc9,0x72,0xc1]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
+  %0 = tail call <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
   %1 = bitcast i32 %U to <32 x i1>
-  %2 = select <32 x i1> %1, <32 x i16> %0, <32 x i16> zeroinitializer
-  %3 = bitcast <32 x i16> %2 to <8 x i64>
+  %2 = select <32 x i1> %1, <32 x bfloat> %0, <32 x bfloat> zeroinitializer
+  %3 = bitcast <32 x bfloat> %2 to <8 x i64>
   ret <8 x i64> %3
 }
 
@@ -48,15 +48,15 @@ define <8 x i64> @test_mm512_mask_cvtne2ps2bf16_512(<8 x i64> %C, i32 %U, <16 x
 ; X64-NEXT:    vcvtne2ps2bf16 %zmm2, %zmm1, %zmm0 {%k1} # encoding: [0x62,0xf2,0x77,0x49,0x72,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
-  %1 = bitcast <8 x i64> %C to <32 x i16>
+  %0 = tail call <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) #4
+  %1 = bitcast <8 x i64> %C to <32 x bfloat>
   %2 = bitcast i32 %U to <32 x i1>
-  %3 = select <32 x i1> %2, <32 x i16> %0, <32 x i16> %1
-  %4 = bitcast <32 x i16> %3 to <8 x i64>
+  %3 = select <32 x i1> %2, <32 x bfloat> %0, <32 x bfloat> %1
+  %4 = bitcast <32 x bfloat> %3 to <8 x i64>
   ret <8 x i64> %4
 }
 
-declare <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float>) #3
+declare <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float>) #3
 
 define <4 x i64> @test_mm512_cvtneps2bf16_512(<16 x float> %A) local_unnamed_addr #2 {
 ; CHECK-LABEL: test_mm512_cvtneps2bf16_512:
@@ -64,8 +64,8 @@ define <4 x i64> @test_mm512_cvtneps2bf16_512(<16 x float> %A) local_unnamed_add
 ; CHECK-NEXT:    vcvtneps2bf16 %zmm0, %ymm0 # encoding: [0x62,0xf2,0x7e,0x48,0x72,0xc0]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
-  %1 = bitcast <16 x i16> %0 to <4 x i64>
+  %0 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
+  %1 = bitcast <16 x bfloat> %0 to <4 x i64>
   ret <4 x i64> %1
 }
 
@@ -82,10 +82,10 @@ define <4 x i64> @test_mm512_maskz_cvtneps2bf16_512(<16 x float> %A, i16 %U) loc
 ; X64-NEXT:    vcvtneps2bf16 %zmm0, %ymm0 {%k1} {z} # encoding: [0x62,0xf2,0x7e,0xc9,0x72,0xc0]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
+  %0 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
   %1 = bitcast i16 %U to <16 x i1>
-  %2 = select <16 x i1> %1, <16 x i16> %0, <16 x i16> zeroinitializer
-  %3 = bitcast <16 x i16> %2 to <4 x i64>
+  %2 = select <16 x i1> %1, <16 x bfloat> %0, <16 x bfloat> zeroinitializer
+  %3 = bitcast <16 x bfloat> %2 to <4 x i64>
   ret <4 x i64> %3
 }
 
@@ -102,27 +102,27 @@ define <4 x i64> @test_mm512_mask_cvtneps2bf16_512(<4 x i64> %C, i16 %U, <16 x f
 ; X64-NEXT:    vcvtneps2bf16 %zmm1, %ymm0 {%k1} # encoding: [0x62,0xf2,0x7e,0x49,0x72,0xc1]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
-  %1 = bitcast <4 x i64> %C to <16 x i16>
+  %0 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) #4
+  %1 = bitcast <4 x i64> %C to <16 x bfloat>
   %2 = bitcast i16 %U to <16 x i1>
-  %3 = select <16 x i1> %2, <16 x i16> %0, <16 x i16> %1
-  %4 = bitcast <16 x i16> %3 to <4 x i64>
+  %3 = select <16 x i1> %2, <16 x bfloat> %0, <16 x bfloat> %1
+  %4 = bitcast <16 x bfloat> %3 to <4 x i64>
   ret <4 x i64> %4
 }
 
-declare <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float>, <16 x i32>, <16 x i32>) #3
+declare <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float>, <32 x bfloat>, <32 x bfloat>) #3
 
-define <16 x float> @test_mm512_dpbf16ps_512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B) local_unnamed_addr #2 {
+define <16 x float> @test_mm512_dpbf16ps_512(<16 x float> %E, <32 x bfloat> %A, <32 x bfloat> %B) local_unnamed_addr #2 {
 ; CHECK-LABEL: test_mm512_dpbf16ps_512:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vdpbf16ps %zmm2, %zmm1, %zmm0 # encoding: [0x62,0xf2,0x76,0x48,0x52,0xc2]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B) #4
+  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <32 x bfloat> %A, <32 x bfloat> %B) #4
   ret <16 x float> %0
 }
 
-define <16 x float> @test_mm512_maskz_dpbf16ps_512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B, i16 zeroext %U) local_unnamed_addr #2 {
+define <16 x float> @test_mm512_maskz_dpbf16ps_512(<16 x float> %E, <32 x bfloat> %A, <32 x bfloat> %B, i16 zeroext %U) local_unnamed_addr #2 {
 ; X86-LABEL: test_mm512_maskz_dpbf16ps_512:
 ; X86:       # %bb.0: # %entry
 ; X86-NEXT:    kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
@@ -135,12 +135,12 @@ define <16 x float> @test_mm512_maskz_dpbf16ps_512(<16 x float> %E, <16 x i32> %
 ; X64-NEXT:    vdpbf16ps %zmm2, %zmm1, %zmm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0xc9,0x52,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B) #4
+  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <32 x bfloat> %A, <32 x bfloat> %B) #4
   %1 = bitcast i16 %U to <16 x i1>
   %2 = select <16 x i1> %1, <16 x float> %0, <16 x float> zeroinitializer
   ret <16 x float> %2
 }
-define <16 x float> @test_mm512_mask_dpbf16ps_512(i16 zeroext %U, <16 x float> %E, <16 x i32> %A, <16 x i32> %B) local_unnamed_addr #2 {
+define <16 x float> @test_mm512_mask_dpbf16ps_512(i16 zeroext %U, <16 x float> %E, <32 x bfloat> %A, <32 x bfloat> %B) local_unnamed_addr #2 {
 ; X86-LABEL: test_mm512_mask_dpbf16ps_512:
 ; X86:       # %bb.0: # %entry
 ; X86-NEXT:    kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
@@ -153,7 +153,7 @@ define <16 x float> @test_mm512_mask_dpbf16ps_512(i16 zeroext %U, <16 x float> %
 ; X64-NEXT:    vdpbf16ps %zmm2, %zmm1, %zmm0 {%k1} # encoding: [0x62,0xf2,0x76,0x49,0x52,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <16 x i32> %A, <16 x i32> %B) #4
+  %0 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <32 x bfloat> %A, <32 x bfloat> %B) #4
   %1 = bitcast i16 %U to <16 x i1>
   %2 = select <16 x i1> %1, <16 x float> %0, <16 x float> %E
   ret <16 x float> %2

diff  --git a/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics-upgrade.ll b/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics-upgrade.ll
new file mode 100644
index 0000000000000..f51fe141fed21
--- /dev/null
+++ b/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics-upgrade.ll
@@ -0,0 +1,370 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=i686-unknown-unknown -mattr=+avx512bf16 -mattr=+avx512vl --show-mc-encoding | FileCheck %s --check-prefixes=CHECK,X86
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512bf16 -mattr=+avx512vl --show-mc-encoding | FileCheck %s --check-prefixes=CHECK,X64
+
+declare <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float>, <4 x float>) #1
+
+define <2 x i64> @test_mm_cvtne2ps2bf16_128(<4 x float> %A, <4 x float> %B) local_unnamed_addr #0 {
+; CHECK-LABEL: test_mm_cvtne2ps2bf16_128:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vcvtne2ps2bf16 %xmm1, %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7f,0x08,0x72,0xc1]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
+  %1 = bitcast <8 x i16> %0 to <2 x i64>
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_mm_maskz_cvtne2ps2bf16_128(<4 x float> %A, <4 x float> %B, i8 zeroext %U) local_unnamed_addr #0 {
+; X86-LABEL: test_mm_maskz_cvtne2ps2bf16_128:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtne2ps2bf16 %xmm1, %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7f,0x08,0x72,0xc1]
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vmovdqu16 %xmm0, %xmm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0x89,0x6f,0xc0]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm_maskz_cvtne2ps2bf16_128:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtne2ps2bf16 %xmm1, %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7f,0x08,0x72,0xc1]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %xmm0, %xmm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0x89,0x6f,0xc0]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
+  %1 = bitcast i8 %U to <8 x i1>
+  %2 = select <8 x i1> %1, <8 x i16> %0, <8 x i16> zeroinitializer
+  %3 = bitcast <8 x i16> %2 to <2 x i64>
+  ret <2 x i64> %3
+}
+
+define <2 x i64> @test_mm_mask_cvtne2ps2bf16_128(<2 x i64> %C, i8 zeroext %U, <4 x float> %A, <4 x float> %B) local_unnamed_addr #0 {
+; X86-LABEL: test_mm_mask_cvtne2ps2bf16_128:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtne2ps2bf16 %xmm2, %xmm1, %xmm1 # encoding: [0x62,0xf2,0x77,0x08,0x72,0xca]
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vmovdqu16 %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xc1]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm_mask_cvtne2ps2bf16_128:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtne2ps2bf16 %xmm2, %xmm1, %xmm1 # encoding: [0x62,0xf2,0x77,0x08,0x72,0xca]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xc1]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
+  %1 = bitcast <2 x i64> %C to <8 x i16>
+  %2 = bitcast i8 %U to <8 x i1>
+  %3 = select <8 x i1> %2, <8 x i16> %0, <8 x i16> %1
+  %4 = bitcast <8 x i16> %3 to <2 x i64>
+  ret <2 x i64> %4
+}
+
+declare <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float>, <8 x float>) #3
+
+define <4 x i64> @test_mm256_cvtne2ps2bf16_256(<8 x float> %A, <8 x float> %B) local_unnamed_addr #1 {
+; CHECK-LABEL: test_mm256_cvtne2ps2bf16_256:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vcvtne2ps2bf16 %ymm1, %ymm0, %ymm0 # encoding: [0x62,0xf2,0x7f,0x28,0x72,0xc1]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
+  %1 = bitcast <16 x i16> %0 to <4 x i64>
+  ret <4 x i64> %1
+}
+
+define <4 x i64> @test_mm256_maskz_cvtne2ps2bf16_256(<8 x float> %A, <8 x float> %B, i16 zeroext %U) local_unnamed_addr #1 {
+; X86-LABEL: test_mm256_maskz_cvtne2ps2bf16_256:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtne2ps2bf16 %ymm1, %ymm0, %ymm0 # encoding: [0x62,0xf2,0x7f,0x28,0x72,0xc1]
+; X86-NEXT:    kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
+; X86-NEXT:    vmovdqu16 %ymm0, %ymm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0xa9,0x6f,0xc0]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm256_maskz_cvtne2ps2bf16_256:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtne2ps2bf16 %ymm1, %ymm0, %ymm0 # encoding: [0x62,0xf2,0x7f,0x28,0x72,0xc1]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %ymm0, %ymm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0xa9,0x6f,0xc0]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
+  %1 = bitcast i16 %U to <16 x i1>
+  %2 = select <16 x i1> %1, <16 x i16> %0, <16 x i16> zeroinitializer
+  %3 = bitcast <16 x i16> %2 to <4 x i64>
+  ret <4 x i64> %3
+}
+
+define <4 x i64> @test_mm256_mask_cvtne2ps2bf16_256(<4 x i64> %C, i16 zeroext %U, <8 x float> %A, <8 x float> %B) local_unnamed_addr #1 {
+; X86-LABEL: test_mm256_mask_cvtne2ps2bf16_256:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtne2ps2bf16 %ymm2, %ymm1, %ymm1 # encoding: [0x62,0xf2,0x77,0x28,0x72,0xca]
+; X86-NEXT:    kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
+; X86-NEXT:    vmovdqu16 %ymm1, %ymm0 {%k1} # encoding: [0x62,0xf1,0xff,0x29,0x6f,0xc1]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm256_mask_cvtne2ps2bf16_256:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtne2ps2bf16 %ymm2, %ymm1, %ymm1 # encoding: [0x62,0xf2,0x77,0x28,0x72,0xca]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %ymm1, %ymm0 {%k1} # encoding: [0x62,0xf1,0xff,0x29,0x6f,0xc1]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
+  %1 = bitcast <4 x i64> %C to <16 x i16>
+  %2 = bitcast i16 %U to <16 x i1>
+  %3 = select <16 x i1> %2, <16 x i16> %0, <16 x i16> %1
+  %4 = bitcast <16 x i16> %3 to <4 x i64>
+  ret <4 x i64> %4
+}
+
+declare <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float>) #3
+
+define <2 x i64> @test_mm256_cvtneps2bf16_256(<8 x float> %A) local_unnamed_addr #2 {
+; CHECK-LABEL: test_mm256_cvtneps2bf16_256:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vcvtneps2bf16 %ymm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x28,0x72,0xc0]
+; CHECK-NEXT:    vzeroupper # encoding: [0xc5,0xf8,0x77]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
+  %1 = bitcast <8 x i16> %0 to <2 x i64>
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_mm256_maskz_cvtneps2bf16_256(<8 x float> %A, i8 zeroext %U) local_unnamed_addr #2 {
+; X86-LABEL: test_mm256_maskz_cvtneps2bf16_256:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtneps2bf16 %ymm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x28,0x72,0xc0]
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vmovdqu16 %xmm0, %xmm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0x89,0x6f,0xc0]
+; X86-NEXT:    vzeroupper # encoding: [0xc5,0xf8,0x77]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm256_maskz_cvtneps2bf16_256:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtneps2bf16 %ymm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x28,0x72,0xc0]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %xmm0, %xmm0 {%k1} {z} # encoding: [0x62,0xf1,0xff,0x89,0x6f,0xc0]
+; X64-NEXT:    vzeroupper # encoding: [0xc5,0xf8,0x77]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
+  %1 = bitcast i8 %U to <8 x i1>
+  %2 = select <8 x i1> %1, <8 x i16> %0, <8 x i16> zeroinitializer
+  %3 = bitcast <8 x i16> %2 to <2 x i64>
+  ret <2 x i64> %3
+}
+
+define <2 x i64> @test_mm256_mask_cvtneps2bf16_256(<2 x i64> %C, i8 zeroext %U, <8 x float> %A) local_unnamed_addr #2 {
+; X86-LABEL: test_mm256_mask_cvtneps2bf16_256:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    vcvtneps2bf16 %ymm1, %xmm1 # encoding: [0x62,0xf2,0x7e,0x28,0x72,0xc9]
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vmovdqu16 %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xc1]
+; X86-NEXT:    vzeroupper # encoding: [0xc5,0xf8,0x77]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm256_mask_cvtneps2bf16_256:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vcvtneps2bf16 %ymm1, %xmm1 # encoding: [0x62,0xf2,0x7e,0x28,0x72,0xc9]
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vmovdqu16 %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xc1]
+; X64-NEXT:    vzeroupper # encoding: [0xc5,0xf8,0x77]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
+  %1 = bitcast <2 x i64> %C to <8 x i16>
+  %2 = bitcast i8 %U to <8 x i1>
+  %3 = select <8 x i1> %2, <8 x i16> %0, <8 x i16> %1
+  %4 = bitcast <8 x i16> %3 to <2 x i64>
+  ret <2 x i64> %4
+}
+
+declare <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float>, <8 x i16>, <4 x i1>) #3
+
+define <2 x i64> @test_mm128_cvtneps2bf16_128(<4 x float> %A) local_unnamed_addr #2 {
+; CHECK-LABEL: test_mm128_cvtneps2bf16_128:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x i16> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>) #4
+  %1 = bitcast <8 x i16> %0 to <2 x i64>
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_mm128_maskz_cvtneps2bf16_128(<4 x float> %A, i8 zeroext %U) local_unnamed_addr #2 {
+; X86-LABEL: test_mm128_maskz_cvtneps2bf16_128:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0 {%k1} {z} # encoding: [0x62,0xf2,0x7e,0x89,0x72,0xc0]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm128_maskz_cvtneps2bf16_128:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vcvtneps2bf16 %xmm0, %xmm0 {%k1} {z} # encoding: [0x62,0xf2,0x7e,0x89,0x72,0xc0]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = bitcast i8 %U to <8 x i1>
+  %1 = shufflevector <8 x i1> %0, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %2 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x i16> zeroinitializer, <4 x i1> %1) #4
+  %3 = bitcast <8 x i16> %2 to <2 x i64>
+  ret <2 x i64> %3
+}
+
+define <2 x i64> @test_mm128_mask_cvtneps2bf16_128(<2 x i64> %C, i8 zeroext %U, <4 x float> %A) local_unnamed_addr #2 {
+; X86-LABEL: test_mm128_mask_cvtneps2bf16_128:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vcvtneps2bf16 %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf2,0x7e,0x09,0x72,0xc1]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm128_mask_cvtneps2bf16_128:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vcvtneps2bf16 %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf2,0x7e,0x09,0x72,0xc1]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = bitcast i8 %U to <8 x i1>
+  %1 = shufflevector <8 x i1> %0, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %2 = bitcast <2 x i64> %C to <8 x i16>
+  %3 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x i16> %2, <4 x i1> %1) #4
+  %4 = bitcast <8 x i16> %3 to <2 x i64>
+  ret <2 x i64> %4
+}
+
+; Make sure we don't fold a select into the 128 bit form of cvtneps2bf16. It
+; always writes zeros to bits 127:64 regardless of mask.
+define <2 x i64> @test_mm128_cvtneps2bf16_128_select(<2 x i64> %C, i8 zeroext %U, <4 x float> %A) local_unnamed_addr #2 {
+; X86-LABEL: test_mm128_cvtneps2bf16_128_select:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vcvtneps2bf16 %xmm1, %xmm1 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc9]
+; X86-NEXT:    vmovdqu16 %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xc1]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm128_cvtneps2bf16_128_select:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vcvtneps2bf16 %xmm1, %xmm1 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc9]
+; X64-NEXT:    vmovdqu16 %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xc1]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = bitcast i8 %U to <8 x i1>
+  %1 = bitcast <2 x i64> %C to <8 x i16>
+  %2 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x i16> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>) #4
+  %3 = select <8 x i1> %0, <8 x i16> %2, <8 x i16> %1
+  %4 = bitcast <8 x i16> %3 to <2 x i64>
+  ret <2 x i64> %4
+}
+
+declare <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float>, <8 x i32>, <8 x i32>) #3
+
+define <8 x float> @test_mm256_dpbf16ps_256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) local_unnamed_addr #2 {
+; CHECK-LABEL: test_mm256_dpbf16ps_256:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vdpbf16ps %ymm2, %ymm1, %ymm0 # encoding: [0x62,0xf2,0x76,0x28,0x52,0xc2]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) #4
+  ret <8 x float> %0
+}
+
+define <8 x float> @test_mm256_maskz_dpbf16ps_256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B, i8 zeroext %U) local_unnamed_addr #2 {
+; X86-LABEL: test_mm256_maskz_dpbf16ps_256:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vdpbf16ps %ymm2, %ymm1, %ymm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0xa9,0x52,0xc2]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm256_maskz_dpbf16ps_256:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vdpbf16ps %ymm2, %ymm1, %ymm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0xa9,0x52,0xc2]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) #4
+  %1 = bitcast i8 %U to <8 x i1>
+  %2 = select <8 x i1> %1, <8 x float> %0, <8 x float> zeroinitializer
+  ret <8 x float> %2
+}
+define <8 x float> @test_mm256_mask_dpbf16ps_256(i8 zeroext %U, <8 x float> %E, <8 x i32> %A, <8 x i32> %B) local_unnamed_addr #2 {
+; X86-LABEL: test_mm256_mask_dpbf16ps_256:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vdpbf16ps %ymm2, %ymm1, %ymm0 {%k1} # encoding: [0x62,0xf2,0x76,0x29,0x52,0xc2]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm256_mask_dpbf16ps_256:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vdpbf16ps %ymm2, %ymm1, %ymm0 {%k1} # encoding: [0x62,0xf2,0x76,0x29,0x52,0xc2]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) #4
+  %1 = bitcast i8 %U to <8 x i1>
+  %2 = select <8 x i1> %1, <8 x float> %0, <8 x float> %E
+  ret <8 x float> %2
+}
+
+declare <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float>, <4 x i32>, <4 x i32>) #3
+
+define <4 x float> @test_mm128_dpbf16ps_128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B) local_unnamed_addr #2 {
+; CHECK-LABEL: test_mm128_dpbf16ps_128:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vdpbf16ps %xmm2, %xmm1, %xmm0 # encoding: [0x62,0xf2,0x76,0x08,0x52,0xc2]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B) #4
+  ret <4 x float> %0
+}
+
+define <4 x float> @test_mm128_maskz_dpbf16ps_128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B, i4 zeroext %U) local_unnamed_addr #2 {
+; X86-LABEL: test_mm128_maskz_dpbf16ps_128:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vdpbf16ps %xmm2, %xmm1, %xmm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0x89,0x52,0xc2]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm128_maskz_dpbf16ps_128:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vdpbf16ps %xmm2, %xmm1, %xmm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0x89,0x52,0xc2]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B) #4
+  %1 = bitcast i4 %U to <4 x i1>
+  %2 = select <4 x i1> %1, <4 x float> %0, <4 x float> zeroinitializer
+  ret <4 x float> %2
+}
+define <4 x float> @test_mm128_mask_dpbf16ps_128(i4 zeroext %U, <4 x float> %E, <4 x i32> %A, <4 x i32> %B) local_unnamed_addr #2 {
+; X86-LABEL: test_mm128_mask_dpbf16ps_128:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
+; X86-NEXT:    kmovd %eax, %k1 # encoding: [0xc5,0xfb,0x92,0xc8]
+; X86-NEXT:    vdpbf16ps %xmm2, %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf2,0x76,0x09,0x52,0xc2]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_mm128_mask_dpbf16ps_128:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    kmovd %edi, %k1 # encoding: [0xc5,0xfb,0x92,0xcf]
+; X64-NEXT:    vdpbf16ps %xmm2, %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf2,0x76,0x09,0x52,0xc2]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B) #4
+  %1 = bitcast i4 %U to <4 x i1>
+  %2 = select <4 x i1> %1, <4 x float> %0, <4 x float> %E
+  ret <4 x float> %2
+}
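
The tests above intentionally keep the legacy integer vector signatures (<8 x i32> and <4 x i32> operands for dpbf16ps): they exercise the IR auto-upgrade path rather than the new bfloat declarations used in the updated tests below. A minimal sketch of the form the upgrader is expected to produce for the legacy dpbf16ps.256 call, assuming it simply bitcasts the old integer operands to the new bfloat vector types (the function name here is illustrative, not part of the patch); for the conversion intrinsics the bitcast presumably goes the other way, from the new bfloat result back to the legacy i16 vector at its uses:

declare <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float>, <16 x bfloat>, <16 x bfloat>)

define <8 x float> @upgraded_dpbf16ps_256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) {
entry:
  ; Reinterpret the legacy <8 x i32> operands as the new <16 x bfloat> type.
  %a = bitcast <8 x i32> %A to <16 x bfloat>
  %b = bitcast <8 x i32> %B to <16 x bfloat>
  ; The float accumulator and result are unchanged by the upgrade.
  %r = call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <16 x bfloat> %a, <16 x bfloat> %b)
  ret <8 x float> %r
}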

diff  --git a/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll
index 170197816ae19..3cdc5de871e21 100644
--- a/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll
@@ -2,7 +2,7 @@
 ; RUN: llc < %s -mtriple=i686-unknown-unknown -mattr=+avx512bf16 -mattr=+avx512vl --show-mc-encoding | FileCheck %s --check-prefixes=CHECK,X86
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512bf16 -mattr=+avx512vl --show-mc-encoding | FileCheck %s --check-prefixes=CHECK,X64
 
-declare <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float>, <4 x float>) #1
+declare <8 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float>, <4 x float>) #1
 
 define <2 x i64> @test_mm_cvtne2ps2bf16_128(<4 x float> %A, <4 x float> %B) local_unnamed_addr #0 {
 ; CHECK-LABEL: test_mm_cvtne2ps2bf16_128:
@@ -10,8 +10,8 @@ define <2 x i64> @test_mm_cvtne2ps2bf16_128(<4 x float> %A, <4 x float> %B) loca
 ; CHECK-NEXT:    vcvtne2ps2bf16 %xmm1, %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7f,0x08,0x72,0xc1]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
-  %1 = bitcast <8 x i16> %0 to <2 x i64>
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
+  %1 = bitcast <8 x bfloat> %0 to <2 x i64>
   ret <2 x i64> %1
 }
 
@@ -29,10 +29,10 @@ define <2 x i64> @test_mm_maskz_cvtne2ps2bf16_128(<4 x float> %A, <4 x float> %B
 ; X64-NEXT:    vcvtne2ps2bf16 %xmm1, %xmm0, %xmm0 {%k1} {z} # encoding: [0x62,0xf2,0x7f,0x89,0x72,0xc1]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
   %1 = bitcast i8 %U to <8 x i1>
-  %2 = select <8 x i1> %1, <8 x i16> %0, <8 x i16> zeroinitializer
-  %3 = bitcast <8 x i16> %2 to <2 x i64>
+  %2 = select <8 x i1> %1, <8 x bfloat> %0, <8 x bfloat> zeroinitializer
+  %3 = bitcast <8 x bfloat> %2 to <2 x i64>
   ret <2 x i64> %3
 }
 
@@ -50,15 +50,15 @@ define <2 x i64> @test_mm_mask_cvtne2ps2bf16_128(<2 x i64> %C, i8 zeroext %U, <4
 ; X64-NEXT:    vcvtne2ps2bf16 %xmm2, %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf2,0x77,0x09,0x72,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
-  %1 = bitcast <2 x i64> %C to <8 x i16>
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) #2
+  %1 = bitcast <2 x i64> %C to <8 x bfloat>
   %2 = bitcast i8 %U to <8 x i1>
-  %3 = select <8 x i1> %2, <8 x i16> %0, <8 x i16> %1
-  %4 = bitcast <8 x i16> %3 to <2 x i64>
+  %3 = select <8 x i1> %2, <8 x bfloat> %0, <8 x bfloat> %1
+  %4 = bitcast <8 x bfloat> %3 to <2 x i64>
   ret <2 x i64> %4
 }
 
-declare <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float>, <8 x float>) #3
+declare <16 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float>, <8 x float>) #3
 
 define <4 x i64> @test_mm256_cvtne2ps2bf16_256(<8 x float> %A, <8 x float> %B) local_unnamed_addr #1 {
 ; CHECK-LABEL: test_mm256_cvtne2ps2bf16_256:
@@ -66,8 +66,8 @@ define <4 x i64> @test_mm256_cvtne2ps2bf16_256(<8 x float> %A, <8 x float> %B) l
 ; CHECK-NEXT:    vcvtne2ps2bf16 %ymm1, %ymm0, %ymm0 # encoding: [0x62,0xf2,0x7f,0x28,0x72,0xc1]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
-  %1 = bitcast <16 x i16> %0 to <4 x i64>
+  %0 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
+  %1 = bitcast <16 x bfloat> %0 to <4 x i64>
   ret <4 x i64> %1
 }
 
@@ -84,10 +84,10 @@ define <4 x i64> @test_mm256_maskz_cvtne2ps2bf16_256(<8 x float> %A, <8 x float>
 ; X64-NEXT:    vcvtne2ps2bf16 %ymm1, %ymm0, %ymm0 {%k1} {z} # encoding: [0x62,0xf2,0x7f,0xa9,0x72,0xc1]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
+  %0 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
   %1 = bitcast i16 %U to <16 x i1>
-  %2 = select <16 x i1> %1, <16 x i16> %0, <16 x i16> zeroinitializer
-  %3 = bitcast <16 x i16> %2 to <4 x i64>
+  %2 = select <16 x i1> %1, <16 x bfloat> %0, <16 x bfloat> zeroinitializer
+  %3 = bitcast <16 x bfloat> %2 to <4 x i64>
   ret <4 x i64> %3
 }
 
@@ -104,15 +104,15 @@ define <4 x i64> @test_mm256_mask_cvtne2ps2bf16_256(<4 x i64> %C, i16 zeroext %U
 ; X64-NEXT:    vcvtne2ps2bf16 %ymm2, %ymm1, %ymm0 {%k1} # encoding: [0x62,0xf2,0x77,0x29,0x72,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
-  %1 = bitcast <4 x i64> %C to <16 x i16>
+  %0 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) #4
+  %1 = bitcast <4 x i64> %C to <16 x bfloat>
   %2 = bitcast i16 %U to <16 x i1>
-  %3 = select <16 x i1> %2, <16 x i16> %0, <16 x i16> %1
-  %4 = bitcast <16 x i16> %3 to <4 x i64>
+  %3 = select <16 x i1> %2, <16 x bfloat> %0, <16 x bfloat> %1
+  %4 = bitcast <16 x bfloat> %3 to <4 x i64>
   ret <4 x i64> %4
 }
 
-declare <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float>) #3
+declare <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float>) #3
 
 define <2 x i64> @test_mm256_cvtneps2bf16_256(<8 x float> %A) local_unnamed_addr #2 {
 ; CHECK-LABEL: test_mm256_cvtneps2bf16_256:
@@ -121,8 +121,8 @@ define <2 x i64> @test_mm256_cvtneps2bf16_256(<8 x float> %A) local_unnamed_addr
 ; CHECK-NEXT:    vzeroupper # encoding: [0xc5,0xf8,0x77]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
-  %1 = bitcast <8 x i16> %0 to <2 x i64>
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
+  %1 = bitcast <8 x bfloat> %0 to <2 x i64>
   ret <2 x i64> %1
 }
 
@@ -142,10 +142,10 @@ define <2 x i64> @test_mm256_maskz_cvtneps2bf16_256(<8 x float> %A, i8 zeroext %
 ; X64-NEXT:    vzeroupper # encoding: [0xc5,0xf8,0x77]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
   %1 = bitcast i8 %U to <8 x i1>
-  %2 = select <8 x i1> %1, <8 x i16> %0, <8 x i16> zeroinitializer
-  %3 = bitcast <8 x i16> %2 to <2 x i64>
+  %2 = select <8 x i1> %1, <8 x bfloat> %0, <8 x bfloat> zeroinitializer
+  %3 = bitcast <8 x bfloat> %2 to <2 x i64>
   ret <2 x i64> %3
 }
 
@@ -165,15 +165,15 @@ define <2 x i64> @test_mm256_mask_cvtneps2bf16_256(<2 x i64> %C, i8 zeroext %U,
 ; X64-NEXT:    vzeroupper # encoding: [0xc5,0xf8,0x77]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
-  %1 = bitcast <2 x i64> %C to <8 x i16>
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) #4
+  %1 = bitcast <2 x i64> %C to <8 x bfloat>
   %2 = bitcast i8 %U to <8 x i1>
-  %3 = select <8 x i1> %2, <8 x i16> %0, <8 x i16> %1
-  %4 = bitcast <8 x i16> %3 to <2 x i64>
+  %3 = select <8 x i1> %2, <8 x bfloat> %0, <8 x bfloat> %1
+  %4 = bitcast <8 x bfloat> %3 to <2 x i64>
   ret <2 x i64> %4
 }
 
-declare <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float>, <8 x i16>, <4 x i1>) #3
+declare <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float>, <8 x bfloat>, <4 x i1>) #3
 
 define <2 x i64> @test_mm128_cvtneps2bf16_128(<4 x float> %A) local_unnamed_addr #2 {
 ; CHECK-LABEL: test_mm128_cvtneps2bf16_128:
@@ -181,8 +181,8 @@ define <2 x i64> @test_mm128_cvtneps2bf16_128(<4 x float> %A) local_unnamed_addr
 ; CHECK-NEXT:    vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x i16> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>) #4
-  %1 = bitcast <8 x i16> %0 to <2 x i64>
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x bfloat> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>) #4
+  %1 = bitcast <8 x bfloat> %0 to <2 x i64>
   ret <2 x i64> %1
 }
 
@@ -202,8 +202,8 @@ define <2 x i64> @test_mm128_maskz_cvtneps2bf16_128(<4 x float> %A, i8 zeroext %
 entry:
   %0 = bitcast i8 %U to <8 x i1>
   %1 = shufflevector <8 x i1> %0, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-  %2 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x i16> zeroinitializer, <4 x i1> %1) #4
-  %3 = bitcast <8 x i16> %2 to <2 x i64>
+  %2 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x bfloat> zeroinitializer, <4 x i1> %1) #4
+  %3 = bitcast <8 x bfloat> %2 to <2 x i64>
   ret <2 x i64> %3
 }
 
@@ -223,9 +223,9 @@ define <2 x i64> @test_mm128_mask_cvtneps2bf16_128(<2 x i64> %C, i8 zeroext %U,
 entry:
   %0 = bitcast i8 %U to <8 x i1>
   %1 = shufflevector <8 x i1> %0, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-  %2 = bitcast <2 x i64> %C to <8 x i16>
-  %3 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x i16> %2, <4 x i1> %1) #4
-  %4 = bitcast <8 x i16> %3 to <2 x i64>
+  %2 = bitcast <2 x i64> %C to <8 x bfloat>
+  %3 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x bfloat> %2, <4 x i1> %1) #4
+  %4 = bitcast <8 x bfloat> %3 to <2 x i64>
   ret <2 x i64> %4
 }
 
@@ -248,26 +248,26 @@ define <2 x i64> @test_mm128_cvtneps2bf16_128_select(<2 x i64> %C, i8 zeroext %U
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
   %0 = bitcast i8 %U to <8 x i1>
-  %1 = bitcast <2 x i64> %C to <8 x i16>
-  %2 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x i16> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>) #4
-  %3 = select <8 x i1> %0, <8 x i16> %2, <8 x i16> %1
-  %4 = bitcast <8 x i16> %3 to <2 x i64>
+  %1 = bitcast <2 x i64> %C to <8 x bfloat>
+  %2 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x bfloat> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>) #4
+  %3 = select <8 x i1> %0, <8 x bfloat> %2, <8 x bfloat> %1
+  %4 = bitcast <8 x bfloat> %3 to <2 x i64>
   ret <2 x i64> %4
 }
 
-declare <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float>, <8 x i32>, <8 x i32>) #3
+declare <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float>, <16 x bfloat>, <16 x bfloat>) #3
 
-define <8 x float> @test_mm256_dpbf16ps_256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) local_unnamed_addr #2 {
+define <8 x float> @test_mm256_dpbf16ps_256(<8 x float> %E, <16 x bfloat> %A, <16 x bfloat> %B) local_unnamed_addr #2 {
 ; CHECK-LABEL: test_mm256_dpbf16ps_256:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vdpbf16ps %ymm2, %ymm1, %ymm0 # encoding: [0x62,0xf2,0x76,0x28,0x52,0xc2]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) #4
+  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <16 x bfloat> %A, <16 x bfloat> %B) #4
   ret <8 x float> %0
 }
 
-define <8 x float> @test_mm256_maskz_dpbf16ps_256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B, i8 zeroext %U) local_unnamed_addr #2 {
+define <8 x float> @test_mm256_maskz_dpbf16ps_256(<8 x float> %E, <16 x bfloat> %A, <16 x bfloat> %B, i8 zeroext %U) local_unnamed_addr #2 {
 ; X86-LABEL: test_mm256_maskz_dpbf16ps_256:
 ; X86:       # %bb.0: # %entry
 ; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
@@ -281,12 +281,12 @@ define <8 x float> @test_mm256_maskz_dpbf16ps_256(<8 x float> %E, <8 x i32> %A,
 ; X64-NEXT:    vdpbf16ps %ymm2, %ymm1, %ymm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0xa9,0x52,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) #4
+  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <16 x bfloat> %A, <16 x bfloat> %B) #4
   %1 = bitcast i8 %U to <8 x i1>
   %2 = select <8 x i1> %1, <8 x float> %0, <8 x float> zeroinitializer
   ret <8 x float> %2
 }
-define <8 x float> @test_mm256_mask_dpbf16ps_256(i8 zeroext %U, <8 x float> %E, <8 x i32> %A, <8 x i32> %B) local_unnamed_addr #2 {
+define <8 x float> @test_mm256_mask_dpbf16ps_256(i8 zeroext %U, <8 x float> %E, <16 x bfloat> %A, <16 x bfloat> %B) local_unnamed_addr #2 {
 ; X86-LABEL: test_mm256_mask_dpbf16ps_256:
 ; X86:       # %bb.0: # %entry
 ; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
@@ -300,25 +300,25 @@ define <8 x float> @test_mm256_mask_dpbf16ps_256(i8 zeroext %U, <8 x float> %E,
 ; X64-NEXT:    vdpbf16ps %ymm2, %ymm1, %ymm0 {%k1} # encoding: [0x62,0xf2,0x76,0x29,0x52,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <8 x i32> %A, <8 x i32> %B) #4
+  %0 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <16 x bfloat> %A, <16 x bfloat> %B) #4
   %1 = bitcast i8 %U to <8 x i1>
   %2 = select <8 x i1> %1, <8 x float> %0, <8 x float> %E
   ret <8 x float> %2
 }
 
-declare <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float>, <4 x i32>, <4 x i32>) #3
+declare <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float>, <8 x bfloat>, <8 x bfloat>) #3
 
-define <4 x float> @test_mm128_dpbf16ps_128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B) local_unnamed_addr #2 {
+define <4 x float> @test_mm128_dpbf16ps_128(<4 x float> %E, <8 x bfloat> %A, <8 x bfloat> %B) local_unnamed_addr #2 {
 ; CHECK-LABEL: test_mm128_dpbf16ps_128:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vdpbf16ps %xmm2, %xmm1, %xmm0 # encoding: [0x62,0xf2,0x76,0x08,0x52,0xc2]
 ; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
 entry:
-  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <4 x i32> %A, <4x i32> %B) #4
+  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <8 x bfloat> %A, <8 x bfloat> %B) #4
   ret <4 x float> %0
 }
 
-define <4 x float> @test_mm128_maskz_dpbf16ps_128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B, i4 zeroext %U) local_unnamed_addr #2 {
+define <4 x float> @test_mm128_maskz_dpbf16ps_128(<4 x float> %E, <8 x bfloat> %A, <8 x bfloat> %B, i4 zeroext %U) local_unnamed_addr #2 {
 ; X86-LABEL: test_mm128_maskz_dpbf16ps_128:
 ; X86:       # %bb.0: # %entry
 ; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
@@ -332,12 +332,12 @@ define <4 x float> @test_mm128_maskz_dpbf16ps_128(<4 x float> %E, <4 x i32> %A,
 ; X64-NEXT:    vdpbf16ps %xmm2, %xmm1, %xmm0 {%k1} {z} # encoding: [0x62,0xf2,0x76,0x89,0x52,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B) #4
+  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <8 x bfloat> %A, <8 x bfloat> %B) #4
   %1 = bitcast i4 %U to <4 x i1>
   %2 = select <4 x i1> %1, <4 x float> %0, <4 x float> zeroinitializer
   ret <4 x float> %2
 }
-define <4 x float> @test_mm128_mask_dpbf16ps_128(i4 zeroext %U, <4 x float> %E, <4 x i32> %A, <4 x i32> %B) local_unnamed_addr #2 {
+define <4 x float> @test_mm128_mask_dpbf16ps_128(i4 zeroext %U, <4 x float> %E, <8 x bfloat> %A, <8 x bfloat> %B) local_unnamed_addr #2 {
 ; X86-LABEL: test_mm128_mask_dpbf16ps_128:
 ; X86:       # %bb.0: # %entry
 ; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax # encoding: [0x0f,0xb6,0x44,0x24,0x04]
@@ -351,7 +351,7 @@ define <4 x float> @test_mm128_mask_dpbf16ps_128(i4 zeroext %U, <4 x float> %E,
 ; X64-NEXT:    vdpbf16ps %xmm2, %xmm1, %xmm0 {%k1} # encoding: [0x62,0xf2,0x76,0x09,0x52,0xc2]
 ; X64-NEXT:    retq # encoding: [0xc3]
 entry:
-  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <4 x i32> %A, <4 x i32> %B) #4
+  %0 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <8 x bfloat> %A, <8 x bfloat> %B) #4
   %1 = bitcast i4 %U to <4 x i1>
   %2 = select <4 x i1> %1, <4 x float> %0, <4 x float> %E
   ret <4 x float> %2

diff  --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index 20b095f3b74f8..ca338235f1fd6 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -1,23 +1,41 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s
+; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s --check-prefixes=CHECK,SSE2
+; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avx512bf16 | FileCheck %s --check-prefixes=CHECK,BF16
 
 define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
-; CHECK-LABEL: add:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    movq %rdx, %rbx
-; CHECK-NEXT:    movzwl (%rsi), %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm1
-; CHECK-NEXT:    movzwl (%rdi), %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    movw %ax, (%rbx)
-; CHECK-NEXT:    popq %rbx
-; CHECK-NEXT:    retq
+; SSE2-LABEL: add:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pushq %rbx
+; SSE2-NEXT:    movq %rdx, %rbx
+; SSE2-NEXT:    movzwl (%rsi), %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm1
+; SSE2-NEXT:    movzwl (%rdi), %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movw %ax, (%rbx)
+; SSE2-NEXT:    popq %rbx
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: add:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rbx
+; BF16-NEXT:    movq %rdx, %rbx
+; BF16-NEXT:    movzwl (%rsi), %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    movzwl (%rdi), %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    movw %ax, (%rbx)
+; BF16-NEXT:    popq %rbx
+; BF16-NEXT:    retq
   %a = load bfloat, ptr %pa
   %b = load bfloat, ptr %pb
   %add = fadd bfloat %a, %b
@@ -26,52 +44,95 @@ define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
 }
 
 define bfloat @add2(bfloat %a, bfloat %b) nounwind {
-; CHECK-LABEL: add2:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    pushq %rax
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    movd %xmm1, %ecx
-; CHECK-NEXT:    shll $16, %ecx
-; CHECK-NEXT:    movd %ecx, %xmm1
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    popq %rax
-; CHECK-NEXT:    retq
+; SSE2-LABEL: add2:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pushq %rax
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movd %xmm1, %ecx
+; SSE2-NEXT:    shll $16, %ecx
+; SSE2-NEXT:    movd %ecx, %xmm1
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    popq %rax
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: add2:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rax
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    vmovd %xmm1, %ecx
+; BF16-NEXT:    shll $16, %ecx
+; BF16-NEXT:    vmovd %ecx, %xmm0
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    popq %rax
+; BF16-NEXT:    retq
   %add = fadd bfloat %a, %b
   ret bfloat %add
 }
 
 define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
-; CHECK-LABEL: add_double:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    pushq %rbp
-; CHECK-NEXT:    pushq %r14
-; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    movq %rdx, %rbx
-; CHECK-NEXT:    movq %rsi, %r14
-; CHECK-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
-; CHECK-NEXT:    callq __truncdfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %ebp
-; CHECK-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
-; CHECK-NEXT:    callq __truncdfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm1
-; CHECK-NEXT:    shll $16, %ebp
-; CHECK-NEXT:    movd %ebp, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    cvtss2sd %xmm0, %xmm0
-; CHECK-NEXT:    movsd %xmm0, (%rbx)
-; CHECK-NEXT:    popq %rbx
-; CHECK-NEXT:    popq %r14
-; CHECK-NEXT:    popq %rbp
-; CHECK-NEXT:    retq
+; SSE2-LABEL: add_double:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pushq %rbp
+; SSE2-NEXT:    pushq %r14
+; SSE2-NEXT:    pushq %rbx
+; SSE2-NEXT:    movq %rdx, %rbx
+; SSE2-NEXT:    movq %rsi, %r14
+; SSE2-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
+; SSE2-NEXT:    callq __truncdfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %ebp
+; SSE2-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
+; SSE2-NEXT:    callq __truncdfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm1
+; SSE2-NEXT:    shll $16, %ebp
+; SSE2-NEXT:    movd %ebp, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    cvtss2sd %xmm0, %xmm0
+; SSE2-NEXT:    movsd %xmm0, (%rbx)
+; SSE2-NEXT:    popq %rbx
+; SSE2-NEXT:    popq %r14
+; SSE2-NEXT:    popq %rbp
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: add_double:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rbp
+; BF16-NEXT:    pushq %r14
+; BF16-NEXT:    pushq %rbx
+; BF16-NEXT:    movq %rdx, %rbx
+; BF16-NEXT:    movq %rsi, %r14
+; BF16-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; BF16-NEXT:    callq __truncdfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %ebp
+; BF16-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; BF16-NEXT:    callq __truncdfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    shll $16, %ebp
+; BF16-NEXT:    vmovd %ebp, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
+; BF16-NEXT:    vmovsd %xmm0, (%rbx)
+; BF16-NEXT:    popq %rbx
+; BF16-NEXT:    popq %r14
+; BF16-NEXT:    popq %rbp
+; BF16-NEXT:    retq
   %la = load double, ptr %pa
   %a = fptrunc double %la to bfloat
   %lb = load double, ptr %pb
@@ -83,30 +144,55 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 }
 
 define double @add_double2(double %da, double %db) nounwind {
-; CHECK-LABEL: add_double2:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    subq $16, %rsp
-; CHECK-NEXT:    movsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT:    callq __truncdfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %ebx
-; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
-; CHECK-NEXT:    # xmm0 = mem[0],zero
-; CHECK-NEXT:    callq __truncdfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm1
-; CHECK-NEXT:    shll $16, %ebx
-; CHECK-NEXT:    movd %ebx, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    cvtss2sd %xmm0, %xmm0
-; CHECK-NEXT:    addq $16, %rsp
-; CHECK-NEXT:    popq %rbx
-; CHECK-NEXT:    retq
+; SSE2-LABEL: add_double2:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pushq %rbx
+; SSE2-NEXT:    subq $16, %rsp
+; SSE2-NEXT:    movsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; SSE2-NEXT:    callq __truncdfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %ebx
+; SSE2-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
+; SSE2-NEXT:    # xmm0 = mem[0],zero
+; SSE2-NEXT:    callq __truncdfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm1
+; SSE2-NEXT:    shll $16, %ebx
+; SSE2-NEXT:    movd %ebx, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    cvtss2sd %xmm0, %xmm0
+; SSE2-NEXT:    addq $16, %rsp
+; SSE2-NEXT:    popq %rbx
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: add_double2:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rbx
+; BF16-NEXT:    subq $16, %rsp
+; BF16-NEXT:    vmovsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; BF16-NEXT:    callq __truncdfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %ebx
+; BF16-NEXT:    vmovq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
+; BF16-NEXT:    # xmm0 = mem[0],zero
+; BF16-NEXT:    callq __truncdfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    shll $16, %ebx
+; BF16-NEXT:    vmovd %ebx, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
+; BF16-NEXT:    addq $16, %rsp
+; BF16-NEXT:    popq %rbx
+; BF16-NEXT:    retq
   %a = fptrunc double %da to bfloat
   %b = fptrunc double %db to bfloat
   %add = fadd bfloat %a, %b
@@ -115,19 +201,33 @@ define double @add_double2(double %da, double %db) nounwind {
 }
 
 define void @add_constant(ptr %pa, ptr %pc) nounwind {
-; CHECK-LABEL: add_constant:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    movq %rsi, %rbx
-; CHECK-NEXT:    movzwl (%rdi), %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    movw %ax, (%rbx)
-; CHECK-NEXT:    popq %rbx
-; CHECK-NEXT:    retq
+; SSE2-LABEL: add_constant:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pushq %rbx
+; SSE2-NEXT:    movq %rsi, %rbx
+; SSE2-NEXT:    movzwl (%rdi), %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movw %ax, (%rbx)
+; SSE2-NEXT:    popq %rbx
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: add_constant:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rbx
+; BF16-NEXT:    movq %rsi, %rbx
+; BF16-NEXT:    movzwl (%rdi), %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    movw %ax, (%rbx)
+; BF16-NEXT:    popq %rbx
+; BF16-NEXT:    retq
   %a = load bfloat, ptr %pa
   %add = fadd bfloat %a, 1.0
   store bfloat %add, ptr %pc
@@ -135,16 +235,27 @@ define void @add_constant(ptr %pa, ptr %pc) nounwind {
 }
 
 define bfloat @add_constant2(bfloat %a) nounwind {
-; CHECK-LABEL: add_constant2:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    pushq %rax
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    popq %rax
-; CHECK-NEXT:    retq
+; SSE2-LABEL: add_constant2:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pushq %rax
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    popq %rax
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: add_constant2:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rax
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    popq %rax
+; BF16-NEXT:    retq
   %add = fadd bfloat %a, 1.0
   ret bfloat %add
 }
@@ -181,140 +292,254 @@ define bfloat @fold_ext_trunc2(bfloat %a) nounwind {
 }
 
 define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
-; CHECK-LABEL: addv:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    pushq %rbp
-; CHECK-NEXT:    pushq %r15
-; CHECK-NEXT:    pushq %r14
-; CHECK-NEXT:    pushq %r13
-; CHECK-NEXT:    pushq %r12
-; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    subq $56, %rsp
-; CHECK-NEXT:    movq %xmm0, %rcx
-; CHECK-NEXT:    movq %rcx, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT:    movq %rcx, %rax
-; CHECK-NEXT:    shrq $32, %rax
-; CHECK-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT:    movq %xmm1, %rdx
-; CHECK-NEXT:    movq %rdx, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT:    movq %rdx, %rax
-; CHECK-NEXT:    shrq $32, %rax
-; CHECK-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT:    movq %rcx, %rax
-; CHECK-NEXT:    shrq $48, %rax
-; CHECK-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT:    movq %rdx, %rax
-; CHECK-NEXT:    shrq $48, %rax
-; CHECK-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
-; CHECK-NEXT:    movq %xmm0, %r12
-; CHECK-NEXT:    movq %r12, %rax
-; CHECK-NEXT:    shrq $32, %rax
-; CHECK-NEXT:    movq %rax, (%rsp) # 8-byte Spill
-; CHECK-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
-; CHECK-NEXT:    movq %xmm0, %r14
-; CHECK-NEXT:    movq %r14, %rbp
-; CHECK-NEXT:    shrq $32, %rbp
-; CHECK-NEXT:    movq %r12, %r15
-; CHECK-NEXT:    shrq $48, %r15
-; CHECK-NEXT:    movq %r14, %r13
-; CHECK-NEXT:    shrq $48, %r13
-; CHECK-NEXT:    movl %r14d, %eax
-; CHECK-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
-; CHECK-NEXT:    movd %eax, %xmm1
-; CHECK-NEXT:    movl %r12d, %eax
-; CHECK-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %ebx
-; CHECK-NEXT:    shll $16, %ebx
-; CHECK-NEXT:    shll $16, %r14d
-; CHECK-NEXT:    movd %r14d, %xmm1
-; CHECK-NEXT:    shll $16, %r12d
-; CHECK-NEXT:    movd %r12d, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    movzwl %ax, %r12d
-; CHECK-NEXT:    orl %ebx, %r12d
-; CHECK-NEXT:    shll $16, %r13d
-; CHECK-NEXT:    movd %r13d, %xmm1
-; CHECK-NEXT:    shll $16, %r15d
-; CHECK-NEXT:    movd %r15d, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %r14d
-; CHECK-NEXT:    shll $16, %r14d
-; CHECK-NEXT:    shll $16, %ebp
-; CHECK-NEXT:    movd %ebp, %xmm1
-; CHECK-NEXT:    movq (%rsp), %rax # 8-byte Reload
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    movzwl %ax, %ebx
-; CHECK-NEXT:    orl %r14d, %ebx
-; CHECK-NEXT:    shlq $32, %rbx
-; CHECK-NEXT:    orq %r12, %rbx
-; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %r15 # 8-byte Reload
-; CHECK-NEXT:    movl %r15d, %eax
-; CHECK-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
-; CHECK-NEXT:    movd %eax, %xmm1
-; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %r14 # 8-byte Reload
-; CHECK-NEXT:    movl %r14d, %eax
-; CHECK-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %ebp
-; CHECK-NEXT:    shll $16, %ebp
-; CHECK-NEXT:    movq %r15, %rax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm1
-; CHECK-NEXT:    movq %r14, %rax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    movzwl %ax, %r14d
-; CHECK-NEXT:    orl %ebp, %r14d
-; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm1
-; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %ebp
-; CHECK-NEXT:    shll $16, %ebp
-; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm1
-; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    movd %eax, %xmm0
-; CHECK-NEXT:    addss %xmm1, %xmm0
-; CHECK-NEXT:    callq __truncsfbf2@PLT
-; CHECK-NEXT:    movd %xmm0, %eax
-; CHECK-NEXT:    movzwl %ax, %eax
-; CHECK-NEXT:    orl %ebp, %eax
-; CHECK-NEXT:    shlq $32, %rax
-; CHECK-NEXT:    orq %r14, %rax
-; CHECK-NEXT:    movq %rax, %xmm0
-; CHECK-NEXT:    movq %rbx, %xmm1
-; CHECK-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
-; CHECK-NEXT:    addq $56, %rsp
-; CHECK-NEXT:    popq %rbx
-; CHECK-NEXT:    popq %r12
-; CHECK-NEXT:    popq %r13
-; CHECK-NEXT:    popq %r14
-; CHECK-NEXT:    popq %r15
-; CHECK-NEXT:    popq %rbp
-; CHECK-NEXT:    retq
+; SSE2-LABEL: addv:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pushq %rbp
+; SSE2-NEXT:    pushq %r15
+; SSE2-NEXT:    pushq %r14
+; SSE2-NEXT:    pushq %r13
+; SSE2-NEXT:    pushq %r12
+; SSE2-NEXT:    pushq %rbx
+; SSE2-NEXT:    subq $56, %rsp
+; SSE2-NEXT:    movq %xmm0, %rcx
+; SSE2-NEXT:    movq %rcx, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; SSE2-NEXT:    movq %rcx, %rax
+; SSE2-NEXT:    shrq $32, %rax
+; SSE2-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; SSE2-NEXT:    movq %xmm1, %rdx
+; SSE2-NEXT:    movq %rdx, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; SSE2-NEXT:    movq %rdx, %rax
+; SSE2-NEXT:    shrq $32, %rax
+; SSE2-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; SSE2-NEXT:    movq %rcx, %rax
+; SSE2-NEXT:    shrq $48, %rax
+; SSE2-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; SSE2-NEXT:    movq %rdx, %rax
+; SSE2-NEXT:    shrq $48, %rax
+; SSE2-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
+; SSE2-NEXT:    movq %xmm0, %r12
+; SSE2-NEXT:    movq %r12, %rax
+; SSE2-NEXT:    shrq $32, %rax
+; SSE2-NEXT:    movq %rax, (%rsp) # 8-byte Spill
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; SSE2-NEXT:    movq %xmm0, %r14
+; SSE2-NEXT:    movq %r14, %rbp
+; SSE2-NEXT:    shrq $32, %rbp
+; SSE2-NEXT:    movq %r12, %r15
+; SSE2-NEXT:    shrq $48, %r15
+; SSE2-NEXT:    movq %r14, %r13
+; SSE2-NEXT:    shrq $48, %r13
+; SSE2-NEXT:    movl %r14d, %eax
+; SSE2-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
+; SSE2-NEXT:    movd %eax, %xmm1
+; SSE2-NEXT:    movl %r12d, %eax
+; SSE2-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %ebx
+; SSE2-NEXT:    shll $16, %ebx
+; SSE2-NEXT:    shll $16, %r14d
+; SSE2-NEXT:    movd %r14d, %xmm1
+; SSE2-NEXT:    shll $16, %r12d
+; SSE2-NEXT:    movd %r12d, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movzwl %ax, %r12d
+; SSE2-NEXT:    orl %ebx, %r12d
+; SSE2-NEXT:    shll $16, %r13d
+; SSE2-NEXT:    movd %r13d, %xmm1
+; SSE2-NEXT:    shll $16, %r15d
+; SSE2-NEXT:    movd %r15d, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %r14d
+; SSE2-NEXT:    shll $16, %r14d
+; SSE2-NEXT:    shll $16, %ebp
+; SSE2-NEXT:    movd %ebp, %xmm1
+; SSE2-NEXT:    movq (%rsp), %rax # 8-byte Reload
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movzwl %ax, %ebx
+; SSE2-NEXT:    orl %r14d, %ebx
+; SSE2-NEXT:    shlq $32, %rbx
+; SSE2-NEXT:    orq %r12, %rbx
+; SSE2-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %r15 # 8-byte Reload
+; SSE2-NEXT:    movl %r15d, %eax
+; SSE2-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
+; SSE2-NEXT:    movd %eax, %xmm1
+; SSE2-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %r14 # 8-byte Reload
+; SSE2-NEXT:    movl %r14d, %eax
+; SSE2-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %ebp
+; SSE2-NEXT:    shll $16, %ebp
+; SSE2-NEXT:    movq %r15, %rax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm1
+; SSE2-NEXT:    movq %r14, %rax
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movzwl %ax, %r14d
+; SSE2-NEXT:    orl %ebp, %r14d
+; SSE2-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm1
+; SSE2-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %ebp
+; SSE2-NEXT:    shll $16, %ebp
+; SSE2-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm1
+; SSE2-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
+; SSE2-NEXT:    shll $16, %eax
+; SSE2-NEXT:    movd %eax, %xmm0
+; SSE2-NEXT:    addss %xmm1, %xmm0
+; SSE2-NEXT:    callq __truncsfbf2@PLT
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movzwl %ax, %eax
+; SSE2-NEXT:    orl %ebp, %eax
+; SSE2-NEXT:    shlq $32, %rax
+; SSE2-NEXT:    orq %r14, %rax
+; SSE2-NEXT:    movq %rax, %xmm0
+; SSE2-NEXT:    movq %rbx, %xmm1
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; SSE2-NEXT:    addq $56, %rsp
+; SSE2-NEXT:    popq %rbx
+; SSE2-NEXT:    popq %r12
+; SSE2-NEXT:    popq %r13
+; SSE2-NEXT:    popq %r14
+; SSE2-NEXT:    popq %r15
+; SSE2-NEXT:    popq %rbp
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: addv:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rbp
+; BF16-NEXT:    pushq %r15
+; BF16-NEXT:    pushq %r14
+; BF16-NEXT:    pushq %r13
+; BF16-NEXT:    pushq %r12
+; BF16-NEXT:    pushq %rbx
+; BF16-NEXT:    subq $40, %rsp
+; BF16-NEXT:    vmovdqa %xmm1, (%rsp) # 16-byte Spill
+; BF16-NEXT:    vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; BF16-NEXT:    vpextrw $7, %xmm1, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm2
+; BF16-NEXT:    vpextrw $7, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm2, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovss %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; BF16-NEXT:    vmovdqa (%rsp), %xmm0 # 16-byte Reload
+; BF16-NEXT:    vpextrw $6, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload
+; BF16-NEXT:    vpextrw $6, %xmm1, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %ebp
+; BF16-NEXT:    vmovdqa (%rsp), %xmm0 # 16-byte Reload
+; BF16-NEXT:    vpextrw $5, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload
+; BF16-NEXT:    vpextrw $5, %xmm1, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %r14d
+; BF16-NEXT:    vmovdqa (%rsp), %xmm0 # 16-byte Reload
+; BF16-NEXT:    vpextrw $4, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload
+; BF16-NEXT:    vpextrw $4, %xmm1, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %r15d
+; BF16-NEXT:    vmovdqa (%rsp), %xmm0 # 16-byte Reload
+; BF16-NEXT:    vpextrw $3, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload
+; BF16-NEXT:    vpextrw $3, %xmm1, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %r12d
+; BF16-NEXT:    vmovdqa (%rsp), %xmm0 # 16-byte Reload
+; BF16-NEXT:    vpextrw $2, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload
+; BF16-NEXT:    vpextrw $2, %xmm1, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %r13d
+; BF16-NEXT:    vmovdqa (%rsp), %xmm0 # 16-byte Reload
+; BF16-NEXT:    vpextrw $1, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload
+; BF16-NEXT:    vpextrw $1, %xmm1, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %ebx
+; BF16-NEXT:    vmovdqa (%rsp), %xmm0 # 16-byte Reload
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload
+; BF16-NEXT:    vmovd %xmm1, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vmovd %xmm0, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vpinsrw $1, %ebx, %xmm0, %xmm0
+; BF16-NEXT:    vpinsrw $2, %r13d, %xmm0, %xmm0
+; BF16-NEXT:    vpinsrw $3, %r12d, %xmm0, %xmm0
+; BF16-NEXT:    vpinsrw $4, %r15d, %xmm0, %xmm0
+; BF16-NEXT:    vpinsrw $5, %r14d, %xmm0, %xmm0
+; BF16-NEXT:    vpinsrw $6, %ebp, %xmm0, %xmm0
+; BF16-NEXT:    vpinsrw $7, {{[-0-9]+}}(%r{{[sb]}}p), %xmm0, %xmm0 # 4-byte Folded Reload
+; BF16-NEXT:    addq $40, %rsp
+; BF16-NEXT:    popq %rbx
+; BF16-NEXT:    popq %r12
+; BF16-NEXT:    popq %r13
+; BF16-NEXT:    popq %r14
+; BF16-NEXT:    popq %r15
+; BF16-NEXT:    popq %rbp
+; BF16-NEXT:    retq
   %add = fadd <8 x bfloat> %a, %b
   ret <8 x bfloat> %add
 }
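
Both check prefixes in bfloat.ll tell the same story: even with AVX512-BF16 enabled there is no native bfloat add, so each operand is widened to float (the shll $16 sequences) and the sum is narrowed again through the __truncsfbf2 libcall; only the instruction forms change (movd/addss vs. vmovd/vaddss), and the <8 x bfloat> case is scalarized element by element. A rough IR-level sketch of that expansion, for illustration only (the real expansion happens during instruction selection, not as an IR transform):

define bfloat @bfloat_add_sketch(bfloat %a, bfloat %b) {
  ; The fpext is lowered as a 16-bit left shift into the high half of an f32.
  %a.f = fpext bfloat %a to float
  %b.f = fpext bfloat %b to float
  ; Plain float add (addss with SSE2, vaddss when avx512bf16 is enabled).
  %sum = fadd float %a.f, %b.f
  ; The fptrunc back to bfloat becomes a call to the __truncsfbf2 libcall.
  %r = fptrunc float %sum to bfloat
  ret bfloat %r
}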

diff  --git a/llvm/test/CodeGen/X86/stack-folding-avx512bf16.ll b/llvm/test/CodeGen/X86/stack-folding-avx512bf16.ll
index 060cab8d720e6..349f70ef43654 100644
--- a/llvm/test/CodeGen/X86/stack-folding-avx512bf16.ll
+++ b/llvm/test/CodeGen/X86/stack-folding-avx512bf16.ll
@@ -9,7 +9,7 @@ target triple = "x86_64-unknown-unknown"
 ; By including a nop call with sideeffects we can force a partial register spill of the
 ; relevant registers and check that the reload is correctly folded into the instruction.
 
-define <32 x i16> @stack_fold_cvtne2ps2bf16(<16 x float> %a0, <16 x float> %a1) {
+define <32 x bfloat> @stack_fold_cvtne2ps2bf16(<16 x float> %a0, <16 x float> %a1) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -19,12 +19,12 @@ define <32 x i16> @stack_fold_cvtne2ps2bf16(<16 x float> %a0, <16 x float> %a1)
 ; CHECK-NEXT:    vcvtne2ps2bf16 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 64-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %a0, <16 x float> %a1)
-  ret <32 x i16> %2
+  %2 = call <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %a0, <16 x float> %a1)
+  ret <32 x bfloat> %2
 }
-declare <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float>, <16 x float>)
+declare <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float>, <16 x float>)
 
-define <32 x i16> @stack_fold_cvtne2ps2bf16_mask(<16 x float> %a0, <16 x float> %a1, ptr %passthru, i32 %U) {
+define <32 x bfloat> @stack_fold_cvtne2ps2bf16_mask(<16 x float> %a0, <16 x float> %a1, ptr %passthru, i32 %U) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16_mask:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -37,15 +37,15 @@ define <32 x i16> @stack_fold_cvtne2ps2bf16_mask(<16 x float> %a0, <16 x float>
 ; CHECK-NEXT:    vmovaps %zmm2, %zmm0
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %a0, <16 x float> %a1)
+  %2 = call <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %a0, <16 x float> %a1)
   %3 = bitcast i32 %U to <32 x i1>
   ; load needed to keep the operation from being scheduled above the asm block
-  %4 = load <32 x i16>, ptr %passthru
-  %5 = select <32 x i1> %3, <32 x i16> %2, <32 x i16> %4
-  ret <32 x i16> %5
+  %4 = load <32 x bfloat>, ptr %passthru
+  %5 = select <32 x i1> %3, <32 x bfloat> %2, <32 x bfloat> %4
+  ret <32 x bfloat> %5
 }
 
-define <32 x i16> @stack_fold_cvtne2ps2bf16_maskz(<16 x float> %a0, <16 x float> %a1, i32 %U) {
+define <32 x bfloat> @stack_fold_cvtne2ps2bf16_maskz(<16 x float> %a0, <16 x float> %a1, i32 %U) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16_maskz:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -56,13 +56,13 @@ define <32 x i16> @stack_fold_cvtne2ps2bf16_maskz(<16 x float> %a0, <16 x float>
 ; CHECK-NEXT:    vcvtne2ps2bf16 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 {%k1} {z} # 64-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %a0, <16 x float> %a1)
+  %2 = call <32 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %a0, <16 x float> %a1)
   %3 = bitcast i32 %U to <32 x i1>
-  %4 = select <32 x i1> %3, <32 x i16> %2, <32 x i16> zeroinitializer
-  ret <32 x i16> %4
+  %4 = select <32 x i1> %3, <32 x bfloat> %2, <32 x bfloat> zeroinitializer
+  ret <32 x bfloat> %4
 }
 
-define <16 x i16> @stack_fold_cvtneps2bf16(<16 x float> %a0) {
+define <16 x bfloat> @stack_fold_cvtneps2bf16(<16 x float> %a0) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -72,12 +72,12 @@ define <16 x i16> @stack_fold_cvtneps2bf16(<16 x float> %a0) {
 ; CHECK-NEXT:    vcvtneps2bf16 {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 64-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %a0)
-  ret <16 x i16> %2
+  %2 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %a0)
+  ret <16 x bfloat> %2
 }
-declare <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float>)
+declare <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float>)
 
-define <16 x i16> @stack_fold_cvtneps2bf16_mask(<16 x float> %a0, ptr %passthru, i16 %U) {
+define <16 x bfloat> @stack_fold_cvtneps2bf16_mask(<16 x float> %a0, ptr %passthru, i16 %U) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16_mask:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -90,15 +90,15 @@ define <16 x i16> @stack_fold_cvtneps2bf16_mask(<16 x float> %a0, ptr %passthru,
 ; CHECK-NEXT:    vmovaps %ymm1, %ymm0
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %a0)
+  %2 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %a0)
   %3 = bitcast i16 %U to <16 x i1>
   ; load needed to keep the operation from being scheduled above the asm block
-  %4 = load <16 x i16>, ptr %passthru
-  %5 = select <16 x i1> %3, <16 x i16> %2, <16 x i16> %4
-  ret <16 x i16> %5
+  %4 = load <16 x bfloat>, ptr %passthru
+  %5 = select <16 x i1> %3, <16 x bfloat> %2, <16 x bfloat> %4
+  ret <16 x bfloat> %5
 }
 
-define <16 x i16> @stack_fold_cvtneps2bf16_maskz(<16 x float> %a0, i16 %U) {
+define <16 x bfloat> @stack_fold_cvtneps2bf16_maskz(<16 x float> %a0, i16 %U) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16_maskz:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -109,13 +109,13 @@ define <16 x i16> @stack_fold_cvtneps2bf16_maskz(<16 x float> %a0, i16 %U) {
 ; CHECK-NEXT:    vcvtneps2bf16 {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 {%k1} {z} # 64-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <16 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %a0)
+  %2 = tail call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %a0)
   %3 = bitcast i16 %U to <16 x i1>
-  %4 = select <16 x i1> %3, <16 x i16> %2, <16 x i16> zeroinitializer
-  ret <16 x i16> %4
+  %4 = select <16 x i1> %3, <16 x bfloat> %2, <16 x bfloat> zeroinitializer
+  ret <16 x bfloat> %4
 }
 
-define <16 x float> @stack_fold_vdpbf16ps(<16 x float> %a0, <16 x i32> %a1, <16 x i32> %a2) {
+define <16 x float> @stack_fold_vdpbf16ps(<16 x float> %a0, <32 x bfloat> %a1, <32 x bfloat> %a2) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -125,12 +125,12 @@ define <16 x float> @stack_fold_vdpbf16ps(<16 x float> %a0, <16 x i32> %a1, <16
 ; CHECK-NEXT:    vdpbf16ps {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm0 # 64-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %a0, <16 x i32> %a1, <16 x i32> %a2)
+  %2 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %a0, <32 x bfloat> %a1, <32 x bfloat> %a2)
   ret <16 x float> %2
 }
-declare <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float>, <16 x i32>, <16 x i32>)
+declare <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float>, <32 x bfloat>, <32 x bfloat>)
 
-define <16 x float> @stack_fold_vdpbf16ps_mask(ptr %a0, <16 x i32> %a1, <16 x i32> %a2, ptr %passthru, i16 %U) {
+define <16 x float> @stack_fold_vdpbf16ps_mask(ptr %a0, <32 x bfloat> %a1, <32 x bfloat> %a2, ptr %passthru, i16 %U) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps_mask:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -145,13 +145,13 @@ define <16 x float> @stack_fold_vdpbf16ps_mask(ptr %a0, <16 x i32> %a1, <16 x i3
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
   ; load needed to keep the operation from being scheduled above the asm block
   %2 = load <16 x float>, ptr %a0
-  %3 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %2, <16 x i32> %a1, <16 x i32> %a2)
+  %3 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %2, <32 x bfloat> %a1, <32 x bfloat> %a2)
   %4 = bitcast i16 %U to <16 x i1>
   %5 = select <16 x i1> %4, <16 x float> %3, <16 x float> %2
   ret <16 x float> %5
 }
 
-define <16 x float> @stack_fold_vdpbf16ps_maskz(<16 x float> %a0, <16 x i32> %a1, <16 x i32> %a2, ptr %U) {
+define <16 x float> @stack_fold_vdpbf16ps_maskz(<16 x float> %a0, <32 x bfloat> %a1, <32 x bfloat> %a2, ptr %U) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps_maskz:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill
@@ -162,7 +162,7 @@ define <16 x float> @stack_fold_vdpbf16ps_maskz(<16 x float> %a0, <16 x i32> %a1
 ; CHECK-NEXT:    vdpbf16ps {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm0 {%k1} {z} # 64-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %a0, <16 x i32> %a1, <16 x i32> %a2)
+  %2 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %a0, <32 x bfloat> %a1, <32 x bfloat> %a2)
   %3 = load i16, ptr %U
   %4 = bitcast i16 %3 to <16 x i1>
   %5 = select <16 x i1> %4, <16 x float> %2, <16 x float> zeroinitializer
@@ -171,7 +171,7 @@ define <16 x float> @stack_fold_vdpbf16ps_maskz(<16 x float> %a0, <16 x i32> %a1
 
 
 
-define <16 x i16> @stack_fold_cvtne2ps2bf16_ymm(<8 x float> %a0, <8 x float> %a1) {
+define <16 x bfloat> @stack_fold_cvtne2ps2bf16_ymm(<8 x float> %a0, <8 x float> %a1) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -181,12 +181,12 @@ define <16 x i16> @stack_fold_cvtne2ps2bf16_ymm(<8 x float> %a0, <8 x float> %a1
 ; CHECK-NEXT:    vcvtne2ps2bf16 {{[-0-9]+}}(%r{{[sb]}}p), %ymm0, %ymm0 # 32-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %a0, <8 x float> %a1)
-  ret <16 x i16> %2
+  %2 = call <16 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %a0, <8 x float> %a1)
+  ret <16 x bfloat> %2
 }
-declare <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float>, <8 x float>)
+declare <16 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float>, <8 x float>)
 
-define <16 x i16> @stack_fold_cvtne2ps2bf16_mask_ymm(<8 x float> %a0, <8 x float> %a1, ptr %passthru, i16 %U) {
+define <16 x bfloat> @stack_fold_cvtne2ps2bf16_mask_ymm(<8 x float> %a0, <8 x float> %a1, ptr %passthru, i16 %U) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16_mask_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -199,15 +199,15 @@ define <16 x i16> @stack_fold_cvtne2ps2bf16_mask_ymm(<8 x float> %a0, <8 x float
 ; CHECK-NEXT:    vmovaps %ymm2, %ymm0
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %a0, <8 x float> %a1)
+  %2 = call <16 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %a0, <8 x float> %a1)
   %3 = bitcast i16 %U to <16 x i1>
   ; load needed to keep the operation from being scheduled above the asm block
-  %4 = load <16 x i16>, ptr %passthru
-  %5 = select <16 x i1> %3, <16 x i16> %2, <16 x i16> %4
-  ret <16 x i16> %5
+  %4 = load <16 x bfloat>, ptr %passthru
+  %5 = select <16 x i1> %3, <16 x bfloat> %2, <16 x bfloat> %4
+  ret <16 x bfloat> %5
 }
 
-define <16 x i16> @stack_fold_cvtne2ps2bf16_maskz_ymm(<8 x float> %a0, <8 x float> %a1, i16 %U) {
+define <16 x bfloat> @stack_fold_cvtne2ps2bf16_maskz_ymm(<8 x float> %a0, <8 x float> %a1, i16 %U) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16_maskz_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -218,13 +218,13 @@ define <16 x i16> @stack_fold_cvtne2ps2bf16_maskz_ymm(<8 x float> %a0, <8 x floa
 ; CHECK-NEXT:    vcvtne2ps2bf16 {{[-0-9]+}}(%r{{[sb]}}p), %ymm0, %ymm0 {%k1} {z} # 32-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <16 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %a0, <8 x float> %a1)
+  %2 = call <16 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %a0, <8 x float> %a1)
   %3 = bitcast i16 %U to <16 x i1>
-  %4 = select <16 x i1> %3, <16 x i16> %2, <16 x i16> zeroinitializer
-  ret <16 x i16> %4
+  %4 = select <16 x i1> %3, <16 x bfloat> %2, <16 x bfloat> zeroinitializer
+  ret <16 x bfloat> %4
 }
 
-define <8 x i16> @stack_fold_cvtneps2bf16_ymm(<8 x float> %a0) {
+define <8 x bfloat> @stack_fold_cvtneps2bf16_ymm(<8 x float> %a0) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -235,12 +235,12 @@ define <8 x i16> @stack_fold_cvtneps2bf16_ymm(<8 x float> %a0) {
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %a0)
-  ret <8 x i16> %2
+  %2 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %a0)
+  ret <8 x bfloat> %2
 }
-declare <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float>)
+declare <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float>)
 
-define <8 x i16> @stack_fold_cvtneps2bf16_mask_ymm(<8 x float> %a0, ptr %passthru, i8 %U) {
+define <8 x bfloat> @stack_fold_cvtneps2bf16_mask_ymm(<8 x float> %a0, ptr %passthru, i8 %U) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16_mask_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -254,15 +254,15 @@ define <8 x i16> @stack_fold_cvtneps2bf16_mask_ymm(<8 x float> %a0, ptr %passthr
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %a0)
+  %2 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %a0)
   %3 = bitcast i8 %U to <8 x i1>
   ; load needed to keep the operation from being scheduled above the asm block
-  %4 = load <8 x i16>, ptr %passthru
-  %5 = select <8 x i1> %3, <8 x i16> %2, <8 x i16> %4
-  ret <8 x i16> %5
+  %4 = load <8 x bfloat>, ptr %passthru
+  %5 = select <8 x i1> %3, <8 x bfloat> %2, <8 x bfloat> %4
+  ret <8 x bfloat> %5
 }
 
-define <8 x i16> @stack_fold_cvtneps2bf16_maskz_ymm(<8 x float> %a0, i8 %U) {
+define <8 x bfloat> @stack_fold_cvtneps2bf16_maskz_ymm(<8 x float> %a0, i8 %U) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16_maskz_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -274,13 +274,13 @@ define <8 x i16> @stack_fold_cvtneps2bf16_maskz_ymm(<8 x float> %a0, i8 %U) {
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <8 x i16> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %a0)
+  %2 = tail call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %a0)
   %3 = bitcast i8 %U to <8 x i1>
-  %4 = select <8 x i1> %3, <8 x i16> %2, <8 x i16> zeroinitializer
-  ret <8 x i16> %4
+  %4 = select <8 x i1> %3, <8 x bfloat> %2, <8 x bfloat> zeroinitializer
+  ret <8 x bfloat> %4
 }
 
-define <8 x float> @stack_fold_vdpbf16ps_ymm(<8 x float> %a0, <8 x i32> %a1, <8 x i32> %a2) {
+define <8 x float> @stack_fold_vdpbf16ps_ymm(<8 x float> %a0, <16 x bfloat> %a1, <16 x bfloat> %a2) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -290,12 +290,12 @@ define <8 x float> @stack_fold_vdpbf16ps_ymm(<8 x float> %a0, <8 x i32> %a1, <8
 ; CHECK-NEXT:    vdpbf16ps {{[-0-9]+}}(%r{{[sb]}}p), %ymm1, %ymm0 # 32-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %a0, <8 x i32> %a1, <8 x i32> %a2)
+  %2 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %a0, <16 x bfloat> %a1, <16 x bfloat> %a2)
   ret <8 x float> %2
 }
-declare <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float>, <8 x i32>, <8 x i32>)
+declare <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float>, <16 x bfloat>, <16 x bfloat>)
 
-define <8 x float> @stack_fold_vdpbf16ps_mask_ymm(ptr %a0, <8 x i32> %a1, <8 x i32> %a2, ptr %passthru, i8 %U) {
+define <8 x float> @stack_fold_vdpbf16ps_mask_ymm(ptr %a0, <16 x bfloat> %a1, <16 x bfloat> %a2, ptr %passthru, i8 %U) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps_mask_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -310,13 +310,13 @@ define <8 x float> @stack_fold_vdpbf16ps_mask_ymm(ptr %a0, <8 x i32> %a1, <8 x i
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
   ; load needed to keep the operation from being scheduled above the asm block
   %2 = load <8 x float>, ptr %a0
-  %3 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %2, <8 x i32> %a1, <8 x i32> %a2)
+  %3 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %2, <16 x bfloat> %a1, <16 x bfloat> %a2)
   %4 = bitcast i8 %U to <8 x i1>
   %5 = select <8 x i1> %4, <8 x float> %3, <8 x float> %2
   ret <8 x float> %5
 }
 
-define <8 x float> @stack_fold_vdpbf16ps_maskz_ymm(<8 x float> %a0, <8 x i32> %a1, <8 x i32> %a2, ptr %U) {
+define <8 x float> @stack_fold_vdpbf16ps_maskz_ymm(<8 x float> %a0, <16 x bfloat> %a1, <16 x bfloat> %a2, ptr %U) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps_maskz_ymm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovups %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill
@@ -328,7 +328,7 @@ define <8 x float> @stack_fold_vdpbf16ps_maskz_ymm(<8 x float> %a0, <8 x i32> %a
 ; CHECK-NEXT:    vdpbf16ps {{[-0-9]+}}(%r{{[sb]}}p), %ymm1, %ymm0 {%k1} {z} # 32-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %a0, <8 x i32> %a1, <8 x i32> %a2)
+  %2 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %a0, <16 x bfloat> %a1, <16 x bfloat> %a2)
   %3 = load i8, ptr %U
   %4 = bitcast i8 %3 to <8 x i1>
   %5 = select <8 x i1> %4, <8 x float> %2, <8 x float> zeroinitializer
@@ -338,7 +338,7 @@ define <8 x float> @stack_fold_vdpbf16ps_maskz_ymm(<8 x float> %a0, <8 x i32> %a
 
 
 
-define <8 x i16> @stack_fold_cvtne2ps2bf16_xmm(<4 x float> %a0, <4 x float> %a1) {
+define <8 x bfloat> @stack_fold_cvtne2ps2bf16_xmm(<4 x float> %a0, <4 x float> %a1) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -348,12 +348,12 @@ define <8 x i16> @stack_fold_cvtne2ps2bf16_xmm(<4 x float> %a0, <4 x float> %a1)
 ; CHECK-NEXT:    vcvtne2ps2bf16 {{[-0-9]+}}(%r{{[sb]}}p), %xmm0, %xmm0 # 16-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %a0, <4 x float> %a1)
-  ret <8 x i16> %2
+  %2 = call <8 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %a0, <4 x float> %a1)
+  ret <8 x bfloat> %2
 }
-declare <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float>, <4 x float>)
+declare <8 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float>, <4 x float>)
 
-define <8 x i16> @stack_fold_cvtne2ps2bf16_mask_xmm(<4 x float> %a0, <4 x float> %a1, ptr %passthru, i8 %U) {
+define <8 x bfloat> @stack_fold_cvtne2ps2bf16_mask_xmm(<4 x float> %a0, <4 x float> %a1, ptr %passthru, i8 %U) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16_mask_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -366,15 +366,15 @@ define <8 x i16> @stack_fold_cvtne2ps2bf16_mask_xmm(<4 x float> %a0, <4 x float>
 ; CHECK-NEXT:    vmovaps %xmm2, %xmm0
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %a0, <4 x float> %a1)
+  %2 = call <8 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %a0, <4 x float> %a1)
   %3 = bitcast i8 %U to <8 x i1>
   ; load needed to keep the operation from being scheduled above the asm block
-  %4 = load <8 x i16>, ptr %passthru
-  %5 = select <8 x i1> %3, <8 x i16> %2, <8 x i16> %4
-  ret <8 x i16> %5
+  %4 = load <8 x bfloat>, ptr %passthru
+  %5 = select <8 x i1> %3, <8 x bfloat> %2, <8 x bfloat> %4
+  ret <8 x bfloat> %5
 }
 
-define <8 x i16> @stack_fold_cvtne2ps2bf16_maskz_xmm(<4 x float> %a0, <4 x float> %a1, i8 %U) {
+define <8 x bfloat> @stack_fold_cvtne2ps2bf16_maskz_xmm(<4 x float> %a0, <4 x float> %a1, i8 %U) {
 ; CHECK-LABEL: stack_fold_cvtne2ps2bf16_maskz_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -385,13 +385,13 @@ define <8 x i16> @stack_fold_cvtne2ps2bf16_maskz_xmm(<4 x float> %a0, <4 x float
 ; CHECK-NEXT:    vcvtne2ps2bf16 {{[-0-9]+}}(%r{{[sb]}}p), %xmm0, %xmm0 {%k1} {z} # 16-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = call <8 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %a0, <4 x float> %a1)
+  %2 = call <8 x bfloat> @llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %a0, <4 x float> %a1)
   %3 = bitcast i8 %U to <8 x i1>
-  %4 = select <8 x i1> %3, <8 x i16> %2, <8 x i16> zeroinitializer
-  ret <8 x i16> %4
+  %4 = select <8 x i1> %3, <8 x bfloat> %2, <8 x bfloat> zeroinitializer
+  ret <8 x bfloat> %4
 }
 
-define <8 x i16> @stack_fold_cvtneps2bf16_xmm(<4 x float> %a0) {
+define <8 x bfloat> @stack_fold_cvtneps2bf16_xmm(<4 x float> %a0) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -401,12 +401,12 @@ define <8 x i16> @stack_fold_cvtneps2bf16_xmm(<4 x float> %a0) {
 ; CHECK-NEXT:    vcvtneps2bf16x {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %a0, <8 x i16> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
-  ret <8 x i16> %2
+  %2 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %a0, <8 x bfloat> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+  ret <8 x bfloat> %2
 }
-declare <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float>, <8 x i16>, <4 x i1>)
+declare <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float>, <8 x bfloat>, <4 x i1>)
 
-define <8 x i16> @stack_fold_cvtneps2bf16_mask_xmm(<4 x float> %a0, ptr %passthru, i8 %U) {
+define <8 x bfloat> @stack_fold_cvtneps2bf16_mask_xmm(<4 x float> %a0, ptr %passthru, i8 %U) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16_mask_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -419,14 +419,14 @@ define <8 x i16> @stack_fold_cvtneps2bf16_mask_xmm(<4 x float> %a0, ptr %passthr
 ; CHECK-NEXT:    vmovaps %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = load <8 x i16>, ptr %passthru
+  %2 = load <8 x bfloat>, ptr %passthru
   %3 = bitcast i8 %U to <8 x i1>
   %4 = shufflevector <8 x i1> %3, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-  %5 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %a0, <8 x i16> %2, <4 x i1> %4)
-  ret <8 x i16> %5
+  %5 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %a0, <8 x bfloat> %2, <4 x i1> %4)
+  ret <8 x bfloat> %5
 }
 
-define <8 x i16> @stack_fold_cvtneps2bf16_maskz_xmm(<4 x float> %a0, i8 %U) {
+define <8 x bfloat> @stack_fold_cvtneps2bf16_maskz_xmm(<4 x float> %a0, i8 %U) {
 ; CHECK-LABEL: stack_fold_cvtneps2bf16_maskz_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -439,11 +439,11 @@ define <8 x i16> @stack_fold_cvtneps2bf16_maskz_xmm(<4 x float> %a0, i8 %U) {
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm1},~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
   %2 = bitcast i8 %U to <8 x i1>
   %3 = shufflevector <8 x i1> %2, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-  %4 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %a0, <8 x i16> zeroinitializer, <4 x i1> %3)
-  ret <8 x i16> %4
+  %4 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %a0, <8 x bfloat> zeroinitializer, <4 x i1> %3)
+  ret <8 x bfloat> %4
 }
 
-define <4 x float> @stack_fold_vdpbf16ps_xmm(<4 x float> %a0, <4 x i32> %a1, <4 x i32> %a2) {
+define <4 x float> @stack_fold_vdpbf16ps_xmm(<4 x float> %a0, <8 x bfloat> %a1, <8 x bfloat> %a2) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -453,12 +453,12 @@ define <4 x float> @stack_fold_vdpbf16ps_xmm(<4 x float> %a0, <4 x i32> %a1, <4
 ; CHECK-NEXT:    vdpbf16ps {{[-0-9]+}}(%r{{[sb]}}p), %xmm1, %xmm0 # 16-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %a0, <4 x i32> %a1, <4 x i32> %a2)
+  %2 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %a0, <8 x bfloat> %a1, <8 x bfloat> %a2)
   ret <4 x float> %2
 }
-declare <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float>, <4 x i32>, <4 x i32>)
+declare <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float>, <8 x bfloat>, <8 x bfloat>)
 
-define <4 x float> @stack_fold_vdpbf16ps_mask_xmm(ptr %a0, <4 x i32> %a1, <4 x i32> %a2, ptr %passthru, i8 %U) {
+define <4 x float> @stack_fold_vdpbf16ps_mask_xmm(ptr %a0, <8 x bfloat> %a1, <8 x bfloat> %a2, ptr %passthru, i8 %U) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps_mask_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -473,14 +473,14 @@ define <4 x float> @stack_fold_vdpbf16ps_mask_xmm(ptr %a0, <4 x i32> %a1, <4 x i
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm2},~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
   ; load needed to keep the operation from being scheduled above the asm block
   %2 = load <4 x float>, ptr %a0
-  %3 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %2, <4 x i32> %a1, <4 x i32> %a2)
+  %3 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %2, <8 x bfloat> %a1, <8 x bfloat> %a2)
   %4 = bitcast i8 %U to <8 x i1>
   %5 = shufflevector <8 x i1> %4, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
   %6 = select <4 x i1> %5, <4 x float> %3, <4 x float> %2
   ret <4 x float> %6
 }
 
-define <4 x float> @stack_fold_vdpbf16ps_maskz_xmm(<4 x float> %a0, <4 x i32> %a1, <4 x i32> %a2, ptr %U) {
+define <4 x float> @stack_fold_vdpbf16ps_maskz_xmm(<4 x float> %a0, <8 x bfloat> %a1, <8 x bfloat> %a2, ptr %U) {
 ; CHECK-LABEL: stack_fold_vdpbf16ps_maskz_xmm:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vmovaps %xmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
@@ -492,7 +492,7 @@ define <4 x float> @stack_fold_vdpbf16ps_maskz_xmm(<4 x float> %a0, <4 x i32> %a
 ; CHECK-NEXT:    vdpbf16ps {{[-0-9]+}}(%r{{[sb]}}p), %xmm1, %xmm0 {%k1} {z} # 16-byte Folded Reload
 ; CHECK-NEXT:    retq
   %1 = tail call <2 x i64> asm sideeffect "nop", "=x,~{xmm3},~{xmm4},~{xmm5},~{xmm6},~{xmm7},~{xmm8},~{xmm9},~{xmm10},~{xmm11},~{xmm12},~{xmm13},~{xmm14},~{xmm15},~{xmm16},~{xmm17},~{xmm18},~{xmm19},~{xmm20},~{xmm21},~{xmm22},~{xmm23},~{xmm24},~{xmm25},~{xmm26},~{xmm27},~{xmm28},~{xmm29},~{xmm30},~{xmm31},~{flags}"()
-  %2 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %a0, <4 x i32> %a1, <4 x i32> %a2)
+  %2 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %a0, <8 x bfloat> %a1, <8 x bfloat> %a2)
   %3 = load i8, ptr %U
   %4 = bitcast i8 %3 to <8 x i1>
   %5 = shufflevector <8 x i1> %4, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
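
For readers skimming the test updates above, here is a minimal standalone sketch of the updated IR signatures in use. It is illustrative only and not part of the commit: the function name @bf16_dot_sketch is made up, and lowering the calls assumes a target with the avx512bf16 and avx512vl features, as in the tests.

declare <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float>)
declare <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float>, <8 x bfloat>, <8 x bfloat>)

define <4 x float> @bf16_dot_sketch(<4 x float> %acc, <8 x float> %a, <8 x float> %b) {
  ; Round the float inputs to bfloat; the intrinsic now returns <8 x bfloat> instead of <8 x i16>.
  %abf = call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %a)
  %bbf = call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %b)
  ; Accumulate the bfloat pair products into the float accumulator; the operands are now <8 x bfloat> instead of <4 x i32>.
  %r = call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %acc, <8 x bfloat> %abf, <8 x bfloat> %bbf)
  ret <4 x float> %r
}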

More information about the llvm-commits mailing list