[llvm] 1f49b71 - [SVE][CodeGen] Enable reciprocal estimates for scalable fdiv/fsqrt

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 25 03:31:19 PDT 2021


Author: Kerry McLaughlin
Date: 2021-10-25T11:30:44+01:00
New Revision: 1f49b71fe5fa061065a30c89e6e95b7d123e4bb5

URL: https://github.com/llvm/llvm-project/commit/1f49b71fe5fa061065a30c89e6e95b7d123e4bb5
DIFF: https://github.com/llvm/llvm-project/commit/1f49b71fe5fa061065a30c89e6e95b7d123e4bb5.diff

LOG: [SVE][CodeGen] Enable reciprocal estimates for scalable fdiv/fsqrt

This patch enables the use of reciprocal estimates for SVE
when both the -Ofast and -mrecip flags are used.

Reviewed By: david-arm, paulwalker-arm

Differential Revision: https://reviews.llvm.org/D111657

Added: 
    llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 560effa39f0f..6abb991fe228 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4130,6 +4130,18 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::aarch64_sve_frecpx:
     return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+  case Intrinsic::aarch64_sve_frecpe_x:
+    return DAG.getNode(AArch64ISD::FRECPE, dl, Op.getValueType(),
+                       Op.getOperand(1));
+  case Intrinsic::aarch64_sve_frecps_x:
+    return DAG.getNode(AArch64ISD::FRECPS, dl, Op.getValueType(),
+                       Op.getOperand(1), Op.getOperand(2));
+  case Intrinsic::aarch64_sve_frsqrte_x:
+    return DAG.getNode(AArch64ISD::FRSQRTE, dl, Op.getValueType(),
+                       Op.getOperand(1));
+  case Intrinsic::aarch64_sve_frsqrts_x:
+    return DAG.getNode(AArch64ISD::FRSQRTS, dl, Op.getValueType(),
+                       Op.getOperand(1), Op.getOperand(2));
   case Intrinsic::aarch64_sve_fabs:
     return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -8235,10 +8247,12 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
                            SDValue Operand, SelectionDAG &DAG,
                            int &ExtraSteps) {
   EVT VT = Operand.getValueType();
-  if (ST->hasNEON() &&
-      (VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
-       VT == MVT::f32 || VT == MVT::v1f32 ||
-       VT == MVT::v2f32 || VT == MVT::v4f32)) {
+  if ((ST->hasNEON() &&
+       (VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
+        VT == MVT::f32 || VT == MVT::v1f32 || VT == MVT::v2f32 ||
+        VT == MVT::v4f32)) ||
+      (ST->hasSVE() &&
+       (VT == MVT::nxv8f16 || VT == MVT::nxv4f32 || VT == MVT::nxv2f64))) {
     if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified)
       // For the reciprocal estimates, convergence is quadratic, so the number
       // of digits is doubled after each iteration.  In ARMv8, the accuracy of

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 8879789e83f7..0ac0eb2999d2 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -402,8 +402,8 @@ let Predicates = [HasSVEorStreamingSVE] in {
   defm SMIN_ZPZZ : sve_int_bin_pred_bhsd<AArch64smin_p>;
   defm UMIN_ZPZZ : sve_int_bin_pred_bhsd<AArch64umin_p>;
 
-  defm FRECPE_ZZ  : sve_fp_2op_u_zd<0b110, "frecpe",  int_aarch64_sve_frecpe_x>;
-  defm FRSQRTE_ZZ : sve_fp_2op_u_zd<0b111, "frsqrte", int_aarch64_sve_frsqrte_x>;
+  defm FRECPE_ZZ  : sve_fp_2op_u_zd<0b110, "frecpe",  AArch64frecpe>;
+  defm FRSQRTE_ZZ : sve_fp_2op_u_zd<0b111, "frsqrte", AArch64frsqrte>;
 
   defm FADD_ZPmI    : sve_fp_2op_i_p_zds<0b000, "fadd", "FADD_ZPZI", sve_fpimm_half_one, fpimm_half, fpimm_one, int_aarch64_sve_fadd>;
   defm FSUB_ZPmI    : sve_fp_2op_i_p_zds<0b001, "fsub", "FSUB_ZPZI", sve_fpimm_half_one, fpimm_half, fpimm_one, int_aarch64_sve_fsub>;
@@ -484,8 +484,8 @@ let Predicates = [HasSVE] in {
 } // End HasSVE
 
 let Predicates = [HasSVEorStreamingSVE] in {
-  defm FRECPS_ZZZ  : sve_fp_3op_u_zd<0b110, "frecps",  int_aarch64_sve_frecps_x>;
-  defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts", int_aarch64_sve_frsqrts_x>;
+  defm FRECPS_ZZZ  : sve_fp_3op_u_zd<0b110, "frecps",  AArch64frecps>;
+  defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts", AArch64frsqrts>;
 } // End HasSVEorStreamingSVE
 
 let Predicates = [HasSVE] in {

diff --git a/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll b/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
new file mode 100644
index 000000000000..2385436cfe58
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
@@ -0,0 +1,179 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; FDIV
+
+define <vscale x 8 x half> @fdiv_8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fdiv_8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fdiv z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %fdiv = fdiv fast <vscale x 8 x half> %a, %b
+  ret <vscale x 8 x half> %fdiv
+}
+
+define <vscale x 8 x half> @fdiv_recip_8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) #0 {
+; CHECK-LABEL: fdiv_recip_8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    frecpe z2.h, z1.h
+; CHECK-NEXT:    frecps z3.h, z1.h, z2.h
+; CHECK-NEXT:    fmul z2.h, z2.h, z3.h
+; CHECK-NEXT:    frecps z1.h, z1.h, z2.h
+; CHECK-NEXT:    fmul z1.h, z2.h, z1.h
+; CHECK-NEXT:    fmul z0.h, z1.h, z0.h
+; CHECK-NEXT:    ret
+  %fdiv = fdiv fast <vscale x 8 x half> %a, %b
+  ret <vscale x 8 x half> %fdiv
+}
+
+define <vscale x 4 x float> @fdiv_4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fdiv_4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %fdiv = fdiv fast <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %fdiv
+}
+
+define <vscale x 4 x float> @fdiv_recip_4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; CHECK-LABEL: fdiv_recip_4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    frecpe z2.s, z1.s
+; CHECK-NEXT:    frecps z3.s, z1.s, z2.s
+; CHECK-NEXT:    fmul z2.s, z2.s, z3.s
+; CHECK-NEXT:    frecps z1.s, z1.s, z2.s
+; CHECK-NEXT:    fmul z1.s, z2.s, z1.s
+; CHECK-NEXT:    fmul z0.s, z1.s, z0.s
+; CHECK-NEXT:    ret
+  %fdiv = fdiv fast <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %fdiv
+}
+
+define <vscale x 2 x double> @fdiv_2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: fdiv_2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fdiv z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %fdiv = fdiv fast <vscale x 2 x double> %a, %b
+  ret <vscale x 2 x double> %fdiv
+}
+
+define <vscale x 2 x double> @fdiv_recip_2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) #0 {
+; CHECK-LABEL: fdiv_recip_2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    frecpe z2.d, z1.d
+; CHECK-NEXT:    frecps z3.d, z1.d, z2.d
+; CHECK-NEXT:    fmul z2.d, z2.d, z3.d
+; CHECK-NEXT:    frecps z3.d, z1.d, z2.d
+; CHECK-NEXT:    fmul z2.d, z2.d, z3.d
+; CHECK-NEXT:    frecps z1.d, z1.d, z2.d
+; CHECK-NEXT:    fmul z1.d, z2.d, z1.d
+; CHECK-NEXT:    fmul z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+  %fdiv = fdiv fast <vscale x 2 x double> %a, %b
+  ret <vscale x 2 x double> %fdiv
+}
+
+; FSQRT
+
+define <vscale x 8 x half> @fsqrt_8f16(<vscale x 8 x half> %a) {
+; CHECK-LABEL: fsqrt_8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fsqrt z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
+  ret <vscale x 8 x half> %fsqrt
+}
+
+define <vscale x 8 x half> @fsqrt_recip_8f16(<vscale x 8 x half> %a) #0 {
+; CHECK-LABEL: fsqrt_recip_8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    frsqrte z1.h, z0.h
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fmul z2.h, z1.h, z1.h
+; CHECK-NEXT:    fcmeq p0.h, p0/z, z0.h, #0.0
+; CHECK-NEXT:    frsqrts z2.h, z0.h, z2.h
+; CHECK-NEXT:    fmul z1.h, z1.h, z2.h
+; CHECK-NEXT:    fmul z2.h, z1.h, z1.h
+; CHECK-NEXT:    frsqrts z2.h, z0.h, z2.h
+; CHECK-NEXT:    fmul z1.h, z1.h, z2.h
+; CHECK-NEXT:    fmul z1.h, z0.h, z1.h
+; CHECK-NEXT:    sel z0.h, p0, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
+  ret <vscale x 8 x half> %fsqrt
+}
+
+define <vscale x 4 x float> @fsqrt_4f32(<vscale x 4 x float> %a) {
+; CHECK-LABEL: fsqrt_4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fsqrt z0.s, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
+  ret <vscale x 4 x float> %fsqrt
+}
+
+define <vscale x 4 x float> @fsqrt_recip_4f32(<vscale x 4 x float> %a) #0 {
+; CHECK-LABEL: fsqrt_recip_4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    frsqrte z1.s, z0.s
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmul z2.s, z1.s, z1.s
+; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    frsqrts z2.s, z0.s, z2.s
+; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
+; CHECK-NEXT:    fmul z2.s, z1.s, z1.s
+; CHECK-NEXT:    frsqrts z2.s, z0.s, z2.s
+; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
+; CHECK-NEXT:    fmul z1.s, z0.s, z1.s
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
+  ret <vscale x 4 x float> %fsqrt
+}
+
+define <vscale x 2 x double> @fsqrt_2f64(<vscale x 2 x double> %a) {
+; CHECK-LABEL: fsqrt_2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fsqrt z0.d, p0/m, z0.d
+; CHECK-NEXT:    ret
+  %fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
+  ret <vscale x 2 x double> %fsqrt
+}
+
+define <vscale x 2 x double> @fsqrt_recip_2f64(<vscale x 2 x double> %a) #0 {
+; CHECK-LABEL: fsqrt_recip_2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    frsqrte z1.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
+; CHECK-NEXT:    fcmeq p0.d, p0/z, z0.d, #0.0
+; CHECK-NEXT:    frsqrts z2.d, z0.d, z2.d
+; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
+; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
+; CHECK-NEXT:    frsqrts z2.d, z0.d, z2.d
+; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
+; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
+; CHECK-NEXT:    frsqrts z2.d, z0.d, z2.d
+; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
+; CHECK-NEXT:    fmul z1.d, z0.d, z1.d
+; CHECK-NEXT:    sel z0.d, p0, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
+  ret <vscale x 2 x double> %fsqrt
+}
+
+declare <vscale x 2 x half> @llvm.sqrt.nxv2f16(<vscale x 2 x half>)
+declare <vscale x 4 x half> @llvm.sqrt.nxv4f16(<vscale x 4 x half>)
+declare <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half>)
+declare <vscale x 2 x float> @llvm.sqrt.nxv2f32(<vscale x 2 x float>)
+declare <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float>)
+declare <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double>)
+
+attributes #0 = { "reciprocal-estimates"="all" }


        


More information about the llvm-commits mailing list