[llvm-branch-commits] [llvm] a9f5e43 - [AArch64] Use faddp to implement fadd reductions.
Sander de Smalen via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 6 01:42:10 PST 2021
Author: Sander de Smalen
Date: 2021-01-06T09:36:51Z
New Revision: a9f5e4375b36e5316b8d6f9731be6bfa5a70e276
URL: https://github.com/llvm/llvm-project/commit/a9f5e4375b36e5316b8d6f9731be6bfa5a70e276
DIFF: https://github.com/llvm/llvm-project/commit/a9f5e4375b36e5316b8d6f9731be6bfa5a70e276.diff
LOG: [AArch64] Use faddp to implement fadd reductions.
Mark VECREDUCE_FADD as Legal for the NEON FP vector types (f16 only with +fullfp16)
and select it with ISel patterns that use pair-wise faddp instructions.
Reviewed By: dmgreen
Differential Revision: https://reviews.llvm.org/D59259
Added:
Modified:
llvm/include/llvm/Target/TargetSelectionDAG.td
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64InstrInfo.td
llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll
llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index d5b8aeb1055d..0c6eef939ea4 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -250,6 +250,10 @@ def SDTVecInsert : SDTypeProfile<1, 3, [ // vector insert
def SDTVecReduce : SDTypeProfile<1, 1, [ // vector reduction
SDTCisInt<0>, SDTCisVec<1>
]>;
+def SDTFPVecReduce : SDTypeProfile<1, 1, [ // FP vector reduction
+ SDTCisFP<0>, SDTCisVec<1>
+]>;
+
def SDTSubVecExtract : SDTypeProfile<1, 2, [// subvector extract
SDTCisSubVecOfVec<0,1>, SDTCisInt<2>
@@ -439,6 +443,7 @@ def vecreduce_smax : SDNode<"ISD::VECREDUCE_SMAX", SDTVecReduce>;
def vecreduce_umax : SDNode<"ISD::VECREDUCE_UMAX", SDTVecReduce>;
def vecreduce_smin : SDNode<"ISD::VECREDUCE_SMIN", SDTVecReduce>;
def vecreduce_umin : SDNode<"ISD::VECREDUCE_UMIN", SDTVecReduce>;
+def vecreduce_fadd : SDNode<"ISD::VECREDUCE_FADD", SDTFPVecReduce>;
def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>;
def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index faed7c64a15e..2b9dc84a06cc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -999,6 +999,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
+
+ if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16())
+ setOperationAction(ISD::VECREDUCE_FADD, VT, Legal);
}
for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 4d70fb334828..7e9f2fb95188 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -4989,6 +4989,26 @@ defm FMAXNMP : SIMDFPPairwiseScalar<0, 0b01100, "fmaxnmp">;
defm FMAXP : SIMDFPPairwiseScalar<0, 0b01111, "fmaxp">;
defm FMINNMP : SIMDFPPairwiseScalar<1, 0b01100, "fminnmp">;
defm FMINP : SIMDFPPairwiseScalar<1, 0b01111, "fminp">;
+
+let Predicates = [HasFullFP16] in {
+def : Pat<(f16 (vecreduce_fadd (v8f16 V128:$Rn))),
+ (FADDPv2i16p
+ (EXTRACT_SUBREG
+ (FADDPv8f16 (FADDPv8f16 V128:$Rn, (v8f16 (IMPLICIT_DEF))), (v8f16 (IMPLICIT_DEF))),
+ dsub))>;
+def : Pat<(f16 (vecreduce_fadd (v4f16 V64:$Rn))),
+ (FADDPv2i16p (FADDPv4f16 V64:$Rn, (v4f16 (IMPLICIT_DEF))))>;
+}
+def : Pat<(f32 (vecreduce_fadd (v4f32 V128:$Rn))),
+ (FADDPv2i32p
+ (EXTRACT_SUBREG
+ (FADDPv4f32 V128:$Rn, (v4f32 (IMPLICIT_DEF))),
+ dsub))>;
+def : Pat<(f32 (vecreduce_fadd (v2f32 V64:$Rn))),
+ (FADDPv2i32p V64:$Rn)>;
+def : Pat<(f64 (vecreduce_fadd (v2f64 V128:$Rn))),
+ (FADDPv2i64p V128:$Rn)>;
+
def : Pat<(v2i64 (AArch64saddv V128:$Rn)),
(INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (ADDPv2i64p V128:$Rn), dsub)>;
def : Pat<(v2i64 (AArch64uaddv V128:$Rn)),
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll b/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll
index 69b9c3e22d7a..2b5fcd4b839a 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization.ll
@@ -51,8 +51,7 @@ define float @test_v3f32(<3 x float> %a) nounwind {
; CHECK-NEXT: mov w8, #-2147483648
; CHECK-NEXT: fmov s1, w8
; CHECK-NEXT: mov v0.s[3], v1.s[0]
-; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: faddp v0.4s, v0.4s, v0.4s
; CHECK-NEXT: faddp s0, v0.2s
; CHECK-NEXT: ret
%b = call reassoc float @llvm.vector.reduce.fadd.f32.v3f32(float -0.0, <3 x float> %a)
@@ -73,8 +72,7 @@ define float @test_v5f32(<5 x float> %a) nounwind {
; CHECK-NEXT: mov v0.s[3], v3.s[0]
; CHECK-NEXT: mov v5.s[0], v4.s[0]
; CHECK-NEXT: fadd v0.4s, v0.4s, v5.4s
-; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: faddp v0.4s, v0.4s, v0.4s
; CHECK-NEXT: faddp s0, v0.2s
; CHECK-NEXT: ret
%b = call reassoc float @llvm.vector.reduce.fadd.f32.v5f32(float -0.0, <5 x float> %a)
@@ -95,8 +93,7 @@ define float @test_v16f32(<16 x float> %a) nounwind {
; CHECK-NEXT: fadd v1.4s, v1.4s, v3.4s
; CHECK-NEXT: fadd v0.4s, v0.4s, v2.4s
; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: faddp v0.4s, v0.4s, v0.4s
; CHECK-NEXT: faddp s0, v0.2s
; CHECK-NEXT: ret
%b = call reassoc float @llvm.vector.reduce.fadd.f32.v16f32(float -0.0, <16 x float> %a)
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll b/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
index 12e8be962f79..5bb36432a91b 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-fadd.ll
@@ -5,13 +5,11 @@
define float @add_HalfS(<2 x float> %bin.rdx) {
; CHECK-LABEL: add_HalfS:
; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NEXT: faddp s0, v0.2s
; CHECK-NEXT: ret
;
; CHECKNOFP16-LABEL: add_HalfS:
; CHECKNOFP16: // %bb.0:
-; CHECKNOFP16-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECKNOFP16-NEXT: faddp s0, v0.2s
; CHECKNOFP16-NEXT: ret
%r = call fast float @llvm.vector.reduce.fadd.f32.v2f32(float -0.0, <2 x float> %bin.rdx)
@@ -21,12 +19,8 @@ define float @add_HalfS(<2 x float> %bin.rdx) {
define half @add_HalfH(<4 x half> %bin.rdx) {
; CHECK-LABEL: add_HalfH:
; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT: mov h1, v0.h[3]
-; CHECK-NEXT: mov h2, v0.h[2]
+; CHECK-NEXT: faddp v0.4h, v0.4h, v0.4h
; CHECK-NEXT: faddp h0, v0.2h
-; CHECK-NEXT: fadd h0, h0, h2
-; CHECK-NEXT: fadd h0, h0, h1
; CHECK-NEXT: ret
;
; CHECKNOFP16-LABEL: add_HalfH:
@@ -56,13 +50,9 @@ define half @add_HalfH(<4 x half> %bin.rdx) {
define half @add_H(<8 x half> %bin.rdx) {
; CHECK-LABEL: add_H:
; CHECK: // %bb.0:
-; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: fadd v0.4h, v0.4h, v1.4h
-; CHECK-NEXT: mov h1, v0.h[2]
-; CHECK-NEXT: faddp h2, v0.2h
-; CHECK-NEXT: fadd h1, h2, h1
-; CHECK-NEXT: mov h0, v0.h[3]
-; CHECK-NEXT: fadd h0, h1, h0
+; CHECK-NEXT: faddp v0.8h, v0.8h, v0.8h
+; CHECK-NEXT: faddp v0.8h, v0.8h, v0.8h
+; CHECK-NEXT: faddp h0, v0.2h
; CHECK-NEXT: ret
;
; CHECKNOFP16-LABEL: add_H:
@@ -110,15 +100,13 @@ define half @add_H(<8 x half> %bin.rdx) {
define float @add_S(<4 x float> %bin.rdx) {
; CHECK-LABEL: add_S:
; CHECK: // %bb.0:
-; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: faddp v0.4s, v0.4s, v0.4s
; CHECK-NEXT: faddp s0, v0.2s
; CHECK-NEXT: ret
;
; CHECKNOFP16-LABEL: add_S:
; CHECKNOFP16: // %bb.0:
-; CHECKNOFP16-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECKNOFP16-NEXT: fadd v0.2s, v0.2s, v1.2s
+; CHECKNOFP16-NEXT: faddp v0.4s, v0.4s, v0.4s
; CHECKNOFP16-NEXT: faddp s0, v0.2s
; CHECKNOFP16-NEXT: ret
%r = call fast float @llvm.vector.reduce.fadd.f32.v4f32(float -0.0, <4 x float> %bin.rdx)
@@ -143,13 +131,9 @@ define half @add_2H(<16 x half> %bin.rdx) {
; CHECK-LABEL: add_2H:
; CHECK: // %bb.0:
; CHECK-NEXT: fadd v0.8h, v0.8h, v1.8h
-; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: fadd v0.4h, v0.4h, v1.4h
-; CHECK-NEXT: mov h1, v0.h[2]
-; CHECK-NEXT: faddp h2, v0.2h
-; CHECK-NEXT: fadd h1, h2, h1
-; CHECK-NEXT: mov h0, v0.h[3]
-; CHECK-NEXT: fadd h0, h1, h0
+; CHECK-NEXT: faddp v0.8h, v0.8h, v0.8h
+; CHECK-NEXT: faddp v0.8h, v0.8h, v0.8h
+; CHECK-NEXT: faddp h0, v0.2h
; CHECK-NEXT: ret
;
; CHECKNOFP16-LABEL: add_2H:
@@ -237,16 +221,14 @@ define float @add_2S(<8 x float> %bin.rdx) {
; CHECK-LABEL: add_2S:
; CHECK: // %bb.0:
; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: fadd v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: faddp v0.4s, v0.4s, v0.4s
; CHECK-NEXT: faddp s0, v0.2s
; CHECK-NEXT: ret
;
; CHECKNOFP16-LABEL: add_2S:
; CHECKNOFP16: // %bb.0:
; CHECKNOFP16-NEXT: fadd v0.4s, v0.4s, v1.4s
-; CHECKNOFP16-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECKNOFP16-NEXT: fadd v0.2s, v0.2s, v1.2s
+; CHECKNOFP16-NEXT: faddp v0.4s, v0.4s, v0.4s
; CHECKNOFP16-NEXT: faddp s0, v0.2s
; CHECKNOFP16-NEXT: ret
%r = call fast float @llvm.vector.reduce.fadd.f32.v8f32(float -0.0, <8 x float> %bin.rdx)
@@ -269,6 +251,29 @@ define double @add_2D(<4 x double> %bin.rdx) {
ret double %r
}
+; Test a reduction whose start value is not -0.0; it is added in with a separate fadd.
+define float @add_S_init_42(<4 x float> %bin.rdx) {
+; CHECK-LABEL: add_S_init_42:
+; CHECK: // %bb.0:
+; CHECK-NEXT: faddp v0.4s, v0.4s, v0.4s
+; CHECK-NEXT: mov w8, #1109917696
+; CHECK-NEXT: faddp s0, v0.2s
+; CHECK-NEXT: fmov s1, w8
+; CHECK-NEXT: fadd s0, s0, s1
+; CHECK-NEXT: ret
+;
+; CHECKNOFP16-LABEL: add_S_init_42:
+; CHECKNOFP16: // %bb.0:
+; CHECKNOFP16-NEXT: faddp v0.4s, v0.4s, v0.4s
+; CHECKNOFP16-NEXT: mov w8, #1109917696
+; CHECKNOFP16-NEXT: faddp s0, v0.2s
+; CHECKNOFP16-NEXT: fmov s1, w8
+; CHECKNOFP16-NEXT: fadd s0, s0, s1
+; CHECKNOFP16-NEXT: ret
+ %r = call fast float @llvm.vector.reduce.fadd.f32.v4f32(float 42.0, <4 x float> %bin.rdx)
+ ret float %r
+}
+
; Function Attrs: nounwind readnone
declare half @llvm.vector.reduce.fadd.f16.v4f16(half, <4 x half>)
declare half @llvm.vector.reduce.fadd.f16.v8f16(half, <8 x half>)
More information about the llvm-branch-commits
mailing list