[llvm] [SelectionDAG] Require last operand of (STRICT_)FP_ROUND to be a TargetConstant. (PR #117639)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 14:45:28 PST 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/117639

Fix all the places I could find that didn't do this. We were already mostly correct for FP_ROUND after 9a976f36615dbe15e76c12b22f711b2e597a8e51, but not STRICT_FP_ROUND.

>From 04b7a8bd6f91225e814ffb759951ffe7b9e4ae50 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 25 Nov 2024 14:26:45 -0800
Subject: [PATCH] [SelectionDAG] Require last operand of (STRICT_)FP_ROUND to
 be a TargetConstant.

Fix all the places I could find that didn't do this. We were already
mostly correct for FP_ROUND, but not STRICT_FP_ROUND.
---
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 14 ++++---
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 13 +++----
 .../Target/AArch64/AArch64ISelLowering.cpp    | 37 ++++++++++---------
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp     |  2 +-
 .../Target/Hexagon/HexagonISelLoweringHVX.cpp |  7 ++--
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  2 +-
 llvm/lib/Target/PowerPC/PPCISelLowering.cpp   | 13 ++++---
 llvm/lib/Target/X86/X86ISelLowering.cpp       |  5 ++-
 8 files changed, 51 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index fbc96bade15f5a..639a4397688016 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -5277,7 +5277,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
       Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1);
     else
       Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1,
-                         DAG.getIntPtrConstant(0, dl));
+                         DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
     Results.push_back(Tmp1);
     break;
   }
@@ -5425,7 +5425,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     Tmp1 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
                        {Tmp3, Tmp1, Tmp2});
     Tmp1 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
-                       {Tmp1.getValue(1), Tmp1, DAG.getIntPtrConstant(0, dl)});
+                       {Tmp1.getValue(1), Tmp1,
+                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
     Results.push_back(Tmp1);
     Results.push_back(Tmp1.getValue(1));
     break;
@@ -5450,7 +5451,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     Tmp4 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
                        {Tmp4, Tmp1, Tmp2, Tmp3});
     Tmp4 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
-                       {Tmp4.getValue(1), Tmp4, DAG.getIntPtrConstant(0, dl)});
+                       {Tmp4.getValue(1), Tmp4,
+                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
     Results.push_back(Tmp4);
     Results.push_back(Tmp4.getValue(1));
     break;
@@ -5478,7 +5480,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
                        {Tmp1.getValue(1), Tmp1, Node->getOperand(2)});
     Tmp3 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
-                       {Tmp2.getValue(1), Tmp2, DAG.getIntPtrConstant(0, dl)});
+                       {Tmp2.getValue(1), Tmp2,
+                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
     Results.push_back(Tmp3);
     Results.push_back(Tmp3.getValue(1));
     break;
@@ -5562,7 +5565,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
                        {Tmp1.getValue(1), Tmp1});
     Tmp3 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
-                       {Tmp2.getValue(1), Tmp2, DAG.getIntPtrConstant(0, dl)});
+                       {Tmp2.getValue(1), Tmp2,
+                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
     Results.push_back(Tmp3);
     Results.push_back(Tmp3.getValue(1));
     break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3a8ec3c6105bc0..7c5ed04830b16a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1460,7 +1460,7 @@ SelectionDAG::getStrictFPExtendOrRound(SDValue Op, SDValue Chain,
       VT.bitsGT(Op.getValueType())
           ? getNode(ISD::STRICT_FP_EXTEND, DL, {VT, MVT::Other}, {Chain, Op})
           : getNode(ISD::STRICT_FP_ROUND, DL, {VT, MVT::Other},
-                    {Chain, Op, getIntPtrConstant(0, DL)});
+                    {Chain, Op, getIntPtrConstant(0, DL, /*isTarget=*/true)});
 
   return std::pair<SDValue, SDValue>(Res, SDValue(Res.getNode(), 1));
 }
@@ -7355,11 +7355,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
       return N1;
     break;
   case ISD::FP_ROUND:
-    assert(VT.isFloatingPoint() &&
-           N1.getValueType().isFloatingPoint() &&
-           VT.bitsLE(N1.getValueType()) &&
-           N2C && (N2C->getZExtValue() == 0 || N2C->getZExtValue() == 1) &&
-           "Invalid FP_ROUND!");
+    assert(VT.isFloatingPoint() && N1.getValueType().isFloatingPoint() &&
+           VT.bitsLE(N1.getValueType()) && N2C &&
+           (N2C->getZExtValue() == 0 || N2C->getZExtValue() == 1) &&
+           N2.getOpcode() == ISD::TargetConstant && "Invalid FP_ROUND!");
     if (N1.getValueType() == VT) return N1;  // noop conversion.
     break;
   case ISD::AssertSext:
@@ -10542,7 +10541,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
     assert(VTList.VTs[0].isFloatingPoint() &&
            Ops[1].getValueType().isFloatingPoint() &&
            VTList.VTs[0].bitsLT(Ops[1].getValueType()) &&
-           isa<ConstantSDNode>(Ops[2]) &&
+           Ops[2].getOpcode() == ISD::TargetConstant &&
            (Ops[2]->getAsZExtVal() == 0 || Ops[2]->getAsZExtVal() == 1) &&
            "Invalid STRICT_FP_ROUND!");
     break;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ed2d9a07cec630..e1be825fcf7bf3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4901,13 +4901,14 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
     if (IsStrict) {
       SDValue Val = DAG.getNode(Op.getOpcode(), dl, {F32, MVT::Other},
                                 {Op.getOperand(0), In});
-      return DAG.getNode(
-          ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
-          {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
+      return DAG.getNode(ISD::STRICT_FP_ROUND, dl,
+                         {Op.getValueType(), MVT::Other},
+                         {Val.getValue(1), Val.getValue(0),
+                          DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
     }
     return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
                        DAG.getNode(Op.getOpcode(), dl, F32, In),
-                       DAG.getIntPtrConstant(0, dl));
+                       DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
   }
 
   uint64_t VTSize = VT.getFixedSizeInBits();
@@ -4919,9 +4920,9 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
     if (IsStrict) {
       In = DAG.getNode(Opc, dl, {CastVT, MVT::Other},
                        {Op.getOperand(0), In});
-      return DAG.getNode(
-          ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
-          {In.getValue(1), In.getValue(0), DAG.getIntPtrConstant(0, dl)});
+      return DAG.getNode(ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
+                         {In.getValue(1), In.getValue(0),
+                          DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
     }
     In = DAG.getNode(Opc, dl, CastVT, In);
     return DAG.getNode(ISD::FP_ROUND, dl, VT, In,
@@ -4969,13 +4970,14 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
     if (IsStrict) {
       SDValue Val = DAG.getNode(Op.getOpcode(), dl, {PromoteVT, MVT::Other},
                                 {Op.getOperand(0), SrcVal});
-      return DAG.getNode(
-          ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
-          {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
+      return DAG.getNode(ISD::STRICT_FP_ROUND, dl,
+                         {Op.getValueType(), MVT::Other},
+                         {Val.getValue(1), Val.getValue(0),
+                          DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
     }
     return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
                        DAG.getNode(Op.getOpcode(), dl, PromoteVT, SrcVal),
-                       DAG.getIntPtrConstant(0, dl));
+                       DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
   };
 
   if (Op.getValueType() == MVT::bf16) {
@@ -5067,12 +5069,13 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
           DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, NeedsAdjustment);
       SDValue Adjusted = DAG.getNode(ISD::BITCAST, DL, MVT::f64, AdjustedBits);
       return IsStrict
-                 ? DAG.getNode(ISD::STRICT_FP_ROUND, DL,
-                               {Op.getValueType(), MVT::Other},
-                               {Rounded.getValue(1), Adjusted,
-                                DAG.getIntPtrConstant(0, DL)})
+                 ? DAG.getNode(
+                       ISD::STRICT_FP_ROUND, DL,
+                       {Op.getValueType(), MVT::Other},
+                       {Rounded.getValue(1), Adjusted,
+                        DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)})
                  : DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), Adjusted,
-                               DAG.getIntPtrConstant(0, DL, true));
+                               DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
     }
   }
 
@@ -7109,7 +7112,7 @@ static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) {
       DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, X.getValueType(), FScale, Zero);
   if (X.getValueType() != XScalarTy)
     Final = DAG.getNode(ISD::FP_ROUND, DL, XScalarTy, Final,
-                        DAG.getIntPtrConstant(1, SDLoc(Op)));
+                        DAG.getIntPtrConstant(1, SDLoc(Op), /*isTarget=*/true));
   return Final;
 }
 
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index d35bb15ac6566a..f326416a324178 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10756,7 +10756,7 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
   Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast);
   Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags());
   SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot,
-                             DAG.getConstant(0, SL, MVT::i32));
+                             DAG.getTargetConstant(0, SL, MVT::i32));
   return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS,
                      Op->getFlags());
 }
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
index ab0f41343ce211..816e063f8dbbe5 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
@@ -1575,9 +1575,10 @@ HexagonTargetLowering::resizeToWidth(SDValue VecV, MVT ResTy, bool Signed,
   unsigned ResWidth = ResTy.getSizeInBits();
 
   if (InpTy.isFloatingPoint()) {
-    return InpWidth < ResWidth ? DAG.getNode(ISD::FP_EXTEND, dl, ResTy, VecV)
-                               : DAG.getNode(ISD::FP_ROUND, dl, ResTy, VecV,
-                                             getZero(dl, MVT::i32, DAG));
+    return InpWidth < ResWidth
+               ? DAG.getNode(ISD::FP_EXTEND, dl, ResTy, VecV)
+               : DAG.getNode(ISD::FP_ROUND, dl, ResTy, VecV,
+                             DAG.getTargetConstant(0, dl, MVT::i32));
   }
 
   assert(InpTy.isInteger());
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b9003ddbd3187c..62647b31285188 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2786,7 +2786,7 @@ SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
     return DAG.getNode(
         ISD::FP_ROUND, Loc, MVT::bf16,
         DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
-        DAG.getIntPtrConstant(0, Loc));
+        DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
   }
 
   // Everything else is considered legal.
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
index 87a4ad3752c649..f4d3668726164b 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -8963,9 +8963,10 @@ SDValue PPCTargetLowering::LowerINT_TO_FP(SDValue Op,
 
     if (Op.getValueType() == MVT::f32 && !Subtarget.hasFPCVT()) {
       if (IsStrict)
-        FP = DAG.getNode(ISD::STRICT_FP_ROUND, dl,
-                         DAG.getVTList(MVT::f32, MVT::Other),
-                         {Chain, FP, DAG.getIntPtrConstant(0, dl)}, Flags);
+        FP = DAG.getNode(
+            ISD::STRICT_FP_ROUND, dl, DAG.getVTList(MVT::f32, MVT::Other),
+            {Chain, FP, DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)},
+            Flags);
       else
         FP = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, FP,
                          DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
@@ -9044,9 +9045,9 @@ SDValue PPCTargetLowering::LowerINT_TO_FP(SDValue Op,
     Chain = FP.getValue(1);
   if (Op.getValueType() == MVT::f32 && !Subtarget.hasFPCVT()) {
     if (IsStrict)
-      FP = DAG.getNode(ISD::STRICT_FP_ROUND, dl,
-                       DAG.getVTList(MVT::f32, MVT::Other),
-                       {Chain, FP, DAG.getIntPtrConstant(0, dl)}, Flags);
+      FP = DAG.getNode(
+          ISD::STRICT_FP_ROUND, dl, DAG.getVTList(MVT::f32, MVT::Other),
+          {Chain, FP, DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)}, Flags);
     else
       FP = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, FP,
                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 9048d1d83f1874..868be4721f9f46 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -19595,7 +19595,7 @@ static SDValue promoteXINT_TO_FP(SDValue Op, const SDLoc &dl,
   MVT VT = Op.getSimpleValueType();
   MVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
 
-  SDValue Rnd = DAG.getIntPtrConstant(0, dl);
+  SDValue Rnd = DAG.getIntPtrConstant(0, dl, /*isTarget=*/true);
   if (IsStrict)
     return DAG.getNode(
         ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
@@ -20266,7 +20266,8 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op,
     if (DstVT == MVT::f80)
       return Add;
     return DAG.getNode(ISD::STRICT_FP_ROUND, dl, {DstVT, MVT::Other},
-                       {Add.getValue(1), Add, DAG.getIntPtrConstant(0, dl)});
+                       {Add.getValue(1), Add,
+                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
   }
   unsigned Opc = ISD::FADD;
   // Windows needs the precision control changed to 80bits around this add.



More information about the llvm-commits mailing list