[llvm] fceb3e3 - [ARM] MVE VADDLV lowering

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 19 03:08:18 PST 2020


Author: David Green
Date: 2020-02-19T11:07:20Z
New Revision: fceb3e3b4aec635d4e31dda618c7e4c17516cdb9

URL: https://github.com/llvm/llvm-project/commit/fceb3e3b4aec635d4e31dda618c7e4c17516cdb9
DIFF: https://github.com/llvm/llvm-project/commit/fceb3e3b4aec635d4e31dda618c7e4c17516cdb9.diff

LOG: [ARM] MVE VADDLV lowering

Following on from the extra VADDV lowering, this extends things to
handle VADDLV, which sums the vector elements into a pair of i32
registers that together are treated as an i64. As with VADDV, this
needs to be done in DAGCombine, as the i64 type is otherwise illegal;
it is a fairly simple addition on top of the existing code.
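
For illustration, a minimal reconstruction (from the diff context below) of
the zext case in the updated mve-vecreduce-add.ll test, which now selects a
single vaddlv instead of expanding the widened reduction:

  define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_zext(<4 x i32> %x) {
  entry:
    %xx = zext <4 x i32> %x to <4 x i64>
    %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)
    ret i64 %z
  }

  ; with MVE enabled this now compiles to:
  ;   vaddlv.u32 r0, r1, q0
  ;   bx lr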

A VADDLVA instruction is also handled here; it accumulates the
incoming values from the two general purpose registers into the
reduction. Unlike the non-long version, where we could simply add
patterns for add(x, VADDV), the long version needs to handle this
early, before the i64 has been split into too many pieces.
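
A sketch of the accumulating form, based on the add_v4i32_v4i64_acc_zext
test in the diff (the trailing add of the i64 %a accumulator is implied by
the test name and the vaddlva output rather than shown in the hunk):

  define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_acc_zext(<4 x i32> %x, i64 %a) {
  entry:
    %xx = zext <4 x i32> %x to <4 x i64>
    %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)
    %r = add i64 %z, %a    ; assumed accumulator add
    ret i64 %r
  }

  ; with MVE enabled this now selects the accumulating form:
  ;   vaddlva.u32 r0, r1, q0
  ;   bx lr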

Differential Revision: https://reviews.llvm.org/D74224

Added: 
    

Modified: 
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/lib/Target/ARM/ARMISelLowering.h
    llvm/lib/Target/ARM/ARMInstrMVE.td
    llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 40815692c71d..f07c0f022b7f 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -943,6 +943,7 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
     setTargetDAGCombine(ISD::INTRINSIC_VOID);
     setTargetDAGCombine(ISD::VECREDUCE_ADD);
+    setTargetDAGCombine(ISD::ADD);
   }
 
   if (!Subtarget->hasFP64()) {
@@ -1656,6 +1657,10 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
   case ARMISD::VMULLu:        return "ARMISD::VMULLu";
   case ARMISD::VADDVs:        return "ARMISD::VADDVs";
   case ARMISD::VADDVu:        return "ARMISD::VADDVu";
+  case ARMISD::VADDLVs:       return "ARMISD::VADDLVs";
+  case ARMISD::VADDLVu:       return "ARMISD::VADDLVu";
+  case ARMISD::VADDLVAs:      return "ARMISD::VADDLVAs";
+  case ARMISD::VADDLVAu:      return "ARMISD::VADDLVAu";
   case ARMISD::UMAAL:         return "ARMISD::UMAAL";
   case ARMISD::UMLAL:         return "ARMISD::UMLAL";
   case ARMISD::SMLAL:         return "ARMISD::SMLAL";
@@ -11731,6 +11736,53 @@ static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
+static SDValue PerformADDVecReduce(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI,
+                                   const ARMSubtarget *Subtarget) {
+  if (!Subtarget->hasMVEIntegerOps() || N->getValueType(0) != MVT::i64)
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // We are looking for an i64 add of a VADDLVx. Due to these being i64's, this
+  // will look like:
+  //   t1: i32,i32 = ARMISD::VADDLVs x
+  //   t2: i64 = build_pair t1, t1:1
+  //   t3: i64 = add t2, y
+  // We also need to check for sext / zext and commutative adds.
+  auto MakeVecReduce = [&](unsigned Opcode, unsigned OpcodeA, SDValue NA,
+                           SDValue NB) {
+    if (NB->getOpcode() != ISD::BUILD_PAIR)
+      return SDValue();
+    SDValue VecRed = NB->getOperand(0);
+    if (VecRed->getOpcode() != Opcode || VecRed.getResNo() != 0 ||
+        NB->getOperand(1) != SDValue(VecRed.getNode(), 1))
+      return SDValue();
+
+    SDLoc dl(N);
+    SDValue Lo = DCI.DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA,
+                                 DCI.DAG.getConstant(0, dl, MVT::i32));
+    SDValue Hi = DCI.DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA,
+                                 DCI.DAG.getConstant(1, dl, MVT::i32));
+    SDValue Red =
+        DCI.DAG.getNode(OpcodeA, dl, DCI.DAG.getVTList({MVT::i32, MVT::i32}),
+                        Lo, Hi, VecRed->getOperand(0));
+    return DCI.DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Red,
+                           SDValue(Red.getNode(), 1));
+  };
+
+  if (SDValue M = MakeVecReduce(ARMISD::VADDLVs, ARMISD::VADDLVAs, N0, N1))
+    return M;
+  if (SDValue M = MakeVecReduce(ARMISD::VADDLVu, ARMISD::VADDLVAu, N0, N1))
+    return M;
+  if (SDValue M = MakeVecReduce(ARMISD::VADDLVs, ARMISD::VADDLVAs, N1, N0))
+    return M;
+  if (SDValue M = MakeVecReduce(ARMISD::VADDLVu, ARMISD::VADDLVAu, N1, N0))
+    return M;
+  return SDValue();
+}
+
 bool
 ARMTargetLowering::isDesirableToCommuteWithShift(const SDNode *N,
                                                  CombineLevel Level) const {
@@ -11902,6 +11954,9 @@ static SDValue PerformADDCombine(SDNode *N,
   if (SDValue Result = PerformSHLSimplify(N, DCI, Subtarget))
     return Result;
 
+  if (SDValue Result = PerformADDVecReduce(N, DCI, Subtarget))
+    return Result;
+
   // First try with the default operand order.
   if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget))
     return Result;
@@ -13945,6 +14000,7 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
 
   // Cases:
   //   VADDV u/s 8/16/32
+  //   VADDLV u/s 32
 
   auto IsVADDV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes) {
     if (ResVT != RetTy || N0->getOpcode() != ExtendCode)
@@ -13954,11 +14010,19 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
       return A;
     return SDValue();
   };
+  auto Create64bitNode = [&](unsigned Opcode, ArrayRef<SDValue> Ops) {
+    SDValue Node = DAG.getNode(Opcode, dl, {MVT::i32, MVT::i32}, Ops);
+    return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Node, SDValue(Node.getNode(), 1));
+  };
 
   if (SDValue A = IsVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}))
     return DAG.getNode(ARMISD::VADDVs, dl, ResVT, A);
   if (SDValue A = IsVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}))
     return DAG.getNode(ARMISD::VADDVu, dl, ResVT, A);
+  if (SDValue A = IsVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32}))
+    return Create64bitNode(ARMISD::VADDLVs, {A});
+  if (SDValue A = IsVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32}))
+    return Create64bitNode(ARMISD::VADDLVu, {A});
 
   return SDValue();
 }

diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index a12d7299ff8e..c635622ee8d4 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -209,6 +209,10 @@ class VectorType;
       // MVE reductions
       VADDVs,
       VADDVu,
+      VADDLVs,
+      VADDLVu,
+      VADDLVAs,
+      VADDLVAu,
 
       SMULWB,       // Signed multiply word by half word, bottom
       SMULWT,       // Signed multiply word by half word, top

diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 2b0704fd28f2..e4c73f824f94 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -691,6 +691,30 @@ multiclass MVE_VADDLV_A<string suffix, bit U, list<dag> pattern=[]> {
 defm MVE_VADDLVs32 : MVE_VADDLV_A<"s32", 0b0>;
 defm MVE_VADDLVu32 : MVE_VADDLV_A<"u32", 0b1>;
 
+def SDTVecReduceL : SDTypeProfile<2, 1, [    // VADDLV
+  SDTCisInt<0>, SDTCisInt<1>, SDTCisVec<2>
+]>;
+def SDTVecReduceLA : SDTypeProfile<2, 3, [    // VADDLVA
+  SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>, SDTCisInt<3>,
+  SDTCisVec<4>
+]>;
+def ARMVADDLVs      : SDNode<"ARMISD::VADDLVs", SDTVecReduceL>;
+def ARMVADDLVu      : SDNode<"ARMISD::VADDLVu", SDTVecReduceL>;
+def ARMVADDLVAs     : SDNode<"ARMISD::VADDLVAs", SDTVecReduceLA>;
+def ARMVADDLVAu     : SDNode<"ARMISD::VADDLVAu", SDTVecReduceLA>;
+
+let Predicates = [HasMVEInt] in {
+  def : Pat<(ARMVADDLVs (v4i32 MQPR:$val1)),
+            (MVE_VADDLVs32no_acc (v4i32 MQPR:$val1))>;
+  def : Pat<(ARMVADDLVu (v4i32 MQPR:$val1)),
+            (MVE_VADDLVu32no_acc (v4i32 MQPR:$val1))>;
+
+  def : Pat<(ARMVADDLVAs tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1)),
+            (MVE_VADDLVs32acc tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1))>;
+  def : Pat<(ARMVADDLVAu tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1)),
+            (MVE_VADDLVu32acc tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1))>;
+}
+
 class MVE_VMINMAXNMV<string iname, string suffix, bit sz,
                      bit bit_17, bit bit_7, list<dag> pattern=[]>
   : MVE_rDest<(outs rGPR:$RdaDest), (ins rGPR:$RdaSrc, MQPR:$Qm),

diff --git a/llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll b/llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll
index 4ada1a65512e..ced01f0606c7 100644
--- a/llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll
@@ -14,36 +14,8 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_zext(<4 x i32> %x) {
 ; CHECK-LABEL: add_v4i32_v4i64_zext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    adr r0, .LCPI1_0
-; CHECK-NEXT:    vmov.f32 s4, s0
-; CHECK-NEXT:    vldrw.u32 q2, [r0]
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vand q1, q1, q2
-; CHECK-NEXT:    vmov r2, s6
-; CHECK-NEXT:    vmov r3, s4
-; CHECK-NEXT:    vmov r0, s7
-; CHECK-NEXT:    vmov r1, s5
-; CHECK-NEXT:    vmov.f32 s4, s2
-; CHECK-NEXT:    vmov.f32 s6, s3
-; CHECK-NEXT:    vand q0, q1, q2
-; CHECK-NEXT:    adds r2, r2, r3
-; CHECK-NEXT:    vmov r3, s0
-; CHECK-NEXT:    adcs r0, r1
-; CHECK-NEXT:    vmov r1, s1
-; CHECK-NEXT:    adds r2, r2, r3
-; CHECK-NEXT:    vmov r3, s3
-; CHECK-NEXT:    adcs r1, r0
-; CHECK-NEXT:    vmov r0, s2
-; CHECK-NEXT:    adds r0, r0, r2
-; CHECK-NEXT:    adcs r1, r3
+; CHECK-NEXT:    vaddlv.u32 r0, r1, q0
 ; CHECK-NEXT:    bx lr
-; CHECK-NEXT:    .p2align 4
-; CHECK-NEXT:  @ %bb.1:
-; CHECK-NEXT:  .LCPI1_0:
-; CHECK-NEXT:    .long 4294967295 @ 0xffffffff
-; CHECK-NEXT:    .long 0 @ 0x0
-; CHECK-NEXT:    .long 4294967295 @ 0xffffffff
-; CHECK-NEXT:    .long 0 @ 0x0
 entry:
   %xx = zext <4 x i32> %x to <4 x i64>
   %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)
@@ -53,29 +25,7 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_sext(<4 x i32> %x) {
 ; CHECK-LABEL: add_v4i32_v4i64_sext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.f32 s4, s0
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vmov r0, s4
-; CHECK-NEXT:    vmov.32 q2[0], r0
-; CHECK-NEXT:    asrs r0, r0, #31
-; CHECK-NEXT:    vmov.32 q2[1], r0
-; CHECK-NEXT:    vmov r0, s6
-; CHECK-NEXT:    vmov.32 q2[2], r0
-; CHECK-NEXT:    vmov.f32 s4, s2
-; CHECK-NEXT:    vmov.f32 s6, s3
-; CHECK-NEXT:    asrs r1, r0, #31
-; CHECK-NEXT:    vmov.32 q2[3], r1
-; CHECK-NEXT:    vmov r2, s10
-; CHECK-NEXT:    vmov r3, s8
-; CHECK-NEXT:    vmov r1, s9
-; CHECK-NEXT:    adds r2, r2, r3
-; CHECK-NEXT:    vmov r3, s6
-; CHECK-NEXT:    adc.w r0, r1, r0, asr #31
-; CHECK-NEXT:    vmov r1, s4
-; CHECK-NEXT:    adds r2, r2, r1
-; CHECK-NEXT:    adc.w r1, r0, r1, asr #31
-; CHECK-NEXT:    adds r0, r2, r3
-; CHECK-NEXT:    adc.w r1, r1, r3, asr #31
+; CHECK-NEXT:    vaddlv.s32 r0, r1, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %xx = sext <4 x i32> %x to <4 x i64>
@@ -856,40 +806,8 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_acc_zext(<4 x i32> %x, i64 %a) {
 ; CHECK-LABEL: add_v4i32_v4i64_acc_zext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, lr}
-; CHECK-NEXT:    push {r4, lr}
-; CHECK-NEXT:    adr r2, .LCPI29_0
-; CHECK-NEXT:    vmov.f32 s4, s0
-; CHECK-NEXT:    vldrw.u32 q2, [r2]
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vand q1, q1, q2
-; CHECK-NEXT:    vmov r2, s6
-; CHECK-NEXT:    vmov r3, s4
-; CHECK-NEXT:    vmov r12, s7
-; CHECK-NEXT:    vmov lr, s5
-; CHECK-NEXT:    vmov.f32 s4, s2
-; CHECK-NEXT:    vmov.f32 s6, s3
-; CHECK-NEXT:    vand q0, q1, q2
-; CHECK-NEXT:    adds r4, r3, r2
-; CHECK-NEXT:    vmov r3, s0
-; CHECK-NEXT:    vmov r2, s1
-; CHECK-NEXT:    adc.w r12, r12, lr
-; CHECK-NEXT:    adds r3, r3, r4
-; CHECK-NEXT:    vmov r4, s3
-; CHECK-NEXT:    adc.w r12, r12, r2
-; CHECK-NEXT:    vmov r2, s2
-; CHECK-NEXT:    adds r2, r2, r3
-; CHECK-NEXT:    adc.w r3, r12, r4
-; CHECK-NEXT:    adds r0, r0, r2
-; CHECK-NEXT:    adcs r1, r3
-; CHECK-NEXT:    pop {r4, pc}
-; CHECK-NEXT:    .p2align 4
-; CHECK-NEXT:  @ %bb.1:
-; CHECK-NEXT:  .LCPI29_0:
-; CHECK-NEXT:    .long 4294967295 @ 0xffffffff
-; CHECK-NEXT:    .long 0 @ 0x0
-; CHECK-NEXT:    .long 4294967295 @ 0xffffffff
-; CHECK-NEXT:    .long 0 @ 0x0
+; CHECK-NEXT:    vaddlva.u32 r0, r1, q0
+; CHECK-NEXT:    bx lr
 entry:
   %xx = zext <4 x i32> %x to <4 x i64>
   %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)
@@ -900,34 +818,8 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_acc_sext(<4 x i32> %x, i64 %a) {
 ; CHECK-LABEL: add_v4i32_v4i64_acc_sext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r7, lr}
-; CHECK-NEXT:    push {r7, lr}
-; CHECK-NEXT:    vmov.f32 s4, s0
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vmov r2, s4
-; CHECK-NEXT:    vmov.32 q2[0], r2
-; CHECK-NEXT:    asrs r2, r2, #31
-; CHECK-NEXT:    vmov.32 q2[1], r2
-; CHECK-NEXT:    vmov r2, s6
-; CHECK-NEXT:    vmov.32 q2[2], r2
-; CHECK-NEXT:    vmov.f32 s4, s2
-; CHECK-NEXT:    vmov.f32 s6, s3
-; CHECK-NEXT:    asrs r3, r2, #31
-; CHECK-NEXT:    vmov.32 q2[3], r3
-; CHECK-NEXT:    vmov lr, s10
-; CHECK-NEXT:    vmov r3, s8
-; CHECK-NEXT:    vmov r12, s9
-; CHECK-NEXT:    adds.w r3, r3, lr
-; CHECK-NEXT:    adc.w r12, r12, r2, asr #31
-; CHECK-NEXT:    vmov r2, s4
-; CHECK-NEXT:    adds r3, r3, r2
-; CHECK-NEXT:    adc.w r12, r12, r2, asr #31
-; CHECK-NEXT:    vmov r2, s6
-; CHECK-NEXT:    adds r3, r3, r2
-; CHECK-NEXT:    adc.w r2, r12, r2, asr #31
-; CHECK-NEXT:    adds r0, r0, r3
-; CHECK-NEXT:    adcs r1, r2
-; CHECK-NEXT:    pop {r7, pc}
+; CHECK-NEXT:    vaddlva.s32 r0, r1, q0
+; CHECK-NEXT:    bx lr
 entry:
   %xx = sext <4 x i32> %x to <4 x i64>
   %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)


        

