[llvm] 061e368 - [SelectionDAG] Implement soft FP legalisation for bf16 FP_EXTEND and BF16_TO_FP

Alex Bradbury via llvm-commits llvm-commits at lists.llvm.org
Mon May 29 02:33:51 PDT 2023


Author: Alex Bradbury
Date: 2023-05-29T10:32:28+01:00
New Revision: 061e368fe213bd0701261a3e59f796c7439484fc

URL: https://github.com/llvm/llvm-project/commit/061e368fe213bd0701261a3e59f796c7439484fc
DIFF: https://github.com/llvm/llvm-project/commit/061e368fe213bd0701261a3e59f796c7439484fc.diff

LOG: [SelectionDAG] Implement soft FP legalisation for bf16 FP_EXTEND and BF16_TO_FP

As discussed in D151436, it's safe to do this as a simple shift (as is
done in LegalizeDAG.cpp) rather than needing a libcall. The added test
cases for RISC-V previously just triggered an assertion.

Codegen for bfloat_to_double will be slightly improved by D151434.

Differential Revision: https://reviews.llvm.org/D151563

Added: 
    llvm/test/CodeGen/RISCV/bfloat.ll

Modified: 
    llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index f1e80ce7e037d..29a1951bf9a3a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -107,6 +107,7 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::STRICT_FP_ROUND:
     case ISD::FP_ROUND:    R = SoftenFloatRes_FP_ROUND(N); break;
     case ISD::FP16_TO_FP:  R = SoftenFloatRes_FP16_TO_FP(N); break;
+    case ISD::BF16_TO_FP:  R = SoftenFloatRes_BF16_TO_FP(N); break;
     case ISD::STRICT_FPOW:
     case ISD::FPOW:        R = SoftenFloatRes_FPOW(N); break;
     case ISD::STRICT_FPOWI:
@@ -510,10 +511,12 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
       return BitConvertToInteger(Op);
   }
 
-  // There's only a libcall for f16 -> f32, so proceed in two stages. Also, it's
-  // entirely possible for both f16 and f32 to be legal, so use the fully
-  // hard-float FP_EXTEND rather than FP16_TO_FP.
-  if (Op.getValueType() == MVT::f16 && N->getValueType(0) != MVT::f32) {
+  // There's only a libcall for f16 -> f32 and shifting is only valid for bf16
+  // -> f32, so proceed in two stages. Also, it's entirely possible for both
+  // f16 and f32 to be legal, so use the fully hard-float FP_EXTEND rather
+  // than FP16_TO_FP.
+  if ((Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16) &&
+      N->getValueType(0) != MVT::f32) {
     if (IsStrict) {
       Op = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N),
                        { MVT::f32, MVT::Other }, { Chain, Op });
@@ -523,6 +526,9 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
     }
   }
 
+  if (Op.getValueType() == MVT::bf16)
+    return SoftenFloatRes_BF16_TO_FP(N);
+
   RTLIB::Libcall LC = RTLIB::getFPEXT(Op.getValueType(), N->getValueType(0));
   assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_EXTEND!");
   TargetLowering::MakeLibCallOptions CallOptions;
@@ -555,6 +561,21 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP16_TO_FP(SDNode *N) {
   return TLI.makeLibCall(DAG, LC, NVT, Res32, CallOptions, SDLoc(N)).first;
 }
 
+// FIXME: Should we just use 'normal' FP_EXTEND / FP_TRUNC instead of special
+// nodes?
+SDValue DAGTypeLegalizer::SoftenFloatRes_BF16_TO_FP(SDNode *N) {
+  assert(N->getValueType(0) == MVT::f32 &&
+         "Can only soften BF16_TO_FP with f32 result");
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), MVT::f32);
+  SDValue Op = N->getOperand(0);
+  SDLoc DL(N);
+  Op = DAG.getNode(ISD::ANY_EXTEND, DL, NVT,
+                   DAG.getNode(ISD::BITCAST, DL, MVT::i16, Op));
+  SDValue Res = DAG.getNode(ISD::SHL, DL, NVT, Op,
+                            DAG.getShiftAmountConstant(16, NVT, DL));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::SoftenFloatRes_FP_ROUND(SDNode *N) {
   bool IsStrict = N->isStrictFPOpcode();
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 09d47caeef471..e73b6b1a826cf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -560,6 +560,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SoftenFloatRes_FNEG(SDNode *N);
   SDValue SoftenFloatRes_FP_EXTEND(SDNode *N);
   SDValue SoftenFloatRes_FP16_TO_FP(SDNode *N);
+  SDValue SoftenFloatRes_BF16_TO_FP(SDNode *N);
   SDValue SoftenFloatRes_FP_ROUND(SDNode *N);
   SDValue SoftenFloatRes_FPOW(SDNode *N);
   SDValue SoftenFloatRes_FPOWI(SDNode *N);

diff  --git a/llvm/test/CodeGen/RISCV/bfloat.ll b/llvm/test/CodeGen/RISCV/bfloat.ll
new file mode 100644
index 0000000000000..e7583a595ff06
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/bfloat.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s | FileCheck %s -check-prefix=RV32I-ILP32
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s -check-prefix=RV64I-LP64
+
+; TODO: Enable codegen for hard float.
+
+define bfloat @float_to_bfloat(float %a) nounwind {
+; RV32I-ILP32-LABEL: float_to_bfloat:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    addi sp, sp, -16
+; RV32I-ILP32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT:    call __truncsfbf2@plt
+; RV32I-ILP32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT:    addi sp, sp, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: float_to_bfloat:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    addi sp, sp, -16
+; RV64I-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT:    call __truncsfbf2@plt
+; RV64I-LP64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT:    addi sp, sp, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fptrunc float %a to bfloat
+  ret bfloat %1
+}
+
+define bfloat @double_to_bfloat(double %a) nounwind {
+; RV32I-ILP32-LABEL: double_to_bfloat:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    addi sp, sp, -16
+; RV32I-ILP32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT:    call __truncdfbf2@plt
+; RV32I-ILP32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT:    addi sp, sp, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: double_to_bfloat:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    addi sp, sp, -16
+; RV64I-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT:    call __truncdfbf2@plt
+; RV64I-LP64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT:    addi sp, sp, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fptrunc double %a to bfloat
+  ret bfloat %1
+}
+
+define float @bfloat_to_float(bfloat %a) nounwind {
+; RV32I-ILP32-LABEL: bfloat_to_float:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    slli a0, a0, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: bfloat_to_float:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    slliw a0, a0, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fpext bfloat %a to float
+  ret float %1
+}
+
+define double @bfloat_to_double(bfloat %a) nounwind {
+; RV32I-ILP32-LABEL: bfloat_to_double:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    addi sp, sp, -16
+; RV32I-ILP32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT:    slli a0, a0, 16
+; RV32I-ILP32-NEXT:    call __extendsfdf2@plt
+; RV32I-ILP32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT:    addi sp, sp, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: bfloat_to_double:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    addi sp, sp, -16
+; RV64I-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT:    slli a0, a0, 48
+; RV64I-LP64-NEXT:    srli a0, a0, 32
+; RV64I-LP64-NEXT:    call __extendsfdf2@plt
+; RV64I-LP64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT:    addi sp, sp, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fpext bfloat %a to double
+  ret double %1
+}
+
+define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
+; RV32I-ILP32-LABEL: bfloat_add:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    addi sp, sp, -16
+; RV32I-ILP32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT:    slli a0, a0, 16
+; RV32I-ILP32-NEXT:    slli a1, a1, 16
+; RV32I-ILP32-NEXT:    call __addsf3@plt
+; RV32I-ILP32-NEXT:    call __truncsfbf2@plt
+; RV32I-ILP32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT:    addi sp, sp, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: bfloat_add:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    addi sp, sp, -16
+; RV64I-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT:    slliw a0, a0, 16
+; RV64I-LP64-NEXT:    slliw a1, a1, 16
+; RV64I-LP64-NEXT:    call __addsf3@plt
+; RV64I-LP64-NEXT:    call __truncsfbf2@plt
+; RV64I-LP64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT:    addi sp, sp, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fadd bfloat %a, %b
+  ret bfloat %1
+}


        


More information about the llvm-commits mailing list