[llvm] [SelectionDAG] Add `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16` (PR #80056)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 30 12:35:36 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Shilei Tian (shiltian)

<details>
<summary>Changes</summary>

This patch adds support for the `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16` SelectionDAG opcodes.

Fixes #78540.


---
Full diff: https://github.com/llvm/llvm-project/pull/80056.diff


8 Files Affected:

- (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+2) 
- (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+26-8) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+15-10) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+2) 
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+4-2) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll (+46-5) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 349d1286c8dc4..29fa3bd842c14 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -921,6 +921,8 @@ enum NodeType {
   /// has native conversions.
   BF16_TO_FP,
   FP_TO_BF16,
+  STRICT_BF16_TO_FP,
+  STRICT_FP_TO_BF16,
 
   /// Perform various unary floating-point operations inspired by libm. For
   /// FPOWI, the result is undefined if the integer operand doesn't fit into
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 3130f6c4dce59..d1015630b05d1 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -698,6 +698,8 @@ END_TWO_BYTE_PACK()
         return false;
       case ISD::STRICT_FP16_TO_FP:
       case ISD::STRICT_FP_TO_FP16:
+      case ISD::STRICT_BF16_TO_FP:
+      case ISD::STRICT_FP_TO_BF16:
 #define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)               \
       case ISD::STRICT_##DAGN:
 #include "llvm/IR/ConstrainedOps.def"
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index d29e44f95798c..beac23a070163 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1033,6 +1033,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
                                     Node->getOperand(0).getValueType());
     break;
   case ISD::STRICT_FP_TO_FP16:
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_SINT_TO_FP:
   case ISD::STRICT_UINT_TO_FP:
   case ISD::STRICT_LRINT:
@@ -3248,12 +3249,18 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       Results.push_back(Tmp1);
     break;
   }
+  case ISD::STRICT_BF16_TO_FP:
+    // When strict mode is enforced we can't do expansion because it
+    // does not honor the "strict" properties.
+    if (TLI.isStrictFPEnabled())
+      break;
+    LLVM_FALLTHROUGH;
   case ISD::BF16_TO_FP: {
     // Always expand bf16 to f32 casts, they lower to ext + shift.
     //
     // Note that the operand of this code can be bf16 or an integer type in case
     // bf16 is not supported on the target and was softened.
-    SDValue Op = Node->getOperand(0);
+    SDValue Op = Node->getOperand(Node->getOpcode() == ISD::BF16_TO_FP ? 0 : 1);
     if (Op.getValueType() == MVT::bf16) {
       Op = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32,
                        DAG.getNode(ISD::BITCAST, dl, MVT::i16, Op));
@@ -3271,8 +3278,14 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     Results.push_back(Op);
     break;
   }
+  case ISD::STRICT_FP_TO_BF16:
+    // When strict mode is enforced we can't do expansion because it
+    // does not honor the "strict" properties.
+    if (TLI.isStrictFPEnabled())
+      break;
+    LLVM_FALLTHROUGH;
   case ISD::FP_TO_BF16: {
-    SDValue Op = Node->getOperand(0);
+    SDValue Op = Node->getOperand(Node->getOpcode() == ISD::FP_TO_BF16 ? 0 : 1);
     if (Op.getValueType() != MVT::f32)
       Op = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, Op,
                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
@@ -4773,12 +4786,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
     break;
   }
   case ISD::STRICT_FP_EXTEND:
-  case ISD::STRICT_FP_TO_FP16: {
-    RTLIB::Libcall LC =
-        Node->getOpcode() == ISD::STRICT_FP_TO_FP16
-            ? RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16)
-            : RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
-                              Node->getValueType(0));
+  case ISD::STRICT_FP_TO_FP16:
+  case ISD::STRICT_FP_TO_BF16: {
+    RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+    if (Node->getOpcode() == ISD::STRICT_FP_TO_FP16)
+      LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16);
+    else if (Node->getOpcode() == ISD::STRICT_FP_TO_BF16)
+      LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::bf16);
+    else
+      LC = RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
+                           Node->getValueType(0));
+
     assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to legalize as libcall");
 
     TargetLowering::MakeLibCallOptions CallOptions;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index f0a04589fbfdc..ea0696be8edc4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
   case ISD::STRICT_FP_TO_FP16:
   case ISD::FP_TO_FP16:  // Same as FP_ROUND for softening purposes
   case ISD::FP_TO_BF16:
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_FP_ROUND:
   case ISD::FP_ROUND:    Res = SoftenFloatOp_FP_ROUND(N); break;
   case ISD::STRICT_FP_TO_SINT:
@@ -2193,13 +2194,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
   if (RetVT == MVT::f16)
     return ISD::STRICT_FP_TO_FP16;
 
-  if (OpVT == MVT::bf16) {
-    // TODO: return ISD::STRICT_BF16_TO_FP;
-  }
+  if (OpVT == MVT::bf16)
+    return ISD::STRICT_BF16_TO_FP;
 
-  if (RetVT == MVT::bf16) {
-    // TODO: return ISD::STRICT_FP_TO_BF16;
-  }
+  if (RetVT == MVT::bf16)
+    return ISD::STRICT_FP_TO_BF16;
 
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
@@ -2999,10 +2998,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
   EVT SVT = N->getOperand(0).getValueType();
 
   if (N->isStrictFPOpcode()) {
-    assert(RVT == MVT::f16);
-    SDValue Res =
-        DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
-                    {N->getOperand(0), N->getOperand(1)});
+    // FIXME: assume we only have two f16 variants for now.
+    unsigned Opcode;
+    if (RVT == MVT::f16)
+      Opcode = ISD::STRICT_FP_TO_FP16;
+    else if (RVT == MVT::bf16)
+      Opcode = ISD::STRICT_FP_TO_BF16;
+    else
+      llvm_unreachable("unknown half type");
+    SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
+                              {N->getOperand(0), N->getOperand(1)});
     ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
     return Res;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 814f746f5a4d9..62a21ad71b622 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,6 +165,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_FP_TO_FP16:
     Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
     break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index a28d834f0522f..c0981d8362a3b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -379,7 +379,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::FP_TO_FP16:                 return "fp_to_fp16";
   case ISD::STRICT_FP_TO_FP16:          return "strict_fp_to_fp16";
   case ISD::BF16_TO_FP:                 return "bf16_to_fp";
+  case ISD::STRICT_BF16_TO_FP:          return "strict_bf16_to_fp";
   case ISD::FP_TO_BF16:                 return "fp_to_bf16";
+  case ISD::STRICT_FP_TO_BF16:          return "strict_fp_to_bf16";
   case ISD::LROUND:                     return "lround";
   case ISD::STRICT_LROUND:              return "strict_lround";
   case ISD::LLROUND:                    return "llround";
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 7ab062bcc4da7..d5ec49fb3114f 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -539,8 +539,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction({ISD::FSIN, ISD::FCOS, ISD::FDIV}, MVT::f32, Custom);
   setOperationAction(ISD::FDIV, MVT::f64, Custom);
 
-  setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
-  setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
+  setOperationAction({ISD::BF16_TO_FP, ISD::STRICT_BF16_TO_FP},
+                     {MVT::i16, MVT::f32, MVT::f64}, Expand);
+  setOperationAction({ISD::FP_TO_BF16, ISD::STRICT_FP_TO_BF16},
+                     {MVT::i16, MVT::f32, MVT::f64}, Expand);
 
   // Custom lower these because we can't specify a rule based on an illegal
   // source bf16.
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll b/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll
index 04bf2120b78cf..549b4d0fbd01f 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll
@@ -1094,11 +1094,52 @@ define <4 x i1> @isnan_v4bf16(<4 x bfloat> %x) nounwind {
   ret <4 x i1> %1
 }
 
-; FIXME: Broken for gfx6/7
-; define i1 @isnan_bf16_strictfp(bfloat %x) strictfp nounwind {
-;   %1 = call i1 @llvm.is.fpclass.bf16(bfloat %x, i32 3) strictfp ; nan
-;   ret i1 %1
-; }
+define i1 @isnan_bf16_strictfp(bfloat %x) strictfp nounwind {
+; GFX7CHECK-LABEL: isnan_bf16_strictfp:
+; GFX7CHECK:       ; %bb.0:
+; GFX7CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7CHECK-NEXT:    v_bfe_u32 v0, v0, 16, 15
+; GFX7CHECK-NEXT:    s_movk_i32 s4, 0x7f80
+; GFX7CHECK-NEXT:    v_cmp_lt_i32_e32 vcc, s4, v0
+; GFX7CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc
+; GFX7CHECK-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX8CHECK-LABEL: isnan_bf16_strictfp:
+; GFX8CHECK:       ; %bb.0:
+; GFX8CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX8CHECK-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
+; GFX8CHECK-NEXT:    s_movk_i32 s4, 0x7f80
+; GFX8CHECK-NEXT:    v_cmp_lt_i16_e32 vcc, s4, v0
+; GFX8CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc
+; GFX8CHECK-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX9CHECK-LABEL: isnan_bf16_strictfp:
+; GFX9CHECK:       ; %bb.0:
+; GFX9CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX9CHECK-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
+; GFX9CHECK-NEXT:    s_movk_i32 s4, 0x7f80
+; GFX9CHECK-NEXT:    v_cmp_lt_i16_e32 vcc, s4, v0
+; GFX9CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc
+; GFX9CHECK-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX10CHECK-LABEL: isnan_bf16_strictfp:
+; GFX10CHECK:       ; %bb.0:
+; GFX10CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX10CHECK-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
+; GFX10CHECK-NEXT:    v_cmp_lt_i16_e32 vcc_lo, 0x7f80, v0
+; GFX10CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc_lo
+; GFX10CHECK-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX11CHECK-LABEL: isnan_bf16_strictfp:
+; GFX11CHECK:       ; %bb.0:
+; GFX11CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX11CHECK-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
+; GFX11CHECK-NEXT:    v_cmp_lt_i16_e32 vcc_lo, 0x7f80, v0
+; GFX11CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc_lo
+; GFX11CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %1 = call i1 @llvm.is.fpclass.bf16(bfloat %x, i32 3) strictfp ; nan
+  ret i1 %1
+}
 
 define i1 @isinf_bf16(bfloat %x) nounwind {
 ; GFX7CHECK-LABEL: isinf_bf16:

``````````

</details>


https://github.com/llvm/llvm-project/pull/80056


More information about the llvm-commits mailing list