[compiler-rt] fb34d53 - Promote bf16 to f32 when the target doesn't support it

Benjamin Kramer via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 15 04:01:16 PDT 2022


Author: Benjamin Kramer
Date: 2022-06-15T12:56:31+02:00
New Revision: fb34d531af953119593be74753b89baf99fbc194

URL: https://github.com/llvm/llvm-project/commit/fb34d531af953119593be74753b89baf99fbc194
DIFF: https://github.com/llvm/llvm-project/commit/fb34d531af953119593be74753b89baf99fbc194.diff

LOG: Promote bf16 to f32 when the target doesn't support it

This is modeled after the half-precision fp support. Two new nodes are
introduced for casting from and to bf16. Since casting from bf16 is a
simple operation (a bf16 value is the upper 16 bits of an f32), I opted
to always lower it directly to integer arithmetic (extend + shift left
by 16). The other direction is more complicated if you want to preserve
IEEE semantics, so it's handled by a new __truncsfbf2 compiler-rt
builtin.

This is of course very bare bones, but sufficient to get a semi-softened
fadd on x86.

Possible future improvements:
 - Targets with bf16 conversion instructions can now make fp_to_bf16 legal
 - The software conversion to bf16 can be replaced by a trivial
   implementation under fast math.

Differential Revision: https://reviews.llvm.org/D126953

Added: 
    compiler-rt/lib/builtins/truncsfbf2.c
    llvm/test/CodeGen/X86/bfloat.ll

Modified: 
    compiler-rt/lib/builtins/CMakeLists.txt
    compiler-rt/lib/builtins/fp_trunc.h
    llvm/include/llvm/CodeGen/ISDOpcodes.h
    llvm/include/llvm/IR/RuntimeLibcalls.def
    llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
    llvm/lib/CodeGen/TargetLoweringBase.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/compiler-rt/lib/builtins/CMakeLists.txt b/compiler-rt/lib/builtins/CMakeLists.txt
index 1b2c7b201c229..007c2725d729b 100644
--- a/compiler-rt/lib/builtins/CMakeLists.txt
+++ b/compiler-rt/lib/builtins/CMakeLists.txt
@@ -167,6 +167,7 @@ set(GENERIC_SOURCES
   trampoline_setup.c
   truncdfhf2.c
   truncdfsf2.c
+  truncsfbf2.c
   truncsfhf2.c
   ucmpdi2.c
   ucmpti2.c

diff  --git a/compiler-rt/lib/builtins/fp_trunc.h b/compiler-rt/lib/builtins/fp_trunc.h
index 00595edd5e018..7a54564a3520a 100644
--- a/compiler-rt/lib/builtins/fp_trunc.h
+++ b/compiler-rt/lib/builtins/fp_trunc.h
@@ -59,6 +59,12 @@ typedef uint16_t dst_rep_t;
 #define DST_REP_C UINT16_C
 static const int dstSigBits = 10;
 
+#elif defined DST_BFLOAT
+typedef uint16_t dst_t;
+typedef uint16_t dst_rep_t;
+#define DST_REP_C UINT16_C
+static const int dstSigBits = 7;
+
 #else
 #error Destination should be single precision or double precision!
 #endif // end destination precision

diff  --git a/compiler-rt/lib/builtins/truncsfbf2.c b/compiler-rt/lib/builtins/truncsfbf2.c
new file mode 100644
index 0000000000000..6bed116af9868
--- /dev/null
+++ b/compiler-rt/lib/builtins/truncsfbf2.c
@@ -0,0 +1,13 @@
+//===-- lib/truncsfbf2.c - single -> bfloat conversion ------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#define SRC_SINGLE // source format: single precision (float argument below)
+#define DST_BFLOAT // destination format: bfloat16 (fp_trunc.h: uint16_t rep, dstSigBits = 7)
+#include "fp_trunc_impl.inc"
+// Truncates a float to bfloat16; the result is the raw 16-bit encoding (dst_t is uint16_t).
+COMPILER_RT_ABI dst_t __truncsfbf2(float a) { return __truncXfYf2__(a); }

diff  --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index eae6bb926511e..120f89952a95a 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -898,6 +898,13 @@ enum NodeType {
   STRICT_FP16_TO_FP,
   STRICT_FP_TO_FP16,
 
+  /// BF16_TO_FP, FP_TO_BF16 - These operators are used to perform promotions
+  /// and truncation for bfloat16. These nodes form a semi-softened interface
+  /// for dealing with bf16 (as an i16), which is often a storage-only type but
+  /// has native conversions.
+  BF16_TO_FP,
+  FP_TO_BF16,
+
   /// Perform various unary floating-point operations inspired by libm. For
   /// FPOWI, the result is undefined if if the integer operand doesn't fit into
   /// sizeof(int).

diff  --git a/llvm/include/llvm/IR/RuntimeLibcalls.def b/llvm/include/llvm/IR/RuntimeLibcalls.def
index ccd4f2be38a64..b5b9fb4987999 100644
--- a/llvm/include/llvm/IR/RuntimeLibcalls.def
+++ b/llvm/include/llvm/IR/RuntimeLibcalls.def
@@ -310,6 +310,7 @@ HANDLE_LIBCALL(FPROUND_F64_F16, "__truncdfhf2")
 HANDLE_LIBCALL(FPROUND_F80_F16, "__truncxfhf2")
 HANDLE_LIBCALL(FPROUND_F128_F16, "__trunctfhf2")
 HANDLE_LIBCALL(FPROUND_PPCF128_F16, "__trunctfhf2")
+HANDLE_LIBCALL(FPROUND_F32_BF16, "__truncsfbf2")
 HANDLE_LIBCALL(FPROUND_F64_F32, "__truncdfsf2")
 HANDLE_LIBCALL(FPROUND_F80_F32, "__truncxfsf2")
 HANDLE_LIBCALL(FPROUND_F128_F32, "__trunctfsf2")

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 74a5250a44801..8bdc9410d1310 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -998,6 +998,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
       Action = TLI.getOperationAction(Node->getOpcode(), MVT::Other);
     break;
   case ISD::FP_TO_FP16:
+  case ISD::FP_TO_BF16:
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
   case ISD::EXTRACT_VECTOR_ELT:
@@ -2904,6 +2905,18 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
                                  Node->getValueType(0), dl)))
       Results.push_back(Tmp1);
     break;
+  case ISD::BF16_TO_FP: {
+    // Always expand bf16 to f32 casts, they lower to ext + shift.
+    SDValue Op = DAG.getNode(ISD::BITCAST, dl, MVT::i16, Node->getOperand(0));
+    Op = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Op);
+    Op = DAG.getNode(
+        ISD::SHL, dl, MVT::i32, Op,
+        DAG.getConstant(16, dl,
+                        TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+    Op = DAG.getNode(ISD::BITCAST, dl, MVT::f32, Op);
+    Results.push_back(Op);
+    break;
+  }
   case ISD::SIGN_EXTEND_INREG: {
     EVT ExtraVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
     EVT VT = Node->getValueType(0);
@@ -4216,6 +4229,13 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
     Results.push_back(ExpandLibCall(LC, Node, false));
     break;
   }
+  case ISD::FP_TO_BF16: {
+    RTLIB::Libcall LC =
+        RTLIB::getFPROUND(Node->getOperand(0).getValueType(), MVT::bf16);
+    assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to expand fp_to_bf16");
+    Results.push_back(ExpandLibCall(LC, Node, false));
+    break;
+  }
   case ISD::STRICT_SINT_TO_FP:
   case ISD::STRICT_UINT_TO_FP:
   case ISD::SINT_TO_FP:

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 12f4118ff9bc9..f464208cd9dcb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -834,6 +834,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
   case ISD::BR_CC:       Res = SoftenFloatOp_BR_CC(N); break;
   case ISD::STRICT_FP_TO_FP16:
   case ISD::FP_TO_FP16:  // Same as FP_ROUND for softening purposes
+  case ISD::FP_TO_BF16:
   case ISD::STRICT_FP_ROUND:
   case ISD::FP_ROUND:    Res = SoftenFloatOp_FP_ROUND(N); break;
   case ISD::STRICT_FP_TO_SINT:
@@ -885,16 +886,19 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
   // returns an i16 so doesn't meet the constraints necessary for FP_ROUND.
   assert(N->getOpcode() == ISD::FP_ROUND || N->getOpcode() == ISD::FP_TO_FP16 ||
          N->getOpcode() == ISD::STRICT_FP_TO_FP16 ||
+         N->getOpcode() == ISD::FP_TO_BF16 ||
          N->getOpcode() == ISD::STRICT_FP_ROUND);
 
   bool IsStrict = N->isStrictFPOpcode();
   SDValue Op = N->getOperand(IsStrict ? 1 : 0);
   EVT SVT = Op.getValueType();
   EVT RVT = N->getValueType(0);
-  EVT FloatRVT = (N->getOpcode() == ISD::FP_TO_FP16 ||
-                  N->getOpcode() == ISD::STRICT_FP_TO_FP16)
-                     ? MVT::f16
-                     : RVT;
+  EVT FloatRVT = RVT;
+  if (N->getOpcode() == ISD::FP_TO_FP16 ||
+      N->getOpcode() == ISD::STRICT_FP_TO_FP16)
+    FloatRVT = MVT::f16;
+  else if (N->getOpcode() == ISD::FP_TO_BF16)
+    FloatRVT = MVT::bf16;
 
   RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, FloatRVT);
   assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_ROUND libcall");
@@ -2068,9 +2072,13 @@ SDValue DAGTypeLegalizer::ExpandFloatOp_LLRINT(SDNode *N) {
 
 static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
   if (OpVT == MVT::f16) {
-      return ISD::FP16_TO_FP;
+    return ISD::FP16_TO_FP;
   } else if (RetVT == MVT::f16) {
-      return ISD::FP_TO_FP16;
+    return ISD::FP_TO_FP16;
+  } else if (OpVT == MVT::bf16) {
+    return ISD::BF16_TO_FP;
+  } else if (RetVT == MVT::bf16) {
+    return ISD::FP_TO_BF16;
   }
 
   report_fatal_error("Attempt at an invalid promotion-related conversion");

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 60a3c02e38cf1..bbfc6e5ef64f5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -365,6 +365,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::STRICT_FP16_TO_FP:          return "strict_fp16_to_fp";
   case ISD::FP_TO_FP16:                 return "fp_to_fp16";
   case ISD::STRICT_FP_TO_FP16:          return "strict_fp_to_fp16";
+  case ISD::BF16_TO_FP:                 return "bf16_to_fp";
+  case ISD::FP_TO_BF16:                 return "fp_to_bf16";
   case ISD::LROUND:                     return "lround";
   case ISD::STRICT_LROUND:              return "strict_lround";
   case ISD::LLROUND:                    return "llround";

diff  --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index c21ffaf205331..ad0f95b06e24a 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -274,6 +274,9 @@ RTLIB::Libcall RTLIB::getFPROUND(EVT OpVT, EVT RetVT) {
       return FPROUND_F128_F16;
     if (OpVT == MVT::ppcf128)
       return FPROUND_PPCF128_F16;
+  } else if (RetVT == MVT::bf16) {
+    if (OpVT == MVT::f32)
+      return FPROUND_F32_BF16;
   } else if (RetVT == MVT::f32) {
     if (OpVT == MVT::f64)
       return FPROUND_F64_F32;
@@ -1373,6 +1376,16 @@ void TargetLoweringBase::computeRegisterProperties(
     }
   }
 
+  // Decide how to handle bf16. If the target does not have native bf16 support,
+  // promote it to f32, because there are no bf16 library calls (except for
+  // converting from f32 to bf16).
+  if (!isTypeLegal(MVT::bf16)) {
+    NumRegistersForVT[MVT::bf16] = NumRegistersForVT[MVT::f32];
+    RegisterTypeForVT[MVT::bf16] = RegisterTypeForVT[MVT::f32];
+    TransformToType[MVT::bf16] = MVT::f32;
+    ValueTypeActions.setTypeAction(MVT::bf16, TypePromoteFloat);
+  }
+
   // Loop over all of the vector value types to see which need transformations.
   for (unsigned i = MVT::FIRST_VECTOR_VALUETYPE;
        i <= (unsigned)MVT::LAST_VECTOR_VALUETYPE; ++i) {

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 8474f90019511..ca2fb6d92975d 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -412,14 +412,15 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(Op, MVT::f128, Expand);
   }
 
-  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
-  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
-  setLoadExtAction(ISD::EXTLOAD, MVT::f80, MVT::f16, Expand);
-  setLoadExtAction(ISD::EXTLOAD, MVT::f128, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f32, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f64, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f80, MVT::f16, Expand);
-  setTruncStoreAction(MVT::f128, MVT::f16, Expand);
+  for (MVT VT : {MVT::f32, MVT::f64, MVT::f80, MVT::f128}) {
+    setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
+    setTruncStoreAction(VT, MVT::f16, Expand);
+    setTruncStoreAction(VT, MVT::bf16, Expand);
+
+    setOperationAction(ISD::BF16_TO_FP, VT, Expand);
+    setOperationAction(ISD::FP_TO_BF16, VT, Expand);
+  }
 
   setOperationAction(ISD::PARITY, MVT::i8, Custom);
   setOperationAction(ISD::PARITY, MVT::i16, Custom);
@@ -916,7 +917,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
 
       // EXTLOAD for MVT::f16 vectors is not legal because f16 vectors are
       // split/scalarized right now.
-      if (VT.getVectorElementType() == MVT::f16)
+      if (VT.getVectorElementType() == MVT::f16 ||
+          VT.getVectorElementType() == MVT::bf16)
         setLoadExtAction(ISD::EXTLOAD, InnerVT, VT, Expand);
     }
   }

diff  --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
new file mode 100644
index 0000000000000..089506c39dae0
--- /dev/null
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -0,0 +1,28 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s
+
+define void @add(ptr %pa, ptr %pb, ptr %pc) {
+; CHECK-LABEL: add:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset %rbx, -16
+; CHECK-NEXT:    movq %rdx, %rbx
+; CHECK-NEXT:    movzwl (%rdi), %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movd %eax, %xmm1
+; CHECK-NEXT:    movzwl (%rsi), %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movd %eax, %xmm0
+; CHECK-NEXT:    addss %xmm1, %xmm0
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movw %ax, (%rbx)
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 8
+; CHECK-NEXT:    retq
+  %a = load bfloat, ptr %pa
+  %b = load bfloat, ptr %pb
+  %add = fadd bfloat %a, %b ; no native bf16: operands widen to f32 (movzwl + shll $16), sum narrows via __truncsfbf2
+  store bfloat %add, ptr %pc
+  ret void
+}


        


More information about the llvm-commits mailing list