[llvm] [X86] Add support for `__bf16` to `f16` conversion (PR #134859)
Antonio Frighetto via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 8 07:16:35 PDT 2025
https://github.com/antoniofrighetto created https://github.com/llvm/llvm-project/pull/134859
`bf16` is a type introduced with AVX512-BF16 (defined as a typedef of `short`), and it should be able to use the same SSE/AVX registers that are already used for `f16`.
Fixes: https://github.com/llvm/llvm-project/issues/134222.
>From 7f2a1120c9fc31912e7d98de99dd2a04f368b5c6 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me at antoniofrighetto.com>
Date: Tue, 8 Apr 2025 16:06:44 +0200
Subject: [PATCH] [X86] Add support for `__bf16` to `f16` conversion
`bf16` is a typedef short type introduced in AVX-512_BF16 and
should be able to leverage SSE/AVX registers used for `f16`.
Fixes: https://github.com/llvm/llvm-project/issues/134222.
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 12 +++++++++++-
llvm/lib/Target/X86/X86InstrAVX512.td | 5 +++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a4381b99dbae0..ad170f49aebb5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -661,10 +661,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
};
if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
- // f16, f32 and f64 use SSE.
+ // f16, bf16, f32 and f64 use SSE.
// Set up the FP register classes.
addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
: &X86::FR16RegClass);
+ addRegisterClass(MVT::bf16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
+ : &X86::FR16RegClass);
addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
: &X86::FR32RegClass);
addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
@@ -676,6 +678,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// non-optsize case.
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
+ // Set the operation action Custom for bitcast to do the customization
+ // later.
+ setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+
for (auto VT : { MVT::f32, MVT::f64 }) {
// Use ANDPD to simulate FABS.
setOperationAction(ISD::FABS, VT, Custom);
@@ -32151,6 +32157,10 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getZExtOrTrunc(V, DL, DstVT);
}
+ // Bitcasts between f16 and bf16 should be legal.
+ if (DstVT == MVT::f16 || DstVT == MVT::bf16)
+ return Op;
+
assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
SrcVT == MVT::i64) && "Unexpected VT!");
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 0ab94cca41425..5705094dca55b 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -2456,6 +2456,11 @@ let Predicates = [HasFP16] in {
(VCMPSHZrmi FR16X:$src1, addr:$src2, (X86cmpm_imm_commute timm:$cc))>;
}
+let Predicates = [HasAVX512, HasBF16] in {
+ def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
+ def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
+}
+
// ----------------------------------------------------------------
// FPClass
More information about the llvm-commits mailing list