[llvm] 061e368 - [SelectionDAG] Implement soft FP legalisation for bf16 FP_EXTEND and BF16_TO_FP
Alex Bradbury via llvm-commits
llvm-commits at lists.llvm.org
Mon May 29 02:33:51 PDT 2023
Author: Alex Bradbury
Date: 2023-05-29T10:32:28+01:00
New Revision: 061e368fe213bd0701261a3e59f796c7439484fc
URL: https://github.com/llvm/llvm-project/commit/061e368fe213bd0701261a3e59f796c7439484fc
DIFF: https://github.com/llvm/llvm-project/commit/061e368fe213bd0701261a3e59f796c7439484fc.diff
LOG: [SelectionDAG] Implement soft FP legalisation for bf16 FP_EXTEND and BF16_TO_FP
As discussed in D151436, it's safe to lower a bf16 -> f32 extension as a simple
left shift by 16 (as is already done in LegalizeDAG.cpp) rather than needing a
libcall. The added test cases for RISC-V previously just triggered an assertion.
Codegen for bfloat_to_double will be slightly improved by D151434.
Differential Revision: https://reviews.llvm.org/D151563
Added:
llvm/test/CodeGen/RISCV/bfloat.ll
Modified:
llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index f1e80ce7e037d..29a1951bf9a3a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -107,6 +107,7 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
case ISD::STRICT_FP_ROUND:
case ISD::FP_ROUND: R = SoftenFloatRes_FP_ROUND(N); break;
case ISD::FP16_TO_FP: R = SoftenFloatRes_FP16_TO_FP(N); break;
+ case ISD::BF16_TO_FP: R = SoftenFloatRes_BF16_TO_FP(N); break;
case ISD::STRICT_FPOW:
case ISD::FPOW: R = SoftenFloatRes_FPOW(N); break;
case ISD::STRICT_FPOWI:
@@ -510,10 +511,12 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
return BitConvertToInteger(Op);
}
- // There's only a libcall for f16 -> f32, so proceed in two stages. Also, it's
- // entirely possible for both f16 and f32 to be legal, so use the fully
- // hard-float FP_EXTEND rather than FP16_TO_FP.
- if (Op.getValueType() == MVT::f16 && N->getValueType(0) != MVT::f32) {
+ // There's only a libcall for f16 -> f32 and shifting is only valid for bf16
+ // -> f32, so proceed in two stages. Also, it's entirely possible for both
+ // f16 and f32 to be legal, so use the fully hard-float FP_EXTEND rather
+ // than FP16_TO_FP.
+ if ((Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16) &&
+ N->getValueType(0) != MVT::f32) {
if (IsStrict) {
Op = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N),
{ MVT::f32, MVT::Other }, { Chain, Op });
@@ -523,6 +526,9 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
}
}
+ if (Op.getValueType() == MVT::bf16)
+ return SoftenFloatRes_BF16_TO_FP(N);
+
RTLIB::Libcall LC = RTLIB::getFPEXT(Op.getValueType(), N->getValueType(0));
assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_EXTEND!");
TargetLowering::MakeLibCallOptions CallOptions;
@@ -555,6 +561,21 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP16_TO_FP(SDNode *N) {
return TLI.makeLibCall(DAG, LC, NVT, Res32, CallOptions, SDLoc(N)).first;
}
+// FIXME: Should we just use 'normal' FP_EXTEND / FP_TRUNC instead of special
+// nodes?
+//
+// Soften a bf16 -> f32 conversion. This is reached both for ISD::BF16_TO_FP
+// and, from SoftenFloatRes_FP_EXTEND, for (STRICT_)FP_EXTEND nodes whose
+// source operand is bf16. A bf16 value is the top half of the equivalent f32
+// bit pattern, so the softened result is the 16 input bits shifted left by 16;
+// no libcall is needed.
+SDValue DAGTypeLegalizer::SoftenFloatRes_BF16_TO_FP(SDNode *N) {
+  assert(N->getValueType(0) == MVT::f32 &&
+         "Can only soften BF16_TO_FP with f32 result");
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), MVT::f32);
+  // Strict nodes (e.g. STRICT_FP_EXTEND) carry the input chain as operand 0
+  // and the value being converted as operand 1.
+  bool IsStrict = N->isStrictFPOpcode();
+  SDValue Op = N->getOperand(IsStrict ? 1 : 0);
+  SDLoc DL(N);
+  Op = DAG.getNode(ISD::ANY_EXTEND, DL, NVT,
+                   DAG.getNode(ISD::BITCAST, DL, MVT::i16, Op));
+  SDValue Res = DAG.getNode(ISD::SHL, DL, NVT, Op,
+                            DAG.getShiftAmountConstant(16, NVT, DL));
+  // The integer shift produces no chain of its own, so users of the strict
+  // node's output chain are rewired to its input chain.
+  if (IsStrict)
+    ReplaceValueWith(SDValue(N, 1), N->getOperand(0));
+  return Res;
+}
+
SDValue DAGTypeLegalizer::SoftenFloatRes_FP_ROUND(SDNode *N) {
bool IsStrict = N->isStrictFPOpcode();
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 09d47caeef471..e73b6b1a826cf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -560,6 +560,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SoftenFloatRes_FNEG(SDNode *N);
SDValue SoftenFloatRes_FP_EXTEND(SDNode *N);
SDValue SoftenFloatRes_FP16_TO_FP(SDNode *N);
+ SDValue SoftenFloatRes_BF16_TO_FP(SDNode *N);
SDValue SoftenFloatRes_FP_ROUND(SDNode *N);
SDValue SoftenFloatRes_FPOW(SDNode *N);
SDValue SoftenFloatRes_FPOWI(SDNode *N);
diff --git a/llvm/test/CodeGen/RISCV/bfloat.ll b/llvm/test/CodeGen/RISCV/bfloat.ll
new file mode 100644
index 0000000000000..e7583a595ff06
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/bfloat.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s | FileCheck %s -check-prefix=RV32I-ILP32
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s -check-prefix=RV64I-LP64
+
+; TODO: Enable codegen for hard float. Until then, these RUN lines only
+; exercise the soft-float ILP32/LP64 ABIs, where bfloat values are passed and
+; returned in integer registers.
+
+; Truncation to bfloat is lowered to a libcall (__truncsfbf2 / __truncdfbf2).
+define bfloat @float_to_bfloat(float %a) nounwind {
+; RV32I-ILP32-LABEL: float_to_bfloat:
+; RV32I-ILP32: # %bb.0:
+; RV32I-ILP32-NEXT: addi sp, sp, -16
+; RV32I-ILP32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT: call __truncsfbf2@plt
+; RV32I-ILP32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT: addi sp, sp, 16
+; RV32I-ILP32-NEXT: ret
+;
+; RV64I-LP64-LABEL: float_to_bfloat:
+; RV64I-LP64: # %bb.0:
+; RV64I-LP64-NEXT: addi sp, sp, -16
+; RV64I-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT: call __truncsfbf2@plt
+; RV64I-LP64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT: addi sp, sp, 16
+; RV64I-LP64-NEXT: ret
+ %1 = fptrunc float %a to bfloat
+ ret bfloat %1
+}
+
+define bfloat @double_to_bfloat(double %a) nounwind {
+; RV32I-ILP32-LABEL: double_to_bfloat:
+; RV32I-ILP32: # %bb.0:
+; RV32I-ILP32-NEXT: addi sp, sp, -16
+; RV32I-ILP32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT: call __truncdfbf2@plt
+; RV32I-ILP32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT: addi sp, sp, 16
+; RV32I-ILP32-NEXT: ret
+;
+; RV64I-LP64-LABEL: double_to_bfloat:
+; RV64I-LP64: # %bb.0:
+; RV64I-LP64-NEXT: addi sp, sp, -16
+; RV64I-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT: call __truncdfbf2@plt
+; RV64I-LP64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT: addi sp, sp, 16
+; RV64I-LP64-NEXT: ret
+ %1 = fptrunc double %a to bfloat
+ ret bfloat %1
+}
+
+; bf16 -> f32 extension needs no libcall: the bfloat bits are shifted into the
+; upper half of the float bit pattern.
+define float @bfloat_to_float(bfloat %a) nounwind {
+; RV32I-ILP32-LABEL: bfloat_to_float:
+; RV32I-ILP32: # %bb.0:
+; RV32I-ILP32-NEXT: slli a0, a0, 16
+; RV32I-ILP32-NEXT: ret
+;
+; RV64I-LP64-LABEL: bfloat_to_float:
+; RV64I-LP64: # %bb.0:
+; RV64I-LP64-NEXT: slliw a0, a0, 16
+; RV64I-LP64-NEXT: ret
+ %1 = fpext bfloat %a to float
+ ret float %1
+}
+
+; bf16 -> f64 extension goes via f32: shift into a float bit pattern, then
+; call __extendsfdf2.
+define double @bfloat_to_double(bfloat %a) nounwind {
+; RV32I-ILP32-LABEL: bfloat_to_double:
+; RV32I-ILP32: # %bb.0:
+; RV32I-ILP32-NEXT: addi sp, sp, -16
+; RV32I-ILP32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT: slli a0, a0, 16
+; RV32I-ILP32-NEXT: call __extendsfdf2@plt
+; RV32I-ILP32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT: addi sp, sp, 16
+; RV32I-ILP32-NEXT: ret
+;
+; RV64I-LP64-LABEL: bfloat_to_double:
+; RV64I-LP64: # %bb.0:
+; RV64I-LP64-NEXT: addi sp, sp, -16
+; RV64I-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT: slli a0, a0, 48
+; RV64I-LP64-NEXT: srli a0, a0, 32
+; RV64I-LP64-NEXT: call __extendsfdf2@plt
+; RV64I-LP64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT: addi sp, sp, 16
+; RV64I-LP64-NEXT: ret
+ %1 = fpext bfloat %a to double
+ ret double %1
+}
+
+; bfloat arithmetic is performed in float (__addsf3) and the result is
+; truncated back to bfloat (__truncsfbf2).
+define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
+; RV32I-ILP32-LABEL: bfloat_add:
+; RV32I-ILP32: # %bb.0:
+; RV32I-ILP32-NEXT: addi sp, sp, -16
+; RV32I-ILP32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT: slli a0, a0, 16
+; RV32I-ILP32-NEXT: slli a1, a1, 16
+; RV32I-ILP32-NEXT: call __addsf3@plt
+; RV32I-ILP32-NEXT: call __truncsfbf2@plt
+; RV32I-ILP32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT: addi sp, sp, 16
+; RV32I-ILP32-NEXT: ret
+;
+; RV64I-LP64-LABEL: bfloat_add:
+; RV64I-LP64: # %bb.0:
+; RV64I-LP64-NEXT: addi sp, sp, -16
+; RV64I-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT: slliw a0, a0, 16
+; RV64I-LP64-NEXT: slliw a1, a1, 16
+; RV64I-LP64-NEXT: call __addsf3@plt
+; RV64I-LP64-NEXT: call __truncsfbf2@plt
+; RV64I-LP64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT: addi sp, sp, 16
+; RV64I-LP64-NEXT: ret
+ %1 = fadd bfloat %a, %b
+ ret bfloat %1
+}
More information about the llvm-commits
mailing list