[llvm] [SelectionDAG] Require last operand of (STRICT_)FP_ROUND to be a TargetConstant. (PR #117639)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 25 14:45:28 PST 2024
https://github.com/topperc created https://github.com/llvm/llvm-project/pull/117639
Fix all the places I could find that didn't do this. We were already mostly correct for FP_ROUND after 9a976f36615dbe15e76c12b22f711b2e597a8e51, but not STRICT_FP_ROUND.
>From 04b7a8bd6f91225e814ffb759951ffe7b9e4ae50 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 25 Nov 2024 14:26:45 -0800
Subject: [PATCH] [SelectionDAG] Require last operand of (STRICT_)FP_ROUND to
be a TargetConstant.
Fix all the places I could find that didn't do this. We were already
mostly correct for FP_ROUND, but not STRICT_FP_ROUND.
---
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 14 ++++---
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 13 +++----
.../Target/AArch64/AArch64ISelLowering.cpp | 37 ++++++++++---------
llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 2 +-
.../Target/Hexagon/HexagonISelLoweringHVX.cpp | 7 ++--
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 2 +-
llvm/lib/Target/PowerPC/PPCISelLowering.cpp | 13 ++++---
llvm/lib/Target/X86/X86ISelLowering.cpp | 5 ++-
8 files changed, 51 insertions(+), 42 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index fbc96bade15f5a..639a4397688016 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -5277,7 +5277,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1);
else
Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1,
- DAG.getIntPtrConstant(0, dl));
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
Results.push_back(Tmp1);
break;
}
@@ -5425,7 +5425,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp1 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp3, Tmp1, Tmp2});
Tmp1 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
- {Tmp1.getValue(1), Tmp1, DAG.getIntPtrConstant(0, dl)});
+ {Tmp1.getValue(1), Tmp1,
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
Results.push_back(Tmp1);
Results.push_back(Tmp1.getValue(1));
break;
@@ -5450,7 +5451,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp4 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp4, Tmp1, Tmp2, Tmp3});
Tmp4 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
- {Tmp4.getValue(1), Tmp4, DAG.getIntPtrConstant(0, dl)});
+ {Tmp4.getValue(1), Tmp4,
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
Results.push_back(Tmp4);
Results.push_back(Tmp4.getValue(1));
break;
@@ -5478,7 +5480,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp1.getValue(1), Tmp1, Node->getOperand(2)});
Tmp3 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
- {Tmp2.getValue(1), Tmp2, DAG.getIntPtrConstant(0, dl)});
+ {Tmp2.getValue(1), Tmp2,
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
Results.push_back(Tmp3);
Results.push_back(Tmp3.getValue(1));
break;
@@ -5562,7 +5565,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp1.getValue(1), Tmp1});
Tmp3 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
- {Tmp2.getValue(1), Tmp2, DAG.getIntPtrConstant(0, dl)});
+ {Tmp2.getValue(1), Tmp2,
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
Results.push_back(Tmp3);
Results.push_back(Tmp3.getValue(1));
break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3a8ec3c6105bc0..7c5ed04830b16a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1460,7 +1460,7 @@ SelectionDAG::getStrictFPExtendOrRound(SDValue Op, SDValue Chain,
VT.bitsGT(Op.getValueType())
? getNode(ISD::STRICT_FP_EXTEND, DL, {VT, MVT::Other}, {Chain, Op})
: getNode(ISD::STRICT_FP_ROUND, DL, {VT, MVT::Other},
- {Chain, Op, getIntPtrConstant(0, DL)});
+ {Chain, Op, getIntPtrConstant(0, DL, /*isTarget=*/true)});
return std::pair<SDValue, SDValue>(Res, SDValue(Res.getNode(), 1));
}
@@ -7355,11 +7355,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return N1;
break;
case ISD::FP_ROUND:
- assert(VT.isFloatingPoint() &&
- N1.getValueType().isFloatingPoint() &&
- VT.bitsLE(N1.getValueType()) &&
- N2C && (N2C->getZExtValue() == 0 || N2C->getZExtValue() == 1) &&
- "Invalid FP_ROUND!");
+ assert(VT.isFloatingPoint() && N1.getValueType().isFloatingPoint() &&
+ VT.bitsLE(N1.getValueType()) && N2C &&
+ (N2C->getZExtValue() == 0 || N2C->getZExtValue() == 1) &&
+ N2.getOpcode() == ISD::TargetConstant && "Invalid FP_ROUND!");
if (N1.getValueType() == VT) return N1; // noop conversion.
break;
case ISD::AssertSext:
@@ -10542,7 +10541,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
assert(VTList.VTs[0].isFloatingPoint() &&
Ops[1].getValueType().isFloatingPoint() &&
VTList.VTs[0].bitsLT(Ops[1].getValueType()) &&
- isa<ConstantSDNode>(Ops[2]) &&
+ Ops[2].getOpcode() == ISD::TargetConstant &&
(Ops[2]->getAsZExtVal() == 0 || Ops[2]->getAsZExtVal() == 1) &&
"Invalid STRICT_FP_ROUND!");
break;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ed2d9a07cec630..e1be825fcf7bf3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4901,13 +4901,14 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
if (IsStrict) {
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {F32, MVT::Other},
{Op.getOperand(0), In});
- return DAG.getNode(
- ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
- {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
+ return DAG.getNode(ISD::STRICT_FP_ROUND, dl,
+ {Op.getValueType(), MVT::Other},
+ {Val.getValue(1), Val.getValue(0),
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
}
return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
DAG.getNode(Op.getOpcode(), dl, F32, In),
- DAG.getIntPtrConstant(0, dl));
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
}
uint64_t VTSize = VT.getFixedSizeInBits();
@@ -4919,9 +4920,9 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
if (IsStrict) {
In = DAG.getNode(Opc, dl, {CastVT, MVT::Other},
{Op.getOperand(0), In});
- return DAG.getNode(
- ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
- {In.getValue(1), In.getValue(0), DAG.getIntPtrConstant(0, dl)});
+ return DAG.getNode(ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
+ {In.getValue(1), In.getValue(0),
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
}
In = DAG.getNode(Opc, dl, CastVT, In);
return DAG.getNode(ISD::FP_ROUND, dl, VT, In,
@@ -4969,13 +4970,14 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
if (IsStrict) {
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {PromoteVT, MVT::Other},
{Op.getOperand(0), SrcVal});
- return DAG.getNode(
- ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
- {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
+ return DAG.getNode(ISD::STRICT_FP_ROUND, dl,
+ {Op.getValueType(), MVT::Other},
+ {Val.getValue(1), Val.getValue(0),
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
}
return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
DAG.getNode(Op.getOpcode(), dl, PromoteVT, SrcVal),
- DAG.getIntPtrConstant(0, dl));
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
};
if (Op.getValueType() == MVT::bf16) {
@@ -5067,12 +5069,13 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, NeedsAdjustment);
SDValue Adjusted = DAG.getNode(ISD::BITCAST, DL, MVT::f64, AdjustedBits);
return IsStrict
- ? DAG.getNode(ISD::STRICT_FP_ROUND, DL,
- {Op.getValueType(), MVT::Other},
- {Rounded.getValue(1), Adjusted,
- DAG.getIntPtrConstant(0, DL)})
+ ? DAG.getNode(
+ ISD::STRICT_FP_ROUND, DL,
+ {Op.getValueType(), MVT::Other},
+ {Rounded.getValue(1), Adjusted,
+ DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)})
: DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), Adjusted,
- DAG.getIntPtrConstant(0, DL, true));
+ DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
}
}
@@ -7109,7 +7112,7 @@ static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) {
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, X.getValueType(), FScale, Zero);
if (X.getValueType() != XScalarTy)
Final = DAG.getNode(ISD::FP_ROUND, DL, XScalarTy, Final,
- DAG.getIntPtrConstant(1, SDLoc(Op)));
+ DAG.getIntPtrConstant(1, SDLoc(Op), /*isTarget=*/true));
return Final;
}
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index d35bb15ac6566a..f326416a324178 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10756,7 +10756,7 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast);
Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags());
SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot,
- DAG.getConstant(0, SL, MVT::i32));
+ DAG.getTargetConstant(0, SL, MVT::i32));
return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS,
Op->getFlags());
}
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
index ab0f41343ce211..816e063f8dbbe5 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
@@ -1575,9 +1575,10 @@ HexagonTargetLowering::resizeToWidth(SDValue VecV, MVT ResTy, bool Signed,
unsigned ResWidth = ResTy.getSizeInBits();
if (InpTy.isFloatingPoint()) {
- return InpWidth < ResWidth ? DAG.getNode(ISD::FP_EXTEND, dl, ResTy, VecV)
- : DAG.getNode(ISD::FP_ROUND, dl, ResTy, VecV,
- getZero(dl, MVT::i32, DAG));
+ return InpWidth < ResWidth
+ ? DAG.getNode(ISD::FP_EXTEND, dl, ResTy, VecV)
+ : DAG.getNode(ISD::FP_ROUND, dl, ResTy, VecV,
+ DAG.getTargetConstant(0, dl, MVT::i32));
}
assert(InpTy.isInteger());
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b9003ddbd3187c..62647b31285188 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2786,7 +2786,7 @@ SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
return DAG.getNode(
ISD::FP_ROUND, Loc, MVT::bf16,
DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
- DAG.getIntPtrConstant(0, Loc));
+ DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
}
// Everything else is considered legal.
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
index 87a4ad3752c649..f4d3668726164b 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -8963,9 +8963,10 @@ SDValue PPCTargetLowering::LowerINT_TO_FP(SDValue Op,
if (Op.getValueType() == MVT::f32 && !Subtarget.hasFPCVT()) {
if (IsStrict)
- FP = DAG.getNode(ISD::STRICT_FP_ROUND, dl,
- DAG.getVTList(MVT::f32, MVT::Other),
- {Chain, FP, DAG.getIntPtrConstant(0, dl)}, Flags);
+ FP = DAG.getNode(
+ ISD::STRICT_FP_ROUND, dl, DAG.getVTList(MVT::f32, MVT::Other),
+ {Chain, FP, DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)},
+ Flags);
else
FP = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, FP,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
@@ -9044,9 +9045,9 @@ SDValue PPCTargetLowering::LowerINT_TO_FP(SDValue Op,
Chain = FP.getValue(1);
if (Op.getValueType() == MVT::f32 && !Subtarget.hasFPCVT()) {
if (IsStrict)
- FP = DAG.getNode(ISD::STRICT_FP_ROUND, dl,
- DAG.getVTList(MVT::f32, MVT::Other),
- {Chain, FP, DAG.getIntPtrConstant(0, dl)}, Flags);
+ FP = DAG.getNode(
+ ISD::STRICT_FP_ROUND, dl, DAG.getVTList(MVT::f32, MVT::Other),
+ {Chain, FP, DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)}, Flags);
else
FP = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, FP,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 9048d1d83f1874..868be4721f9f46 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -19595,7 +19595,7 @@ static SDValue promoteXINT_TO_FP(SDValue Op, const SDLoc &dl,
MVT VT = Op.getSimpleValueType();
MVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
- SDValue Rnd = DAG.getIntPtrConstant(0, dl);
+ SDValue Rnd = DAG.getIntPtrConstant(0, dl, /*isTarget=*/true);
if (IsStrict)
return DAG.getNode(
ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
@@ -20266,7 +20266,8 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op,
if (DstVT == MVT::f80)
return Add;
return DAG.getNode(ISD::STRICT_FP_ROUND, dl, {DstVT, MVT::Other},
- {Add.getValue(1), Add, DAG.getIntPtrConstant(0, dl)});
+ {Add.getValue(1), Add,
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)});
}
unsigned Opc = ISD::FADD;
// Windows needs the precision control changed to 80bits around this add.
More information about the llvm-commits
mailing list