[llvm] 672b62e - [AArch64][SVE] Custom lowering of floating-point reductions

Cullen Rhodes via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 30 03:24:22 PDT 2020


Author: Cullen Rhodes
Date: 2020-04-30T10:18:40Z
New Revision: 672b62ea21dfe5f9bfb2b0362785f2685be830a0

URL: https://github.com/llvm/llvm-project/commit/672b62ea21dfe5f9bfb2b0362785f2685be830a0
DIFF: https://github.com/llvm/llvm-project/commit/672b62ea21dfe5f9bfb2b0362785f2685be830a0.diff

LOG: [AArch64][SVE] Custom lowering of floating-point reductions

Summary:
This patch implements custom floating-point reduction ISD nodes that
have vector results, which are used to lower the following intrinsics:

    * llvm.aarch64.sve.fadda
    * llvm.aarch64.sve.faddv
    * llvm.aarch64.sve.fmaxv
    * llvm.aarch64.sve.fmaxnmv
    * llvm.aarch64.sve.fminv
    * llvm.aarch64.sve.fminnmv

SVE reduction instructions produce their result in the first element of
a vector register, with all other bits set to zero.
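
For illustration (this example is not part of the commit), a fast
reduction such as faddv takes a predicate and a vector and returns a
scalar; with this lowering it should compile to a single reduction
instruction:

    declare half @llvm.aarch64.sve.faddv.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>)

    define half @faddv_f16(<vscale x 8 x i1> %pg, <vscale x 8 x half> %a) {
      ; Lowered to FADDV_PRED with a vector result; lane 0 is then
      ; extracted as the scalar return value.
      %res = call half @llvm.aarch64.sve.faddv.nxv8f16(<vscale x 8 x i1> %pg, <vscale x 8 x half> %a)
      ret half %res
    }

The expected code is "faddv h0, p0, z0.h", matching the checks in
llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll.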

Changes in this patch were implemented by Paul Walker and Sander de
Smalen.

Reviewers: sdesmalen, efriedma, rengolin

Reviewed By: efriedma

Differential Revision: https://reviews.llvm.org/D78723

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/lib/Target/AArch64/SVEInstrFormats.td
    llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d8151dd0e99..e5f3a4e3e8da 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1366,6 +1366,12 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
   case AArch64ISD::REV:               return "AArch64ISD::REV";
   case AArch64ISD::REINTERPRET_CAST:  return "AArch64ISD::REINTERPRET_CAST";
   case AArch64ISD::TBL:               return "AArch64ISD::TBL";
+  case AArch64ISD::FADDA_PRED:        return "AArch64ISD::FADDA_PRED";
+  case AArch64ISD::FADDV_PRED:        return "AArch64ISD::FADDV_PRED";
+  case AArch64ISD::FMAXV_PRED:        return "AArch64ISD::FMAXV_PRED";
+  case AArch64ISD::FMAXNMV_PRED:      return "AArch64ISD::FMAXNMV_PRED";
+  case AArch64ISD::FMINV_PRED:        return "AArch64ISD::FMINV_PRED";
+  case AArch64ISD::FMINNMV_PRED:      return "AArch64ISD::FMINNMV_PRED";
   case AArch64ISD::NOT:               return "AArch64ISD::NOT";
   case AArch64ISD::BIT:               return "AArch64ISD::BIT";
   case AArch64ISD::CBZ:               return "AArch64ISD::CBZ";
@@ -11308,6 +11314,46 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
   return DAG.getZExtOrTrunc(Res, DL, VT);
 }
 
+static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc,
+                                     SelectionDAG &DAG) {
+  SDLoc DL(N);
+
+  SDValue Pred = N->getOperand(1);
+  SDValue VecToReduce = N->getOperand(2);
+
+  EVT ReduceVT = VecToReduce.getValueType();
+  SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
+
+  // SVE reductions set the whole vector register with the first element
+  // containing the reduction result, which we'll now extract.
+  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
+                     Zero);
+}
+
+static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
+                                            SelectionDAG &DAG) {
+  SDLoc DL(N);
+
+  SDValue Pred = N->getOperand(1);
+  SDValue InitVal = N->getOperand(2);
+  SDValue VecToReduce = N->getOperand(3);
+  EVT ReduceVT = VecToReduce.getValueType();
+
+  // Ordered reductions use the first lane of the result vector as the
+  // reduction's initial value.
+  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+  InitVal = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ReduceVT,
+                        DAG.getUNDEF(ReduceVT), InitVal, Zero);
+
+  SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, InitVal, VecToReduce);
+
+  // SVE reductions set the whole vector register with the first element
+  // containing the reduction result, which we'll now extract.
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
+                     Zero);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -11391,6 +11437,18 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_udiv:
      return DAG.getNode(AArch64ISD::UDIV_PRED, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
+  case Intrinsic::aarch64_sve_fadda:
+    return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG);
+  case Intrinsic::aarch64_sve_faddv:
+    return combineSVEReductionFP(N, AArch64ISD::FADDV_PRED, DAG);
+  case Intrinsic::aarch64_sve_fmaxnmv:
+    return combineSVEReductionFP(N, AArch64ISD::FMAXNMV_PRED, DAG);
+  case Intrinsic::aarch64_sve_fmaxv:
+    return combineSVEReductionFP(N, AArch64ISD::FMAXV_PRED, DAG);
+  case Intrinsic::aarch64_sve_fminnmv:
+    return combineSVEReductionFP(N, AArch64ISD::FMINNMV_PRED, DAG);
+  case Intrinsic::aarch64_sve_fminv:
+    return combineSVEReductionFP(N, AArch64ISD::FMINV_PRED, DAG);
   case Intrinsic::aarch64_sve_sel:
     return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));

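For the ordered reduction handled by combineSVEReductionOrderedFP above,
fadda additionally takes a scalar initial value, which is inserted into
lane 0 of an undef vector before the reduction. A minimal example
(illustrative, not taken from the commit):

    declare float @llvm.aarch64.sve.fadda.nxv4f32(<vscale x 4 x i1>, float, <vscale x 4 x float>)

    define float @fadda_f32(<vscale x 4 x i1> %pg, float %init, <vscale x 4 x float> %a) {
      ; %init becomes lane 0 of the accumulator operand of FADDA_PRED;
      ; lane 0 of the result is extracted as the scalar return value.
      %res = call float @llvm.aarch64.sve.fadda.nxv4f32(<vscale x 4 x i1> %pg, float %init, <vscale x 4 x float> %a)
      ret float %res
    }

The expected code is "fadda s0, p0, s0, z1.s".
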
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index fe67d75fab17..c5afa9bea9e2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -215,6 +215,14 @@ enum NodeType : unsigned {
   REV,
   TBL,
 
+  // Floating-point reductions.
+  FADDA_PRED,
+  FADDV_PRED,
+  FMAXV_PRED,
+  FMAXNMV_PRED,
+  FMINV_PRED,
+  FMINNMV_PRED,
+
   INSR,
   PTEST,
   PTRUE,

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index ed790633e266..7c964dad4277 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -134,16 +134,20 @@ def sve_cntw_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -4>">;
 def sve_cntd_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -2>">;
 
 def SDT_AArch64Reduce : SDTypeProfile<1, 2, [SDTCisVec<1>, SDTCisVec<2>]>;
-
-def AArch64smaxv_pred      :  SDNode<"AArch64ISD::SMAXV_PRED", SDT_AArch64Reduce>;
-def AArch64umaxv_pred      :  SDNode<"AArch64ISD::UMAXV_PRED", SDT_AArch64Reduce>;
-def AArch64sminv_pred      :  SDNode<"AArch64ISD::SMINV_PRED", SDT_AArch64Reduce>;
-def AArch64uminv_pred      :  SDNode<"AArch64ISD::UMINV_PRED", SDT_AArch64Reduce>;
-def AArch64orv_pred        :  SDNode<"AArch64ISD::ORV_PRED", SDT_AArch64Reduce>;
-def AArch64eorv_pred       :  SDNode<"AArch64ISD::EORV_PRED", SDT_AArch64Reduce>;
-def AArch64andv_pred       :  SDNode<"AArch64ISD::ANDV_PRED", SDT_AArch64Reduce>;
-def AArch64lasta           :  SDNode<"AArch64ISD::LASTA",     SDT_AArch64Reduce>;
-def AArch64lastb           :  SDNode<"AArch64ISD::LASTB",     SDT_AArch64Reduce>;
+def AArch64faddv_pred   : SDNode<"AArch64ISD::FADDV_PRED",   SDT_AArch64Reduce>;
+def AArch64fmaxv_pred   : SDNode<"AArch64ISD::FMAXV_PRED",   SDT_AArch64Reduce>;
+def AArch64fmaxnmv_pred : SDNode<"AArch64ISD::FMAXNMV_PRED", SDT_AArch64Reduce>;
+def AArch64fminv_pred   : SDNode<"AArch64ISD::FMINV_PRED",   SDT_AArch64Reduce>;
+def AArch64fminnmv_pred : SDNode<"AArch64ISD::FMINNMV_PRED", SDT_AArch64Reduce>;
+def AArch64smaxv_pred   : SDNode<"AArch64ISD::SMAXV_PRED",   SDT_AArch64Reduce>;
+def AArch64umaxv_pred   : SDNode<"AArch64ISD::UMAXV_PRED",   SDT_AArch64Reduce>;
+def AArch64sminv_pred   : SDNode<"AArch64ISD::SMINV_PRED",   SDT_AArch64Reduce>;
+def AArch64uminv_pred   : SDNode<"AArch64ISD::UMINV_PRED",   SDT_AArch64Reduce>;
+def AArch64orv_pred     : SDNode<"AArch64ISD::ORV_PRED",     SDT_AArch64Reduce>;
+def AArch64eorv_pred    : SDNode<"AArch64ISD::EORV_PRED",    SDT_AArch64Reduce>;
+def AArch64andv_pred    : SDNode<"AArch64ISD::ANDV_PRED",    SDT_AArch64Reduce>;
+def AArch64lasta        : SDNode<"AArch64ISD::LASTA",        SDT_AArch64Reduce>;
+def AArch64lastb        : SDNode<"AArch64ISD::LASTB",        SDT_AArch64Reduce>;
 
 def SDT_AArch64DIV : SDTypeProfile<1, 3, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>,
@@ -156,6 +160,7 @@ def AArch64udiv_pred      :  SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64DIV>;
 def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
 def AArch64clasta_n   : SDNode<"AArch64ISD::CLASTA_N",   SDT_AArch64ReduceWithInit>;
 def AArch64clastb_n   : SDNode<"AArch64ISD::CLASTB_N",   SDT_AArch64ReduceWithInit>;
+def AArch64fadda_pred : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
 
 def SDT_AArch64Rev   : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
 def AArch64rev       : SDNode<"AArch64ISD::REV", SDT_AArch64Rev>;
@@ -352,12 +357,21 @@ let Predicates = [HasSVE] in {
   defm FMUL_ZZZI   : sve_fp_fmul_by_indexed_elem<"fmul", int_aarch64_sve_fmul_lane>;
 
   // SVE floating point reductions.
-  defm FADDA_VPZ   : sve_fp_2op_p_vd<0b000, "fadda", int_aarch64_sve_fadda>;
-  defm FADDV_VPZ   : sve_fp_fast_red<0b000, "faddv", int_aarch64_sve_faddv>;
-  defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", int_aarch64_sve_fmaxnmv>;
-  defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", int_aarch64_sve_fminnmv>;
-  defm FMAXV_VPZ   : sve_fp_fast_red<0b110, "fmaxv", int_aarch64_sve_fmaxv>;
-  defm FMINV_VPZ   : sve_fp_fast_red<0b111, "fminv", int_aarch64_sve_fminv>;
+  defm FADDA_VPZ   : sve_fp_2op_p_vd<0b000, "fadda",   AArch64fadda_pred>;
+  defm FADDV_VPZ   : sve_fp_fast_red<0b000, "faddv",   AArch64faddv_pred>;
+  defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", AArch64fmaxnmv_pred>;
+  defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", AArch64fminnmv_pred>;
+  defm FMAXV_VPZ   : sve_fp_fast_red<0b110, "fmaxv",   AArch64fmaxv_pred>;
+  defm FMINV_VPZ   : sve_fp_fast_red<0b111, "fminv",   AArch64fminv_pred>;
+
+  // Use more efficient NEON instructions to extract elements within the NEON
+  // part (first 128bits) of an SVE register.
+  def : Pat<(vector_extract (nxv8f16 ZPR:$Zs), (i64 0)),
+            (f16 (EXTRACT_SUBREG (v8f16 (EXTRACT_SUBREG ZPR:$Zs, zsub)), hsub))>;
+  def : Pat<(vector_extract (nxv4f32 ZPR:$Zs), (i64 0)),
+            (f32 (EXTRACT_SUBREG (v4f32 (EXTRACT_SUBREG ZPR:$Zs, zsub)), ssub))>;
+  def : Pat<(vector_extract (nxv2f64 ZPR:$Zs), (i64 0)),
+            (f64 (EXTRACT_SUBREG (v2f64 (EXTRACT_SUBREG ZPR:$Zs, zsub)), dsub))>;
 
   // Splat immediate (unpredicated)
   defm DUP_ZI   : sve_int_dup_imm<"dup">;

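The vector_extract patterns added above exploit the fact that the
scalar FP registers alias the low bits of the Z registers, so reading
lane 0 of a scalable vector needs no instruction at all. For example
(illustrative, not from the commit):

    define double @extract_lane0(<vscale x 2 x double> %v) {
      ; d0 is the low 64 bits of z0, so this should lower to a plain ret.
      %e = extractelement <vscale x 2 x double> %v, i64 0
      ret double %e
    }
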
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 54fab60e33fd..a75bf47fa39b 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -4444,8 +4444,8 @@ multiclass sve2_int_while_rr<bits<1> rw, string asm, string op> {
 //===----------------------------------------------------------------------===//
 
 class sve_fp_fast_red<bits<2> sz, bits<3> opc, string asm,
-                      ZPRRegOp zprty, RegisterClass dstRegClass>
-: I<(outs dstRegClass:$Vd), (ins PPR3bAny:$Pg, zprty:$Zn),
+                      ZPRRegOp zprty, FPRasZPROperand dstOpType>
+: I<(outs dstOpType:$Vd), (ins PPR3bAny:$Pg, zprty:$Zn),
   asm, "\t$Vd, $Pg, $Zn",
   "",
   []>, Sched<[]> {
@@ -4463,13 +4463,13 @@ class sve_fp_fast_red<bits<2> sz, bits<3> opc, string asm,
 }
 
 multiclass sve_fp_fast_red<bits<3> opc, string asm, SDPatternOperator op> {
-  def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16>;
-  def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32>;
-  def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64>;
+  def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16asZPR>;
+  def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32asZPR>;
+  def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64asZPR>;
 
-  def : SVE_2_Op_Pat<f16, op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
-  def : SVE_2_Op_Pat<f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
-  def : SVE_2_Op_Pat<f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
+  def : SVE_2_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_2_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_2_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
 
 
@@ -4478,8 +4478,8 @@ multiclass sve_fp_fast_red<bits<3> opc, string asm, SDPatternOperator op> {
 //===----------------------------------------------------------------------===//
 
 class sve_fp_2op_p_vd<bits<2> sz, bits<3> opc, string asm,
-                      ZPRRegOp zprty, RegisterClass dstRegClass>
-: I<(outs dstRegClass:$Vdn), (ins PPR3bAny:$Pg, dstRegClass:$_Vdn, zprty:$Zm),
+                      ZPRRegOp zprty, FPRasZPROperand dstOpType>
+: I<(outs dstOpType:$Vdn), (ins PPR3bAny:$Pg, dstOpType:$_Vdn, zprty:$Zm),
   asm, "\t$Vdn, $Pg, $_Vdn, $Zm",
   "",
   []>,
@@ -4500,13 +4500,13 @@ class sve_fp_2op_p_vd<bits<2> sz, bits<3> opc, string asm,
 }
 
 multiclass sve_fp_2op_p_vd<bits<3> opc, string asm, SDPatternOperator op> {
-  def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16>;
-  def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32>;
-  def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64>;
+  def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16asZPR>;
+  def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32asZPR>;
+  def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64asZPR>;
 
-  def : SVE_3_Op_Pat<f16, op, nxv8i1, f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
-  def : SVE_3_Op_Pat<f32, op, nxv4i1, f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
-  def : SVE_3_Op_Pat<f64, op, nxv2i1, f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
+  def : SVE_3_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_3_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_3_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
 
 //===----------------------------------------------------------------------===//

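The switch from RegisterClass destinations (FPR16/FPR32/FPR64) to
FPRasZPROperand above appears to model the destination as a Z register
that is printed and parsed as its overlapping scalar FP register, which
is why the pattern result types change from f16/f32/f64 to
nxv8f16/nxv4f32/nxv2f64: the instruction's whole-vector result is now
visible to the lane-0 extract patterns. A sketch of the end-to-end
effect (illustrative):

    declare double @llvm.aarch64.sve.fmaxnmv.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>)

    define double @fmaxnmv_f64(<vscale x 2 x i1> %pg, <vscale x 2 x double> %a) {
      ; Expected code: "fmaxnmv d0, p0, z0.d" with no extra move, since
      ; d0 already holds lane 0 of the instruction's Z-register result.
      %res = call double @llvm.aarch64.sve.fmaxnmv.nxv2f64(<vscale x 2 x i1> %pg, <vscale x 2 x double> %a)
      ret double %res
    }
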
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll
index 083a7d35439c..c933c2eab40d 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll
@@ -1,4 +1,4 @@
-; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve -asm-verbose=0 < %s | FileCheck %s
 
 ;
 ; FADDA
