[llvm] 3931734 - [AArch64][SVE] Add initial backend support for FP splat_vector

Cameron McInally via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 19 08:19:25 PST 2020


Author: Cameron McInally
Date: 2020-02-19T10:19:11-06:00
New Revision: 393173499099649eb95816132889295c3b463faf

URL: https://github.com/llvm/llvm-project/commit/393173499099649eb95816132889295c3b463faf
DIFF: https://github.com/llvm/llvm-project/commit/393173499099649eb95816132889295c3b463faf.diff

LOG: [AArch64][SVE] Add initial backend support for FP splat_vector

Differential Revision: https://reviews.llvm.org/D74632

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/sve-vector-splat.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 58a518fbe5d7..0f2b5eb29bfe 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -874,6 +874,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     }
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
+
+    for (MVT VT : MVT::fp_scalable_vector_valuetypes()) {
+      if (isTypeLegal(VT)) {
+        setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+      }
+    }
   }
 
   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
@@ -7483,14 +7489,6 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
   // Extend input splat value where needed to fit into a GPR (32b or 64b only)
   // FPRs don't have this restriction.
   switch (ElemVT.getSimpleVT().SimpleTy) {
-  case MVT::i8:
-  case MVT::i16:
-  case MVT::i32:
-    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32);
-    return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
-  case MVT::i64:
-    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
-    return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
   case MVT::i1: {
     // The general case of i1.  There isn't any natural way to do this,
     // so we use some trickery with whilelo.
@@ -7503,13 +7501,24 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID,
                        DAG.getConstant(0, dl, MVT::i64), SplatVal);
   }
-  // TODO: we can support float types, but haven't added patterns yet.
+  case MVT::i8:
+  case MVT::i16:
+  case MVT::i32:
+    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32);
+    break;
+  case MVT::i64:
+    SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
+    break;
   case MVT::f16:
   case MVT::f32:
   case MVT::f64:
+    // Fine as is
+    break;
   default:
     report_fatal_error("Unsupported SPLAT_VECTOR input operand type");
   }
+
+  return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
 }
 
 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 58baf67ee447..6914b7c7d40e 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -66,6 +66,7 @@ class Pseudo<dag oops, dag iops, list<dag> pattern, string cstr = "">
   dag InOperandList  = iops;
   let Pattern        = pattern;
   let isCodeGenOnly  = 1;
+  let isPseudo       = 1;
 }
 
 // Real instructions (have encoding information)

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index f09349973373..e76e4acd0893 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -296,6 +296,28 @@ let Predicates = [HasSVE] in {
   defm CPY_ZPmR : sve_int_perm_cpy_r<"cpy", AArch64dup_pred>;
   defm CPY_ZPmV : sve_int_perm_cpy_v<"cpy", AArch64dup_pred>;
 
+  // Duplicate FP scalar into all vector elements
+  def : Pat<(nxv8f16 (AArch64dup (f16 FPR16:$src))),
+            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+  def : Pat<(nxv4f16 (AArch64dup (f16 FPR16:$src))),
+            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+  def : Pat<(nxv2f16 (AArch64dup (f16 FPR16:$src))),
+            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+  def : Pat<(nxv4f32 (AArch64dup (f32 FPR32:$src))),
+            (DUP_ZZI_S (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$src, ssub), 0)>;
+  def : Pat<(nxv2f32 (AArch64dup (f32 FPR32:$src))),
+            (DUP_ZZI_S (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$src, ssub), 0)>;
+  def : Pat<(nxv2f64 (AArch64dup (f64 FPR64:$src))),
+            (DUP_ZZI_D (INSERT_SUBREG (IMPLICIT_DEF), FPR64:$src, dsub), 0)>;
+
+  // Duplicate +0.0 into all vector elements
+  def : Pat<(nxv8f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv4f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv2f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
+  def : Pat<(nxv4f32 (AArch64dup (f32 fpimm0))), (DUP_ZI_S 0, 0)>;
+  def : Pat<(nxv2f32 (AArch64dup (f32 fpimm0))), (DUP_ZI_S 0, 0)>;
+  def : Pat<(nxv2f64 (AArch64dup (f64 fpimm0))), (DUP_ZI_D 0, 0)>;
+
   // Select elements from either vector (predicated)
   defm SEL_ZPZZ    : sve_int_sel_vvv<"sel", vselect>;
 

diff  --git a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll
index 086241c4e0a7..fdb3ee6c066c 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll
@@ -133,3 +133,104 @@ define <vscale x 16 x i1> @sve_splat_16xi1(i1 %val) {
   %splat = shufflevector <vscale x 16 x i1> %ins, <vscale x 16 x i1> undef, <vscale x 16 x i32> zeroinitializer
   ret <vscale x 16 x i1> %splat
 }
+
+;; Splats of legal floating point vector types
+
+define <vscale x 8 x half> @splat_nxv8f16(half %val) {
+; CHECK-LABEL: splat_nxv8f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 8 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 8 x half> %1, <vscale x 8 x half> undef, <vscale x 8 x i32> zeroinitializer
+  ret <vscale x 8 x half> %2
+}
+
+define <vscale x 4 x half> @splat_nxv4f16(half %val) {
+; CHECK-LABEL: splat_nxv4f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 4 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 4 x half> %1, <vscale x 4 x half> undef, <vscale x 4 x i32> zeroinitializer
+  ret <vscale x 4 x half> %2
+}
+
+define <vscale x 2 x half> @splat_nxv2f16(half %val) {
+; CHECK-LABEL: splat_nxv2f16:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x half> undef, half %val, i32 0
+  %2 = shufflevector <vscale x 2 x half> %1, <vscale x 2 x half> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x half> %2
+}
+
+define <vscale x 4 x float> @splat_nxv4f32(float %val) {
+; CHECK-LABEL: splat_nxv4f32:
+; CHECK: mov z0.s, s0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 4 x float> undef, float %val, i32 0
+  %2 = shufflevector <vscale x 4 x float> %1, <vscale x 4 x float> undef, <vscale x 4 x i32> zeroinitializer
+  ret <vscale x 4 x float> %2
+}
+
+define <vscale x 2 x float> @splat_nxv2f32(float %val) {
+; CHECK-LABEL: splat_nxv2f32:
+; CHECK: mov z0.s, s0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x float> undef, float %val, i32 0
+  %2 = shufflevector <vscale x 2 x float> %1, <vscale x 2 x float> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x float> %2
+}
+
+define <vscale x 2 x double> @splat_nxv2f64(double %val) {
+; CHECK-LABEL: splat_nxv2f64:
+; CHECK: mov z0.d, d0
+; CHECK-NEXT: ret
+  %1 = insertelement <vscale x 2 x double> undef, double %val, i32 0
+  %2 = shufflevector <vscale x 2 x double> %1, <vscale x 2 x double> undef, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x double> %2
+}
+
+; TODO: The f16 constant should be folded into the move.
+define <vscale x 8 x half> @splat_nxv8f16_zero() {
+; CHECK-LABEL: splat_nxv8f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 8 x half> zeroinitializer
+}
+
+; TODO: The f16 constant should be folded into the move.
+define <vscale x 4 x half> @splat_nxv4f16_zero() {
+; CHECK-LABEL: splat_nxv4f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 4 x half> zeroinitializer
+}
+
+; TODO: The f16 constant should be folded into the move.
+define <vscale x 2 x half> @splat_nxv2f16_zero() {
+; CHECK-LABEL: splat_nxv2f16_zero:
+; CHECK: mov z0.h, h0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x half> zeroinitializer
+}
+
+define <vscale x 4 x float> @splat_nxv4f32_zero() {
+; CHECK-LABEL: splat_nxv4f32_zero:
+; CHECK: mov z0.s, #0
+; CHECK-NEXT: ret
+  ret <vscale x 4 x float> zeroinitializer
+}
+
+define <vscale x 2 x float> @splat_nxv2f32_zero() {
+; CHECK-LABEL: splat_nxv2f32_zero:
+; CHECK: mov z0.s, #0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x float> zeroinitializer
+}
+
+define <vscale x 2 x double> @splat_nxv2f64_zero() {
+; CHECK-LABEL: splat_nxv2f64_zero:
+; CHECK: mov z0.d, #0
+; CHECK-NEXT: ret
+  ret <vscale x 2 x double> zeroinitializer
+}


        


More information about the llvm-commits mailing list