[llvm] [X86][BF16] Do not lower to VCVTNEPS2BF16 without AVX512VL (PR #86395)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 23 07:13:36 PDT 2024


https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/86395

Fixes: #86305

From 21ac12037fee7848ed73286b010ad05870fce940 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Sat, 23 Mar 2024 22:09:10 +0800
Subject: [PATCH] [X86][BF16] Do not lower to VCVTNEPS2BF16 without AVX512VL

Fixes: #86305
---
 llvm/lib/Target/X86/X86ISelLowering.cpp |  7 ++-
 llvm/test/CodeGen/X86/pr86305.ll        | 74 +++++++++++++++++++++++++
 2 files changed, 79 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/pr86305.ll

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 35f756ea5e1d86..9acbe17d0bcad2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21512,7 +21512,9 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
   }
 
   if (VT.getScalarType() == MVT::bf16) {
-    if (SVT.getScalarType() == MVT::f32 && isTypeLegal(VT))
+    if (SVT.getScalarType() == MVT::f32 &&
+        ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
+         Subtarget.hasAVXNECONVERT()))
       return Op;
     return SDValue();
   }
@@ -21619,7 +21621,8 @@ SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
   SDLoc DL(Op);
 
   MVT SVT = Op.getOperand(0).getSimpleValueType();
-  if (SVT == MVT::f32 && (Subtarget.hasBF16() || Subtarget.hasAVXNECONVERT())) {
+  if (SVT == MVT::f32 && ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
+                          Subtarget.hasAVXNECONVERT())) {
     SDValue Res;
     Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
     Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);
diff --git a/llvm/test/CodeGen/X86/pr86305.ll b/llvm/test/CodeGen/X86/pr86305.ll
new file mode 100644
index 00000000000000..79b42bb2532ca9
--- /dev/null
+++ b/llvm/test/CodeGen/X86/pr86305.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avx512bf16 | FileCheck %s
+
+define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
+; CHECK-LABEL: add:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    movq %rdx, %rbx
+; CHECK-NEXT:    movzwl (%rsi), %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    vmovd %eax, %xmm0
+; CHECK-NEXT:    movzwl (%rdi), %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    vmovd %eax, %xmm1
+; CHECK-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vpextrw $0, %xmm0, (%rbx)
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    retq
+  %a = load bfloat, ptr %pa
+  %b = load bfloat, ptr %pb
+  %add = fadd bfloat %a, %b
+  store bfloat %add, ptr %pc
+  ret void
+}
+
+define <4 x bfloat> @fptrunc_v4f32(<4 x float> %a) nounwind {
+; CHECK-LABEL: fptrunc_v4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbp
+; CHECK-NEXT:    pushq %r15
+; CHECK-NEXT:    pushq %r14
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    subq $72, %rsp
+; CHECK-NEXT:    vmovaps %xmm0, (%rsp) # 16-byte Spill
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vmovaps %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT:    vpermilpd $1, (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[1,0]
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vmovapd %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT:    vpshufd $255, (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[3,3,3,3]
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vpextrw $0, %xmm0, %ebx
+; CHECK-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT:    vpextrw $0, %xmm0, %ebp
+; CHECK-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT:    vpextrw $0, %xmm0, %r14d
+; CHECK-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT:    vpextrw $0, %xmm0, %r15d
+; CHECK-NEXT:    vmovshdup (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[1,1,3,3]
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vpextrw $0, %xmm0, %eax
+; CHECK-NEXT:    vmovd %r15d, %xmm0
+; CHECK-NEXT:    vpinsrw $1, %eax, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $2, %r14d, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $3, %ebp, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $4, %ebx, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $5, %ebx, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $6, %ebx, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $7, %ebx, %xmm0, %xmm0
+; CHECK-NEXT:    addq $72, %rsp
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    popq %r14
+; CHECK-NEXT:    popq %r15
+; CHECK-NEXT:    popq %rbp
+; CHECK-NEXT:    retq
+  %b = fptrunc <4 x float> %a to <4 x bfloat>
+  ret <4 x bfloat> %b
+}



More information about the llvm-commits mailing list