[llvm] [SelectionDAG] Add `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16` (PR #80056)
Shilei Tian via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 14 10:27:16 PST 2024
https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/80056
>From 1ab61d1a4b35bf01f91d859440caad7b81adf02a Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Wed, 14 Feb 2024 13:26:38 -0500
Subject: [PATCH] [SelectionDAG] Add `STRICT_BF16_TO_FP` and
`STRICT_FP_TO_BF16`
This patch adds support for `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16`.
Fix #78540.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 2 ++
llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 2 ++
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 24 +++++++++++----
.../SelectionDAG/LegalizeFloatTypes.cpp | 29 ++++++++++++-------
.../SelectionDAG/LegalizeIntegerTypes.cpp | 1 +
.../SelectionDAG/SelectionDAGDumper.cpp | 2 ++
6 files changed, 43 insertions(+), 17 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 8cb0bc9fd98133..e6bfa0d3f39684 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -921,6 +921,8 @@ enum NodeType {
/// has native conversions.
BF16_TO_FP,
FP_TO_BF16,
+ STRICT_BF16_TO_FP,
+ STRICT_FP_TO_BF16,
/// Perform various unary floating-point operations inspired by libm. For
/// FPOWI, the result is undefined if the integer operand doesn't fit into
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 3130f6c4dce598..d1015630b05d12 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -698,6 +698,8 @@ END_TWO_BYTE_PACK()
return false;
case ISD::STRICT_FP16_TO_FP:
case ISD::STRICT_FP_TO_FP16:
+ case ISD::STRICT_BF16_TO_FP:
+ case ISD::STRICT_FP_TO_BF16:
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
case ISD::STRICT_##DAGN:
#include "llvm/IR/ConstrainedOps.def"
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 252b6e9997a710..b31eb074b49726 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1034,6 +1034,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Node->getOperand(0).getValueType());
break;
case ISD::STRICT_FP_TO_FP16:
+ case ISD::STRICT_FP_TO_BF16:
case ISD::STRICT_SINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
case ISD::STRICT_LRINT:
@@ -3265,6 +3266,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(Tmp1);
break;
}
+ case ISD::STRICT_BF16_TO_FP:
+ // We don't support this expansion for now.
+ break;
case ISD::BF16_TO_FP: {
// Always expand bf16 to f32 casts, they lower to ext + shift.
//
@@ -3288,6 +3292,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(Op);
break;
}
+ case ISD::STRICT_FP_TO_BF16:
+ // We don't support this expansion for now.
+ break;
case ISD::FP_TO_BF16: {
SDValue Op = Node->getOperand(0);
if (Op.getValueType() != MVT::f32)
@@ -4790,12 +4797,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
break;
}
case ISD::STRICT_FP_EXTEND:
- case ISD::STRICT_FP_TO_FP16: {
- RTLIB::Libcall LC =
- Node->getOpcode() == ISD::STRICT_FP_TO_FP16
- ? RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16)
- : RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
- Node->getValueType(0));
+ case ISD::STRICT_FP_TO_FP16:
+ case ISD::STRICT_FP_TO_BF16: {
+ RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+ if (Node->getOpcode() == ISD::STRICT_FP_TO_FP16)
+ LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16);
+ else if (Node->getOpcode() == ISD::STRICT_FP_TO_BF16)
+ LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::bf16);
+ else
+ LC = RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
+ Node->getValueType(0));
+
assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to legalize as libcall");
TargetLowering::MakeLibCallOptions CallOptions;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index f0a04589fbfdc2..fa0c22c5632789 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
case ISD::STRICT_FP_TO_FP16:
case ISD::FP_TO_FP16: // Same as FP_ROUND for softening purposes
case ISD::FP_TO_BF16:
+ case ISD::STRICT_FP_TO_BF16:
case ISD::STRICT_FP_ROUND:
case ISD::FP_ROUND: Res = SoftenFloatOp_FP_ROUND(N); break;
case ISD::STRICT_FP_TO_SINT:
@@ -970,6 +971,7 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
assert(N->getOpcode() == ISD::FP_ROUND || N->getOpcode() == ISD::FP_TO_FP16 ||
N->getOpcode() == ISD::STRICT_FP_TO_FP16 ||
N->getOpcode() == ISD::FP_TO_BF16 ||
+ N->getOpcode() == ISD::STRICT_FP_TO_BF16 ||
N->getOpcode() == ISD::STRICT_FP_ROUND);
bool IsStrict = N->isStrictFPOpcode();
@@ -980,7 +982,8 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
if (N->getOpcode() == ISD::FP_TO_FP16 ||
N->getOpcode() == ISD::STRICT_FP_TO_FP16)
FloatRVT = MVT::f16;
- else if (N->getOpcode() == ISD::FP_TO_BF16)
+ else if (N->getOpcode() == ISD::FP_TO_BF16 ||
+ N->getOpcode() == ISD::STRICT_FP_TO_BF16)
FloatRVT = MVT::bf16;
RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, FloatRVT);
@@ -2193,13 +2196,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
if (RetVT == MVT::f16)
return ISD::STRICT_FP_TO_FP16;
- if (OpVT == MVT::bf16) {
- // TODO: return ISD::STRICT_BF16_TO_FP;
- }
+ if (OpVT == MVT::bf16)
+ return ISD::STRICT_BF16_TO_FP;
- if (RetVT == MVT::bf16) {
- // TODO: return ISD::STRICT_FP_TO_BF16;
- }
+ if (RetVT == MVT::bf16)
+ return ISD::STRICT_FP_TO_BF16;
report_fatal_error("Attempt at an invalid promotion-related conversion");
}
@@ -2999,10 +3000,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
EVT SVT = N->getOperand(0).getValueType();
if (N->isStrictFPOpcode()) {
- assert(RVT == MVT::f16);
- SDValue Res =
- DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
- {N->getOperand(0), N->getOperand(1)});
+ // FIXME: assume we only have two f16 variants for now.
+ unsigned Opcode;
+ if (RVT == MVT::f16)
+ Opcode = ISD::STRICT_FP_TO_FP16;
+ else if (RVT == MVT::bf16)
+ Opcode = ISD::STRICT_FP_TO_BF16;
+ else
+ llvm_unreachable("unknown half type");
+ SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
+ {N->getOperand(0), N->getOperand(1)});
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
return Res;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index a4ba261686c688..82c1dbb97223d3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,6 +165,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::FP_TO_FP16:
Res = PromoteIntRes_FP_TO_FP16_BF16(N);
break;
+ case ISD::STRICT_FP_TO_BF16:
case ISD::STRICT_FP_TO_FP16:
Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 0fbd999694f104..18ca17e53dac38 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,7 +380,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FP_TO_FP16: return "fp_to_fp16";
case ISD::STRICT_FP_TO_FP16: return "strict_fp_to_fp16";
case ISD::BF16_TO_FP: return "bf16_to_fp";
+ case ISD::STRICT_BF16_TO_FP: return "strict_bf16_to_fp";
case ISD::FP_TO_BF16: return "fp_to_bf16";
+ case ISD::STRICT_FP_TO_BF16: return "strict_fp_to_bf16";
case ISD::LROUND: return "lround";
case ISD::STRICT_LROUND: return "strict_lround";
case ISD::LLROUND: return "llround";
More information about the llvm-commits mailing list