[llvm] DAG: Fix ABI lowering with FP promote in strictfp functions (PR #74405)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 4 19:11:52 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Matt Arsenault (arsenm)

<details>
<summary>Changes</summary>

Previously, ABI lowering in strictfp functions emitted non-strict casts
for illegal floating-point types.

---

Patch is 72.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74405.diff


10 Files Affected:

- (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+13) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+63) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+14-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+37-22) 
- (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+8-1) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll (+3) 
- (modified) llvm/test/CodeGen/AMDGPU/strict_fpext.ll (+254-26) 
- (modified) llvm/test/CodeGen/AMDGPU/strict_fptrunc.ll (+227-21) 
- (added) llvm/test/CodeGen/AMDGPU/strictfp_f16_abi_promote.ll (+558) 


``````````diff
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 798e6a1d9525e..de7bf26868b34 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -614,6 +614,12 @@ def strict_sint_to_fp : SDNode<"ISD::STRICT_SINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
 def strict_uint_to_fp : SDNode<"ISD::STRICT_UINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
+
+def strict_f16_to_fp  : SDNode<"ISD::STRICT_FP16_TO_FP",
+                               SDTIntToFPOp, [SDNPHasChain]>;
+def strict_fp_to_f16  : SDNode<"ISD::STRICT_FP_TO_FP16",
+                               SDTFPToIntOp, [SDNPHasChain]>;
+
 def strict_fsetcc  : SDNode<"ISD::STRICT_FSETCC",  SDTSetCC, [SDNPHasChain]>;
 def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
 
@@ -1558,6 +1564,13 @@ def any_fsetccs : PatFrags<(ops node:$lhs, node:$rhs, node:$pred),
                           [(strict_fsetccs node:$lhs, node:$rhs, node:$pred),
                            (setcc node:$lhs, node:$rhs, node:$pred)]>;
 
+def any_f16_to_fp : PatFrags<(ops node:$src),
+                              [(f16_to_fp node:$src),
+                               (strict_f16_to_fp node:$src)]>;
+def any_fp_to_f16 : PatFrags<(ops node:$src),
+                              [(fp_to_f16 node:$src),
+                               (strict_fp_to_f16 node:$src)]>;
+
 multiclass binary_atomic_op_ord {
   def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
       (!cast<SDPatternOperator>(NAME) node:$ptr, node:$val)> {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 630aa4a07d7b9..d7a688511b726 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2181,6 +2181,20 @@ static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
 
+static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
+  if (OpVT == MVT::f16) {
+    return ISD::STRICT_FP16_TO_FP;
+  } else if (RetVT == MVT::f16) {
+    return ISD::STRICT_FP_TO_FP16;
+  } else if (OpVT == MVT::bf16) {
+    // return ISD::STRICT_BF16_TO_FP;
+  } else if (RetVT == MVT::bf16) {
+    // return ISD::STRICT_FP_TO_BF16;
+  }
+
+  report_fatal_error("Attempt at an invalid promotion-related conversion");
+}
+
 bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
   LLVM_DEBUG(dbgs() << "Promote float operand " << OpNo << ": "; N->dump(&DAG));
   SDValue R = SDValue();
@@ -2214,6 +2228,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::FP_TO_UINT_SAT:
                           R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
     case ISD::FP_EXTEND:  R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
+    case ISD::STRICT_FP_EXTEND:
+      R = PromoteFloatOp_STRICT_FP_EXTEND(N, OpNo);
+      break;
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
     case ISD::STORE:      R = PromoteFloatOp_STORE(N, OpNo); break;
@@ -2276,6 +2293,26 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo) {
   return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, Op);
 }
 
+SDValue DAGTypeLegalizer::PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N,
+                                                          unsigned OpNo) {
+  assert(OpNo == 1);
+
+  SDValue Op = GetPromotedFloat(N->getOperand(1));
+  EVT VT = N->getValueType(0);
+
+  // Desired VT is same as promoted type.  Use promoted float directly.
+  if (VT == Op->getValueType(0)) {
+    ReplaceValueWith(SDValue(N, 1), N->getOperand(0));
+    return Op;
+  }
+
+  // Else, extend the promoted float value to the desired VT.
+  SDValue Res = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N), N->getVTList(),
+                            N->getOperand(0), Op);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 // Promote the float operands used for comparison.  The true- and false-
 // operands have the same type as the result and are promoted, if needed, by
 // PromoteFloatRes_SELECT_CC
@@ -2393,6 +2430,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::FFREXP:     R = PromoteFloatRes_FFREXP(N); break;
 
     case ISD::FP_ROUND:   R = PromoteFloatRes_FP_ROUND(N); break;
+    case ISD::STRICT_FP_ROUND:
+      R = PromoteFloatRes_STRICT_FP_ROUND(N);
+      break;
     case ISD::LOAD:       R = PromoteFloatRes_LOAD(N); break;
     case ISD::SELECT:     R = PromoteFloatRes_SELECT(N); break;
     case ISD::SELECT_CC:  R = PromoteFloatRes_SELECT_CC(N); break;
@@ -2598,6 +2638,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
   return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
 }
 
+// Explicit operation to reduce precision.  Reduce the value to half precision
+// and promote it back to the legal type.
+SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
+  SDLoc DL(N);
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Op = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT OpVT = Op->getValueType(0);
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+
+  // Round promoted float to desired precision
+  SDValue Round = DAG.getNode(GetPromotionOpcodeStrict(OpVT, VT), DL,
+                              DAG.getVTList(IVT, MVT::Other), Chain, Op);
+  // Promote it back to the legal output type
+  SDValue Res =
+      DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
+                  DAG.getVTList(NVT, MVT::Other), Round.getValue(1), Round);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteFloatRes_LOAD(SDNode *N) {
   LoadSDNode *L = cast<LoadSDNode>(N);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 54698edce7d6f..887670fb6baff 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,7 +165,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
-
+  case ISD::STRICT_FP_TO_FP16:
+    Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
+    break;
   case ISD::GET_ROUNDING: Res = PromoteIntRes_GET_ROUNDING(N); break;
 
   case ISD::AND:
@@ -787,6 +789,16 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
   return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N) {
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDLoc dl(N);
+
+  SDValue Res = DAG.getNode(N->getOpcode(), dl, DAG.getVTList(NVT, MVT::Other),
+                            N->getOperand(0), N->getOperand(1));
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDLoc dl(N);
@@ -1804,6 +1816,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::FP16_TO_FP:
   case ISD::VP_UINT_TO_FP:
   case ISD::UINT_TO_FP:   Res = PromoteIntOp_UINT_TO_FP(N); break;
+  case ISD::STRICT_FP16_TO_FP:
   case ISD::STRICT_UINT_TO_FP:  Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
   case ISD::ZERO_EXTEND:  Res = PromoteIntOp_ZERO_EXTEND(N); break;
   case ISD::VP_ZERO_EXTEND: Res = PromoteIntOp_VP_ZERO_EXTEND(N); break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index e9bd54089d062..9361e7fff2190 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
+  SDValue PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N);
   SDValue PromoteIntRes_XRINT(SDNode *N);
   SDValue PromoteIntRes_FREEZE(SDNode *N);
   SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
@@ -698,6 +699,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatRes_ExpOp(SDNode *N);
   SDValue PromoteFloatRes_FFREXP(SDNode *N);
   SDValue PromoteFloatRes_FP_ROUND(SDNode *N);
+  SDValue PromoteFloatRes_STRICT_FP_ROUND(SDNode *N);
   SDValue PromoteFloatRes_LOAD(SDNode *N);
   SDValue PromoteFloatRes_SELECT(SDNode *N);
   SDValue PromoteFloatRes_SELECT_CC(SDNode *N);
@@ -712,6 +714,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index da7d9ace4114a..29684d3372bdb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -153,6 +153,7 @@ static const unsigned MaxParallelChains = 64;
 static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
                                       const SDValue *Parts, unsigned NumParts,
                                       MVT PartVT, EVT ValueVT, const Value *V,
+                                      SDValue InChain,
                                       std::optional<CallingConv::ID> CC);
 
 /// getCopyFromParts - Create a value that contains the specified legal parts
@@ -163,6 +164,7 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
 static SDValue
 getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
                  unsigned NumParts, MVT PartVT, EVT ValueVT, const Value *V,
+                 SDValue InChain,
                  std::optional<CallingConv::ID> CC = std::nullopt,
                  std::optional<ISD::NodeType> AssertOp = std::nullopt) {
   // Let the target assemble the parts if it wants to
@@ -173,7 +175,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
 
   if (ValueVT.isVector())
     return getCopyFromPartsVector(DAG, DL, Parts, NumParts, PartVT, ValueVT, V,
-                                  CC);
+                                  InChain, CC);
 
   assert(NumParts > 0 && "No parts to assemble!");
   SDValue Val = Parts[0];
@@ -194,10 +196,10 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
       EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), RoundBits/2);
 
       if (RoundParts > 2) {
-        Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2,
-                              PartVT, HalfVT, V);
-        Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2,
-                              RoundParts / 2, PartVT, HalfVT, V);
+        Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2, PartVT, HalfVT, V,
+                              InChain);
+        Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2, RoundParts / 2,
+                              PartVT, HalfVT, V, InChain);
       } else {
         Lo = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[0]);
         Hi = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[1]);
@@ -213,7 +215,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
         unsigned OddParts = NumParts - RoundParts;
         EVT OddVT = EVT::getIntegerVT(*DAG.getContext(), OddParts * PartBits);
         Hi = getCopyFromParts(DAG, DL, Parts + RoundParts, OddParts, PartVT,
-                              OddVT, V, CC);
+                              OddVT, V, InChain, CC);
 
         // Combine the round and odd parts.
         Lo = Val;
@@ -243,7 +245,8 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
       assert(ValueVT.isFloatingPoint() && PartVT.isInteger() &&
              !PartVT.isVector() && "Unexpected split");
       EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), ValueVT.getSizeInBits());
-      Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V, CC);
+      Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V,
+                             InChain, CC);
     }
   }
 
@@ -283,10 +286,20 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
 
   if (PartEVT.isFloatingPoint() && ValueVT.isFloatingPoint()) {
     // FP_ROUND's are always exact here.
-    if (ValueVT.bitsLT(Val.getValueType()))
-      return DAG.getNode(
-          ISD::FP_ROUND, DL, ValueVT, Val,
-          DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout())));
+    if (ValueVT.bitsLT(Val.getValueType())) {
+
+      SDValue NoChange =
+          DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
+
+      if (DAG.getMachineFunction().getFunction().getAttributes().hasFnAttr(
+              llvm::Attribute::StrictFP)) {
+        return DAG.getNode(ISD::STRICT_FP_ROUND, DL,
+                           DAG.getVTList(ValueVT, MVT::Other), InChain, Val,
+                           NoChange);
+      }
+
+      return DAG.getNode(ISD::FP_ROUND, DL, ValueVT, Val, NoChange);
+    }
 
     return DAG.getNode(ISD::FP_EXTEND, DL, ValueVT, Val);
   }
@@ -324,6 +337,7 @@ static void diagnosePossiblyInvalidConstraint(LLVMContext &Ctx, const Value *V,
 static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
                                       const SDValue *Parts, unsigned NumParts,
                                       MVT PartVT, EVT ValueVT, const Value *V,
+                                      SDValue InChain,
                                       std::optional<CallingConv::ID> CallConv) {
   assert(ValueVT.isVector() && "Not a vector value");
   assert(NumParts > 0 && "No parts to assemble!");
@@ -362,8 +376,8 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
       // If the register was not expanded, truncate or copy the value,
       // as appropriate.
       for (unsigned i = 0; i != NumParts; ++i)
-        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1,
-                                  PartVT, IntermediateVT, V, CallConv);
+        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1, PartVT, IntermediateVT,
+                                  V, InChain, CallConv);
     } else if (NumParts > 0) {
       // If the intermediate type was expanded, build the intermediate
       // operands from the parts.
@@ -371,8 +385,8 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
              "Must expand into a divisible number of parts!");
       unsigned Factor = NumParts / NumIntermediates;
       for (unsigned i = 0; i != NumIntermediates; ++i)
-        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor,
-                                  PartVT, IntermediateVT, V, CallConv);
+        Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor, PartVT,
+                                  IntermediateVT, V, InChain, CallConv);
     }
 
     // Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the
@@ -926,7 +940,7 @@ SDValue RegsForValue::getCopyFromRegs(SelectionDAG &DAG,
     }
 
     Values[Value] = getCopyFromParts(DAG, dl, Parts.begin(), NumRegs,
-                                     RegisterVT, ValueVT, V, CallConv);
+                                     RegisterVT, ValueVT, V, Chain, CallConv);
     Part += NumRegs;
     Parts.clear();
   }
@@ -10641,9 +10655,9 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
       unsigned NumRegs = getNumRegistersForCallingConv(CLI.RetTy->getContext(),
                                                        CLI.CallConv, VT);
 
-      ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
-                                              NumRegs, RegisterVT, VT, nullptr,
-                                              CLI.CallConv, AssertOp));
+      ReturnValues.push_back(getCopyFromParts(
+          CLI.DAG, CLI.DL, &InVals[CurReg], NumRegs, RegisterVT, VT, nullptr,
+          CLI.Chain, CLI.CallConv, AssertOp));
       CurReg += NumRegs;
     }
 
@@ -11122,8 +11136,9 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
     MVT VT = ValueVTs[0].getSimpleVT();
     MVT RegVT = TLI->getRegisterType(*CurDAG->getContext(), VT);
     std::optional<ISD::NodeType> AssertOp;
-    SDValue ArgValue = getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT,
-                                        nullptr, F.getCallingConv(), AssertOp);
+    SDValue ArgValue =
+        getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT, nullptr, NewRoot,
+                         F.getCallingConv(), AssertOp);
 
     MachineFunction& MF = SDB->DAG.getMachineFunction();
     MachineRegisterInfo& RegInfo = MF.getRegInfo();
@@ -11195,7 +11210,7 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
           AssertOp = ISD::AssertZext;
 
         ArgValues.push_back(getCopyFromParts(DAG, dl, &InVals[i], NumParts,
-                                             PartVT, VT, nullptr,
+                                             PartVT, VT, nullptr, NewRoot,
                                              F.getCallingConv(), AssertOp));
       }
 
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index 9362fe5d9678b..d23d6b4f93da8 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -1097,7 +1097,7 @@ def : Pat <
 multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16_inst_e64> {
   // f16_to_fp patterns
   def : GCNPat <
-    (f32 (f16_to_fp i32:$src0)),
+    (f32 (any_f16_to_fp i32:$src0)),
     (cvt_f32_f16_inst_e64 SRCMODS.NONE, $src0)
   >;
 
@@ -1151,6 +1151,13 @@ multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16
     (f16 (uint_to_fp i32:$src)),
     (cvt_f16_f32_inst_e64 SRCMODS.NONE, (V_CVT_F32_U32_e32 VSrc_b32:$src))
   >;
+
+  // This is only used on targets without half support
+  // TODO: Introduce strict variant of AMDGPUfp_to_f16 and share custom lowering
+  def : GCNPat <
+    (i32 (strict_fp_to_f16 (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
+    (cvt_f16_f32_inst_e64 $src0_modifiers, f32:$src0)
+  >;
 }
 
 let SubtargetPredicate = NotHasTrue16BitInsts in
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll b/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll
index d74948a460c98..20d175ecfd909 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll
@@ -1316,6 +1316,9 @@ define i1 @isnan_f16_strictfp(half %x) strictfp nounwind {
 ; GFX7SELDAG-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX7SELDAG-NEXT:    v_cvt_f16_f32_e32 v0, v0
 ; GFX7SELDAG-NEXT:    s_movk_i32 s4, 0x7c00
+; GFX7SELDAG-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7SELDAG-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7SELDAG-NEXT:    v_cvt_f16_f32_e32 v0, v0
 ; GFX7SELDAG-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
 ; GFX7SELDAG-NEXT:    v_cmp_lt_i32_e32 vcc, s4, v0
 ; GFX7SELDAG-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc
diff --git a/llvm/test/CodeGen/AMDGPU/strict_fpext.ll b/llvm/test/CodeGen/AMDGPU/strict_fpext.ll
index 22bebb7ad26f5..fe59a8491c91a 100644
--- a/llvm/test/CodeGen/AMDGPU/strict_fpext.ll
+++ b/llvm/test/CodeGen/AMDGPU/strict_fpext.ll
@@ -1,22 +1,47 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
 ; FIXME: Missing operand promote for f16
-; XUN: llc -mtriple=amdgcn-mesa-mesa3d -mcpu=tahiti < %s | FileChe...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74405


More information about the llvm-commits mailing list