[llvm] 672b62e - [AArch64][SVE] Custom lowering of floating-point reductions
Cullen Rhodes via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 30 03:24:22 PDT 2020
Author: Cullen Rhodes
Date: 2020-04-30T10:18:40Z
New Revision: 672b62ea21dfe5f9bfb2b0362785f2685be830a0
URL: https://github.com/llvm/llvm-project/commit/672b62ea21dfe5f9bfb2b0362785f2685be830a0
DIFF: https://github.com/llvm/llvm-project/commit/672b62ea21dfe5f9bfb2b0362785f2685be830a0.diff
LOG: [AArch64][SVE] Custom lowering of floating-point reductions
Summary:
This patch implements custom floating-point reduction ISD nodes that
have vector results, which are used to lower the following intrinsics:
* llvm.aarch64.sve.fadda
* llvm.aarch64.sve.faddv
* llvm.aarch64.sve.fmaxv
* llvm.aarch64.sve.fmaxnmv
* llvm.aarch64.sve.fminv
* llvm.aarch64.sve.fminnmv
SVE reduction instructions keep their result within a vector register,
with all other bits set to zero.
Changes in this patch were implemented by Paul Walker and Sander de
Smalen.
Reviewers: sdesmalen, efriedma, rengolin
Reviewed By: efriedma
Differential Revision: https://reviews.llvm.org/D78723
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
llvm/lib/Target/AArch64/SVEInstrFormats.td
llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d8151dd0e99..e5f3a4e3e8da 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1366,6 +1366,12 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
case AArch64ISD::REV: return "AArch64ISD::REV";
case AArch64ISD::REINTERPRET_CAST: return "AArch64ISD::REINTERPRET_CAST";
case AArch64ISD::TBL: return "AArch64ISD::TBL";
+ case AArch64ISD::FADDA_PRED: return "AArch64ISD::FADDA_PRED";
+ case AArch64ISD::FADDV_PRED: return "AArch64ISD::FADDV_PRED";
+ case AArch64ISD::FMAXV_PRED: return "AArch64ISD::FMAXV_PRED";
+ case AArch64ISD::FMAXNMV_PRED: return "AArch64ISD::FMAXNMV_PRED";
+ case AArch64ISD::FMINV_PRED: return "AArch64ISD::FMINV_PRED";
+ case AArch64ISD::FMINNMV_PRED: return "AArch64ISD::FMINNMV_PRED";
case AArch64ISD::NOT: return "AArch64ISD::NOT";
case AArch64ISD::BIT: return "AArch64ISD::BIT";
case AArch64ISD::CBZ: return "AArch64ISD::CBZ";
@@ -11308,6 +11314,46 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
return DAG.getZExtOrTrunc(Res, DL, VT);
}
+static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc,
+ SelectionDAG &DAG) {
+ SDLoc DL(N);
+ // Rebuild N as Opc(Pred, Vec) with a vector result, then return lane 0.
+ SDValue Pred = N->getOperand(1); // forwarded unchanged as Opc's first operand
+ SDValue VecToReduce = N->getOperand(2);
+ // The new node takes the type of the vector being reduced, not N's own type.
+ EVT ReduceVT = VecToReduce.getValueType();
+ SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
+
+ // SVE reductions set the whole vector register with the first element
+ // containing the reduction result, which we'll now extract.
+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64); // lane index 0
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
+ Zero);
+}
+
+static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
+ SelectionDAG &DAG) {
+ SDLoc DL(N);
+ // Rebuild N as Opc(Pred, Init, Vec) with a vector result, then return lane 0.
+ SDValue Pred = N->getOperand(1);
+ SDValue InitVal = N->getOperand(2); // rewrapped into a ReduceVT vector below
+ SDValue VecToReduce = N->getOperand(3);
+ EVT ReduceVT = VecToReduce.getValueType();
+
+ // Ordered reductions use the first lane of the result vector as the
+ // reduction's initial value.
+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64); // lane 0, for insert and extract
+ InitVal = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ReduceVT,
+ DAG.getUNDEF(ReduceVT), InitVal, Zero); // every other lane stays undef
+
+ SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, InitVal, VecToReduce);
+
+ // SVE reductions set the whole vector register with the first element
+ // containing the reduction result, which we'll now extract.
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
+ Zero);
+}
+
static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -11391,6 +11437,18 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::aarch64_sve_udiv:
return DAG.getNode(AArch64ISD::UDIV_PRED, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
+ case Intrinsic::aarch64_sve_fadda:
+ return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG);
+ case Intrinsic::aarch64_sve_faddv:
+ return combineSVEReductionFP(N, AArch64ISD::FADDV_PRED, DAG);
+ case Intrinsic::aarch64_sve_fmaxnmv:
+ return combineSVEReductionFP(N, AArch64ISD::FMAXNMV_PRED, DAG);
+ case Intrinsic::aarch64_sve_fmaxv:
+ return combineSVEReductionFP(N, AArch64ISD::FMAXV_PRED, DAG);
+ case Intrinsic::aarch64_sve_fminnmv:
+ return combineSVEReductionFP(N, AArch64ISD::FMINNMV_PRED, DAG);
+ case Intrinsic::aarch64_sve_fminv:
+ return combineSVEReductionFP(N, AArch64ISD::FMINV_PRED, DAG);
case Intrinsic::aarch64_sve_sel:
return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index fe67d75fab17..c5afa9bea9e2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -215,6 +215,14 @@ enum NodeType : unsigned {
REV,
TBL,
+ // Floating-point reductions.
+ FADDA_PRED,
+ FADDV_PRED,
+ FMAXV_PRED,
+ FMAXNMV_PRED,
+ FMINV_PRED,
+ FMINNMV_PRED,
+
INSR,
PTEST,
PTRUE,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index ed790633e266..7c964dad4277 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -134,16 +134,20 @@ def sve_cntw_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -4>">;
def sve_cntd_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -2>">;
def SDT_AArch64Reduce : SDTypeProfile<1, 2, [SDTCisVec<1>, SDTCisVec<2>]>;
-
-def AArch64smaxv_pred : SDNode<"AArch64ISD::SMAXV_PRED", SDT_AArch64Reduce>;
-def AArch64umaxv_pred : SDNode<"AArch64ISD::UMAXV_PRED", SDT_AArch64Reduce>;
-def AArch64sminv_pred : SDNode<"AArch64ISD::SMINV_PRED", SDT_AArch64Reduce>;
-def AArch64uminv_pred : SDNode<"AArch64ISD::UMINV_PRED", SDT_AArch64Reduce>;
-def AArch64orv_pred : SDNode<"AArch64ISD::ORV_PRED", SDT_AArch64Reduce>;
-def AArch64eorv_pred : SDNode<"AArch64ISD::EORV_PRED", SDT_AArch64Reduce>;
-def AArch64andv_pred : SDNode<"AArch64ISD::ANDV_PRED", SDT_AArch64Reduce>;
-def AArch64lasta : SDNode<"AArch64ISD::LASTA", SDT_AArch64Reduce>;
-def AArch64lastb : SDNode<"AArch64ISD::LASTB", SDT_AArch64Reduce>;
+def AArch64faddv_pred : SDNode<"AArch64ISD::FADDV_PRED", SDT_AArch64Reduce>;
+def AArch64fmaxv_pred : SDNode<"AArch64ISD::FMAXV_PRED", SDT_AArch64Reduce>;
+def AArch64fmaxnmv_pred : SDNode<"AArch64ISD::FMAXNMV_PRED", SDT_AArch64Reduce>;
+def AArch64fminv_pred : SDNode<"AArch64ISD::FMINV_PRED", SDT_AArch64Reduce>;
+def AArch64fminnmv_pred : SDNode<"AArch64ISD::FMINNMV_PRED", SDT_AArch64Reduce>;
+def AArch64smaxv_pred : SDNode<"AArch64ISD::SMAXV_PRED", SDT_AArch64Reduce>;
+def AArch64umaxv_pred : SDNode<"AArch64ISD::UMAXV_PRED", SDT_AArch64Reduce>;
+def AArch64sminv_pred : SDNode<"AArch64ISD::SMINV_PRED", SDT_AArch64Reduce>;
+def AArch64uminv_pred : SDNode<"AArch64ISD::UMINV_PRED", SDT_AArch64Reduce>;
+def AArch64orv_pred : SDNode<"AArch64ISD::ORV_PRED", SDT_AArch64Reduce>;
+def AArch64eorv_pred : SDNode<"AArch64ISD::EORV_PRED", SDT_AArch64Reduce>;
+def AArch64andv_pred : SDNode<"AArch64ISD::ANDV_PRED", SDT_AArch64Reduce>;
+def AArch64lasta : SDNode<"AArch64ISD::LASTA", SDT_AArch64Reduce>;
+def AArch64lastb : SDNode<"AArch64ISD::LASTB", SDT_AArch64Reduce>;
def SDT_AArch64DIV : SDTypeProfile<1, 3, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>,
@@ -156,6 +160,7 @@ def AArch64udiv_pred : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64DIV>;
def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>;
+def AArch64fadda_pred : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
def SDT_AArch64Rev : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
def AArch64rev : SDNode<"AArch64ISD::REV", SDT_AArch64Rev>;
@@ -352,12 +357,21 @@ let Predicates = [HasSVE] in {
defm FMUL_ZZZI : sve_fp_fmul_by_indexed_elem<"fmul", int_aarch64_sve_fmul_lane>;
// SVE floating point reductions.
- defm FADDA_VPZ : sve_fp_2op_p_vd<0b000, "fadda", int_aarch64_sve_fadda>;
- defm FADDV_VPZ : sve_fp_fast_red<0b000, "faddv", int_aarch64_sve_faddv>;
- defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", int_aarch64_sve_fmaxnmv>;
- defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", int_aarch64_sve_fminnmv>;
- defm FMAXV_VPZ : sve_fp_fast_red<0b110, "fmaxv", int_aarch64_sve_fmaxv>;
- defm FMINV_VPZ : sve_fp_fast_red<0b111, "fminv", int_aarch64_sve_fminv>;
+ defm FADDA_VPZ : sve_fp_2op_p_vd<0b000, "fadda", AArch64fadda_pred>;
+ defm FADDV_VPZ : sve_fp_fast_red<0b000, "faddv", AArch64faddv_pred>;
+ defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", AArch64fmaxnmv_pred>;
+ defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", AArch64fminnmv_pred>;
+ defm FMAXV_VPZ : sve_fp_fast_red<0b110, "fmaxv", AArch64fmaxv_pred>;
+ defm FMINV_VPZ : sve_fp_fast_red<0b111, "fminv", AArch64fminv_pred>;
+
+ // Use more efficient NEON instructions to extract elements within the NEON
+ // part (first 128bits) of an SVE register.
+ def : Pat<(vector_extract (nxv8f16 ZPR:$Zs), (i64 0)),
+ (f16 (EXTRACT_SUBREG (v8f16 (EXTRACT_SUBREG ZPR:$Zs, zsub)), hsub))>;
+ def : Pat<(vector_extract (nxv4f32 ZPR:$Zs), (i64 0)),
+ (f32 (EXTRACT_SUBREG (v4f32 (EXTRACT_SUBREG ZPR:$Zs, zsub)), ssub))>;
+ def : Pat<(vector_extract (nxv2f64 ZPR:$Zs), (i64 0)),
+ (f64 (EXTRACT_SUBREG (v2f64 (EXTRACT_SUBREG ZPR:$Zs, zsub)), dsub))>;
// Splat immediate (unpredicated)
defm DUP_ZI : sve_int_dup_imm<"dup">;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 54fab60e33fd..a75bf47fa39b 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -4444,8 +4444,8 @@ multiclass sve2_int_while_rr<bits<1> rw, string asm, string op> {
//===----------------------------------------------------------------------===//
class sve_fp_fast_red<bits<2> sz, bits<3> opc, string asm,
- ZPRRegOp zprty, RegisterClass dstRegClass>
-: I<(outs dstRegClass:$Vd), (ins PPR3bAny:$Pg, zprty:$Zn),
+ ZPRRegOp zprty, FPRasZPROperand dstOpType>
+: I<(outs dstOpType:$Vd), (ins PPR3bAny:$Pg, zprty:$Zn),
asm, "\t$Vd, $Pg, $Zn",
"",
[]>, Sched<[]> {
@@ -4463,13 +4463,13 @@ class sve_fp_fast_red<bits<2> sz, bits<3> opc, string asm,
}
multiclass sve_fp_fast_red<bits<3> opc, string asm, SDPatternOperator op> {
- def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16>;
- def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32>;
- def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64>;
+ def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16asZPR>;
+ def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32asZPR>;
+ def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64asZPR>;
- def : SVE_2_Op_Pat<f16, op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
- def : SVE_2_Op_Pat<f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
- def : SVE_2_Op_Pat<f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
+ def : SVE_2_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
+ def : SVE_2_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
+ def : SVE_2_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
}
@@ -4478,8 +4478,8 @@ multiclass sve_fp_fast_red<bits<3> opc, string asm, SDPatternOperator op> {
//===----------------------------------------------------------------------===//
class sve_fp_2op_p_vd<bits<2> sz, bits<3> opc, string asm,
- ZPRRegOp zprty, RegisterClass dstRegClass>
-: I<(outs dstRegClass:$Vdn), (ins PPR3bAny:$Pg, dstRegClass:$_Vdn, zprty:$Zm),
+ ZPRRegOp zprty, FPRasZPROperand dstOpType>
+: I<(outs dstOpType:$Vdn), (ins PPR3bAny:$Pg, dstOpType:$_Vdn, zprty:$Zm),
asm, "\t$Vdn, $Pg, $_Vdn, $Zm",
"",
[]>,
@@ -4500,13 +4500,13 @@ class sve_fp_2op_p_vd<bits<2> sz, bits<3> opc, string asm,
}
multiclass sve_fp_2op_p_vd<bits<3> opc, string asm, SDPatternOperator op> {
- def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16>;
- def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32>;
- def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64>;
+ def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16asZPR>;
+ def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32asZPR>;
+ def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64asZPR>;
- def : SVE_3_Op_Pat<f16, op, nxv8i1, f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
- def : SVE_3_Op_Pat<f32, op, nxv4i1, f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
- def : SVE_3_Op_Pat<f64, op, nxv2i1, f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
+ def : SVE_3_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
+ def : SVE_3_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
+ def : SVE_3_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll
index 083a7d35439c..c933c2eab40d 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll
@@ -1,4 +1,4 @@
-; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve -asm-verbose=0 < %s | FileCheck %s
;
; FADDA
More information about the llvm-commits
mailing list