[llvm] [X86] Add support for `__bf16` to `f16` conversion (PR #134859)

Antonio Frighetto via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 8 07:16:35 PDT 2025


https://github.com/antoniofrighetto created https://github.com/llvm/llvm-project/pull/134859

`bf16` is a type (typedef'd to `short`) introduced with AVX-512_BF16, and it should be able to use the same SSE/AVX registers that are already used for `f16`.

Fixes: https://github.com/llvm/llvm-project/issues/134222.

>From 7f2a1120c9fc31912e7d98de99dd2a04f368b5c6 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me at antoniofrighetto.com>
Date: Tue, 8 Apr 2025 16:06:44 +0200
Subject: [PATCH] [X86] Add support for `__bf16` to `f16` conversion

`bf16` is a type (typedef'd to `short`) introduced with AVX-512_BF16,
and it should be able to use the same SSE/AVX registers as `f16`.

Fixes: https://github.com/llvm/llvm-project/issues/134222.
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 12 +++++++++++-
 llvm/lib/Target/X86/X86InstrAVX512.td   |  5 +++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a4381b99dbae0..ad170f49aebb5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -661,10 +661,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
   };
 
   if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
-    // f16, f32 and f64 use SSE.
+    // f16, bf16, f32 and f64 use SSE.
     // Set up the FP register classes.
     addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
                                                      : &X86::FR16RegClass);
+    addRegisterClass(MVT::bf16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
+                                                      : &X86::FR16RegClass);
     addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
                                                      : &X86::FR32RegClass);
     addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
@@ -676,6 +678,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     // non-optsize case.
     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
 
+    // Mark BITCAST for bf16 as Custom so that it is routed through the
+    // target's custom lowering (see LowerBITCAST) instead of being expanded.
+    setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+
     for (auto VT : { MVT::f32, MVT::f64 }) {
       // Use ANDPD to simulate FABS.
       setOperationAction(ISD::FABS, VT, Custom);
@@ -32151,6 +32157,10 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
     return DAG.getZExtOrTrunc(V, DL, DstVT);
   }
 
+  // Bitcasts between f16 and bf16 should be legal.
+  if (DstVT == MVT::f16 || DstVT == MVT::bf16)
+    return Op;
+
   assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
           SrcVT == MVT::i64) && "Unexpected VT!");
 
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 0ab94cca41425..5705094dca55b 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -2456,6 +2456,11 @@ let Predicates = [HasFP16] in {
             (VCMPSHZrmi FR16X:$src1, addr:$src2, (X86cmpm_imm_commute timm:$cc))>;
 }
 
+let Predicates = [HasAVX512, HasBF16] in {
+  def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
+  def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
+}
+
 // ----------------------------------------------------------------
 // FPClass
 



More information about the llvm-commits mailing list