[llvm] aab6f7d - [AArch64][SVE] Add lowering for llvm fabs

Muhammad Asif Manzoor via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 1 16:41:51 PDT 2020


Author: Muhammad Asif Manzoor
Date: 2020-10-01T19:41:25-04:00
New Revision: aab6f7db471d577d313f334cba37667c35158420

URL: https://github.com/llvm/llvm-project/commit/aab6f7db471d577d313f334cba37667c35158420
DIFF: https://github.com/llvm/llvm-project/commit/aab6f7db471d577d313f334cba37667c35158420.diff

LOG: [AArch64][SVE] Add lowering for llvm fabs

Add the functionality to lower fabs for passthru variant

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D88679

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/lib/Target/AArch64/SVEInstrFormats.td
    llvm/test/CodeGen/AArch64/sve-fp.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fb70b2d801da..d7d326fa019d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -191,6 +191,7 @@ static bool isMergePassthruOpcode(unsigned Opc) {
   case AArch64ISD::FCVTZS_MERGE_PASSTHRU:
   case AArch64ISD::FSQRT_MERGE_PASSTHRU:
   case AArch64ISD::FRECPX_MERGE_PASSTHRU:
+  case AArch64ISD::FABS_MERGE_PASSTHRU:
     return true;
   }
 }
@@ -1054,6 +1055,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FROUNDEVEN, VT, Custom);
       setOperationAction(ISD::FTRUNC, VT, Custom);
       setOperationAction(ISD::FSQRT, VT, Custom);
+      setOperationAction(ISD::FABS, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
       setOperationAction(ISD::FP_ROUND, VT, Custom);
     }
@@ -1592,6 +1594,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FCVTZS_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::FRECPX_MERGE_PASSTHRU)
+    MAKE_CASE(AArch64ISD::FABS_MERGE_PASSTHRU)
     MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO)
     MAKE_CASE(AArch64ISD::ADC)
     MAKE_CASE(AArch64ISD::SBC)
@@ -3521,6 +3524,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::aarch64_sve_frecpx:
     return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+  case Intrinsic::aarch64_sve_fabs:
+    return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(),
+                       Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
   case Intrinsic::aarch64_sve_convert_to_svbool: {
     EVT OutVT = Op.getValueType();
     EVT InVT = Op.getOperand(1).getValueType();
@@ -3834,6 +3840,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
   case ISD::FSQRT:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
+  case ISD::FABS:
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
   case ISD::FP_ROUND:
   case ISD::STRICT_FP_ROUND:
     return LowerFP_ROUND(Op, DAG);

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1b8f62e427db..dc23fb838f97 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -95,6 +95,7 @@ enum NodeType : unsigned {
 
   // Predicated instructions with the result of inactive lanes provided by the
   // last operand.
+  FABS_MERGE_PASSTHRU,
   FCEIL_MERGE_PASSTHRU,
   FFLOOR_MERGE_PASSTHRU,
   FNEARBYINT_MERGE_PASSTHRU,

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index e2c8eb9115cf..d0b526ee4755 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -201,9 +201,10 @@ def SDT_AArch64IntExtend : SDTypeProfile<1, 4, [
 ]>;
 
 // Predicated operations with the result of inactive lanes provided by the last operand.
-def AArch64fneg_mt : SDNode<"AArch64ISD::FNEG_MERGE_PASSTHRU", SDT_AArch64Arith>;
-def AArch64sxt_mt  : SDNode<"AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU", SDT_AArch64IntExtend>;
-def AArch64uxt_mt  : SDNode<"AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU", SDT_AArch64IntExtend>;
+def AArch64fneg_mt   : SDNode<"AArch64ISD::FNEG_MERGE_PASSTHRU", SDT_AArch64Arith>;
+def AArch64fabs_mt   : SDNode<"AArch64ISD::FABS_MERGE_PASSTHRU", SDT_AArch64Arith>;
+def AArch64sxt_mt    : SDNode<"AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU", SDT_AArch64IntExtend>;
+def AArch64uxt_mt    : SDNode<"AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU", SDT_AArch64IntExtend>;
 def AArch64frintp_mt : SDNode<"AArch64ISD::FCEIL_MERGE_PASSTHRU", SDT_AArch64Arith>;
 def AArch64frintm_mt : SDNode<"AArch64ISD::FFLOOR_MERGE_PASSTHRU", SDT_AArch64Arith>;
 def AArch64frinti_mt : SDNode<"AArch64ISD::FNEARBYINT_MERGE_PASSTHRU", SDT_AArch64Arith>;
@@ -211,7 +212,7 @@ def AArch64frintx_mt : SDNode<"AArch64ISD::FRINT_MERGE_PASSTHRU", SDT_AArch64Ari
 def AArch64frinta_mt : SDNode<"AArch64ISD::FROUND_MERGE_PASSTHRU", SDT_AArch64Arith>;
 def AArch64frintn_mt : SDNode<"AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU", SDT_AArch64Arith>;
 def AArch64frintz_mt : SDNode<"AArch64ISD::FTRUNC_MERGE_PASSTHRU", SDT_AArch64Arith>;
-def AArch64fsqrt_mt : SDNode<"AArch64ISD::FSQRT_MERGE_PASSTHRU", SDT_AArch64Arith>;
+def AArch64fsqrt_mt  : SDNode<"AArch64ISD::FSQRT_MERGE_PASSTHRU", SDT_AArch64Arith>;
 def AArch64frecpx_mt : SDNode<"AArch64ISD::FRECPX_MERGE_PASSTHRU", SDT_AArch64Arith>;
 
 def SDT_AArch64FCVT : SDTypeProfile<1, 3, [
@@ -378,8 +379,8 @@ let Predicates = [HasSVE] in {
 
   defm CNOT_ZPmZ : sve_int_un_pred_arit_1<   0b011, "cnot", int_aarch64_sve_cnot>;
   defm NOT_ZPmZ  : sve_int_un_pred_arit_1<   0b110, "not",  int_aarch64_sve_not>;
-  defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", int_aarch64_sve_fabs, null_frag>;
-  defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", null_frag, AArch64fneg_mt>;
+  defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
+  defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;
 
   defm SMAX_ZPmZ : sve_int_bin_pred_arit_1<0b000, "smax", "SMAX_ZPZZ", int_aarch64_sve_smax, DestructiveBinaryComm>;
   defm UMAX_ZPmZ : sve_int_bin_pred_arit_1<0b001, "umax", "UMAX_ZPZZ", int_aarch64_sve_umax, DestructiveBinaryComm>;

diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 45a712c897a4..7d5a0695035e 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -3802,24 +3802,17 @@ multiclass sve_int_un_pred_arit_1<bits<3> opc, string asm,
   def : SVE_3_Op_Pat<nxv2i64, op, nxv2i64, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
 
-// TODO: Remove int_op once its last use is converted to ir_op.
-multiclass sve_int_un_pred_arit_1_fp<bits<3> opc, string asm,
-                                     SDPatternOperator int_op,
-                                     SDPatternOperator ir_op> {
+multiclass sve_int_un_pred_arit_1_fp<bits<3> opc, string asm, SDPatternOperator op> {
   def _H : sve_int_un_pred_arit<0b01, { opc, 0b1 }, asm, ZPR16>;
   def _S : sve_int_un_pred_arit<0b10, { opc, 0b1 }, asm, ZPR32>;
   def _D : sve_int_un_pred_arit<0b11, { opc, 0b1 }, asm, ZPR64>;
 
-  def : SVE_3_Op_Pat<nxv8f16, int_op, nxv8f16, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
-  def : SVE_3_Op_Pat<nxv4f32, int_op, nxv4f32, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
-  def : SVE_3_Op_Pat<nxv2f64, int_op, nxv2f64, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
-
-  def : SVE_1_Op_Passthru_Pat<nxv8f16, ir_op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
-  def : SVE_1_Op_Passthru_Pat<nxv4f16, ir_op, nxv4i1, nxv4f16, !cast<Instruction>(NAME # _H)>;
-  def : SVE_1_Op_Passthru_Pat<nxv2f16, ir_op, nxv2i1, nxv2f16, !cast<Instruction>(NAME # _H)>;
-  def : SVE_1_Op_Passthru_Pat<nxv4f32, ir_op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
-  def : SVE_1_Op_Passthru_Pat<nxv2f32, ir_op, nxv2i1, nxv2f32, !cast<Instruction>(NAME # _S)>;
-  def : SVE_1_Op_Passthru_Pat<nxv2f64, ir_op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
+  def : SVE_1_Op_Passthru_Pat<nxv8f16, op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_1_Op_Passthru_Pat<nxv4f16, op, nxv4i1, nxv4f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_1_Op_Passthru_Pat<nxv2f16, op, nxv2i1, nxv2f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_1_Op_Passthru_Pat<nxv4f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_1_Op_Passthru_Pat<nxv2f32, op, nxv2i1, nxv2f32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_1_Op_Passthru_Pat<nxv2f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/llvm/test/CodeGen/AArch64/sve-fp.ll b/llvm/test/CodeGen/AArch64/sve-fp.ll
index 5334e66b22f7..7ca1fdee7f32 100644
--- a/llvm/test/CodeGen/AArch64/sve-fp.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp.ll
@@ -542,6 +542,68 @@ define <vscale x 2 x double> @fsqrt_nxv2f64(<vscale x 2 x double> %a) {
   ret <vscale x 2 x double> %res
 }
 
+; FABS
+
+define <vscale x 8 x half> @fabs_nxv8f16(<vscale x 8 x half> %a) {
+; CHECK-LABEL: fabs_nxv8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fabs z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %res = call <vscale x 8 x half> @llvm.fabs.nxv8f16(<vscale x 8 x half> %a)
+  ret <vscale x 8 x half> %res
+}
+
+define <vscale x 4 x half> @fabs_nxv4f16(<vscale x 4 x half> %a) {
+; CHECK-LABEL: fabs_nxv4f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fabs z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %res = call <vscale x 4 x half> @llvm.fabs.nxv4f16(<vscale x 4 x half> %a)
+  ret <vscale x 4 x half> %res
+}
+
+define <vscale x 2 x half> @fabs_nxv2f16(<vscale x 2 x half> %a) {
+; CHECK-LABEL: fabs_nxv2f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fabs z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %res = call <vscale x 2 x half> @llvm.fabs.nxv2f16(<vscale x 2 x half> %a)
+  ret <vscale x 2 x half> %res
+}
+
+define <vscale x 4 x float> @fabs_nxv4f32(<vscale x 4 x float> %a) {
+; CHECK-LABEL: fabs_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fabs z0.s, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %res = call <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float> %a)
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x float> @fabs_nxv2f32(<vscale x 2 x float> %a) {
+; CHECK-LABEL: fabs_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fabs z0.s, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %res = call <vscale x 2 x float> @llvm.fabs.nxv2f32(<vscale x 2 x float> %a)
+  ret <vscale x 2 x float> %res
+}
+
+define <vscale x 2 x double> @fabs_nxv2f64(<vscale x 2 x double> %a) {
+; CHECK-LABEL: fabs_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fabs z0.d, p0/m, z0.d
+; CHECK-NEXT:    ret
+  %res = call <vscale x 2 x double> @llvm.fabs.nxv2f64(<vscale x 2 x double> %a)
+  ret <vscale x 2 x double> %res
+}
+
 declare <vscale x 8 x half> @llvm.aarch64.sve.frecps.x.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 4 x float>  @llvm.aarch64.sve.frecps.x.nxv4f32(<vscale x 4 x float> , <vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.aarch64.sve.frecps.x.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
@@ -564,5 +626,12 @@ declare <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float>)
 declare <vscale x 2 x float> @llvm.sqrt.nxv2f32(<vscale x 2 x float>)
 declare <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double>)
 
+declare <vscale x 8 x half> @llvm.fabs.nxv8f16( <vscale x 8 x half>)
+declare <vscale x 4 x half> @llvm.fabs.nxv4f16( <vscale x 4 x half>)
+declare <vscale x 2 x half> @llvm.fabs.nxv2f16( <vscale x 2 x half>)
+declare <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float>)
+declare <vscale x 2 x float> @llvm.fabs.nxv2f32(<vscale x 2 x float>)
+declare <vscale x 2 x double> @llvm.fabs.nxv2f64(<vscale x 2 x double>)
+
 ; Function Attrs: nounwind readnone
 declare double @llvm.aarch64.sve.faddv.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>) #2


        


More information about the llvm-commits mailing list