[llvm] 56ae2ce - [AArch64][SVE] Add lowering for llvm.fma.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 9 16:13:20 PDT 2020


Author: Eli Friedman
Date: 2020-07-09T16:12:41-07:00
New Revision: 56ae2cebcdf7884470212ed2a04c1bce73d5c996

URL: https://github.com/llvm/llvm-project/commit/56ae2cebcdf7884470212ed2a04c1bce73d5c996
DIFF: https://github.com/llvm/llvm-project/commit/56ae2cebcdf7884470212ed2a04c1bce73d5c996.diff

LOG: [AArch64][SVE] Add lowering for llvm.fma.

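This is currently bare-bones: llvm.fma on SVE vectors is always selected
as a predicated FMLA, and we aren't yet taking advantage of any of the
other FMA variant instructions (e.g. FMAD, which could avoid the extra
register move visible in the tests).  But it's enough to at least
generate code.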

Differential Revision: https://reviews.llvm.org/D83444

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/sve-fp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c3735c8784ca..1a3bbaf1832d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -948,6 +948,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
         setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
         setOperationAction(ISD::SELECT, VT, Custom);
+        setOperationAction(ISD::FMA, VT, Custom);
       }
     }
 
@@ -1470,6 +1471,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FADD_PRED)
     MAKE_CASE(AArch64ISD::FADDA_PRED)
     MAKE_CASE(AArch64ISD::FADDV_PRED)
+    MAKE_CASE(AArch64ISD::FMA_PRED)
     MAKE_CASE(AArch64ISD::FMAXV_PRED)
     MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
     MAKE_CASE(AArch64ISD::FMINV_PRED)
@@ -3455,6 +3457,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerF128Call(Op, DAG, RTLIB::SUB_F128);
   case ISD::FMUL:
     return LowerF128Call(Op, DAG, RTLIB::MUL_F128);
+  case ISD::FMA:
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
     return LowerF128Call(Op, DAG, RTLIB::DIV_F128);
   case ISD::FP_ROUND:

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 4b395acd816d..1be44797aac7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -77,6 +77,7 @@ enum NodeType : unsigned {
   FADD_PRED,
   SDIV_PRED,
   UDIV_PRED,
+  FMA_PRED,
   SMIN_MERGE_OP1,
   UMIN_MERGE_OP1,
   SMAX_MERGE_OP1,

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 825082230d1f..28a54e6f7d79 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -167,9 +167,15 @@ def SDT_AArch64Arith : SDTypeProfile<1, 3, [
   SDTCVecEltisVT<1,i1>, SDTCisSameAs<2,3>
 ]>;
 
+def SDT_AArch64FMA : SDTypeProfile<1, 4, [
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>, SDTCisVec<4>,
+  SDTCVecEltisVT<1,i1>, SDTCisSameAs<2,3>, SDTCisSameAs<3,4>
+]>;
+
 // Predicated operations with the result of inactive lanes being unspecified.
 def AArch64add_p  : SDNode<"AArch64ISD::ADD_PRED",  SDT_AArch64Arith>;
 def AArch64fadd_p : SDNode<"AArch64ISD::FADD_PRED", SDT_AArch64Arith>;
+def AArch64fma_p  : SDNode<"AArch64ISD::FMA_PRED",  SDT_AArch64FMA>;
 def AArch64sdiv_p : SDNode<"AArch64ISD::SDIV_PRED", SDT_AArch64Arith>;
 def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
 
@@ -393,6 +399,16 @@ let Predicates = [HasSVE] in {
   defm FNMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b10, "fnmad", int_aarch64_sve_fnmad>;
   defm FNMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b11, "fnmsb", int_aarch64_sve_fnmsb>;
 
+  // Add patterns for FMA where disabled lanes are undef.
+  // FIXME: Implement a pseudo so we can choose a better instruction after
+  // regalloc.
+  def : Pat<(nxv8f16 (AArch64fma_p nxv8i1:$P, nxv8f16:$Op1, nxv8f16:$Op2, nxv8f16:$Op3)),
+            (FMLA_ZPmZZ_H $P, $Op3, $Op1, $Op2)>;
+  def : Pat<(nxv4f32 (AArch64fma_p nxv4i1:$P, nxv4f32:$Op1, nxv4f32:$Op2, nxv4f32:$Op3)),
+            (FMLA_ZPmZZ_S $P, $Op3, $Op1, $Op2)>;
+  def : Pat<(nxv2f64 (AArch64fma_p nxv2i1:$P, nxv2f64:$Op1, nxv2f64:$Op2, nxv2f64:$Op3)),
+            (FMLA_ZPmZZ_D $P, $Op3, $Op1, $Op2)>;
+
   defm FTMAD_ZZI : sve_fp_ftmad<"ftmad", int_aarch64_sve_ftmad_x>;
 
   defm FMLA_ZZZI : sve_fp_fma_by_indexed_elem<0b0, "fmla", int_aarch64_sve_fmla_lane>;

diff  --git a/llvm/test/CodeGen/AArch64/sve-fp.ll b/llvm/test/CodeGen/AArch64/sve-fp.ll
index c7cf917b2e64..e3c0ba72bda1 100644
--- a/llvm/test/CodeGen/AArch64/sve-fp.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp.ll
@@ -85,6 +85,56 @@ define <vscale x 2 x double> @fmul_d(<vscale x 2 x double> %a, <vscale x 2 x dou
   ret <vscale x 2 x double> %res
 }
 
+define <vscale x 8 x half> @fma_half(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fma_half:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fmla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  ret <vscale x 8 x half> %r
+}
+define <vscale x 4 x float> @fma_float(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fma_float:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmla z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c)
+  ret <vscale x 4 x float> %r
+}
+define <vscale x 2 x double> @fma_double_1(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_double_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fmla z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c)
+  ret <vscale x 2 x double> %r
+}
+define <vscale x 2 x double> @fma_double_2(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_double_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fmla z2.d, p0/m, z1.d, z0.d
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %b, <vscale x 2 x double> %a, <vscale x 2 x double> %c)
+  ret <vscale x 2 x double> %r
+}
+define <vscale x 2 x double> @fma_double_3(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fma_double_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fmla z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    ret
+  %r = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %c, <vscale x 2 x double> %b, <vscale x 2 x double> %a)
+  ret <vscale x 2 x double> %r
+}
+
 define <vscale x 8 x half> @frecps_h(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
 ; CHECK-LABEL: frecps_h:
 ; CHECK:       // %bb.0:
@@ -166,5 +216,9 @@ declare <vscale x 8 x half> @llvm.aarch64.sve.frsqrts.x.nxv8f16(<vscale x 8 x ha
 declare <vscale x 4 x float> @llvm.aarch64.sve.frsqrts.x.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.aarch64.sve.frsqrts.x.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
 
+declare <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>)
+declare <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x half>)
+
 ; Function Attrs: nounwind readnone
 declare double @llvm.aarch64.sve.faddv.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>) #2


        


More information about the llvm-commits mailing list