[llvm] [X86][BF16] Do not lower to VCVTNEPS2BF16 without AVX512VL (PR #86395)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 23 07:14:11 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86
Author: Phoebe Wang (phoebewang)
<details>
<summary>Changes</summary>
Fixes: #<!-- -->86305
---
Full diff: https://github.com/llvm/llvm-project/pull/86395.diff
2 Files Affected:
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+5-2)
- (added) llvm/test/CodeGen/X86/pr86305.ll (+74)
``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 35f756ea5e1d86..9acbe17d0bcad2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21512,7 +21512,9 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
}
if (VT.getScalarType() == MVT::bf16) {
- if (SVT.getScalarType() == MVT::f32 && isTypeLegal(VT))
+ if (SVT.getScalarType() == MVT::f32 &&
+ ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
+ Subtarget.hasAVXNECONVERT()))
return Op;
return SDValue();
}
@@ -21619,7 +21621,8 @@ SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
SDLoc DL(Op);
MVT SVT = Op.getOperand(0).getSimpleValueType();
- if (SVT == MVT::f32 && (Subtarget.hasBF16() || Subtarget.hasAVXNECONVERT())) {
+ if (SVT == MVT::f32 && ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
+ Subtarget.hasAVXNECONVERT())) {
SDValue Res;
Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);
diff --git a/llvm/test/CodeGen/X86/pr86305.ll b/llvm/test/CodeGen/X86/pr86305.ll
new file mode 100644
index 00000000000000..79b42bb2532ca9
--- /dev/null
+++ b/llvm/test/CodeGen/X86/pr86305.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avx512bf16 | FileCheck %s
+
+define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
+; CHECK-LABEL: add:
+; CHECK: # %bb.0:
+; CHECK-NEXT: pushq %rbx
+; CHECK-NEXT: movq %rdx, %rbx
+; CHECK-NEXT: movzwl (%rsi), %eax
+; CHECK-NEXT: shll $16, %eax
+; CHECK-NEXT: vmovd %eax, %xmm0
+; CHECK-NEXT: movzwl (%rdi), %eax
+; CHECK-NEXT: shll $16, %eax
+; CHECK-NEXT: vmovd %eax, %xmm1
+; CHECK-NEXT: vaddss %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vpextrw $0, %xmm0, (%rbx)
+; CHECK-NEXT: popq %rbx
+; CHECK-NEXT: retq
+ %a = load bfloat, ptr %pa
+ %b = load bfloat, ptr %pb
+ %add = fadd bfloat %a, %b
+ store bfloat %add, ptr %pc
+ ret void
+}
+
+define <4 x bfloat> @fptrunc_v4f32(<4 x float> %a) nounwind {
+; CHECK-LABEL: fptrunc_v4f32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: pushq %rbp
+; CHECK-NEXT: pushq %r15
+; CHECK-NEXT: pushq %r14
+; CHECK-NEXT: pushq %rbx
+; CHECK-NEXT: subq $72, %rsp
+; CHECK-NEXT: vmovaps %xmm0, (%rsp) # 16-byte Spill
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vmovaps %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT: vpermilpd $1, (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT: # xmm0 = mem[1,0]
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vmovapd %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT: vpshufd $255, (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT: # xmm0 = mem[3,3,3,3]
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vpextrw $0, %xmm0, %ebx
+; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT: vpextrw $0, %xmm0, %ebp
+; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT: vpextrw $0, %xmm0, %r14d
+; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT: vpextrw $0, %xmm0, %r15d
+; CHECK-NEXT: vmovshdup (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT: # xmm0 = mem[1,1,3,3]
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vpextrw $0, %xmm0, %eax
+; CHECK-NEXT: vmovd %r15d, %xmm0
+; CHECK-NEXT: vpinsrw $1, %eax, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $2, %r14d, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $3, %ebp, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $4, %ebx, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $5, %ebx, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $6, %ebx, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $7, %ebx, %xmm0, %xmm0
+; CHECK-NEXT: addq $72, %rsp
+; CHECK-NEXT: popq %rbx
+; CHECK-NEXT: popq %r14
+; CHECK-NEXT: popq %r15
+; CHECK-NEXT: popq %rbp
+; CHECK-NEXT: retq
+ %b = fptrunc <4 x float> %a to <4 x bfloat>
+ ret <4 x bfloat> %b
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/86395
More information about the llvm-commits mailing list