[llvm] d417488 - [AArch64][SVE] Add lowering for llvm fsqrt

Danilo C. Grael via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 15 12:28:54 PDT 2020


Author: Muhammad Asif Manzoor
Date: 2020-09-15T15:26:17-04:00
New Revision: d417488ef5a6cd1089900defcd6d5ae5a1d47fd4

URL: https://github.com/llvm/llvm-project/commit/d417488ef5a6cd1089900defcd6d5ae5a1d47fd4
DIFF: https://github.com/llvm/llvm-project/commit/d417488ef5a6cd1089900defcd6d5ae5a1d47fd4.diff

LOG: [AArch64][SVE] Add lowering for llvm fsqrt

Add the functionality to lower llvm.sqrt (ISD::FSQRT) for SVE using the merging passthru variant, FSQRT_MERGE_PASSTHRU.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D87707

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/sve-fp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 820661454783..b961e5a30cd0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -145,6 +145,7 @@ static bool isMergePassthruOpcode(unsigned Opc) {
   case AArch64ISD::FROUND_MERGE_PASSTHRU:
   case AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU:
   case AArch64ISD::FTRUNC_MERGE_PASSTHRU:
+  case AArch64ISD::FSQRT_MERGE_PASSTHRU:
     return true;
   }
 }
@@ -990,6 +991,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::FROUND, VT, Custom);
         setOperationAction(ISD::FROUNDEVEN, VT, Custom);
         setOperationAction(ISD::FTRUNC, VT, Custom);
+        setOperationAction(ISD::FSQRT, VT, Custom);
       }
     }
 
@@ -1502,6 +1504,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FROUND_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::FTRUNC_MERGE_PASSTHRU)
+    MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO)
     MAKE_CASE(AArch64ISD::ADC)
     MAKE_CASE(AArch64ISD::SBC)
@@ -3385,6 +3388,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::aarch64_sve_frintz:
     return DAG.getNode(AArch64ISD::FTRUNC_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+  case Intrinsic::aarch64_sve_fsqrt:
+    return DAG.getNode(AArch64ISD::FSQRT_MERGE_PASSTHRU, dl, Op.getValueType(),
+                       Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
   case Intrinsic::aarch64_sve_convert_to_svbool: {
     EVT OutVT = Op.getValueType();
     EVT InVT = Op.getOperand(1).getValueType();
@@ -3696,6 +3702,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
   case ISD::FTRUNC:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
+  case ISD::FSQRT:
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
   case ISD::FP_ROUND:
   case ISD::STRICT_FP_ROUND:
     return LowerFP_ROUND(Op, DAG);

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d6e511891752..e34caacd272d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -102,6 +102,7 @@ enum NodeType : unsigned {
   FRINT_MERGE_PASSTHRU,
   FROUND_MERGE_PASSTHRU,
   FROUNDEVEN_MERGE_PASSTHRU,
+  FSQRT_MERGE_PASSTHRU,
   FTRUNC_MERGE_PASSTHRU,
   SIGN_EXTEND_INREG_MERGE_PASSTHRU,
   ZERO_EXTEND_INREG_MERGE_PASSTHRU,

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index e01a34242a8d..63545d30b2d1 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -209,6 +209,7 @@ def AArch64frintx_mt : SDNode<"AArch64ISD::FRINT_MERGE_PASSTHRU", SDT_AArch64Ari
 def AArch64frinta_mt : SDNode<"AArch64ISD::FROUND_MERGE_PASSTHRU", SDT_AArch64Arith>;
 def AArch64frintn_mt : SDNode<"AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU", SDT_AArch64Arith>;
 def AArch64frintz_mt : SDNode<"AArch64ISD::FTRUNC_MERGE_PASSTHRU", SDT_AArch64Arith>;
+def AArch64fsqrt_mt : SDNode<"AArch64ISD::FSQRT_MERGE_PASSTHRU", SDT_AArch64Arith>;
 
 def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
 def AArch64clasta_n   : SDNode<"AArch64ISD::CLASTA_N",   SDT_AArch64ReduceWithInit>;
@@ -1430,7 +1431,7 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
   defm FRINTX_ZPmZ : sve_fp_2op_p_zd_HSD<0b00110, "frintx", null_frag, AArch64frintx_mt>;
   defm FRINTI_ZPmZ : sve_fp_2op_p_zd_HSD<0b00111, "frinti", null_frag, AArch64frinti_mt>;
   defm FRECPX_ZPmZ : sve_fp_2op_p_zd_HSD<0b01100, "frecpx", int_aarch64_sve_frecpx>;
-  defm FSQRT_ZPmZ  : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt",  int_aarch64_sve_fsqrt>;
+  defm FSQRT_ZPmZ  : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt",  null_frag, AArch64fsqrt_mt>;
 
   let Predicates = [HasBF16, HasSVE] in {
     defm BFDOT_ZZZ    : sve_bfloat_dot<"bfdot", int_aarch64_sve_bfdot>;

diff  --git a/llvm/test/CodeGen/AArch64/sve-fp.ll b/llvm/test/CodeGen/AArch64/sve-fp.ll
index e4aea2847bc4..5334e66b22f7 100644
--- a/llvm/test/CodeGen/AArch64/sve-fp.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp.ll
@@ -480,6 +480,68 @@ define void @float_copy(<vscale x 4 x float>* %P1, <vscale x 4 x float>* %P2) {
   ret void
 }
 
+; FSQRT
+
+define <vscale x 8 x half> @fsqrt_nxv8f16(<vscale x 8 x half> %a) {
+; CHECK-LABEL: fsqrt_nxv8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fsqrt z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %res = call <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
+  ret <vscale x 8 x half> %res
+}
+
+define <vscale x 4 x half> @fsqrt_nxv4f16(<vscale x 4 x half> %a) {
+; CHECK-LABEL: fsqrt_nxv4f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fsqrt z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %res = call <vscale x 4 x half> @llvm.sqrt.nxv4f16(<vscale x 4 x half> %a)
+  ret <vscale x 4 x half> %res
+}
+
+define <vscale x 2 x half> @fsqrt_nxv2f16(<vscale x 2 x half> %a) {
+; CHECK-LABEL: fsqrt_nxv2f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fsqrt z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %res = call <vscale x 2 x half> @llvm.sqrt.nxv2f16(<vscale x 2 x half> %a)
+  ret <vscale x 2 x half> %res
+}
+
+define <vscale x 4 x float> @fsqrt_nxv4f32(<vscale x 4 x float> %a) {
+; CHECK-LABEL: fsqrt_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fsqrt z0.s, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %res = call <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x float> @fsqrt_nxv2f32(<vscale x 2 x float> %a) {
+; CHECK-LABEL: fsqrt_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fsqrt z0.s, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %res = call <vscale x 2 x float> @llvm.sqrt.nxv2f32(<vscale x 2 x float> %a)
+  ret <vscale x 2 x float> %res
+}
+
+define <vscale x 2 x double> @fsqrt_nxv2f64(<vscale x 2 x double> %a) {
+; CHECK-LABEL: fsqrt_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fsqrt z0.d, p0/m, z0.d
+; CHECK-NEXT:    ret
+  %res = call <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
+  ret <vscale x 2 x double> %res
+}
+
 declare <vscale x 8 x half> @llvm.aarch64.sve.frecps.x.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 4 x float>  @llvm.aarch64.sve.frecps.x.nxv4f32(<vscale x 4 x float> , <vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.aarch64.sve.frecps.x.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
@@ -495,5 +557,12 @@ declare <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half>, <vscale x 8 x
 declare <vscale x 4 x half> @llvm.fma.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>, <vscale x 4 x half>)
 declare <vscale x 2 x half> @llvm.fma.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half>, <vscale x 2 x half>)
 
+declare <vscale x 8 x half> @llvm.sqrt.nxv8f16( <vscale x 8 x half>)
+declare <vscale x 4 x half> @llvm.sqrt.nxv4f16( <vscale x 4 x half>)
+declare <vscale x 2 x half> @llvm.sqrt.nxv2f16( <vscale x 2 x half>)
+declare <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float>)
+declare <vscale x 2 x float> @llvm.sqrt.nxv2f32(<vscale x 2 x float>)
+declare <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double>)
+
 ; Function Attrs: nounwind readnone
 declare double @llvm.aarch64.sve.faddv.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>) #2


        


More information about the llvm-commits mailing list