[llvm] 2887f14 - [ISel] Port AArch64 SABD and UABD to DAGCombine

David Green via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 26 11:34:41 PDT 2021


Author: David Green
Date: 2021-06-26T19:34:16+01:00
New Revision: 2887f1463930044a6093f111dc8eba5594144c33

URL: https://github.com/llvm/llvm-project/commit/2887f1463930044a6093f111dc8eba5594144c33
DIFF: https://github.com/llvm/llvm-project/commit/2887f1463930044a6093f111dc8eba5594144c33.diff

LOG: [ISel] Port AArch64 SABD and UABD to DAGCombine

This ports the AArch64 SABD and UABD nodes over to DAGCombine, where they can
be used by more backends (notably MVE in a follow-up patch). The matching code
has changed very little, only to handle legal operations and types
differently. It matches (ABS (SUB (EXTEND a), (EXTEND b))), producing an
abds/abdu which is zero-extended to the original type.

Differential Revision: https://reviews.llvm.org/D91937
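
For illustration only (not from this commit or its tests), a minimal C++
sketch of the source-level idiom that produces the matched DAG shape -- the
extends, subtract and abs below become the (ABS (SUB (EXTEND a), (EXTEND b)))
nodes that the new combine can turn into an abdu:

  // Hypothetical example; the function name and types are made up.
  #include <cstdint>
  #include <cstdlib>

  void abd_u8(const uint8_t *a, const uint8_t *b, uint8_t *out, int n) {
    for (int i = 0; i < n; ++i)
      // Zero-extend both inputs, subtract, take abs, truncate back to i8.
      // Once vectorized, this is the abs(sub(zext, zext)) pattern that
      // combineABSToABD looks for, so AArch64 can select it to uabd.
      out[i] = static_cast<uint8_t>(std::abs(int(a[i]) - int(b[i])));
  }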

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ISDOpcodes.h
    llvm/include/llvm/Target/TargetSelectionDAG.td
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
    llvm/lib/CodeGen/TargetLoweringBase.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64InstrInfo.td

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index adad8c18e5583..6eb70ab477089 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -611,6 +611,13 @@ enum NodeType {
   MULHU,
   MULHS,
 
+  // ABDS/ABDU - Absolute difference - Return the absolute difference between
+  // two numbers interpreted as signed/unsigned.
+  // i.e trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1)
+  //  or trunc(abs(zext(Op0) - zext(Op1))) becomes abdu(Op0, Op1)
+  ABDS,
+  ABDU,
+
   /// [US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned
   /// integers.
   SMIN,

diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 1913396609fd0..c7f22bf173910 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -369,6 +369,8 @@ def mul        : SDNode<"ISD::MUL"       , SDTIntBinOp,
                         [SDNPCommutative, SDNPAssociative]>;
 def mulhs      : SDNode<"ISD::MULHS"     , SDTIntBinOp, [SDNPCommutative]>;
 def mulhu      : SDNode<"ISD::MULHU"     , SDTIntBinOp, [SDNPCommutative]>;
+def abds       : SDNode<"ISD::ABDS"      , SDTIntBinOp, [SDNPCommutative]>;
+def abdu       : SDNode<"ISD::ABDU"      , SDTIntBinOp, [SDNPCommutative]>;
 def smullohi   : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def umullohi   : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def sdiv       : SDNode<"ISD::SDIV"      , SDTIntBinOp>;

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 63c979c4dbd73..5ea3de9d0db66 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9071,6 +9071,40 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
   return SDValue();
 }
 
+// Given a ABS node, detect the following pattern:
+// (ABS (SUB (EXTEND a), (EXTEND b))).
+// Generates UABD/SABD instruction.
+static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
+                               const TargetLowering &TLI) {
+  SDValue AbsOp1 = N->getOperand(0);
+  SDValue Op0, Op1;
+
+  if (AbsOp1.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  Op0 = AbsOp1.getOperand(0);
+  Op1 = AbsOp1.getOperand(1);
+
+  unsigned Opc0 = Op0.getOpcode();
+  // Check if the operands of the sub are (zero|sign)-extended.
+  if (Opc0 != Op1.getOpcode() ||
+      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
+    return SDValue();
+
+  EVT VT1 = Op0.getOperand(0).getValueType();
+  EVT VT2 = Op1.getOperand(0).getValueType();
+  // Check if the operands are of same type and valid size.
+  unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
+  if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1))
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+  SDValue ABD =
+      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
+  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
+}
+
 SDValue DAGCombiner::visitABS(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -9084,6 +9118,10 @@ SDValue DAGCombiner::visitABS(SDNode *N) {
   // fold (abs x) -> x iff not-negative
   if (DAG.SignBitIsZero(N0))
     return N0;
+
+  if (SDValue ABD = combineABSToABD(N, DAG, TLI))
+    return ABD;
+
   return SDValue();
 }
 

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 73c207e589fba..40083c614a6c7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -231,6 +231,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::MUL:                        return "mul";
   case ISD::MULHU:                      return "mulhu";
   case ISD::MULHS:                      return "mulhs";
+  case ISD::ABDS:                       return "abds";
+  case ISD::ABDU:                       return "abdu";
   case ISD::SDIV:                       return "sdiv";
   case ISD::UDIV:                       return "udiv";
   case ISD::SREM:                       return "srem";

diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index d2c291f2ae72b..ebac779984ec5 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -813,6 +813,10 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SUBC, VT, Expand);
     setOperationAction(ISD::SUBE, VT, Expand);
 
+    // Absolute difference
+    setOperationAction(ISD::ABDS, VT, Expand);
+    setOperationAction(ISD::ABDU, VT, Expand);
+
     // These default to Expand so they will be expanded to CTLZ/CTTZ by default.
     setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Expand);
     setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Expand);

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fd5c9e05a8a25..9886d6374665b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1050,6 +1050,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::USUBSAT, VT, Legal);
     }
 
+    for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
+                   MVT::v4i32}) {
+      setOperationAction(ISD::ABDS, VT, Legal);
+      setOperationAction(ISD::ABDU, VT, Legal);
+    }
+
     // Vector reductions
     for (MVT VT : { MVT::v4f16, MVT::v2f32,
                     MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
@@ -2116,8 +2122,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
-    MAKE_CASE(AArch64ISD::UABD)
-    MAKE_CASE(AArch64ISD::SABD)
     MAKE_CASE(AArch64ISD::UADDLP)
     MAKE_CASE(AArch64ISD::CALL_RVMARKER)
   }
@@ -4082,8 +4086,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   }
   case Intrinsic::aarch64_neon_sabd:
   case Intrinsic::aarch64_neon_uabd: {
-    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD
-                                                            : AArch64ISD::SABD;
+    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU
+                                                            : ISD::ABDS;
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
@@ -12099,8 +12103,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
   SDValue UABDHigh8Op1 =
       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
                   DAG.getConstant(8, DL, MVT::i64));
-  SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
-                                  DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1);
+  SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+                                  UABDHigh8Op0, UABDHigh8Op1);
   SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
 
   // Second, create the node pattern of UABAL.
@@ -12110,8 +12114,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
   SDValue UABDLo8Op1 =
       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
                   DAG.getConstant(0, DL, MVT::i64));
-  SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
-                                DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1);
+  SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+                                UABDLo8Op0, UABDLo8Op1);
   SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
   SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
 
@@ -12170,48 +12174,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
 }
 
-// Given a ABS node, detect the following pattern:
-// (ABS (SUB (EXTEND a), (EXTEND b))).
-// Generates UABD/SABD instruction.
-static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const AArch64Subtarget *Subtarget) {
-  SDValue AbsOp1 = N->getOperand(0);
-  SDValue Op0, Op1;
-
-  if (AbsOp1.getOpcode() != ISD::SUB)
-    return SDValue();
-
-  Op0 = AbsOp1.getOperand(0);
-  Op1 = AbsOp1.getOperand(1);
-
-  unsigned Opc0 = Op0.getOpcode();
-  // Check if the operands of the sub are (zero|sign)-extended.
-  if (Opc0 != Op1.getOpcode() ||
-      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
-    return SDValue();
-
-  EVT VectorT1 = Op0.getOperand(0).getValueType();
-  EVT VectorT2 = Op1.getOperand(0).getValueType();
-  // Check if vectors are of same type and valid size.
-  uint64_t Size = VectorT1.getFixedSizeInBits();
-  if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
-    return SDValue();
-
-  // Check if vector element types are valid.
-  EVT VT1 = VectorT1.getVectorElementType();
-  if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
-    return SDValue();
-
-  Op0 = Op0.getOperand(0);
-  Op1 = Op1.getOperand(0);
-  unsigned ABDOpcode =
-      (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
-  SDValue ABD =
-      DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
-  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
-}
-
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -14377,8 +14339,8 @@ static SDValue performExtendCombine(SDNode *N,
   // helps the backend to decide that an sabdl2 would be useful, saving a real
   // extract_high operation.
   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
-      (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
-       N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
+      (N->getOperand(0).getOpcode() == ISD::ABDU ||
+       N->getOperand(0).getOpcode() == ISD::ABDS)) {
     SDNode *ABDNode = N->getOperand(0).getNode();
     SDValue NewABD =
         tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
@@ -16344,8 +16306,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
-  case ISD::ABS:
-    return performABSCombine(N, DAG, DCI, Subtarget);
   case ISD::ADD:
   case ISD::SUB:
     return performAddSubCombine(N, DCI, DAG);

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index ced2607862123..7daa61996739f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -236,10 +236,6 @@ enum NodeType : unsigned {
   SRHADD,
   URHADD,
 
-  // Absolute difference
-  UABD,
-  SABD,
-
   // Unsigned Add Long Pairwise
   UADDLP,
 

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 301f1ed69638d..7802144fb2c98 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -579,14 +579,11 @@ def AArch64urhadd   : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>;
 def AArch64shadd   : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>;
 def AArch64uhadd   : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>;
 
-def AArch64uabd_n   : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>;
-def AArch64sabd_n   : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>;
-
 def AArch64uabd     : PatFrags<(ops node:$lhs, node:$rhs),
-                               [(AArch64uabd_n node:$lhs, node:$rhs),
+                               [(abdu node:$lhs, node:$rhs),
                                 (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>;
 def AArch64sabd     : PatFrags<(ops node:$lhs, node:$rhs),
-                               [(AArch64sabd_n node:$lhs, node:$rhs),
+                               [(abds node:$lhs, node:$rhs),
                                 (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
 
 def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>;

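As a side note, the semantics of the new nodes follow the comment added to
ISDOpcodes.h above: abds(a, b) = trunc(abs(sext(a) - sext(b))) and
abdu(a, b) = trunc(abs(zext(a) - zext(b))). A small scalar C++ reference
model (illustrative only; the helper names are made up) of the i8 case:

  #include <cstdint>
  #include <cstdio>
  #include <cstdlib>

  // trunc(abs(sext(a) - sext(b))), returned as the raw 8-bit value.
  static uint8_t abds8(int8_t a, int8_t b) {
    return static_cast<uint8_t>(std::abs(int(a) - int(b)));
  }
  // trunc(abs(zext(a) - zext(b))), returned as the raw 8-bit value.
  static uint8_t abdu8(uint8_t a, uint8_t b) {
    return static_cast<uint8_t>(std::abs(int(a) - int(b)));
  }

  int main() {
    // 0x9c is -100 as a signed i8 but 156 as an unsigned i8, so the two
    // nodes differ for the same bit patterns: this prints "200 56".
    std::printf("%d %d\n", abds8(int8_t(0x9c), 100), abdu8(uint8_t(0x9c), 100));
    return 0;
  }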

        

