[llvm] d5fd3d9 - [AArch64] Match pairwise add/fadd pattern

Sanne Wouda via llvm-commits <llvm-commits at lists.llvm.org>
Thu Sep 17 08:27:47 PDT 2020


Author: Sanne Wouda
Date: 2020-09-17T16:27:01+01:00
New Revision: d5fd3d9b903ef6d96c6b3b82434dd0461faaba55

URL: https://github.com/llvm/llvm-project/commit/d5fd3d9b903ef6d96c6b3b82434dd0461faaba55
DIFF: https://github.com/llvm/llvm-project/commit/d5fd3d9b903ef6d96c6b3b82434dd0461faaba55.diff

LOG: [AArch64] Match pairwise add/fadd pattern

D75689 turns the faddp pattern into a shuffle with vector add.

Match this new pattern in a target-specific DAG combine, rather than in ISel,
because legalization (for v2f32) turns it into a bit of a mess.

- extended to cover f16, f32, f64 and i64
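
For illustration only (this snippet is not part of the commit), the IR shape that the
new performExtractVectorEltCombine targets looks roughly like the sketch below. It is
modeled on the comment added in AArch64ISelLowering.cpp and on the shufflevector in the
faddp.ll tests; the function name is made up:

    define float @pairwise_fadd_f32(<2 x float> %a) {
    entry:
      ; broadcast lane 1 of %a, add element-wise, then take lane 0 (= a[0] + a[1])
      %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
      %add = fadd <2 x float> %a, %shift
      %res = extractelement <2 x float> %add, i64 0
      ret float %res
    }

With this patch, a function of that shape should lower to a single faddp s0, v0.2s
(and analogously to faddp/addp for f16, f64 and i64), as the updated CHECK lines in
the tests below show.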

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/test/CodeGen/AArch64/faddp-half.ll
    llvm/test/CodeGen/AArch64/faddp.ll
    llvm/test/CodeGen/AArch64/vecreduce-fadd.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c4f02d36c7a7..7b5cf792a332 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -745,6 +745,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::INTRINSIC_VOID);
   setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
   setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
+  setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
 
   setTargetDAGCombine(ISD::GlobalAddress);
 
@@ -11602,6 +11603,60 @@ performVectorTruncateCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   return ResultHADD;
 }
 
+static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
+  switch (Opcode) {
+  case ISD::FADD:
+    return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64;
+  case ISD::ADD:
+    return VT == MVT::i64;
+  default:
+    return false;
+  }
+}
+
+static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+  ConstantSDNode *ConstantN1 = dyn_cast<ConstantSDNode>(N1);
+
+  EVT VT = N->getValueType(0);
+  const bool FullFP16 =
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
+
+  // Rewrite for pairwise fadd pattern
+  //   (f32 (extract_vector_elt
+  //           (fadd (vXf32 Other)
+  //                 (vector_shuffle (vXf32 Other) undef <1,X,...> )) 0))
+  // ->
+  //   (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
+  //              (extract_vector_elt (vXf32 Other) 1))
+  if (ConstantN1 && ConstantN1->getZExtValue() == 0 &&
+      hasPairwiseAdd(N0->getOpcode(), VT, FullFP16)) {
+    SDLoc DL(N0);
+    SDValue N00 = N0->getOperand(0);
+    SDValue N01 = N0->getOperand(1);
+
+    ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
+    SDValue Other = N00;
+
+    // And handle the commutative case.
+    if (!Shuffle) {
+      Shuffle = dyn_cast<ShuffleVectorSDNode>(N00);
+      Other = N01;
+    }
+
+    if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
+        Other == Shuffle->getOperand(0)) {
+      return DAG.getNode(N0->getOpcode(), DL, VT,
+                         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+                                     DAG.getConstant(0, DL, MVT::i64)),
+                         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+                                     DAG.getConstant(1, DL, MVT::i64)));
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue performConcatVectorsCombine(SDNode *N,
                                            TargetLowering::DAGCombinerInfo &DCI,
                                            SelectionDAG &DAG) {
@@ -14425,6 +14480,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performUzpCombine(N, DAG);
   case ISD::INSERT_VECTOR_ELT:
     return performPostLD1Combine(N, DCI, true);
+  case ISD::EXTRACT_VECTOR_ELT:
+    return performExtractVectorEltCombine(N, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 6a0bb14f5514..06e88b7b2045 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -7482,6 +7482,9 @@ def : Pat<(f64 (fadd (vector_extract (v2f64 FPR128:$Rn), (i64 0)),
 def : Pat<(fadd (vector_extract (v4f32 FPR128:$Rn), (i64 0)),
                 (vector_extract (v4f32 FPR128:$Rn), (i64 1))),
           (f32 (FADDPv2i32p (EXTRACT_SUBREG FPR128:$Rn, dsub)))>;
+def : Pat<(fadd (vector_extract (v8f16 FPR128:$Rn), (i64 0)),
+                (vector_extract (v8f16 FPR128:$Rn), (i64 1))),
+          (f16 (FADDPv2i16p (EXTRACT_SUBREG FPR128:$Rn, dsub)))>;
 
 // Scalar 64-bit shifts in FPR64 registers.
 def : Pat<(i64 (int_aarch64_neon_sshl (i64 FPR64:$Rn), (i64 FPR64:$Rm))),

diff --git a/llvm/test/CodeGen/AArch64/faddp-half.ll b/llvm/test/CodeGen/AArch64/faddp-half.ll
index d89205d3ac5f..449b9a5b8c92 100644
--- a/llvm/test/CodeGen/AArch64/faddp-half.ll
+++ b/llvm/test/CodeGen/AArch64/faddp-half.ll
@@ -6,9 +6,7 @@ define half @faddp_2xhalf(<2 x half> %a) {
 ; CHECK-LABEL: faddp_2xhalf:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    dup v1.4h, v0.h[1]
-; CHECK-NEXT:    fadd v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $q0
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: faddp_2xhalf:
@@ -32,9 +30,7 @@ define half @faddp_2xhalf_commute(<2 x half> %a) {
 ; CHECK-LABEL: faddp_2xhalf_commute:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    dup v1.4h, v0.h[1]
-; CHECK-NEXT:    fadd v0.4h, v1.4h, v0.4h
-; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $q0
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: faddp_2xhalf_commute:
@@ -58,9 +54,7 @@ define half @faddp_4xhalf(<4 x half> %a) {
 ; CHECK-LABEL: faddp_4xhalf:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    dup v1.4h, v0.h[1]
-; CHECK-NEXT:    fadd v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $q0
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: faddp_4xhalf:
@@ -84,9 +78,7 @@ define half @faddp_4xhalf_commute(<4 x half> %a) {
 ; CHECK-LABEL: faddp_4xhalf_commute:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    dup v1.4h, v0.h[1]
-; CHECK-NEXT:    fadd v0.4h, v1.4h, v0.4h
-; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $q0
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: faddp_4xhalf_commute:
@@ -109,9 +101,7 @@ entry:
 define half @faddp_8xhalf(<8 x half> %a) {
 ; CHECK-LABEL: faddp_8xhalf:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.8h, v0.h[1]
-; CHECK-NEXT:    fadd v0.8h, v0.8h, v1.8h
-; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $q0
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: faddp_8xhalf:
@@ -132,9 +122,7 @@ entry:
 define half @faddp_8xhalf_commute(<8 x half> %a) {
 ; CHECK-LABEL: faddp_8xhalf_commute:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.8h, v0.h[1]
-; CHECK-NEXT:    fadd v0.8h, v1.8h, v0.8h
-; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $q0
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: faddp_8xhalf_commute:

diff --git a/llvm/test/CodeGen/AArch64/faddp.ll b/llvm/test/CodeGen/AArch64/faddp.ll
index 299ff08b513f..06e976136c37 100644
--- a/llvm/test/CodeGen/AArch64/faddp.ll
+++ b/llvm/test/CodeGen/AArch64/faddp.ll
@@ -5,9 +5,7 @@ define float @faddp_2xfloat(<2 x float> %a) {
 ; CHECK-LABEL: faddp_2xfloat:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    dup v1.2s, v0.s[1]
-; CHECK-NEXT:    fadd v0.2s, v0.2s, v1.2s
-; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
@@ -19,9 +17,7 @@ entry:
 define float @faddp_4xfloat(<4 x float> %a) {
 ; CHECK-LABEL: faddp_4xfloat:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.4s, v0.s[1]
-; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
@@ -33,9 +29,7 @@ entry:
 define float @faddp_4xfloat_commute(<4 x float> %a) {
 ; CHECK-LABEL: faddp_4xfloat_commute:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.4s, v0.s[1]
-; CHECK-NEXT:    fadd v0.4s, v1.4s, v0.4s
-; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
@@ -48,9 +42,7 @@ define float @faddp_2xfloat_commute(<2 x float> %a) {
 ; CHECK-LABEL: faddp_2xfloat_commute:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    dup v1.2s, v0.s[1]
-; CHECK-NEXT:    fadd v0.2s, v1.2s, v0.2s
-; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
@@ -62,9 +54,7 @@ entry:
 define double @faddp_2xdouble(<2 x double> %a) {
 ; CHECK-LABEL: faddp_2xdouble:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.2d, v0.d[1]
-; CHECK-NEXT:    fadd v0.2d, v0.2d, v1.2d
-; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    faddp d0, v0.2d
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <2 x double> %a, <2 x double> undef, <2 x i32> <i32 1, i32 undef>
@@ -76,9 +66,7 @@ entry:
 define double @faddp_2xdouble_commute(<2 x double> %a) {
 ; CHECK-LABEL: faddp_2xdouble_commute:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.2d, v0.d[1]
-; CHECK-NEXT:    fadd v0.2d, v1.2d, v0.2d
-; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    faddp d0, v0.2d
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <2 x double> %a, <2 x double> undef, <2 x i32> <i32 1, i32 undef>
@@ -90,8 +78,7 @@ entry:
 define i64 @addp_2xi64(<2 x i64> %a) {
 ; CHECK-LABEL: addp_2xi64:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.2d, v0.d[1]
-; CHECK-NEXT:    add v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    addp d0, v0.2d
 ; CHECK-NEXT:    fmov x0, d0
 ; CHECK-NEXT:    ret
 entry:
@@ -104,8 +91,7 @@ entry:
 define i64 @addp_2xi64_commute(<2 x i64> %a) {
 ; CHECK-LABEL: addp_2xi64_commute:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.2d, v0.d[1]
-; CHECK-NEXT:    add v0.2d, v1.2d, v0.2d
+; CHECK-NEXT:    addp d0, v0.2d
 ; CHECK-NEXT:    fmov x0, d0
 ; CHECK-NEXT:    ret
 entry:

diff --git a/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll b/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
index 9552f4d0575c..90367377fb4a 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
@@ -22,10 +22,9 @@ define half @add_HalfH(<4 x half> %bin.rdx)  {
 ; CHECK-LABEL: add_HalfH:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    mov h3, v0.h[1]
 ; CHECK-NEXT:    mov h1, v0.h[3]
 ; CHECK-NEXT:    mov h2, v0.h[2]
-; CHECK-NEXT:    fadd h0, h0, h3
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    fadd h0, h0, h2
 ; CHECK-NEXT:    fadd h0, h0, h1
 ; CHECK-NEXT:    ret
@@ -59,10 +58,9 @@ define half @add_H(<8 x half> %bin.rdx)  {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    fadd v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    mov h1, v0.h[1]
-; CHECK-NEXT:    mov h2, v0.h[2]
-; CHECK-NEXT:    fadd h1, h0, h1
-; CHECK-NEXT:    fadd h1, h1, h2
+; CHECK-NEXT:    mov h1, v0.h[2]
+; CHECK-NEXT:    faddp h2, v0.2h
+; CHECK-NEXT:    fadd h1, h2, h1
 ; CHECK-NEXT:    mov h0, v0.h[3]
 ; CHECK-NEXT:    fadd h0, h1, h0
 ; CHECK-NEXT:    ret
@@ -105,7 +103,6 @@ define half @add_H(<8 x half> %bin.rdx)  {
 ; CHECKNOFP16-NEXT:    fadd s0, s0, s1
 ; CHECKNOFP16-NEXT:    fcvt h0, s0
 ; CHECKNOFP16-NEXT:    ret
-
   %r = call fast half @llvm.experimental.vector.reduce.v2.fadd.f16.v8f16(half 0.0, <8 x half> %bin.rdx)
   ret half %r
 }
@@ -148,10 +145,9 @@ define half @add_2H(<16 x half> %bin.rdx)  {
 ; CHECK-NEXT:    fadd v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    fadd v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    mov h1, v0.h[1]
-; CHECK-NEXT:    mov h2, v0.h[2]
-; CHECK-NEXT:    fadd h1, h0, h1
-; CHECK-NEXT:    fadd h1, h1, h2
+; CHECK-NEXT:    mov h1, v0.h[2]
+; CHECK-NEXT:    faddp h2, v0.2h
+; CHECK-NEXT:    fadd h1, h2, h1
 ; CHECK-NEXT:    mov h0, v0.h[3]
 ; CHECK-NEXT:    fadd h0, h1, h0
 ; CHECK-NEXT:    ret


        

