[llvm] [X86][BF16] Improve vectorization of BF16 (PR #88486)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 12 02:23:23 PDT 2024


https://github.com/phoebewang updated https://github.com/llvm/llvm-project/pull/88486

>From 587de34d61b6a576dfa6b20b2810c2ce1a35a666 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Thu, 11 Apr 2024 21:20:48 +0800
Subject: [PATCH] [X86][BF16] Improve vectorization of BF16

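1. Move the vector BF16 FP_EXTEND expansion from LowerFP_EXTEND into combineFP_EXTEND, so that it also helps small vectors;
2. In combineFP_EXTEND, fold an FP_EXTEND of a BF16 FP_ROUND back to the FP_ROUND's source value, removing the fptrunc-then-fpext pair left behind after promotion;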
---
 llvm/lib/Target/X86/X86ISelLowering.cpp      |  53 +++--
 llvm/test/CodeGen/X86/bfloat.ll              | 234 +++----------------
 llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll |   8 +-
 3 files changed, 66 insertions(+), 229 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b7cb4b7dafeb69..a8eb445f1e75e5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21433,25 +21433,9 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
     return Res;
   }
 
-  if (!SVT.isVector())
+  if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
     return Op;
 
-  if (SVT.getVectorElementType() == MVT::bf16) {
-    // FIXME: Do we need to support strict FP?
-    assert(!IsStrict && "Strict FP doesn't support BF16");
-    if (VT.getVectorElementType() == MVT::f64) {
-      MVT TmpVT = VT.changeVectorElementType(MVT::f32);
-      return DAG.getNode(ISD::FP_EXTEND, DL, VT,
-                         DAG.getNode(ISD::FP_EXTEND, DL, TmpVT, In));
-    }
-    assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
-    MVT NVT = SVT.changeVectorElementType(MVT::i32);
-    In = DAG.getBitcast(SVT.changeTypeToInteger(), In);
-    In = DAG.getNode(ISD::ZERO_EXTEND, DL, NVT, In);
-    In = DAG.getNode(ISD::SHL, DL, NVT, In, DAG.getConstant(16, DL, NVT));
-    return DAG.getBitcast(VT, In);
-  }
-
   if (SVT.getVectorElementType() == MVT::f16) {
     if (Subtarget.hasFP16() && isTypeLegal(SVT))
       return Op;
@@ -56517,17 +56501,40 @@ static SDValue combineFP16_TO_FP(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+  bool IsStrict = N->isStrictFPOpcode();
+  SDValue Src = N->getOperand(IsStrict ? 1 : 0);
+  EVT SrcVT = Src.getValueType();
+
+  SDLoc dl(N);
+  if (SrcVT.getScalarType() == MVT::bf16) {
+    assert(!IsStrict && "Strict FP doesn't support BF16");
+    if (Src.getOpcode() == ISD::FP_ROUND &&
+        Src.getOperand(0).getValueType() == VT)
+      return Src.getOperand(0);
+
+    if (!SrcVT.isVector())
+      return SDValue();
+
+    if (VT.getVectorElementType() == MVT::f64) {
+      MVT TmpVT = VT.getSimpleVT().changeVectorElementType(MVT::f32);
+      return DAG.getNode(ISD::FP_EXTEND, dl, VT,
+                         DAG.getNode(ISD::FP_EXTEND, dl, TmpVT, Src));
+    }
+    assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
+    MVT NVT = SrcVT.getSimpleVT().changeVectorElementType(MVT::i32);
+    Src = DAG.getBitcast(SrcVT.changeTypeToInteger(), Src);
+    Src = DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, Src);
+    Src = DAG.getNode(ISD::SHL, dl, NVT, Src, DAG.getConstant(16, dl, NVT));
+    return DAG.getBitcast(VT, Src);
+  }
+
   if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
     return SDValue();
 
   if (Subtarget.hasFP16())
     return SDValue();
 
-  bool IsStrict = N->isStrictFPOpcode();
-  EVT VT = N->getValueType(0);
-  SDValue Src = N->getOperand(IsStrict ? 1 : 0);
-  EVT SrcVT = Src.getValueType();
-
   if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::f16)
     return SDValue();
 
@@ -56539,8 +56546,6 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
   if (NumElts == 1 || !isPowerOf2_32(NumElts))
     return SDValue();
 
-  SDLoc dl(N);
-
   // Convert the input to vXi16.
   EVT IntVT = SrcVT.changeVectorElementTypeToInteger();
   Src = DAG.getBitcast(IntVT, Src);
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index 8a2109a1c78df9..39d8e2d50c91ea 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -1629,22 +1629,8 @@ define <4 x float> @pr64460_1(<4 x bfloat> %a) {
 ;
 ; SSE2-LABEL: pr64460_1:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    pextrw $1, %xmm0, %eax
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm2
-; SSE2-NEXT:    movd %xmm0, %eax
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm1
-; SSE2-NEXT:    pextrw $3, %xmm0, %eax
-; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm2
-; SSE2-NEXT:    movd %xmm0, %eax
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm0
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
 ; SSE2-NEXT:    movdqa %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
@@ -1666,41 +1652,11 @@ define <8 x float> @pr64460_2(<8 x bfloat> %a) {
 ;
 ; SSE2-LABEL: pr64460_2:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movq %xmm0, %rdx
-; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT:    movq %xmm0, %rcx
-; SSE2-NEXT:    movq %rcx, %rax
-; SSE2-NEXT:    shrq $32, %rax
-; SSE2-NEXT:    movq %rdx, %rsi
-; SSE2-NEXT:    shrq $32, %rsi
-; SSE2-NEXT:    movl %edx, %edi
-; SSE2-NEXT:    andl $-65536, %edi # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %edi, %xmm1
-; SSE2-NEXT:    movl %edx, %edi
-; SSE2-NEXT:    shll $16, %edi
-; SSE2-NEXT:    movd %edi, %xmm0
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
-; SSE2-NEXT:    shrq $48, %rdx
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm1
-; SSE2-NEXT:    shll $16, %esi
-; SSE2-NEXT:    movd %esi, %xmm2
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    andl $-65536, %edx # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %edx, %xmm2
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm1
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
-; SSE2-NEXT:    shrq $48, %rcx
-; SSE2-NEXT:    shll $16, %ecx
-; SSE2-NEXT:    movd %ecx, %xmm2
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm3
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm3 = xmm3[0],xmm2[0],xmm3[1],xmm2[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm3[0]
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    pxor %xmm2, %xmm2
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7]
+; SSE2-NEXT:    movdqa %xmm2, %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: pr64460_2:
@@ -1721,76 +1677,16 @@ define <16 x float> @pr64460_3(<16 x bfloat> %a) {
 ;
 ; SSE2-LABEL: pr64460_3:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movq %xmm1, %rdi
-; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
-; SSE2-NEXT:    movq %xmm1, %rcx
-; SSE2-NEXT:    movq %rcx, %rax
-; SSE2-NEXT:    shrq $32, %rax
-; SSE2-NEXT:    movq %xmm0, %r9
-; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT:    movq %xmm0, %rsi
-; SSE2-NEXT:    movq %rsi, %rdx
-; SSE2-NEXT:    shrq $32, %rdx
-; SSE2-NEXT:    movq %rdi, %r8
-; SSE2-NEXT:    shrq $32, %r8
-; SSE2-NEXT:    movq %r9, %r10
-; SSE2-NEXT:    shrq $32, %r10
-; SSE2-NEXT:    movl %r9d, %r11d
-; SSE2-NEXT:    andl $-65536, %r11d # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %r11d, %xmm1
-; SSE2-NEXT:    movl %r9d, %r11d
-; SSE2-NEXT:    shll $16, %r11d
-; SSE2-NEXT:    movd %r11d, %xmm0
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
-; SSE2-NEXT:    shrq $48, %r9
-; SSE2-NEXT:    shll $16, %r9d
-; SSE2-NEXT:    movd %r9d, %xmm1
-; SSE2-NEXT:    shll $16, %r10d
-; SSE2-NEXT:    movd %r10d, %xmm2
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
-; SSE2-NEXT:    movl %edi, %r9d
-; SSE2-NEXT:    andl $-65536, %r9d # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %r9d, %xmm1
-; SSE2-NEXT:    movl %edi, %r9d
-; SSE2-NEXT:    shll $16, %r9d
-; SSE2-NEXT:    movd %r9d, %xmm2
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT:    shrq $48, %rdi
-; SSE2-NEXT:    shll $16, %edi
-; SSE2-NEXT:    movd %edi, %xmm1
-; SSE2-NEXT:    shll $16, %r8d
-; SSE2-NEXT:    movd %r8d, %xmm3
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm3 = xmm3[0],xmm1[0],xmm3[1],xmm1[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; SSE2-NEXT:    movl %esi, %edi
-; SSE2-NEXT:    andl $-65536, %edi # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %edi, %xmm3
-; SSE2-NEXT:    movl %esi, %edi
-; SSE2-NEXT:    shll $16, %edi
-; SSE2-NEXT:    movd %edi, %xmm1
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm1 = xmm1[0],xmm3[0],xmm1[1],xmm3[1]
-; SSE2-NEXT:    shrq $48, %rsi
-; SSE2-NEXT:    shll $16, %esi
-; SSE2-NEXT:    movd %esi, %xmm3
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm4
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm4 = xmm4[0],xmm3[0],xmm4[1],xmm3[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm4[0]
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    andl $-65536, %edx # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %edx, %xmm4
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm3
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm3 = xmm3[0],xmm4[0],xmm3[1],xmm4[1]
-; SSE2-NEXT:    shrq $48, %rcx
-; SSE2-NEXT:    shll $16, %ecx
-; SSE2-NEXT:    movd %ecx, %xmm4
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm5
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm5 = xmm5[0],xmm4[0],xmm5[1],xmm4[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm5[0]
+; SSE2-NEXT:    pxor %xmm3, %xmm3
+; SSE2-NEXT:    pxor %xmm5, %xmm5
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm5 = xmm5[0],xmm0[0],xmm5[1],xmm0[1],xmm5[2],xmm0[2],xmm5[3],xmm0[3]
+; SSE2-NEXT:    pxor %xmm4, %xmm4
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm4 = xmm4[4],xmm0[4],xmm4[5],xmm0[5],xmm4[6],xmm0[6],xmm4[7],xmm0[7]
+; SSE2-NEXT:    pxor %xmm2, %xmm2
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm1[4],xmm3[5],xmm1[5],xmm3[6],xmm1[6],xmm3[7],xmm1[7]
+; SSE2-NEXT:    movdqa %xmm5, %xmm0
+; SSE2-NEXT:    movdqa %xmm4, %xmm1
 ; SSE2-NEXT:    retq
 ;
 ; F16-LABEL: pr64460_3:
@@ -1822,47 +1718,17 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
 ;
 ; SSE2-LABEL: pr64460_4:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movq %xmm0, %rsi
-; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT:    movq %xmm0, %rdx
-; SSE2-NEXT:    movq %rdx, %rax
-; SSE2-NEXT:    shrq $32, %rax
-; SSE2-NEXT:    movq %rdx, %rcx
-; SSE2-NEXT:    shrq $48, %rcx
-; SSE2-NEXT:    movq %rsi, %rdi
-; SSE2-NEXT:    shrq $32, %rdi
-; SSE2-NEXT:    movq %rsi, %r8
-; SSE2-NEXT:    shrq $48, %r8
-; SSE2-NEXT:    movl %esi, %r9d
-; SSE2-NEXT:    andl $-65536, %r9d # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %r9d, %xmm0
-; SSE2-NEXT:    cvtss2sd %xmm0, %xmm1
-; SSE2-NEXT:    shll $16, %esi
-; SSE2-NEXT:    movd %esi, %xmm0
-; SSE2-NEXT:    cvtss2sd %xmm0, %xmm0
-; SSE2-NEXT:    movlhps {{.*#+}} xmm0 = xmm0[0],xmm1[0]
-; SSE2-NEXT:    shll $16, %r8d
-; SSE2-NEXT:    movd %r8d, %xmm1
-; SSE2-NEXT:    cvtss2sd %xmm1, %xmm2
-; SSE2-NEXT:    shll $16, %edi
-; SSE2-NEXT:    movd %edi, %xmm1
-; SSE2-NEXT:    cvtss2sd %xmm1, %xmm1
-; SSE2-NEXT:    movlhps {{.*#+}} xmm1 = xmm1[0],xmm2[0]
-; SSE2-NEXT:    movl %edx, %esi
-; SSE2-NEXT:    andl $-65536, %esi # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %esi, %xmm2
-; SSE2-NEXT:    cvtss2sd %xmm2, %xmm3
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm2
-; SSE2-NEXT:    cvtss2sd %xmm2, %xmm2
-; SSE2-NEXT:    movlhps {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; SSE2-NEXT:    shll $16, %ecx
-; SSE2-NEXT:    movd %ecx, %xmm3
-; SSE2-NEXT:    cvtss2sd %xmm3, %xmm4
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm3
-; SSE2-NEXT:    cvtss2sd %xmm3, %xmm3
-; SSE2-NEXT:    movlhps {{.*#+}} xmm3 = xmm3[0],xmm4[0]
+; SSE2-NEXT:    pxor %xmm3, %xmm3
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
+; SSE2-NEXT:    cvtps2pd %xmm1, %xmm4
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm0[4],xmm3[5],xmm0[5],xmm3[6],xmm0[6],xmm3[7],xmm0[7]
+; SSE2-NEXT:    cvtps2pd %xmm3, %xmm2
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; SSE2-NEXT:    cvtps2pd %xmm0, %xmm1
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,3,2,3]
+; SSE2-NEXT:    cvtps2pd %xmm0, %xmm3
+; SSE2-NEXT:    movaps %xmm4, %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; F16-LABEL: pr64460_4:
@@ -1874,45 +1740,11 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
 ;
 ; AVXNC-LABEL: pr64460_4:
 ; AVXNC:       # %bb.0:
-; AVXNC-NEXT:    vpextrw $3, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm1
-; AVXNC-NEXT:    vcvtss2sd %xmm1, %xmm1, %xmm1
-; AVXNC-NEXT:    vpextrw $2, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm2
-; AVXNC-NEXT:    vcvtss2sd %xmm2, %xmm2, %xmm2
-; AVXNC-NEXT:    vmovlhps {{.*#+}} xmm1 = xmm2[0],xmm1[0]
-; AVXNC-NEXT:    vpextrw $1, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm2
-; AVXNC-NEXT:    vcvtss2sd %xmm2, %xmm2, %xmm2
-; AVXNC-NEXT:    vmovd %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm3
-; AVXNC-NEXT:    vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT:    vmovlhps {{.*#+}} xmm2 = xmm3[0],xmm2[0]
-; AVXNC-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm2
-; AVXNC-NEXT:    vpextrw $7, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm1
-; AVXNC-NEXT:    vcvtss2sd %xmm1, %xmm1, %xmm1
-; AVXNC-NEXT:    vpextrw $6, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm3
-; AVXNC-NEXT:    vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT:    vmovlhps {{.*#+}} xmm1 = xmm3[0],xmm1[0]
-; AVXNC-NEXT:    vpextrw $5, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm3
-; AVXNC-NEXT:    vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT:    vpextrw $4, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm0
-; AVXNC-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
-; AVXNC-NEXT:    vmovlhps {{.*#+}} xmm0 = xmm0[0],xmm3[0]
-; AVXNC-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm1
-; AVXNC-NEXT:    vmovaps %ymm2, %ymm0
+; AVXNC-NEXT:    vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVXNC-NEXT:    vpslld $16, %ymm0, %ymm1
+; AVXNC-NEXT:    vcvtps2pd %xmm1, %ymm0
+; AVXNC-NEXT:    vextracti128 $1, %ymm1, %xmm1
+; AVXNC-NEXT:    vcvtps2pd %xmm1, %ymm1
 ; AVXNC-NEXT:    retq
   %b = fpext <8 x bfloat> %a to <8 x double>
   ret <8 x double> %b
diff --git a/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll b/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
index eff1937b593436..c079a44bc5efd5 100644
--- a/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
+++ b/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
@@ -10,11 +10,11 @@ define void @test(<2 x ptr> %ptr) {
 ; CHECK-NEXT:  # %bb.2: # %loop.127.preheader
 ; CHECK-NEXT:    retq
 ; CHECK-NEXT:  .LBB0_1: # %ifmerge.89
-; CHECK-NEXT:    movzwl (%rax), %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    vmovd %eax, %xmm0
-; CHECK-NEXT:    vmulss %xmm0, %xmm0, %xmm0
 ; CHECK-NEXT:    vbroadcastss %xmm0, %xmm0
+; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vpbroadcastw (%rax), %xmm2
+; CHECK-NEXT:    vpunpcklwd {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1],xmm1[2],xmm2[2],xmm1[3],xmm2[3]
+; CHECK-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; CHECK-NEXT:    vmovlps %xmm0, (%rax)
 entry:
   br label %then.13



More information about the llvm-commits mailing list