[llvm] DAG: Implement promotion for strict_fp_round (PR #74332)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 5 20:39:25 PST 2023


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


Warning: the C/C++ code formatter, clang-format, found issues in your code.

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 0c568c2535848d1596a612c15248f299ec8c42be 66316f8c8ffbe72b0aa0433bed2dbcbcb6a8a70d -- llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp llvm/lib/Target/AMDGPU/SIISelLowering.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index f7f49c917c..ad23660cf0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2439,7 +2439,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
 
     case ISD::SINT_TO_FP:
     case ISD::UINT_TO_FP: R = PromoteFloatRes_XINT_TO_FP(N); break;
-    case ISD::STRICT_SINT_TO_FP: R = PromoteFloatRes_STRICT_XINT_TO_FP(N); break;
+    case ISD::STRICT_SINT_TO_FP:
+      R = PromoteFloatRes_STRICT_XINT_TO_FP(N);
+      break;
     case ISD::UNDEF:      R = PromoteFloatRes_UNDEF(N); break;
     case ISD::ATOMIC_SWAP: R = BitcastToInt_ATOMIC_SWAP(N); break;
     case ISD::VECREDUCE_FADD:
@@ -2723,13 +2725,15 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_XINT_TO_FP(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
   SDVTList NVTs = DAG.getVTList(NVT, MVT::Other);
 
-  SDValue NV = DAG.getNode(N->getOpcode(), DL, NVTs, N->getOperand(0), N->getOperand(1));
+  SDValue NV =
+      DAG.getNode(N->getOpcode(), DL, NVTs, N->getOperand(0), N->getOperand(1));
 
   // Round the value to the desired precision (that of the source type).
-  SDValue Rounded = DAG.getNode(ISD::STRICT_FP_ROUND, DL, N->getVTList(), NV.getValue(1), NV,
-                                DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
-  return DAG.getNode(
-    ISD::STRICT_FP_EXTEND, DL, NVTs, Rounded.getValue(1), Rounded.getValue(0));
+  SDValue Rounded =
+      DAG.getNode(ISD::STRICT_FP_ROUND, DL, N->getVTList(), NV.getValue(1), NV,
+                  DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+  return DAG.getNode(ISD::STRICT_FP_EXTEND, DL, NVTs, Rounded.getValue(1),
+                     Rounded.getValue(0));
 }
 
 SDValue DAGTypeLegalizer::PromoteFloatRes_UNDEF(SDNode *N) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 7d33421e11..def6b70fbe 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -420,9 +420,9 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
   setOperationAction({ISD::MULHU, ISD::MULHS}, MVT::i16, Expand);
 
   setOperationAction({ISD::MUL, ISD::MULHU, ISD::MULHS}, MVT::i64, Expand);
-  setOperationAction(
-    {ISD::UINT_TO_FP, ISD::SINT_TO_FP, ISD::STRICT_SINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
-      MVT::i64, Custom);
+  setOperationAction({ISD::UINT_TO_FP, ISD::SINT_TO_FP, ISD::STRICT_SINT_TO_FP,
+                      ISD::FP_TO_SINT, ISD::FP_TO_UINT},
+                     MVT::i64, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::i64, Expand);
 
   setOperationAction({ISD::SMIN, ISD::UMIN, ISD::SMAX, ISD::UMAX}, MVT::i32,
@@ -438,17 +438,23 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
 
   for (MVT VT : VectorIntTypes) {
     // Expand the following operations for the current type by default.
-    setOperationAction({ISD::ADD,        ISD::AND,     ISD::FP_TO_SINT,
-                        ISD::FP_TO_UINT, ISD::MUL,     ISD::MULHU,
-                        ISD::MULHS,      ISD::OR,      ISD::SHL,
-                        ISD::SRA,        ISD::SRL,     ISD::ROTL,
-                        ISD::ROTR,       ISD::SUB,     ISD::SINT_TO_FP, ISD::STRICT_SINT_TO_FP,
-                        ISD::UINT_TO_FP, ISD::SDIV,    ISD::UDIV,
-                        ISD::SREM,       ISD::UREM,    ISD::SMUL_LOHI,
-                        ISD::UMUL_LOHI,  ISD::SDIVREM, ISD::UDIVREM,
-                        ISD::SELECT,     ISD::VSELECT, ISD::SELECT_CC,
-                        ISD::XOR,        ISD::BSWAP,   ISD::CTPOP,
-                        ISD::CTTZ,       ISD::CTLZ,    ISD::VECTOR_SHUFFLE,
+    setOperationAction({ISD::ADD,        ISD::AND,
+                        ISD::FP_TO_SINT, ISD::FP_TO_UINT,
+                        ISD::MUL,        ISD::MULHU,
+                        ISD::MULHS,      ISD::OR,
+                        ISD::SHL,        ISD::SRA,
+                        ISD::SRL,        ISD::ROTL,
+                        ISD::ROTR,       ISD::SUB,
+                        ISD::SINT_TO_FP, ISD::STRICT_SINT_TO_FP,
+                        ISD::UINT_TO_FP, ISD::SDIV,
+                        ISD::UDIV,       ISD::SREM,
+                        ISD::UREM,       ISD::SMUL_LOHI,
+                        ISD::UMUL_LOHI,  ISD::SDIVREM,
+                        ISD::UDIVREM,    ISD::SELECT,
+                        ISD::VSELECT,    ISD::SELECT_CC,
+                        ISD::XOR,        ISD::BSWAP,
+                        ISD::CTPOP,      ISD::CTTZ,
+                        ISD::CTLZ,       ISD::VECTOR_SHUFFLE,
                         ISD::SETCC},
                        VT, Expand);
   }
@@ -3242,17 +3248,18 @@ SDValue AMDGPUTargetLowering::LowerINT_TO_FP64(SDValue Op, SelectionDAG &DAG,
 
   if (IsStrict) {
     SDVTList VTs = Op->getVTList();
-    SDValue CvtHi = DAG.getNode(Signed ? ISD::STRICT_SINT_TO_FP : ISD::STRICT_UINT_TO_FP,
-                                SL, VTs, Op.getOperand(0), Hi);
+    SDValue CvtHi =
+        DAG.getNode(Signed ? ISD::STRICT_SINT_TO_FP : ISD::STRICT_UINT_TO_FP,
+                    SL, VTs, Op.getOperand(0), Hi);
 
-    SDValue CvtLo = DAG.getNode(ISD::STRICT_UINT_TO_FP, SL, VTs, CvtHi.getValue(1), Lo);
+    SDValue CvtLo =
+        DAG.getNode(ISD::STRICT_UINT_TO_FP, SL, VTs, CvtHi.getValue(1), Lo);
 
-    SDValue LdExp = DAG.getNode(ISD::STRICT_FLDEXP, SL, VTs,
-                                CvtLo.getValue(1),
-                                CvtHi,
-                                DAG.getConstant(32, SL, MVT::i32));
+    SDValue LdExp = DAG.getNode(ISD::STRICT_FLDEXP, SL, VTs, CvtLo.getValue(1),
+                                CvtHi, DAG.getConstant(32, SL, MVT::i32));
     // TODO: Should this propagate fast-math-flags?
-    return DAG.getNode(ISD::STRICT_FADD, SL, VTs, LdExp.getValue(1), LdExp, CvtLo);
+    return DAG.getNode(ISD::STRICT_FADD, SL, VTs, LdExp.getValue(1), LdExp,
+                       CvtLo);
   }
 
   SDValue CvtHi = DAG.getNode(Signed ? ISD::SINT_TO_FP : ISD::UINT_TO_FP,
@@ -3288,8 +3295,8 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
   if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
     SDLoc DL(Op);
 
-    SDValue IntToFp32 = DAG.getNode(Op.getOpcode(), DL, {MVT::f32, MVT::Other},
-                                    Src);
+    SDValue IntToFp32 =
+        DAG.getNode(Op.getOpcode(), DL, {MVT::f32, MVT::Other}, Src);
     SDValue FPRoundFlag =
         DAG.getIntPtrConstant(0, SDLoc(Op), /*isTarget=*/true);
     SDValue FPRound =
@@ -3320,7 +3327,8 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
     // Promote src to i32
     SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i32, Src);
     if (IsStrict)
-      return DAG.getNode(ISD::STRICT_SINT_TO_FP, DL, Op->getVTList(), Op.getOperand(0), Ext);
+      return DAG.getNode(ISD::STRICT_SINT_TO_FP, DL, Op->getVTList(),
+                         Op.getOperand(0), Ext);
     return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
   }
 
@@ -3332,21 +3340,21 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
     SDLoc DL(Op);
     SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
 
-    SDValue FPRoundFlag =
-      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, DL, /*isTarget=*/true);
 
     if (IsStrict) {
-      SDValue IntToFp32 = DAG.getNode(Op.getOpcode(), DL, DAG.getVTList(MVT::f32, MVT::Other),
-                                      Op.getOperand(0), Src);
+      SDValue IntToFp32 =
+          DAG.getNode(Op.getOpcode(), DL, DAG.getVTList(MVT::f32, MVT::Other),
+                      Op.getOperand(0), Src);
       SDValue FPRound =
-        DAG.getNode(ISD::STRICT_FP_ROUND, DL, Op->getVTList(), IntToFp32.getValue(1),
-                    IntToFp32, FPRoundFlag);
+          DAG.getNode(ISD::STRICT_FP_ROUND, DL, Op->getVTList(),
+                      IntToFp32.getValue(1), IntToFp32, FPRoundFlag);
       return FPRound;
     }
 
     SDValue IntToFp32 = DAG.getNode(Op.getOpcode(), DL, MVT::f32, Src);
     SDValue FPRound =
-      DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, IntToFp32, FPRoundFlag);
+        DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, IntToFp32, FPRoundFlag);
 
     return FPRound;
   }
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index db329e3cbe..09484a3ac1 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -539,12 +539,13 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                         ISD::FSIN, ISD::FROUND, ISD::FPTRUNC_ROUND},
                        MVT::f16, Custom);
 
-    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::STRICT_SINT_TO_FP}, MVT::i16, Custom);
-
     setOperationAction(
-        {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::STRICT_SINT_TO_FP,
-         ISD::UINT_TO_FP},
-        MVT::f16, Promote);
+        {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::STRICT_SINT_TO_FP}, MVT::i16,
+        Custom);
+
+    setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP,
+                        ISD::STRICT_SINT_TO_FP, ISD::UINT_TO_FP},
+                       MVT::f16, Promote);
 
     setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i32, Legal);
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/74332


More information about the llvm-commits mailing list