[llvm] 04e809a - [DAG] Add TargetLowering::expandABD and convert X86 lowering to use it

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri May 5 07:13:39 PDT 2023


Author: Simon Pilgrim
Date: 2023-05-05T15:13:23+01:00
New Revision: 04e809ab90161799a9973429a9af0b25ee1e3261

URL: https://github.com/llvm/llvm-project/commit/04e809ab90161799a9973429a9af0b25ee1e3261
DIFF: https://github.com/llvm/llvm-project/commit/04e809ab90161799a9973429a9af0b25ee1e3261.diff

LOG: [DAG] Add TargetLowering::expandABD and convert X86 lowering to use it

Scalar widening cases are still custom-lowered in the X86 backend; we still need to add ABD promotion/legalization support to handle them generically.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 56a067a7bbae1..4c5e18192ea21 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5033,6 +5033,11 @@ class TargetLowering : public TargetLoweringBase {
   SDValue expandABS(SDNode *N, SelectionDAG &DAG,
                     bool IsNegative = false) const;
 
+  /// Expand ABDS/ABDU (signed/unsigned absolute difference) nodes.
+  /// \param N Node to expand; may have a scalar or a vector result type.
+  /// \returns The expansion result or SDValue() if it fails.
+  SDValue expandABD(SDNode *N, SelectionDAG &DAG) const;
+
   /// Expand BSWAP nodes. Expands scalar/vector BSWAP nodes with i16/i32/i64
   /// scalar types. Returns SDValue() if expand fails.
   /// \param N Node to expand

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 7a370db299e46..b3c4d07eff057 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -2696,6 +2696,11 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     if ((Tmp1 = TLI.expandABS(Node, DAG)))
       Results.push_back(Tmp1);
     break;
+  case ISD::ABDS:
+  case ISD::ABDU:
+    if ((Tmp1 = TLI.expandABD(Node, DAG)))
+      Results.push_back(Tmp1);
+    break;
   case ISD::CTPOP:
     if ((Tmp1 = TLI.expandCTPOP(Node, DAG)))
       Results.push_back(Tmp1);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 5d5157142f0ac..faf3453ca3bb6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -795,6 +795,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
       return;
     }
     break;
+  case ISD::ABDS:
+  case ISD::ABDU:
+    if (SDValue Expanded = TLI.expandABD(Node, DAG)) {
+      Results.push_back(Expanded);
+      return;
+    }
+    break;
   case ISD::BITREVERSE:
     ExpandBITREVERSE(Node, Results);
     return;

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index b18f2d7e9aa71..20b80f882c276 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -8627,6 +8627,38 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::SUB, dl, VT, Shift, Xor);
 }
 
+SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
+  SDLoc dl(N);
+  EVT VT = N->getValueType(0);
+  SDValue LHS = DAG.getFreeze(N->getOperand(0));
+  SDValue RHS = DAG.getFreeze(N->getOperand(1));
+  bool IsSigned = N->getOpcode() == ISD::ABDS;
+  // Each expansion below uses LHS and RHS more than once, hence the freezes.
+  // 1) abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs))
+  //    abdu(lhs, rhs) -> sub(umax(lhs,rhs), umin(lhs,rhs))
+  unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX;
+  unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN;
+  if (isOperationLegal(MaxOpc, VT) && isOperationLegal(MinOpc, VT)) {
+    SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS);
+    SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS);
+    return DAG.getNode(ISD::SUB, dl, VT, Max, Min);
+  }
+
+  // 2) abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs)); one is 0.
+  if (!IsSigned && isOperationLegal(ISD::USUBSAT, VT))
+    return DAG.getNode(ISD::OR, dl, VT,
+                       DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
+                       DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
+
+  // 3) abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
+  //    abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
+  EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
+  SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
+  return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
+                       DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
+}
+
 SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
   SDLoc dl(N);
   EVT VT = N->getValueType(0);

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 5f9aec6bc5e47..03880f2fdddae 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30375,29 +30375,11 @@ static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
   if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.useBWIRegs())
     return splitVectorIntBinary(Op, DAG);
 
-  // TODO: Add TargetLowering expandABD() support.
   SDLoc dl(Op);
   bool IsSigned = Op.getOpcode() == ISD::ABDS;
-  SDValue LHS = DAG.getFreeze(Op.getOperand(0));
-  SDValue RHS = DAG.getFreeze(Op.getOperand(1));
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
-  // abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs))
-  // abdu(lhs, rhs) -> sub(umax(lhs,rhs), umin(lhs,rhs))
-  unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX;
-  unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN;
-  if (TLI.isOperationLegal(MaxOpc, VT) && TLI.isOperationLegal(MinOpc, VT)) {
-    SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS);
-    SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS);
-    return DAG.getNode(ISD::SUB, dl, VT, Max, Min);
-  }
-
-  // abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs))
-  if (!IsSigned && TLI.isOperationLegal(ISD::USUBSAT, VT))
-    return DAG.getNode(ISD::OR, dl, VT,
-                       DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
-                       DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
-
+  // TODO: Move to TargetLowering expandABD() once we have ABD promotion.
   if (VT.isScalarInteger()) {
     unsigned WideBits = std::max<unsigned>(2 * VT.getScalarSizeInBits(), 32u);
     MVT WideVT = MVT::getIntegerVT(WideBits);
@@ -30405,6 +30387,8 @@ static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
       // abds(lhs, rhs) -> trunc(abs(sub(sext(lhs), sext(rhs))))
       // abdu(lhs, rhs) -> trunc(abs(sub(zext(lhs), zext(rhs))))
       unsigned ExtOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+      SDValue LHS = DAG.getFreeze(Op.getOperand(0));
+      SDValue RHS = DAG.getFreeze(Op.getOperand(1));
       LHS = DAG.getNode(ExtOpc, dl, WideVT, LHS);
       RHS = DAG.getNode(ExtOpc, dl, WideVT, RHS);
       SDValue Diff = DAG.getNode(ISD::SUB, dl, WideVT, LHS, RHS);
@@ -30413,13 +30397,8 @@ static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
     }
   }
 
-  // abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
-  // abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
-  EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
-  ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
-  SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
-  return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
-                       DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
+  // Default to expand.
+  return SDValue();
 }
 
 static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,


        


More information about the llvm-commits mailing list