[llvm] [X86][BF16] Do not lower to VCVTNEPS2BF16 without AVX512VL (PR #86395)

via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 23 07:14:11 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Phoebe Wang (phoebewang)

<details>
<summary>Changes</summary>

Fixes: #<!-- -->86305

---
Full diff: https://github.com/llvm/llvm-project/pull/86395.diff


2 Files Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+5-2) 
- (added) llvm/test/CodeGen/X86/pr86305.ll (+74) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 35f756ea5e1d86..9acbe17d0bcad2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21512,7 +21512,9 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
   }
 
   if (VT.getScalarType() == MVT::bf16) {
-    if (SVT.getScalarType() == MVT::f32 && isTypeLegal(VT))
+    if (SVT.getScalarType() == MVT::f32 &&
+        ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
+         Subtarget.hasAVXNECONVERT()))
       return Op;
     return SDValue();
   }
@@ -21619,7 +21621,8 @@ SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
   SDLoc DL(Op);
 
   MVT SVT = Op.getOperand(0).getSimpleValueType();
-  if (SVT == MVT::f32 && (Subtarget.hasBF16() || Subtarget.hasAVXNECONVERT())) {
+  if (SVT == MVT::f32 && ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
+                          Subtarget.hasAVXNECONVERT())) {
     SDValue Res;
     Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
     Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);
diff --git a/llvm/test/CodeGen/X86/pr86305.ll b/llvm/test/CodeGen/X86/pr86305.ll
new file mode 100644
index 00000000000000..79b42bb2532ca9
--- /dev/null
+++ b/llvm/test/CodeGen/X86/pr86305.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avx512bf16 | FileCheck %s
+
+define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
+; CHECK-LABEL: add:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    movq %rdx, %rbx
+; CHECK-NEXT:    movzwl (%rsi), %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    vmovd %eax, %xmm0
+; CHECK-NEXT:    movzwl (%rdi), %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    vmovd %eax, %xmm1
+; CHECK-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vpextrw $0, %xmm0, (%rbx)
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    retq
+  %a = load bfloat, ptr %pa
+  %b = load bfloat, ptr %pb
+  %add = fadd bfloat %a, %b
+  store bfloat %add, ptr %pc
+  ret void
+}
+
+define <4 x bfloat> @fptrunc_v4f32(<4 x float> %a) nounwind {
+; CHECK-LABEL: fptrunc_v4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbp
+; CHECK-NEXT:    pushq %r15
+; CHECK-NEXT:    pushq %r14
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    subq $72, %rsp
+; CHECK-NEXT:    vmovaps %xmm0, (%rsp) # 16-byte Spill
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vmovaps %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT:    vpermilpd $1, (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[1,0]
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vmovapd %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT:    vpshufd $255, (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[3,3,3,3]
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vpextrw $0, %xmm0, %ebx
+; CHECK-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT:    vpextrw $0, %xmm0, %ebp
+; CHECK-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT:    vpextrw $0, %xmm0, %r14d
+; CHECK-NEXT:    vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT:    vpextrw $0, %xmm0, %r15d
+; CHECK-NEXT:    vmovshdup (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[1,1,3,3]
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    vpextrw $0, %xmm0, %eax
+; CHECK-NEXT:    vmovd %r15d, %xmm0
+; CHECK-NEXT:    vpinsrw $1, %eax, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $2, %r14d, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $3, %ebp, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $4, %ebx, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $5, %ebx, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $6, %ebx, %xmm0, %xmm0
+; CHECK-NEXT:    vpinsrw $7, %ebx, %xmm0, %xmm0
+; CHECK-NEXT:    addq $72, %rsp
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    popq %r14
+; CHECK-NEXT:    popq %r15
+; CHECK-NEXT:    popq %rbp
+; CHECK-NEXT:    retq
+  %b = fptrunc <4 x float> %a to <4 x bfloat>
+  ret <4 x bfloat> %b
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/86395


More information about the llvm-commits mailing list