[llvm] a9f5e43 - [AArch64] Use faddp to implement fadd reductions.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 6 01:37:41 PST 2021


Author: Sander de Smalen
Date: 2021-01-06T09:36:51Z
New Revision: a9f5e4375b36e5316b8d6f9731be6bfa5a70e276

URL: https://github.com/llvm/llvm-project/commit/a9f5e4375b36e5316b8d6f9731be6bfa5a70e276
DIFF: https://github.com/llvm/llvm-project/commit/a9f5e4375b36e5316b8d6f9731be6bfa5a70e276.diff

LOG: [AArch64] Use faddp to implement fadd reductions.

Custom-expand legal VECREDUCE_FADD SDNodes
to benefit from pair-wise faddp instructions.

Reviewed By: dmgreen

Differential Revision: https://reviews.llvm.org/D59259

Added: 
    

Modified: 
    llvm/include/llvm/Target/TargetSelectionDAG.td
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll
    llvm/test/CodeGen/AArch64/vecreduce-fadd.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index d5b8aeb1055d..0c6eef939ea4 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -250,6 +250,10 @@ def SDTVecInsert : SDTypeProfile<1, 3, [    // vector insert
 def SDTVecReduce : SDTypeProfile<1, 1, [    // vector reduction
   SDTCisInt<0>, SDTCisVec<1>
 ]>;
+def SDTFPVecReduce : SDTypeProfile<1, 1, [  // FP vector reduction
+  SDTCisFP<0>, SDTCisVec<1>
+]>;
+
 
 def SDTSubVecExtract : SDTypeProfile<1, 2, [// subvector extract
   SDTCisSubVecOfVec<0,1>, SDTCisInt<2>
@@ -439,6 +443,7 @@ def vecreduce_smax  : SDNode<"ISD::VECREDUCE_SMAX", SDTVecReduce>;
 def vecreduce_umax  : SDNode<"ISD::VECREDUCE_UMAX", SDTVecReduce>;
 def vecreduce_smin  : SDNode<"ISD::VECREDUCE_SMIN", SDTVecReduce>;
 def vecreduce_umin  : SDNode<"ISD::VECREDUCE_UMIN", SDTVecReduce>;
+def vecreduce_fadd  : SDNode<"ISD::VECREDUCE_FADD", SDTFPVecReduce>;
 
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index faed7c64a15e..2b9dc84a06cc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -999,6 +999,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                     MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
       setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
+
+      if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16())
+        setOperationAction(ISD::VECREDUCE_FADD, VT, Legal);
     }
     for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
                     MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 4d70fb334828..7e9f2fb95188 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -4989,6 +4989,26 @@ defm FMAXNMP : SIMDFPPairwiseScalar<0, 0b01100, "fmaxnmp">;
 defm FMAXP   : SIMDFPPairwiseScalar<0, 0b01111, "fmaxp">;
 defm FMINNMP : SIMDFPPairwiseScalar<1, 0b01100, "fminnmp">;
 defm FMINP   : SIMDFPPairwiseScalar<1, 0b01111, "fminp">;
+
+let Predicates = [HasFullFP16] in {
+def : Pat<(f16 (vecreduce_fadd (v8f16 V128:$Rn))),
+            (FADDPv2i16p
+              (EXTRACT_SUBREG
+                 (FADDPv8f16 (FADDPv8f16 V128:$Rn, (v8f16 (IMPLICIT_DEF))), (v8f16 (IMPLICIT_DEF))),
+               dsub))>;
+def : Pat<(f16 (vecreduce_fadd (v4f16 V64:$Rn))),
+          (FADDPv2i16p (FADDPv4f16 V64:$Rn, (v4f16 (IMPLICIT_DEF))))>;
+}
+def : Pat<(f32 (vecreduce_fadd (v4f32 V128:$Rn))),
+          (FADDPv2i32p
+            (EXTRACT_SUBREG
+              (FADDPv4f32 V128:$Rn, (v4f32 (IMPLICIT_DEF))),
+             dsub))>;
+def : Pat<(f32 (vecreduce_fadd (v2f32 V64:$Rn))),
+          (FADDPv2i32p V64:$Rn)>;
+def : Pat<(f64 (vecreduce_fadd (v2f64 V128:$Rn))),
+          (FADDPv2i64p V128:$Rn)>;
+
 def : Pat<(v2i64 (AArch64saddv V128:$Rn)),
           (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (ADDPv2i64p V128:$Rn), dsub)>;
 def : Pat<(v2i64 (AArch64uaddv V128:$Rn)),

diff --git a/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll b/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll
index 69b9c3e22d7a..2b5fcd4b839a 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll
@@ -51,8 +51,7 @@ define float @test_v3f32(<3 x float> %a) nounwind {
 ; CHECK-NEXT:    mov w8, #-2147483648
 ; CHECK-NEXT:    fmov s1, w8
 ; CHECK-NEXT:    mov v0.s[3], v1.s[0]
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    faddp v0.4s, v0.4s, v0.4s
 ; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
   %b = call reassoc float @llvm.vector.reduce.fadd.f32.v3f32(float -0.0, <3 x float> %a)
@@ -73,8 +72,7 @@ define float @test_v5f32(<5 x float> %a) nounwind {
 ; CHECK-NEXT:    mov v0.s[3], v3.s[0]
 ; CHECK-NEXT:    mov v5.s[0], v4.s[0]
 ; CHECK-NEXT:    fadd v0.4s, v0.4s, v5.4s
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    faddp v0.4s, v0.4s, v0.4s
 ; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
   %b = call reassoc float @llvm.vector.reduce.fadd.f32.v5f32(float -0.0, <5 x float> %a)
@@ -95,8 +93,7 @@ define float @test_v16f32(<16 x float> %a) nounwind {
 ; CHECK-NEXT:    fadd v1.4s, v1.4s, v3.4s
 ; CHECK-NEXT:    fadd v0.4s, v0.4s, v2.4s
 ; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    faddp v0.4s, v0.4s, v0.4s
 ; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
   %b = call reassoc float @llvm.vector.reduce.fadd.f32.v16f32(float -0.0, <16 x float> %a)

diff --git a/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll b/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
index 12e8be962f79..5bb36432a91b 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
@@ -5,13 +5,11 @@
 define float @add_HalfS(<2 x float> %bin.rdx)  {
 ; CHECK-LABEL: add_HalfS:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: add_HalfS:
 ; CHECKNOFP16:       // %bb.0:
-; CHECKNOFP16-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECKNOFP16-NEXT:    faddp s0, v0.2s
 ; CHECKNOFP16-NEXT:    ret
   %r = call fast float @llvm.vector.reduce.fadd.f32.v2f32(float -0.0, <2 x float> %bin.rdx)
@@ -21,12 +19,8 @@ define float @add_HalfS(<2 x float> %bin.rdx)  {
 define half @add_HalfH(<4 x half> %bin.rdx)  {
 ; CHECK-LABEL: add_HalfH:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    mov h1, v0.h[3]
-; CHECK-NEXT:    mov h2, v0.h[2]
+; CHECK-NEXT:    faddp v0.4h, v0.4h, v0.4h
 ; CHECK-NEXT:    faddp h0, v0.2h
-; CHECK-NEXT:    fadd h0, h0, h2
-; CHECK-NEXT:    fadd h0, h0, h1
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: add_HalfH:
@@ -56,13 +50,9 @@ define half @add_HalfH(<4 x half> %bin.rdx)  {
 define half @add_H(<8 x half> %bin.rdx)  {
 ; CHECK-LABEL: add_H:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fadd v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    mov h1, v0.h[2]
-; CHECK-NEXT:    faddp h2, v0.2h
-; CHECK-NEXT:    fadd h1, h2, h1
-; CHECK-NEXT:    mov h0, v0.h[3]
-; CHECK-NEXT:    fadd h0, h1, h0
+; CHECK-NEXT:    faddp v0.8h, v0.8h, v0.8h
+; CHECK-NEXT:    faddp v0.8h, v0.8h, v0.8h
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: add_H:
@@ -110,15 +100,13 @@ define half @add_H(<8 x half> %bin.rdx)  {
 define float @add_S(<4 x float> %bin.rdx)  {
 ; CHECK-LABEL: add_S:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    faddp v0.4s, v0.4s, v0.4s
 ; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: add_S:
 ; CHECKNOFP16:       // %bb.0:
-; CHECKNOFP16-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECKNOFP16-NEXT:    fadd v0.2s, v0.2s, v1.2s
+; CHECKNOFP16-NEXT:    faddp v0.4s, v0.4s, v0.4s
 ; CHECKNOFP16-NEXT:    faddp s0, v0.2s
 ; CHECKNOFP16-NEXT:    ret
   %r = call fast float @llvm.vector.reduce.fadd.f32.v4f32(float -0.0, <4 x float> %bin.rdx)
@@ -143,13 +131,9 @@ define half @add_2H(<16 x half> %bin.rdx)  {
 ; CHECK-LABEL: add_2H:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fadd v0.8h, v0.8h, v1.8h
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fadd v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    mov h1, v0.h[2]
-; CHECK-NEXT:    faddp h2, v0.2h
-; CHECK-NEXT:    fadd h1, h2, h1
-; CHECK-NEXT:    mov h0, v0.h[3]
-; CHECK-NEXT:    fadd h0, h1, h0
+; CHECK-NEXT:    faddp v0.8h, v0.8h, v0.8h
+; CHECK-NEXT:    faddp v0.8h, v0.8h, v0.8h
+; CHECK-NEXT:    faddp h0, v0.2h
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: add_2H:
@@ -237,16 +221,14 @@ define float @add_2S(<8 x float> %bin.rdx)  {
 ; CHECK-LABEL: add_2S:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    faddp v0.4s, v0.4s, v0.4s
 ; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 ;
 ; CHECKNOFP16-LABEL: add_2S:
 ; CHECKNOFP16:       // %bb.0:
 ; CHECKNOFP16-NEXT:    fadd v0.4s, v0.4s, v1.4s
-; CHECKNOFP16-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECKNOFP16-NEXT:    fadd v0.2s, v0.2s, v1.2s
+; CHECKNOFP16-NEXT:    faddp v0.4s, v0.4s, v0.4s
 ; CHECKNOFP16-NEXT:    faddp s0, v0.2s
 ; CHECKNOFP16-NEXT:    ret
   %r = call fast float @llvm.vector.reduce.fadd.f32.v8f32(float -0.0, <8 x float> %bin.rdx)
@@ -269,6 +251,29 @@ define double @add_2D(<4 x double> %bin.rdx)  {
   ret double %r
 }
 
+; Added at least one test where the start value is not -0.0.
+define float @add_S_init_42(<4 x float> %bin.rdx)  {
+; CHECK-LABEL: add_S_init_42:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    faddp v0.4s, v0.4s, v0.4s
+; CHECK-NEXT:    mov w8, #1109917696
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    fmov s1, w8
+; CHECK-NEXT:    fadd s0, s0, s1
+; CHECK-NEXT:    ret
+;
+; CHECKNOFP16-LABEL: add_S_init_42:
+; CHECKNOFP16:       // %bb.0:
+; CHECKNOFP16-NEXT:    faddp v0.4s, v0.4s, v0.4s
+; CHECKNOFP16-NEXT:    mov w8, #1109917696
+; CHECKNOFP16-NEXT:    faddp s0, v0.2s
+; CHECKNOFP16-NEXT:    fmov s1, w8
+; CHECKNOFP16-NEXT:    fadd s0, s0, s1
+; CHECKNOFP16-NEXT:    ret
+  %r = call fast float @llvm.vector.reduce.fadd.f32.v4f32(float 42.0, <4 x float> %bin.rdx)
+  ret float %r
+}
+
 ; Function Attrs: nounwind readnone
 declare half @llvm.vector.reduce.fadd.f16.v4f16(half, <4 x half>)
 declare half @llvm.vector.reduce.fadd.f16.v8f16(half, <8 x half>)


        


More information about the llvm-commits mailing list