[llvm] 04e809a - [DAG] Add TargetLowering::expandABD and convert X86 lowering to use it
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri May 5 07:13:39 PDT 2023
Author: Simon Pilgrim
Date: 2023-05-05T15:13:23+01:00
New Revision: 04e809ab90161799a9973429a9af0b25ee1e3261
URL: https://github.com/llvm/llvm-project/commit/04e809ab90161799a9973429a9af0b25ee1e3261
DIFF: https://github.com/llvm/llvm-project/commit/04e809ab90161799a9973429a9af0b25ee1e3261.diff
LOG: [DAG] Add TargetLowering::expandABD and convert X86 lowering to use it
Scalar widening cases are still custom-lowered in the X86 backend; we still need to add ABD promotion/legalization support to handle these.
Added:
Modified:
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/lib/Target/X86/X86ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 56a067a7bbae1..4c5e18192ea21 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5033,6 +5033,11 @@ class TargetLowering : public TargetLoweringBase {
SDValue expandABS(SDNode *N, SelectionDAG &DAG,
bool IsNegative = false) const;
+ /// Expand scalar/vector ABDS/ABDU (absolute difference) nodes.
+ /// \param N Node to expand
+ /// \returns The expansion result, or SDValue() if it fails (currently it never does).
+ SDValue expandABD(SDNode *N, SelectionDAG &DAG) const;
+
/// Expand BSWAP nodes. Expands scalar/vector BSWAP nodes with i16/i32/i64
/// scalar types. Returns SDValue() if expand fails.
/// \param N Node to expand
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 7a370db299e46..b3c4d07eff057 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -2696,6 +2696,11 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
if ((Tmp1 = TLI.expandABS(Node, DAG)))
Results.push_back(Tmp1);
break;
+ case ISD::ABDS:
+ case ISD::ABDU:
+ if ((Tmp1 = TLI.expandABD(Node, DAG)))
+ Results.push_back(Tmp1);
+ break;
case ISD::CTPOP:
if ((Tmp1 = TLI.expandCTPOP(Node, DAG)))
Results.push_back(Tmp1);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 5d5157142f0ac..faf3453ca3bb6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -795,6 +795,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
return;
}
break;
+ case ISD::ABDS:
+ case ISD::ABDU:
+ if (SDValue Expanded = TLI.expandABD(Node, DAG)) {
+ Results.push_back(Expanded);
+ return;
+ }
+ break;
case ISD::BITREVERSE:
ExpandBITREVERSE(Node, Results);
return;
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index b18f2d7e9aa71..20b80f882c276 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -8627,6 +8627,38 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::SUB, dl, VT, Shift, Xor);
}
+SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
+ SDLoc dl(N);
+ EVT VT = N->getValueType(0);
+ SDValue LHS = DAG.getFreeze(N->getOperand(0));
+ SDValue RHS = DAG.getFreeze(N->getOperand(1));
+ bool IsSigned = N->getOpcode() == ISD::ABDS;
+ // LHS/RHS are frozen because each expansion below uses them more than once.
+ // abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs))
+ // abdu(lhs, rhs) -> sub(umax(lhs,rhs), umin(lhs,rhs))
+ unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX;
+ unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN;
+ if (isOperationLegal(MaxOpc, VT) && isOperationLegal(MinOpc, VT)) {
+ SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS);
+ SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS);
+ return DAG.getNode(ISD::SUB, dl, VT, Max, Min);
+ }
+
+ // abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs))
+ if (!IsSigned && isOperationLegal(ISD::USUBSAT, VT))
+ return DAG.getNode(ISD::OR, dl, VT,
+ DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
+ DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
+ // Generic fallback: always yields a result, so expandABD never returns SDValue().
+ // abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
+ // abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
+ EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+ ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
+ SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
+ return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
+ DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
+}
+
SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
SDLoc dl(N);
EVT VT = N->getValueType(0);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 5f9aec6bc5e47..03880f2fdddae 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30375,29 +30375,11 @@ static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.useBWIRegs())
return splitVectorIntBinary(Op, DAG);
- // TODO: Add TargetLowering expandABD() support.
SDLoc dl(Op);
bool IsSigned = Op.getOpcode() == ISD::ABDS;
- SDValue LHS = DAG.getFreeze(Op.getOperand(0));
- SDValue RHS = DAG.getFreeze(Op.getOperand(1));
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- // abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs))
- // abdu(lhs, rhs) -> sub(umax(lhs,rhs), umin(lhs,rhs))
- unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX;
- unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN;
- if (TLI.isOperationLegal(MaxOpc, VT) && TLI.isOperationLegal(MinOpc, VT)) {
- SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS);
- SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS);
- return DAG.getNode(ISD::SUB, dl, VT, Max, Min);
- }
-
- // abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs))
- if (!IsSigned && TLI.isOperationLegal(ISD::USUBSAT, VT))
- return DAG.getNode(ISD::OR, dl, VT,
- DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
- DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
-
+ // TODO: Move to TargetLowering expandABD() once we have ABD promotion.
if (VT.isScalarInteger()) {
unsigned WideBits = std::max<unsigned>(2 * VT.getScalarSizeInBits(), 32u);
MVT WideVT = MVT::getIntegerVT(WideBits);
@@ -30405,6 +30387,8 @@ static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
// abds(lhs, rhs) -> trunc(abs(sub(sext(lhs), sext(rhs))))
// abdu(lhs, rhs) -> trunc(abs(sub(zext(lhs), zext(rhs))))
unsigned ExtOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ SDValue LHS = DAG.getFreeze(Op.getOperand(0));
+ SDValue RHS = DAG.getFreeze(Op.getOperand(1));
LHS = DAG.getNode(ExtOpc, dl, WideVT, LHS);
RHS = DAG.getNode(ExtOpc, dl, WideVT, RHS);
SDValue Diff = DAG.getNode(ISD::SUB, dl, WideVT, LHS, RHS);
@@ -30413,13 +30397,8 @@ static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
}
}
- // abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
- // abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
- EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
- ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
- SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
- return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
- DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
+ // Default to expand.
+ return SDValue();
}
static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
More information about the llvm-commits mailing list