[llvm] [X86][BF16] Do not lower to VCVTNEPS2BF16 without AVX512VL (PR #86395)
Phoebe Wang via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 23 07:13:36 PDT 2024
https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/86395
Fixes: #86305
From 21ac12037fee7848ed73286b010ad05870fce940 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Sat, 23 Mar 2024 22:09:10 +0800
Subject: [PATCH] [X86][BF16] Do not lower to VCVTNEPS2BF16 without AVX512VL
Fixes: #86305
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 7 ++-
llvm/test/CodeGen/X86/pr86305.ll | 74 +++++++++++++++++++++++++
2 files changed, 79 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/CodeGen/X86/pr86305.ll
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 35f756ea5e1d86..9acbe17d0bcad2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21512,7 +21512,9 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
}
if (VT.getScalarType() == MVT::bf16) {
- if (SVT.getScalarType() == MVT::f32 && isTypeLegal(VT))
+ if (SVT.getScalarType() == MVT::f32 &&
+ ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
+ Subtarget.hasAVXNECONVERT()))
return Op;
return SDValue();
}
@@ -21619,7 +21621,8 @@ SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
SDLoc DL(Op);
MVT SVT = Op.getOperand(0).getSimpleValueType();
- if (SVT == MVT::f32 && (Subtarget.hasBF16() || Subtarget.hasAVXNECONVERT())) {
+ if (SVT == MVT::f32 && ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
+ Subtarget.hasAVXNECONVERT())) {
SDValue Res;
Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);
diff --git a/llvm/test/CodeGen/X86/pr86305.ll b/llvm/test/CodeGen/X86/pr86305.ll
new file mode 100644
index 00000000000000..79b42bb2532ca9
--- /dev/null
+++ b/llvm/test/CodeGen/X86/pr86305.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avx512bf16 | FileCheck %s
+
+define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
+; CHECK-LABEL: add:
+; CHECK: # %bb.0:
+; CHECK-NEXT: pushq %rbx
+; CHECK-NEXT: movq %rdx, %rbx
+; CHECK-NEXT: movzwl (%rsi), %eax
+; CHECK-NEXT: shll $16, %eax
+; CHECK-NEXT: vmovd %eax, %xmm0
+; CHECK-NEXT: movzwl (%rdi), %eax
+; CHECK-NEXT: shll $16, %eax
+; CHECK-NEXT: vmovd %eax, %xmm1
+; CHECK-NEXT: vaddss %xmm0, %xmm1, %xmm0
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vpextrw $0, %xmm0, (%rbx)
+; CHECK-NEXT: popq %rbx
+; CHECK-NEXT: retq
+ %a = load bfloat, ptr %pa
+ %b = load bfloat, ptr %pb
+ %add = fadd bfloat %a, %b ; added in f32 (vaddss), then narrowed to bf16 via the __truncsfbf2 libcall: the RUN line enables avx512bf16 without avx512vl
+ store bfloat %add, ptr %pc
+ ret void
+}
+
+define <4 x bfloat> @fptrunc_v4f32(<4 x float> %a) nounwind {
+; CHECK-LABEL: fptrunc_v4f32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: pushq %rbp
+; CHECK-NEXT: pushq %r15
+; CHECK-NEXT: pushq %r14
+; CHECK-NEXT: pushq %rbx
+; CHECK-NEXT: subq $72, %rsp
+; CHECK-NEXT: vmovaps %xmm0, (%rsp) # 16-byte Spill
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vmovaps %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT: vpermilpd $1, (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT: # xmm0 = mem[1,0]
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vmovapd %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT: vpshufd $255, (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT: # xmm0 = mem[3,3,3,3]
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vpextrw $0, %xmm0, %ebx
+; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT: vpextrw $0, %xmm0, %ebp
+; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT: vpextrw $0, %xmm0, %r14d
+; CHECK-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload
+; CHECK-NEXT: vpextrw $0, %xmm0, %r15d
+; CHECK-NEXT: vmovshdup (%rsp), %xmm0 # 16-byte Folded Reload
+; CHECK-NEXT: # xmm0 = mem[1,1,3,3]
+; CHECK-NEXT: callq __truncsfbf2@PLT
+; CHECK-NEXT: vpextrw $0, %xmm0, %eax
+; CHECK-NEXT: vmovd %r15d, %xmm0
+; CHECK-NEXT: vpinsrw $1, %eax, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $2, %r14d, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $3, %ebp, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $4, %ebx, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $5, %ebx, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $6, %ebx, %xmm0, %xmm0
+; CHECK-NEXT: vpinsrw $7, %ebx, %xmm0, %xmm0
+; CHECK-NEXT: addq $72, %rsp
+; CHECK-NEXT: popq %rbx
+; CHECK-NEXT: popq %r14
+; CHECK-NEXT: popq %r15
+; CHECK-NEXT: popq %rbp
+; CHECK-NEXT: retq
+ %b = fptrunc <4 x float> %a to <4 x bfloat> ; scalarized into __truncsfbf2 libcalls instead of a 128-bit VCVTNEPS2BF16, which is not used without avx512vl
+ ret <4 x bfloat> %b
+}
More information about the llvm-commits mailing list