[llvm] 05d424d - [AArch64][SVE] Fold fadda(ptrue, x, select(mask, y, -0.0)) into fadda(mask, x, y)

Rosie Sumpter via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 19 00:39:18 PDT 2022


Author: Rosie Sumpter
Date: 2022-07-19T08:31:51+01:00
New Revision: 05d424d165631aee9c27a5ca18a534d0b7823fe7

URL: https://github.com/llvm/llvm-project/commit/05d424d165631aee9c27a5ca18a534d0b7823fe7
DIFF: https://github.com/llvm/llvm-project/commit/05d424d165631aee9c27a5ca18a534d0b7823fe7.diff

LOG: [AArch64][SVE] Fold fadda(ptrue, x, select(mask, y, -0.0)) into fadda(mask, x, y)

This patch adds an SVE pattern to recognize a select feeding an fadda in
the form fadda(ptrue, x, select(mask, y, -0.0)). Because -0.0 is the
neutral value for fadd (x + -0.0 == x), the lanes where the mask is false
do not affect the result, so the select can be folded away, with the
select mask used as the predicate for fadda. This improves the codegen
when vectorizing loops with ordered fp reductions.

Differential Revision: https://reviews.llvm.org/D129623

Added: 
    llvm/test/CodeGen/AArch64/sve-fadda-select.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 02fa36a1df4bb..7df9831d598fa 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -1234,6 +1234,10 @@ def fpimm0 : FPImmLeaf<fAny, [{
   return Imm.isExactlyValue(+0.0);
 }]>;
 
+def fpimm_minus0 : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(-0.0);
+}]>;
+
 def fpimm_half : FPImmLeaf<fAny, [{
   return Imm.isExactlyValue(+0.5);
 }]>;

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index c66f9cfd9c226..e49b5b8901d77 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -278,10 +278,18 @@ def AArch64scvtf_mt  : SDNode<"AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU", SDT_AArch
 def AArch64fcvtzu_mt : SDNode<"AArch64ISD::FCVTZU_MERGE_PASSTHRU", SDT_AArch64FCVT>;
 def AArch64fcvtzs_mt : SDNode<"AArch64ISD::FCVTZS_MERGE_PASSTHRU", SDT_AArch64FCVT>;
 
-def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
-def AArch64clasta_n   : SDNode<"AArch64ISD::CLASTA_N",   SDT_AArch64ReduceWithInit>;
-def AArch64clastb_n   : SDNode<"AArch64ISD::CLASTB_N",   SDT_AArch64ReduceWithInit>;
-def AArch64fadda_p    : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
+def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3,
+   [SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisVec<3>, SDTCisSameNumEltsAs<1,3>]>;
+def AArch64clasta_n     : SDNode<"AArch64ISD::CLASTA_N",   SDT_AArch64ReduceWithInit>;
+def AArch64clastb_n     : SDNode<"AArch64ISD::CLASTB_N",   SDT_AArch64ReduceWithInit>;
+def AArch64fadda_p_node : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
+
+def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
+    [(AArch64fadda_p_node node:$op1, node:$op2, node:$op3),
+     (AArch64fadda_p_node (SVEAllActive), node:$op2,
+             (vselect node:$op1, node:$op3, (splat_vector (f32 fpimm_minus0)))),
+     (AArch64fadda_p_node (SVEAllActive), node:$op2,
+             (vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
 
 def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
 def AArch64ptest     : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;

diff  --git a/llvm/test/CodeGen/AArch64/sve-fadda-select.ll b/llvm/test/CodeGen/AArch64/sve-fadda-select.ll
new file mode 100644
index 0000000000000..e02984072325c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fadda-select.ll
@@ -0,0 +1,112 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; Fold fadda(ptrue, x, select(mask, y, -0.0)) -> fadda(mask, x, y)
+
+define float @pred_fadda_nxv2f32(float %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x float> poison, float -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x float> %i, <vscale x 2 x float> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x float> %y, <vscale x 2 x float> %minus0
+  %fadda = call float @llvm.vector.reduce.fadd.nxv2f32(float %x, <vscale x 2 x float> %sel)
+  ret float %fadda
+}
+
+define float @pred_fadda_nxv4f32(float %x, <vscale x 4 x float> %y, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 4 x float> poison, float -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 4 x float> %i, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %y, <vscale x 4 x float> %minus0
+  %fadda = call float @llvm.vector.reduce.fadd.nxv4f32(float %x, <vscale x 4 x float> %sel)
+  ret float %fadda
+}
+
+define double @pred_fadda_nxv2f64(double %x, <vscale x 2 x double> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    fadda d0, p0, d0, z1.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x double> poison, double -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x double> %i, <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %y, <vscale x 2 x double> %minus0
+  %fadda = call double @llvm.vector.reduce.fadd.nxv2f64(double %x, <vscale x 2 x double> %sel)
+  ret double %fadda
+}
+
+; Currently the folding doesn't work for f16 element types, since -0.0 is not treated as a legal f16 immediate.
+
+define half @pred_fadda_nxv2f16(half %x, <vscale x 2 x half> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI3_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI3_0
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.d }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.d, p0, z1.d, z2.d
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x half> %i, <vscale x 2 x half> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x half> %y, <vscale x 2 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv2f16(half %x, <vscale x 2 x half> %sel)
+  ret half %fadda
+}
+
+define half @pred_fadda_nxv4f16(half %x, <vscale x 4 x half> %y, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv4f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI4_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI4_0
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.s }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.s, p0, z1.s, z2.s
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 4 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 4 x half> %i, <vscale x 4 x half> poison, <vscale x 4 x i32> zeroinitializer
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x half> %y, <vscale x 4 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv4f16(half %x, <vscale x 4 x half> %sel)
+  ret half %fadda
+}
+
+define half @pred_fadda_nxv8f16(half %x, <vscale x 8 x half> %y, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI5_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI5_0
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.h }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.h, p0, z1.h, z2.h
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 8 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 8 x half> %i, <vscale x 8 x half> poison, <vscale x 8 x i32> zeroinitializer
+  %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %y, <vscale x 8 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv8f16(half %x, <vscale x 8 x half> %sel)
+  ret half %fadda
+}
+
+declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)
+declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)
+declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)
+declare half @llvm.vector.reduce.fadd.nxv2f16(half, <vscale x 2 x half>)
+declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)
+declare half @llvm.vector.reduce.fadd.nxv8f16(half, <vscale x 8 x half>)


        


More information about the llvm-commits mailing list