[llvm] 08ce52e - [AArch64] Improve SAD pattern

Jingu Kang via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 14 07:53:44 PDT 2021


Author: Jingu Kang
Date: 2021-06-14T15:48:51+01:00
New Revision: 08ce52ef5e6b879216f8018b920ef5c0621e797d

URL: https://github.com/llvm/llvm-project/commit/08ce52ef5e6b879216f8018b920ef5c0621e797d
DIFF: https://github.com/llvm/llvm-project/commit/08ce52ef5e6b879216f8018b920ef5c0621e797d.diff

LOG: [AArch64] Improve SAD pattern

Given a vecreduce_add node, detect the pattern below and convert it to a node
sequence using UABDL, [S|U]ABD and UADDLP.

i32 vecreduce_add(
 v16i32 abs(
   v16i32 sub(
    v16i32 [sign|zero]_extend(v16i8 a), v16i32 [sign|zero]_extend(v16i8 b))))
=================>
i32 vecreduce_add(
  v4i32 UADDLP(
    v8i16 add(
      v8i16 zext(
        v8i8 [S|U]ABD low8:v16i8 a, low8:v16i8 b),
      v8i16 zext(
        v8i8 [S|U]ABD high8:v16i8 a, high8:v16i8 b))))
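
For reference, here is a minimal IR sketch of the zero-extend form of this
pattern (the function name is hypothetical; the sign-extend form is analogous):

  define i32 @sad_v16i8(<16 x i8> %a, <16 x i8> %b) {
    ; Widen both inputs, take the absolute difference, then reduce to scalar.
    %a.ext = zext <16 x i8> %a to <16 x i32>
    %b.ext = zext <16 x i8> %b to <16 x i32>
    %sub = sub <16 x i32> %a.ext, %b.ext
    %abs = call <16 x i32> @llvm.abs.v16i32(<16 x i32> %sub, i1 true)
    %red = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %abs)
    ret i32 %red
  }

  declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1)
  declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

On targets without the dot-product extension, such a function now lowers to a
uabdl/uabal2/uaddlp sequence followed by addv, as checked in the updated
arm64-vabs.ll and neon-sad.ll tests below.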

Differential Revision: https://reviews.llvm.org/D104042

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/test/CodeGen/AArch64/arm64-vabs.ll
    llvm/test/CodeGen/AArch64/neon-sad.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0f89bffd9f211..9b0735a6a9866 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2110,6 +2110,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
     MAKE_CASE(AArch64ISD::UABD)
     MAKE_CASE(AArch64ISD::SABD)
+    MAKE_CASE(AArch64ISD::UADDLP)
     MAKE_CASE(AArch64ISD::CALL_RVMARKER)
   }
 #undef MAKE_CASE
@@ -4078,6 +4079,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
+  case Intrinsic::aarch64_neon_uaddlp: {
+    unsigned Opcode = AArch64ISD::UADDLP;
+    return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1));
+  }
   case Intrinsic::aarch64_neon_sdot:
   case Intrinsic::aarch64_neon_udot:
   case Intrinsic::aarch64_sve_sdot:
@@ -11981,13 +11986,106 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
 }
 
+// Given a vecreduce_add node, detect the pattern below and convert it to a
+// node sequence using UABDL, [S|U]ABD and UADDLP.
+//
+// i32 vecreduce_add(
+//  v16i32 abs(
+//    v16i32 sub(
+//     v16i32 [sign|zero]_extend(v16i8 a), v16i32 [sign|zero]_extend(v16i8 b))))
+// =================>
+// i32 vecreduce_add(
+//   v4i32 UADDLP(
+//     v8i16 add(
+//       v8i16 zext(
+//         v8i8 [S|U]ABD low8:v16i8 a, low8:v16i8 b),
+//       v8i16 zext(
+//         v8i8 [S|U]ABD high8:v16i8 a, high8:v16i8 b))))
+static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
+                                                    SelectionDAG &DAG) {
+  // Assumed i32 vecreduce_add
+  if (N->getValueType(0) != MVT::i32)
+    return SDValue();
+
+  SDValue VecReduceOp0 = N->getOperand(0);
+  unsigned Opcode = VecReduceOp0.getOpcode();
+  // Assumed v16i32 abs
+  if (Opcode != ISD::ABS || VecReduceOp0->getValueType(0) != MVT::v16i32)
+    return SDValue();
+
+  SDValue ABS = VecReduceOp0;
+  // Assumed v16i32 sub
+  if (ABS->getOperand(0)->getOpcode() != ISD::SUB ||
+      ABS->getOperand(0)->getValueType(0) != MVT::v16i32)
+    return SDValue();
+
+  SDValue SUB = ABS->getOperand(0);
+  unsigned Opcode0 = SUB->getOperand(0).getOpcode();
+  unsigned Opcode1 = SUB->getOperand(1).getOpcode();
+  // Assumed v16i32 type
+  if (SUB->getOperand(0)->getValueType(0) != MVT::v16i32 ||
+      SUB->getOperand(1)->getValueType(0) != MVT::v16i32)
+    return SDValue();
+
+  // Assumed zext or sext
+  bool IsZExt = false;
+  if (Opcode0 == ISD::ZERO_EXTEND && Opcode1 == ISD::ZERO_EXTEND) {
+    IsZExt = true;
+  } else if (Opcode0 == ISD::SIGN_EXTEND && Opcode1 == ISD::SIGN_EXTEND) {
+    IsZExt = false;
+  } else
+    return SDValue();
+
+  SDValue EXT0 = SUB->getOperand(0);
+  SDValue EXT1 = SUB->getOperand(1);
+  // Assumed zext's operand has v16i8 type
+  if (EXT0->getOperand(0)->getValueType(0) != MVT::v16i8 ||
+      EXT1->getOperand(0)->getValueType(0) != MVT::v16i8)
+    return SDValue();
+
+  // Pattern is detected. Let's convert it to a sequence of nodes.
+  SDLoc DL(N);
+
+  // First, create the node pattern of UABD/SABD.
+  SDValue UABDHigh8Op0 =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
+                  DAG.getConstant(8, DL, MVT::i64));
+  SDValue UABDHigh8Op1 =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
+                  DAG.getConstant(8, DL, MVT::i64));
+  SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
+                                  DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1);
+  SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
+
+  // Second, create the node pattern of UABAL.
+  SDValue UABDLo8Op0 =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
+                  DAG.getConstant(0, DL, MVT::i64));
+  SDValue UABDLo8Op1 =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
+                  DAG.getConstant(0, DL, MVT::i64));
+  SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
+                                DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1);
+  SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
+  SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
+
+  // Third, create the node of UADDLP.
+  SDValue UADDLP = DAG.getNode(AArch64ISD::UADDLP, DL, MVT::v4i32, UABAL);
+
+  // Fourth, create the node of VECREDUCE_ADD.
+  return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
+}
+
 // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
 //   vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
 //   vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
 static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
                                           const AArch64Subtarget *ST) {
+  if (!ST->hasDotProd())
+    return performVecReduceAddCombineWithUADDLP(N, DAG);
+
   SDValue Op0 = N->getOperand(0);
-  if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32 ||
+  if (N->getValueType(0) != MVT::i32 ||
       Op0.getValueType().getVectorElementType() != MVT::i32)
     return SDValue();
 

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index c2ada6f35accc..20872b454e0c0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -240,6 +240,9 @@ enum NodeType : unsigned {
   UABD,
   SABD,
 
+  // Unsigned Add Long Pairwise
+  UADDLP,
+
   // udot/sdot instructions
   UDOT,
   SDOT,

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 33bd0be43f5de..c303d87c838b6 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -271,6 +271,8 @@ def SDT_AArch64ITOF  : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisSameAs<0,1>]>;
 def SDT_AArch64TLSDescCall : SDTypeProfile<0, -2, [SDTCisPtrTy<0>,
                                                  SDTCisPtrTy<1>]>;
 
+def SDT_AArch64uaddlp : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
+
 def SDT_AArch64ldp : SDTypeProfile<2, 1, [SDTCisVT<0, i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>;
 def SDT_AArch64stp : SDTypeProfile<0, 3, [SDTCisVT<0, i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>;
 def SDT_AArch64stnp : SDTypeProfile<0, 3, [SDTCisVT<0, v4i32>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>;
@@ -587,6 +589,11 @@ def AArch64sabd     : PatFrags<(ops node:$lhs, node:$rhs),
                                [(AArch64sabd_n node:$lhs, node:$rhs),
                                 (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
 
+def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>;
+def AArch64uaddlp   : PatFrags<(ops node:$src),
+                               [(AArch64uaddlp_n node:$src),
+                                (int_aarch64_neon_uaddlp node:$src)]>;
+
 def SDT_AArch64SETTAG : SDTypeProfile<0, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>;
 def AArch64stg : SDNode<"AArch64ISD::STG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
 def AArch64stzg : SDNode<"AArch64ISD::STZG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
@@ -4178,9 +4185,8 @@ defm SQXTN  : SIMDMixedTwoVector<0, 0b10100, "sqxtn", int_aarch64_neon_sqxtn>;
 defm SQXTUN : SIMDMixedTwoVector<1, 0b10010, "sqxtun", int_aarch64_neon_sqxtun>;
 defm SUQADD : SIMDTwoVectorBHSDTied<0, 0b00011, "suqadd",int_aarch64_neon_suqadd>;
 defm UADALP : SIMDLongTwoVectorTied<1, 0b00110, "uadalp",
-       BinOpFrag<(add node:$LHS, (int_aarch64_neon_uaddlp node:$RHS))> >;
-defm UADDLP : SIMDLongTwoVector<1, 0b00010, "uaddlp",
-                    int_aarch64_neon_uaddlp>;
+       BinOpFrag<(add node:$LHS, (AArch64uaddlp node:$RHS))> >;
+defm UADDLP : SIMDLongTwoVector<1, 0b00010, "uaddlp", AArch64uaddlp>;
 defm UCVTF  : SIMDTwoVectorIntToFP<1, 0, 0b11101, "ucvtf", uint_to_fp>;
 defm UQXTN  : SIMDMixedTwoVector<1, 0b10100, "uqxtn", int_aarch64_neon_uqxtn>;
 defm URECPE : SIMDTwoVectorS<0, 1, 0b11100, "urecpe", int_aarch64_neon_urecpe>;

diff --git a/llvm/test/CodeGen/AArch64/arm64-vabs.ll b/llvm/test/CodeGen/AArch64/arm64-vabs.ll
index a5945bb7ac766..4d792bcec0458 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vabs.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vabs.ll
@@ -218,12 +218,9 @@ define i16 @uabd16b_rdx(<16 x i8>* %a, <16 x i8>* %b) {
 define i32 @uabd16b_rdx_i32(<16 x i8> %a, <16 x i8> %b) {
 ; CHECK-LABEL: uabd16b_rdx_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    uabd.16b v0, v0, v1
-; CHECK-NEXT:    ushll2.8h v1, v0, #0
-; CHECK-NEXT:    ushll.8h v0, v0, #0
-; CHECK-NEXT:    uaddl2.4s v2, v0, v1
-; CHECK-NEXT:    uaddl.4s v0, v0, v1
-; CHECK-NEXT:    add.4s v0, v0, v2
+; CHECK-NEXT:    uabdl.8h v2, v0, v1
+; CHECK-NEXT:    uabal2.8h v2, v0, v1
+; CHECK-NEXT:    uaddlp.4s v0, v2
 ; CHECK-NEXT:    addv.4s s0, v0
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -240,12 +237,9 @@ define i32 @uabd16b_rdx_i32(<16 x i8> %a, <16 x i8> %b) {
 define i32 @sabd16b_rdx_i32(<16 x i8> %a, <16 x i8> %b) {
 ; CHECK-LABEL: sabd16b_rdx_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sabd.16b v0, v0, v1
-; CHECK-NEXT:    ushll2.8h v1, v0, #0
-; CHECK-NEXT:    ushll.8h v0, v0, #0
-; CHECK-NEXT:    uaddl2.4s v2, v0, v1
-; CHECK-NEXT:    uaddl.4s v0, v0, v1
-; CHECK-NEXT:    add.4s v0, v0, v2
+; CHECK-NEXT:    sabdl.8h v2, v0, v1
+; CHECK-NEXT:    sabal2.8h v2, v0, v1
+; CHECK-NEXT:    uaddlp.4s v0, v2
 ; CHECK-NEXT:    addv.4s s0, v0
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret

diff --git a/llvm/test/CodeGen/AArch64/neon-sad.ll b/llvm/test/CodeGen/AArch64/neon-sad.ll
index c5372a2424d0d..cfd9712efdc33 100644
--- a/llvm/test/CodeGen/AArch64/neon-sad.ll
+++ b/llvm/test/CodeGen/AArch64/neon-sad.ll
@@ -9,12 +9,9 @@ define i32 @test_sad_v16i8_zext(i8* nocapture readonly %a, i8* nocapture readonl
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
-; CHECK-NEXT:    uabd v0.16b, v1.16b, v0.16b
-; CHECK-NEXT:    ushll2 v1.8h, v0.16b, #0
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    uaddl2 v2.4s, v0.8h, v1.8h
-; CHECK-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    add v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    uabdl v2.8h, v1.8b, v0.8b
+; CHECK-NEXT:    uabal2 v2.8h, v1.16b, v0.16b
+; CHECK-NEXT:    uaddlp v0.4s, v2.8h
 ; CHECK-NEXT:    addv s0, v0.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -36,12 +33,9 @@ define i32 @test_sad_v16i8_sext(i8* nocapture readonly %a, i8* nocapture readonl
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
-; CHECK-NEXT:    sabd v0.16b, v1.16b, v0.16b
-; CHECK-NEXT:    ushll2 v1.8h, v0.16b, #0
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    uaddl2 v2.4s, v0.8h, v1.8h
-; CHECK-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    add v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    sabdl v2.8h, v1.8b, v0.8b
+; CHECK-NEXT:    sabal2 v2.8h, v1.16b, v0.16b
+; CHECK-NEXT:    uaddlp v0.4s, v2.8h
 ; CHECK-NEXT:    addv s0, v0.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret

More information about the llvm-commits mailing list