[llvm] 71dc3de - [ARM] Improve min/max vector reductions on Arm
David Green via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 22 09:00:26 PDT 2023
Author: Caleb Zulawski
Date: 2023-03-22T16:00:19Z
New Revision: 71dc3de533b9247223c083a3b058859c9759099c
URL: https://github.com/llvm/llvm-project/commit/71dc3de533b9247223c083a3b058859c9759099c
DIFF: https://github.com/llvm/llvm-project/commit/71dc3de533b9247223c083a3b058859c9759099c.diff
LOG: [ARM] Improve min/max vector reductions on Arm
This patch adds more efficient lowering for vecreduce.min/max under NEON,
using sequences of pairwise vpmin/vpmax to reduce the vector to a single value
(see the scalar sketch after the revision link below).
This largely resolves issues such as #50466, #40981, and #38190.
Differential Revision: https://reviews.llvm.org/D146404
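
As an illustration of the strategy (not part of the patch), here is a
hypothetical scalar model of NEON's pairwise-minimum semantics and of the
resulting reduction tree. Each vpmin step takes adjacent-pair minimums, so
feeding the same d-register to both operands halves the number of meaningful
lanes, and log2(n) steps reduce n lanes down to lane 0:

#include <algorithm>
#include <array>
#include <cstdint>

// Hypothetical model of vpmin.u8 Dd, Dn, Dm: lanes 0-3 of the result hold
// the minimums of adjacent pairs of Dn, lanes 4-7 those of Dm.
static std::array<uint8_t, 8> vpminU8(std::array<uint8_t, 8> N,
                                      std::array<uint8_t, 8> M) {
  std::array<uint8_t, 8> D{};
  for (int I = 0; I < 4; ++I) {
    D[I] = std::min(N[2 * I], N[2 * I + 1]);
    D[I + 4] = std::min(M[2 * I], M[2 * I + 1]);
  }
  return D;
}

// Model of the v8i8 umin reduction: passing the same vector as both operands
// halves the number of lanes still carrying useful data each step,
// 8 -> 4 -> 2 -> 1, after which lane 0 holds the overall minimum.
uint8_t reduceUMinV8I8(std::array<uint8_t, 8> X) {
  X = vpminU8(X, X);
  X = vpminU8(X, X);
  X = vpminU8(X, X);
  return X[0];
}

For 128-bit types the patch first splits the vector into two 64-bit halves
and combines them with a single vpmin/vpmax, which is why the v16i8 test
below starts with "vpmin.u8 d16, d0, d1" rather than a same-operand step.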
Added:
llvm/test/CodeGen/ARM/vecreduce-minmax.ll
Modified:
llvm/lib/Target/ARM/ARMISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 126bbc61a7d30..9c5f0df4d9468 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -1007,6 +1007,14 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
setLoadExtAction(ISD::SEXTLOAD, VT, Ty, Legal);
}
}
+
+ for (auto VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
+ MVT::v4i32}) {
+ setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+ }
}
if (Subtarget->hasNEON() || Subtarget->hasMVEIntegerOps()) {
@@ -10271,6 +10279,80 @@ static SDValue LowerVecReduceF(SDValue Op, SelectionDAG &DAG,
return LowerVecReduce(Op, DAG, ST);
}
+static SDValue LowerVecReduceMinMax(SDValue Op, SelectionDAG &DAG,
+ const ARMSubtarget *ST) {
+ if (!ST->hasNEON())
+ return SDValue();
+
+ SDLoc dl(Op);
+ SDValue Op0 = Op->getOperand(0);
+ EVT VT = Op0.getValueType();
+ EVT EltVT = VT.getVectorElementType();
+
+ unsigned PairwiseIntrinsic = 0;
+ switch (Op->getOpcode()) {
+ default:
+ llvm_unreachable("Expected VECREDUCE opcode");
+ case ISD::VECREDUCE_UMIN:
+ PairwiseIntrinsic = Intrinsic::arm_neon_vpminu;
+ break;
+ case ISD::VECREDUCE_UMAX:
+ PairwiseIntrinsic = Intrinsic::arm_neon_vpmaxu;
+ break;
+ case ISD::VECREDUCE_SMIN:
+ PairwiseIntrinsic = Intrinsic::arm_neon_vpmins;
+ break;
+ case ISD::VECREDUCE_SMAX:
+ PairwiseIntrinsic = Intrinsic::arm_neon_vpmaxs;
+ break;
+ }
+ SDValue PairwiseOp = DAG.getConstant(PairwiseIntrinsic, dl, MVT::i32);
+
+ unsigned NumElts = VT.getVectorNumElements();
+ unsigned NumActiveLanes = NumElts;
+
+ assert((NumActiveLanes == 16 || NumActiveLanes == 8 || NumActiveLanes == 4 ||
+ NumActiveLanes == 2) &&
+ "Only expected a power 2 vector size");
+
+ // Split 128-bit vectors, since vpmin/max takes 2 64-bit vectors.
+ if (VT.is128BitVector()) {
+ SDValue Lo, Hi;
+ std::tie(Lo, Hi) = DAG.SplitVector(Op0, dl);
+ VT = Lo.getValueType();
+ Op0 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, {PairwiseOp, Lo, Hi});
+ NumActiveLanes /= 2;
+ }
+
+ // Use pairwise reductions until only one lane remains.
+ while (NumActiveLanes > 1) {
+ Op0 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, {PairwiseOp, Op0, Op0});
+ NumActiveLanes /= 2;
+ }
+
+ SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Op0,
+ DAG.getConstant(0, dl, MVT::i32));
+
+ // Result type may be wider than element type.
+ if (EltVT != Op.getValueType()) {
+ unsigned Extend = 0;
+ switch (Op->getOpcode()) {
+ default:
+ llvm_unreachable("Expected VECREDUCE opcode");
+ case ISD::VECREDUCE_UMIN:
+ case ISD::VECREDUCE_UMAX:
+ Extend = ISD::ZERO_EXTEND;
+ break;
+ case ISD::VECREDUCE_SMIN:
+ case ISD::VECREDUCE_SMAX:
+ Extend = ISD::SIGN_EXTEND;
+ break;
+ }
+ Res = DAG.getNode(Extend, dl, Op.getValueType(), Res);
+ }
+ return Res;
+}
+
static SDValue LowerAtomicLoadStore(SDValue Op, SelectionDAG &DAG) {
if (isStrongerThanMonotonic(cast<AtomicSDNode>(Op)->getSuccessOrdering()))
// Acquire/Release load/store is not legal for targets without a dmb or
@@ -10502,6 +10584,11 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAX:
return LowerVecReduceF(Op, DAG, Subtarget);
+ case ISD::VECREDUCE_UMIN:
+ case ISD::VECREDUCE_UMAX:
+ case ISD::VECREDUCE_SMIN:
+ case ISD::VECREDUCE_SMAX:
+ return LowerVecReduceMinMax(Op, DAG, Subtarget);
case ISD::ATOMIC_LOAD:
case ISD::ATOMIC_STORE: return LowerAtomicLoadStore(Op, DAG);
case ISD::FSINCOS: return LowerFSINCOS(Op, DAG);
diff --git a/llvm/test/CodeGen/ARM/vecreduce-minmax.ll b/llvm/test/CodeGen/ARM/vecreduce-minmax.ll
new file mode 100644
index 0000000000000..c392e6ca6bfa6
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/vecreduce-minmax.ll
@@ -0,0 +1,219 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=armv7-none-eabi -float-abi=hard -mattr=+neon -verify-machineinstrs | FileCheck %s
+
+define i8 @test_umin_v8i8(<8 x i8> %x) {
+; CHECK-LABEL: test_umin_v8i8:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmin.u8 d16, d0, d0
+; CHECK-NEXT: vpmin.u8 d16, d16, d16
+; CHECK-NEXT: vpmin.u8 d16, d16, d16
+; CHECK-NEXT: vmov.u8 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i8 @llvm.vector.reduce.umin.v8i8(<8 x i8> %x)
+ ret i8 %z
+}
+
+define i8 @test_smin_v8i8(<8 x i8> %x) {
+; CHECK-LABEL: test_smin_v8i8:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmin.s8 d16, d0, d0
+; CHECK-NEXT: vpmin.s8 d16, d16, d16
+; CHECK-NEXT: vpmin.s8 d16, d16, d16
+; CHECK-NEXT: vmov.s8 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i8 @llvm.vector.reduce.smin.v8i8(<8 x i8> %x)
+ ret i8 %z
+}
+
+define i8 @test_umax_v8i8(<8 x i8> %x) {
+; CHECK-LABEL: test_umax_v8i8:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmax.u8 d16, d0, d0
+; CHECK-NEXT: vpmax.u8 d16, d16, d16
+; CHECK-NEXT: vpmax.u8 d16, d16, d16
+; CHECK-NEXT: vmov.u8 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i8 @llvm.vector.reduce.umax.v8i8(<8 x i8> %x)
+ ret i8 %z
+}
+
+define i8 @test_smax_v8i8(<8 x i8> %x) {
+; CHECK-LABEL: test_smax_v8i8:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmax.s8 d16, d0, d0
+; CHECK-NEXT: vpmax.s8 d16, d16, d16
+; CHECK-NEXT: vpmax.s8 d16, d16, d16
+; CHECK-NEXT: vmov.s8 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i8 @llvm.vector.reduce.smax.v8i8(<8 x i8> %x)
+ ret i8 %z
+}
+
+define i16 @test_umin_v4i16(<4 x i16> %x) {
+; CHECK-LABEL: test_umin_v4i16:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmin.u16 d16, d0, d0
+; CHECK-NEXT: vpmin.u16 d16, d16, d16
+; CHECK-NEXT: vmov.u16 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i16 @llvm.vector.reduce.umin.v4i16(<4 x i16> %x)
+ ret i16 %z
+}
+
+define i16 @test_smin_v4i16(<4 x i16> %x) {
+; CHECK-LABEL: test_smin_v4i16:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmin.s16 d16, d0, d0
+; CHECK-NEXT: vpmin.s16 d16, d16, d16
+; CHECK-NEXT: vmov.s16 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i16 @llvm.vector.reduce.smin.v4i16(<4 x i16> %x)
+ ret i16 %z
+}
+
+define i16 @test_umax_v4i16(<4 x i16> %x) {
+; CHECK-LABEL: test_umax_v4i16:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmax.u16 d16, d0, d0
+; CHECK-NEXT: vpmax.u16 d16, d16, d16
+; CHECK-NEXT: vmov.u16 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i16 @llvm.vector.reduce.umax.v4i16(<4 x i16> %x)
+ ret i16 %z
+}
+
+define i16 @test_smax_v4i16(<4 x i16> %x) {
+; CHECK-LABEL: test_smax_v4i16:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmax.s16 d16, d0, d0
+; CHECK-NEXT: vpmax.s16 d16, d16, d16
+; CHECK-NEXT: vmov.s16 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i16 @llvm.vector.reduce.smax.v4i16(<4 x i16> %x)
+ ret i16 %z
+}
+
+define i32 @test_umin_v2i32(<2 x i32> %x) {
+; CHECK-LABEL: test_umin_v2i32:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmin.u32 d16, d0, d0
+; CHECK-NEXT: vmov.32 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i32 @llvm.vector.reduce.umin.v2i32(<2 x i32> %x)
+ ret i32 %z
+}
+
+define i32 @test_smin_v2i32(<2 x i32> %x) {
+; CHECK-LABEL: test_smin_v2i32:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmin.s32 d16, d0, d0
+; CHECK-NEXT: vmov.32 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i32 @llvm.vector.reduce.smin.v2i32(<2 x i32> %x)
+ ret i32 %z
+}
+
+define i32 @test_umax_v2i32(<2 x i32> %x) {
+; CHECK-LABEL: test_umax_v2i32:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmax.u32 d16, d0, d0
+; CHECK-NEXT: vmov.32 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i32 @llvm.vector.reduce.umax.v2i32(<2 x i32> %x)
+ ret i32 %z
+}
+
+define i32 @test_smax_v2i32(<2 x i32> %x) {
+; CHECK-LABEL: test_smax_v2i32:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmax.s32 d16, d0, d0
+; CHECK-NEXT: vmov.32 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i32 @llvm.vector.reduce.smax.v2i32(<2 x i32> %x)
+ ret i32 %z
+}
+
+define i8 @test_umin_v16i8(<16 x i8> %x) {
+; CHECK-LABEL: test_umin_v16i8:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmin.u8 d16, d0, d1
+; CHECK-NEXT: vpmin.u8 d16, d16, d16
+; CHECK-NEXT: vpmin.u8 d16, d16, d16
+; CHECK-NEXT: vpmin.u8 d16, d16, d16
+; CHECK-NEXT: vmov.u8 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i8 @llvm.vector.reduce.umin.v16i8(<16 x i8> %x)
+ ret i8 %z
+}
+
+define i16 @test_smin_v8i16(<8 x i16> %x) {
+; CHECK-LABEL: test_smin_v8i16:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmin.s16 d16, d0, d1
+; CHECK-NEXT: vpmin.s16 d16, d16, d16
+; CHECK-NEXT: vpmin.s16 d16, d16, d16
+; CHECK-NEXT: vmov.s16 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i16 @llvm.vector.reduce.smin.v8i16(<8 x i16> %x)
+ ret i16 %z
+}
+
+define i32 @test_umax_v4i32(<4 x i32> %x) {
+; CHECK-LABEL: test_umax_v4i32:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vpmax.u32 d16, d0, d1
+; CHECK-NEXT: vpmax.u32 d16, d16, d16
+; CHECK-NEXT: vmov.32 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i32 @llvm.vector.reduce.umax.v4i32(<4 x i32> %x)
+ ret i32 %z
+}
+
+define i8 @test_umin_v32i8(<32 x i8> %x) {
+; CHECK-LABEL: test_umin_v32i8:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vmin.u8 q8, q0, q1
+; CHECK-NEXT: vpmin.u8 d16, d16, d17
+; CHECK-NEXT: vpmin.u8 d16, d16, d16
+; CHECK-NEXT: vpmin.u8 d16, d16, d16
+; CHECK-NEXT: vpmin.u8 d16, d16, d16
+; CHECK-NEXT: vmov.u8 r0, d16[0]
+; CHECK-NEXT: bx lr
+entry:
+ %z = call i8 @llvm.vector.reduce.umin.v32i8(<32 x i8> %x)
+ ret i8 %z
+}
+
+declare i8 @llvm.vector.reduce.umin.v8i8(<8 x i8>)
+declare i8 @llvm.vector.reduce.smin.v8i8(<8 x i8>)
+declare i8 @llvm.vector.reduce.umax.v8i8(<8 x i8>)
+declare i8 @llvm.vector.reduce.smax.v8i8(<8 x i8>)
+declare i16 @llvm.vector.reduce.umin.v4i16(<4 x i16>)
+declare i16 @llvm.vector.reduce.smin.v4i16(<4 x i16>)
+declare i16 @llvm.vector.reduce.umax.v4i16(<4 x i16>)
+declare i16 @llvm.vector.reduce.smax.v4i16(<4 x i16>)
+declare i32 @llvm.vector.reduce.umin.v2i32(<2 x i32>)
+declare i32 @llvm.vector.reduce.smin.v2i32(<2 x i32>)
+declare i32 @llvm.vector.reduce.umax.v2i32(<2 x i32>)
+declare i32 @llvm.vector.reduce.smax.v2i32(<2 x i32>)
+
+declare i8 @llvm.vector.reduce.umin.v16i8(<16 x i8>)
+declare i16 @llvm.vector.reduce.smin.v8i16(<8 x i16>)
+declare i32 @llvm.vector.reduce.umax.v4i32(<4 x i32>)
+
+declare i8 @llvm.vector.reduce.umin.v32i8(<32 x i8>)
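
For reference, a hedged example of source code that can reach this lowering:
compiled with clang at -O2 for armv7 with NEON, the loop vectorizer typically
turns a loop like the one below into a call to @llvm.vector.reduce.umin, which
the patch then lowers with vpmin. The exact IR depends on vectorizer
heuristics and flags, so treat this as a sketch rather than a guarantee:

#include <cstdint>

// With, e.g., clang -O2 --target=armv7-none-eabi -mfpu=neon, the vectorizer
// may reduce this loop via @llvm.vector.reduce.umin.v16i8 and so exercise
// the vpmin-based lowering added above. (Illustrative only.)
uint8_t uminOf16(const uint8_t *P) {
  uint8_t Min = 0xff;
  for (int I = 0; I < 16; ++I)
    Min = P[I] < Min ? P[I] : Min;
  return Min;
}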