[llvm] [SelectionDAG] Add `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16` (PR #80056)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 29 10:40:45 PST 2024


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/80056

>From 03de8c2d1ac9c1ba06583251aa42c6ccb9933ab6 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Thu, 29 Feb 2024 12:47:29 -0500
Subject: [PATCH] [SelectionDAG] Add `STRICT_BF16_TO_FP` and
 `STRICT_FP_TO_BF16`

This patch adds support for `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16`.

Fix #78540.
---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  2 +
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h |  2 +
 .../include/llvm/Target/TargetSelectionDAG.td | 13 +++
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 24 +++--
 .../SelectionDAG/LegalizeFloatTypes.cpp       | 41 +++++---
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  1 +
 .../SelectionDAG/SelectionDAGDumper.cpp       |  2 +
 llvm/lib/Target/X86/X86ISelLowering.cpp       |  2 +-
 llvm/test/CodeGen/X86/bfloat-constrained.ll   | 98 +++++++++++++++++++
 9 files changed, 164 insertions(+), 21 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/bfloat-constrained.ll

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index ad876c5db4509a..ef0fec270a43c4 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -921,6 +921,8 @@ enum NodeType {
   /// has native conversions.
   BF16_TO_FP,
   FP_TO_BF16,
+  STRICT_BF16_TO_FP,
+  STRICT_FP_TO_BF16,
 
   /// Perform various unary floating-point operations inspired by libm. For
   /// FPOWI, the result is undefined if the integer operand doesn't fit into
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 3130f6c4dce598..d1015630b05d12 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -698,6 +698,8 @@ END_TWO_BYTE_PACK()
         return false;
       case ISD::STRICT_FP16_TO_FP:
       case ISD::STRICT_FP_TO_FP16:
+      case ISD::STRICT_BF16_TO_FP:
+      case ISD::STRICT_FP_TO_BF16:
 #define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)               \
       case ISD::STRICT_##DAGN:
 #include "llvm/IR/ConstrainedOps.def"
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 5f8bf0d448105d..d84c2d30e44726 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -541,6 +541,8 @@ def fp_to_sint_sat : SDNode<"ISD::FP_TO_SINT_SAT" , SDTFPToIntSatOp>;
 def fp_to_uint_sat : SDNode<"ISD::FP_TO_UINT_SAT" , SDTFPToIntSatOp>;
 def f16_to_fp  : SDNode<"ISD::FP16_TO_FP" , SDTIntToFPOp>;
 def fp_to_f16  : SDNode<"ISD::FP_TO_FP16" , SDTFPToIntOp>;
+def bf16_to_fp  : SDNode<"ISD::BF16_TO_FP" , SDTIntToFPOp>;
+def fp_to_bf16  : SDNode<"ISD::FP_TO_BF16" , SDTFPToIntOp>;
 
 def strict_fadd       : SDNode<"ISD::STRICT_FADD",
                                SDTFPBinOp, [SDNPHasChain, SDNPCommutative]>;
@@ -620,6 +622,11 @@ def strict_f16_to_fp  : SDNode<"ISD::STRICT_FP16_TO_FP",
 def strict_fp_to_f16  : SDNode<"ISD::STRICT_FP_TO_FP16",
                                SDTFPToIntOp, [SDNPHasChain]>;
 
+def strict_bf16_to_fp  : SDNode<"ISD::STRICT_BF16_TO_FP",
+                               SDTIntToFPOp, [SDNPHasChain]>;
+def strict_fp_to_bf16  : SDNode<"ISD::STRICT_FP_TO_BF16",
+                               SDTFPToIntOp, [SDNPHasChain]>;
+
 def strict_fsetcc  : SDNode<"ISD::STRICT_FSETCC",  SDTSetCC, [SDNPHasChain]>;
 def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
 
@@ -1591,6 +1598,12 @@ def any_f16_to_fp : PatFrags<(ops node:$src),
 def any_fp_to_f16 : PatFrags<(ops node:$src),
                               [(fp_to_f16 node:$src),
                                (strict_fp_to_f16 node:$src)]>;
+def any_bf16_to_fp : PatFrags<(ops node:$src),
+                               [(bf16_to_fp node:$src),
+                                (strict_bf16_to_fp node:$src)]>;
+def any_fp_to_bf16 : PatFrags<(ops node:$src),
+                               [(fp_to_bf16 node:$src),
+                                (strict_fp_to_bf16 node:$src)]>;
 
 multiclass binary_atomic_op_ord {
   def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index f5b7752f7ecc8e..dd71d379aa30ee 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1047,6 +1047,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
                                     Node->getOperand(0).getValueType());
     break;
   case ISD::STRICT_FP_TO_FP16:
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_SINT_TO_FP:
   case ISD::STRICT_UINT_TO_FP:
   case ISD::STRICT_LRINT:
@@ -3263,6 +3264,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       Results.push_back(Tmp1);
     break;
   }
+  case ISD::STRICT_BF16_TO_FP:
+    // We don't support this expansion for now.
+    break;
   case ISD::BF16_TO_FP: {
     // Always expand bf16 to f32 casts, they lower to ext + shift.
     //
@@ -3286,6 +3290,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     Results.push_back(Op);
     break;
   }
+  case ISD::STRICT_FP_TO_BF16:
+    // We don't support this expansion for now.
+    break;
   case ISD::FP_TO_BF16: {
     SDValue Op = Node->getOperand(0);
     if (Op.getValueType() != MVT::f32)
@@ -4792,12 +4799,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
     break;
   }
   case ISD::STRICT_FP_EXTEND:
-  case ISD::STRICT_FP_TO_FP16: {
-    RTLIB::Libcall LC =
-        Node->getOpcode() == ISD::STRICT_FP_TO_FP16
-            ? RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16)
-            : RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
-                              Node->getValueType(0));
+  case ISD::STRICT_FP_TO_FP16:
+  case ISD::STRICT_FP_TO_BF16: {
+    RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+    if (Node->getOpcode() == ISD::STRICT_FP_TO_FP16)
+      LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16);
+    else if (Node->getOpcode() == ISD::STRICT_FP_TO_BF16)
+      LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::bf16);
+    else
+      LC = RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
+                           Node->getValueType(0));
+
     assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to legalize as libcall");
 
     TargetLowering::MakeLibCallOptions CallOptions;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index f0a04589fbfdc2..3332c02ec72358 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
   case ISD::STRICT_FP_TO_FP16:
   case ISD::FP_TO_FP16:  // Same as FP_ROUND for softening purposes
   case ISD::FP_TO_BF16:
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_FP_ROUND:
   case ISD::FP_ROUND:    Res = SoftenFloatOp_FP_ROUND(N); break;
   case ISD::STRICT_FP_TO_SINT:
@@ -970,6 +971,7 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
   assert(N->getOpcode() == ISD::FP_ROUND || N->getOpcode() == ISD::FP_TO_FP16 ||
          N->getOpcode() == ISD::STRICT_FP_TO_FP16 ||
          N->getOpcode() == ISD::FP_TO_BF16 ||
+         N->getOpcode() == ISD::STRICT_FP_TO_BF16 ||
          N->getOpcode() == ISD::STRICT_FP_ROUND);
 
   bool IsStrict = N->isStrictFPOpcode();
@@ -980,7 +982,8 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
   if (N->getOpcode() == ISD::FP_TO_FP16 ||
       N->getOpcode() == ISD::STRICT_FP_TO_FP16)
     FloatRVT = MVT::f16;
-  else if (N->getOpcode() == ISD::FP_TO_BF16)
+  else if (N->getOpcode() == ISD::FP_TO_BF16 ||
+           N->getOpcode() == ISD::STRICT_FP_TO_BF16)
     FloatRVT = MVT::bf16;
 
   RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, FloatRVT);
@@ -2193,13 +2196,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
   if (RetVT == MVT::f16)
     return ISD::STRICT_FP_TO_FP16;
 
-  if (OpVT == MVT::bf16) {
-    // TODO: return ISD::STRICT_BF16_TO_FP;
-  }
+  if (OpVT == MVT::bf16)
+    return ISD::STRICT_BF16_TO_FP;
 
-  if (RetVT == MVT::bf16) {
-    // TODO: return ISD::STRICT_FP_TO_BF16;
-  }
+  if (RetVT == MVT::bf16)
+    return ISD::STRICT_FP_TO_BF16;
 
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
@@ -2999,10 +3000,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
   EVT SVT = N->getOperand(0).getValueType();
 
   if (N->isStrictFPOpcode()) {
-    assert(RVT == MVT::f16);
-    SDValue Res =
-        DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
-                    {N->getOperand(0), N->getOperand(1)});
+    // FIXME: assume we only have two f16 variants for now.
+    unsigned Opcode;
+    if (RVT == MVT::f16)
+      Opcode = ISD::STRICT_FP_TO_FP16;
+    else if (RVT == MVT::bf16)
+      Opcode = ISD::STRICT_FP_TO_BF16;
+    else
+      llvm_unreachable("unknown half type");
+    SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
+                              {N->getOperand(0), N->getOperand(1)});
     ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
     return Res;
   }
@@ -3192,10 +3199,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
   Op = GetSoftPromotedHalf(N->getOperand(IsStrict ? 1 : 0));
 
   if (IsStrict) {
-    assert(SVT == MVT::f16);
+    unsigned Opcode;
+    if (SVT == MVT::f16)
+      Opcode = ISD::STRICT_FP16_TO_FP;
+    else if (SVT == MVT::bf16)
+      Opcode = ISD::STRICT_BF16_TO_FP;
+    else
+      llvm_unreachable("unknown half type");
     SDValue Res =
-        DAG.getNode(ISD::STRICT_FP16_TO_FP, SDLoc(N),
-                    {N->getValueType(0), MVT::Other}, {N->getOperand(0), Op});
+        DAG.getNode(Opcode, SDLoc(N), {N->getValueType(0), MVT::Other},
+                    {N->getOperand(0), Op});
     ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
     ReplaceValueWith(SDValue(N, 0), Res);
     return SDValue();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 6e55acd22bb37e..909c669abd120b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,6 +165,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_FP_TO_FP16:
     Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
     break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 0fbd999694f104..18ca17e53dac38 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,7 +380,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::FP_TO_FP16:                 return "fp_to_fp16";
   case ISD::STRICT_FP_TO_FP16:          return "strict_fp_to_fp16";
   case ISD::BF16_TO_FP:                 return "bf16_to_fp";
+  case ISD::STRICT_BF16_TO_FP:          return "strict_bf16_to_fp";
   case ISD::FP_TO_BF16:                 return "fp_to_bf16";
+  case ISD::STRICT_FP_TO_BF16:          return "strict_fp_to_bf16";
   case ISD::LROUND:                     return "lround";
   case ISD::STRICT_LROUND:              return "strict_lround";
   case ISD::LLROUND:                    return "llround";
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b807a97d6e4851..41f83c8075e51b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -393,7 +393,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
   }
 
   for (auto Op : {ISD::FP16_TO_FP, ISD::STRICT_FP16_TO_FP, ISD::FP_TO_FP16,
-                  ISD::STRICT_FP_TO_FP16}) {
+                  ISD::STRICT_FP_TO_FP16, ISD::STRICT_FP_TO_BF16}) {
     // Special handling for half-precision floating point conversions.
     // If we don't have F16C support, then lower half float conversions
     // into library calls.
diff --git a/llvm/test/CodeGen/X86/bfloat-constrained.ll b/llvm/test/CodeGen/X86/bfloat-constrained.ll
new file mode 100644
index 00000000000000..cfc9c143ccbf5d
--- /dev/null
+++ b/llvm/test/CodeGen/X86/bfloat-constrained.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=i686-linux-gnu | FileCheck %s --check-prefix=X32
+; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s --check-prefix=X64
+
+@a = global bfloat 0xR0000, align 2
+@b = global bfloat 0xR0000, align 2
+@c = global bfloat 0xR0000, align 2
+
+; FIXME: We don't have strict extend yet.
+; define float @bfloat_to_float() strictfp {
+;   %1 = load bfloat, ptr @a, align 2
+;   %2 = tail call float @llvm.experimental.constrained.fpext.f32.bfloat(bfloat %1, metadata !"fpexcept.strict") #0
+;   ret float %2
+; }
+
+; define double @bfloat_to_double() strictfp {
+;   %1 = load bfloat, ptr @a, align 2
+;   %2 = tail call double @llvm.experimental.constrained.fpext.f64.bfloat(bfloat %1, metadata !"fpexcept.strict") #0
+;   ret double %2
+; }
+
+define void @float_to_bfloat(float %0) strictfp {
+; X32-LABEL: float_to_bfloat:
+; X32:       # %bb.0:
+; X32-NEXT:    subl $12, %esp
+; X32-NEXT:    .cfi_def_cfa_offset 16
+; X32-NEXT:    flds {{[0-9]+}}(%esp)
+; X32-NEXT:    fstps (%esp)
+; X32-NEXT:    wait
+; X32-NEXT:    calll __truncsfbf2
+; X32-NEXT:    movw %ax, a
+; X32-NEXT:    addl $12, %esp
+; X32-NEXT:    .cfi_def_cfa_offset 4
+; X32-NEXT:    retl
+;
+; X64-LABEL: float_to_bfloat:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    .cfi_def_cfa_offset 16
+; X64-NEXT:    callq __truncsfbf2@PLT
+; X64-NEXT:    movq a@GOTPCREL(%rip), %rcx
+; X64-NEXT:    movw %ax, (%rcx)
+; X64-NEXT:    popq %rax
+; X64-NEXT:    .cfi_def_cfa_offset 8
+; X64-NEXT:    retq
+  %2 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float %0, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  store bfloat %2, ptr @a, align 2
+  ret void
+}
+
+define void @double_to_bfloat(double %0) strictfp {
+; X32-LABEL: double_to_bfloat:
+; X32:       # %bb.0:
+; X32-NEXT:    subl $12, %esp
+; X32-NEXT:    .cfi_def_cfa_offset 16
+; X32-NEXT:    fldl {{[0-9]+}}(%esp)
+; X32-NEXT:    fstpl (%esp)
+; X32-NEXT:    wait
+; X32-NEXT:    calll __truncdfbf2
+; X32-NEXT:    movw %ax, a
+; X32-NEXT:    addl $12, %esp
+; X32-NEXT:    .cfi_def_cfa_offset 4
+; X32-NEXT:    retl
+;
+; X64-LABEL: double_to_bfloat:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    .cfi_def_cfa_offset 16
+; X64-NEXT:    callq __truncdfbf2@PLT
+; X64-NEXT:    movq a@GOTPCREL(%rip), %rcx
+; X64-NEXT:    movw %ax, (%rcx)
+; X64-NEXT:    popq %rax
+; X64-NEXT:    .cfi_def_cfa_offset 8
+; X64-NEXT:    retq
+  %2 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f64(double %0, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  store bfloat %2, ptr @a, align 2
+  ret void
+}
+
+; define void @add() strictfp {
+;   %1 = load bfloat, ptr @a, align 2
+;   %2 = tail call float @llvm.experimental.constrained.fpext.f32.bfloat(bfloat %1, metadata !"fpexcept.strict") #0
+;   %3 = load bfloat, ptr @b, align 2
+;   %4 = tail call float @llvm.experimental.constrained.fpext.f32.bfloat(bfloat %3, metadata !"fpexcept.strict") #0
+;   %5 = tail call float @llvm.experimental.constrained.fadd.f32(float %2, float %4, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+;   %6 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float %5, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+;   store bfloat %6, ptr @c, align 2
+;   ret void
+; }
+
+; declare float @llvm.experimental.constrained.fpext.f32.bfloat(bfloat, metadata)
+; declare double @llvm.experimental.constrained.fpext.f64.bfloat(bfloat, metadata)
+; declare float @llvm.experimental.constrained.fadd.f32(float, float, metadata, metadata)
+declare bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float, metadata, metadata)
+declare bfloat @llvm.experimental.constrained.fptrunc.bfloat.f64(double, metadata, metadata)
+
+attributes #0 = { strictfp }
+



More information about the llvm-commits mailing list