[llvm] 71dc3de - [ARM] Improve min/max vector reductions on Arm

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 22 09:00:26 PDT 2023


Author: Caleb Zulawski
Date: 2023-03-22T16:00:19Z
New Revision: 71dc3de533b9247223c083a3b058859c9759099c

URL: https://github.com/llvm/llvm-project/commit/71dc3de533b9247223c083a3b058859c9759099c
DIFF: https://github.com/llvm/llvm-project/commit/71dc3de533b9247223c083a3b058859c9759099c.diff

LOG: [ARM] Improve min/max vector reductions on Arm

This patch adds some more efficient lowering for vecreduce.min/max under NEON,
using sequences of pairwise vpmin/vpmax to reduce to a single value.

This nearly resolves issues such as #50466, #40981, #38190.

Differential Revision: https://reviews.llvm.org/D146404

Added: 
    llvm/test/CodeGen/ARM/vecreduce-minmax.ll

Modified: 
    llvm/lib/Target/ARM/ARMISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 126bbc61a7d30..9c5f0df4d9468 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -1007,6 +1007,14 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
         setLoadExtAction(ISD::SEXTLOAD, VT, Ty, Legal);
       }
     }
+
+    for (auto VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
+                    MVT::v4i32}) {
+      setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+    }
   }
 
   if (Subtarget->hasNEON() || Subtarget->hasMVEIntegerOps()) {
@@ -10271,6 +10279,80 @@ static SDValue LowerVecReduceF(SDValue Op, SelectionDAG &DAG,
   return LowerVecReduce(Op, DAG, ST);
 }
 
+static SDValue LowerVecReduceMinMax(SDValue Op, SelectionDAG &DAG,
+                                    const ARMSubtarget *ST) {
+  if (!ST->hasNEON())
+    return SDValue();
+  // Lower the reduction to a chain of pairwise vpmin/vpmax, then read lane 0.
+  SDLoc dl(Op);
+  SDValue Op0 = Op->getOperand(0);
+  EVT VT = Op0.getValueType();
+  EVT EltVT = VT.getVectorElementType();
+  // Select the NEON pairwise intrinsic that matches the reduction opcode.
+  unsigned PairwiseIntrinsic = 0;
+  switch (Op->getOpcode()) {
+  default:
+    llvm_unreachable("Expected VECREDUCE opcode");
+  case ISD::VECREDUCE_UMIN:
+    PairwiseIntrinsic = Intrinsic::arm_neon_vpminu;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    PairwiseIntrinsic = Intrinsic::arm_neon_vpmaxu;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    PairwiseIntrinsic = Intrinsic::arm_neon_vpmins;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    PairwiseIntrinsic = Intrinsic::arm_neon_vpmaxs;
+    break;
+  }
+  SDValue PairwiseOp = DAG.getConstant(PairwiseIntrinsic, dl, MVT::i32);
+
+  unsigned NumElts = VT.getVectorNumElements();
+  unsigned NumActiveLanes = NumElts;
+
+  assert((NumActiveLanes == 16 || NumActiveLanes == 8 || NumActiveLanes == 4 ||
+          NumActiveLanes == 2) &&
+         "Only expected a power 2 vector size");
+  // A 128-bit vector is split in half and the two halves are combined with
+  // one pairwise op, since vpmin/vpmax take two 64-bit vectors.
+  if (VT.is128BitVector()) {
+    SDValue Lo, Hi;
+    std::tie(Lo, Hi) = DAG.SplitVector(Op0, dl);
+    VT = Lo.getValueType();
+    Op0 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, {PairwiseOp, Lo, Hi});
+    NumActiveLanes /= 2;
+  }
+  // Apply the pairwise op to the vector and itself, halving the active lane
+  // count each time, until one lane remains.
+  while (NumActiveLanes > 1) {
+    Op0 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, {PairwiseOp, Op0, Op0});
+    NumActiveLanes /= 2;
+  }
+  // The reduced value is read from lane 0.
+  SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Op0,
+                            DAG.getConstant(0, dl, MVT::i32));
+  // The node's result type may be wider than the element type; extend with
+  // the signedness implied by the opcode (zext for U*, sext for S*).
+  if (EltVT != Op.getValueType()) {
+    unsigned Extend = 0;
+    switch (Op->getOpcode()) {
+    default:
+      llvm_unreachable("Expected VECREDUCE opcode");
+    case ISD::VECREDUCE_UMIN:
+    case ISD::VECREDUCE_UMAX:
+      Extend = ISD::ZERO_EXTEND;
+      break;
+    case ISD::VECREDUCE_SMIN:
+    case ISD::VECREDUCE_SMAX:
+      Extend = ISD::SIGN_EXTEND;
+      break;
+    }
+    Res = DAG.getNode(Extend, dl, Op.getValueType(), Res);
+  }
+  return Res;
+}
+
 static SDValue LowerAtomicLoadStore(SDValue Op, SelectionDAG &DAG) {
   if (isStrongerThanMonotonic(cast<AtomicSDNode>(Op)->getSuccessOrdering()))
     // Acquire/Release load/store is not legal for targets without a dmb or
@@ -10502,6 +10584,11 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::VECREDUCE_FMIN:
   case ISD::VECREDUCE_FMAX:
     return LowerVecReduceF(Op, DAG, Subtarget);
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_SMAX:
+    return LowerVecReduceMinMax(Op, DAG, Subtarget);
   case ISD::ATOMIC_LOAD:
   case ISD::ATOMIC_STORE:  return LowerAtomicLoadStore(Op, DAG);
   case ISD::FSINCOS:       return LowerFSINCOS(Op, DAG);

diff  --git a/llvm/test/CodeGen/ARM/vecreduce-minmax.ll b/llvm/test/CodeGen/ARM/vecreduce-minmax.ll
new file mode 100644
index 0000000000000..c392e6ca6bfa6
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/vecreduce-minmax.ll
@@ -0,0 +1,219 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=armv7-none-eabi -float-abi=hard -mattr=+neon -verify-machineinstrs | FileCheck %s
+
+define i8 @test_umin_v8i8(<8 x i8> %x) {
+; CHECK-LABEL: test_umin_v8i8:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmin.u8 d16, d0, d0
+; CHECK-NEXT:    vpmin.u8 d16, d16, d16
+; CHECK-NEXT:    vpmin.u8 d16, d16, d16
+; CHECK-NEXT:    vmov.u8 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i8 @llvm.vector.reduce.umin.v8i8(<8 x i8> %x)
+  ret i8 %z
+}
+
+define i8 @test_smin_v8i8(<8 x i8> %x) {
+; CHECK-LABEL: test_smin_v8i8:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmin.s8 d16, d0, d0
+; CHECK-NEXT:    vpmin.s8 d16, d16, d16
+; CHECK-NEXT:    vpmin.s8 d16, d16, d16
+; CHECK-NEXT:    vmov.s8 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i8 @llvm.vector.reduce.smin.v8i8(<8 x i8> %x)
+  ret i8 %z
+}
+
+define i8 @test_umax_v8i8(<8 x i8> %x) {
+; CHECK-LABEL: test_umax_v8i8:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmax.u8 d16, d0, d0
+; CHECK-NEXT:    vpmax.u8 d16, d16, d16
+; CHECK-NEXT:    vpmax.u8 d16, d16, d16
+; CHECK-NEXT:    vmov.u8 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i8 @llvm.vector.reduce.umax.v8i8(<8 x i8> %x)
+  ret i8 %z
+}
+
+define i8 @test_smax_v8i8(<8 x i8> %x) {
+; CHECK-LABEL: test_smax_v8i8:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmax.s8 d16, d0, d0
+; CHECK-NEXT:    vpmax.s8 d16, d16, d16
+; CHECK-NEXT:    vpmax.s8 d16, d16, d16
+; CHECK-NEXT:    vmov.s8 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i8 @llvm.vector.reduce.smax.v8i8(<8 x i8> %x)
+  ret i8 %z
+}
+
+define i16 @test_umin_v4i16(<4 x i16> %x) {
+; CHECK-LABEL: test_umin_v4i16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmin.u16 d16, d0, d0
+; CHECK-NEXT:    vpmin.u16 d16, d16, d16
+; CHECK-NEXT:    vmov.u16 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i16 @llvm.vector.reduce.umin.v4i16(<4 x i16> %x)
+  ret i16 %z
+}
+
+define i16 @test_smin_v4i16(<4 x i16> %x) {
+; CHECK-LABEL: test_smin_v4i16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmin.s16 d16, d0, d0
+; CHECK-NEXT:    vpmin.s16 d16, d16, d16
+; CHECK-NEXT:    vmov.s16 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i16 @llvm.vector.reduce.smin.v4i16(<4 x i16> %x)
+  ret i16 %z
+}
+
+define i16 @test_umax_v4i16(<4 x i16> %x) {
+; CHECK-LABEL: test_umax_v4i16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmax.u16 d16, d0, d0
+; CHECK-NEXT:    vpmax.u16 d16, d16, d16
+; CHECK-NEXT:    vmov.u16 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i16 @llvm.vector.reduce.umax.v4i16(<4 x i16> %x)
+  ret i16 %z
+}
+
+define i16 @test_smax_v4i16(<4 x i16> %x) {
+; CHECK-LABEL: test_smax_v4i16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmax.s16 d16, d0, d0
+; CHECK-NEXT:    vpmax.s16 d16, d16, d16
+; CHECK-NEXT:    vmov.s16 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i16 @llvm.vector.reduce.smax.v4i16(<4 x i16> %x)
+  ret i16 %z
+}
+
+define i32 @test_umin_v2i32(<2 x i32> %x) {
+; CHECK-LABEL: test_umin_v2i32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmin.u32 d16, d0, d0
+; CHECK-NEXT:    vmov.32 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i32 @llvm.vector.reduce.umin.v2i32(<2 x i32> %x)
+  ret i32 %z
+}
+
+define i32 @test_smin_v2i32(<2 x i32> %x) {
+; CHECK-LABEL: test_smin_v2i32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmin.s32 d16, d0, d0
+; CHECK-NEXT:    vmov.32 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i32 @llvm.vector.reduce.smin.v2i32(<2 x i32> %x)
+  ret i32 %z
+}
+
+define i32 @test_umax_v2i32(<2 x i32> %x) {
+; CHECK-LABEL: test_umax_v2i32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmax.u32 d16, d0, d0
+; CHECK-NEXT:    vmov.32 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i32 @llvm.vector.reduce.umax.v2i32(<2 x i32> %x)
+  ret i32 %z
+}
+
+define i32 @test_smax_v2i32(<2 x i32> %x) {
+; CHECK-LABEL: test_smax_v2i32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmax.s32 d16, d0, d0
+; CHECK-NEXT:    vmov.32 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i32 @llvm.vector.reduce.smax.v2i32(<2 x i32> %x)
+  ret i32 %z
+}
+
+define i8 @test_umin_v16i8(<16 x i8> %x) {
+; CHECK-LABEL: test_umin_v16i8:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmin.u8 d16, d0, d1
+; CHECK-NEXT:    vpmin.u8 d16, d16, d16
+; CHECK-NEXT:    vpmin.u8 d16, d16, d16
+; CHECK-NEXT:    vpmin.u8 d16, d16, d16
+; CHECK-NEXT:    vmov.u8 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i8 @llvm.vector.reduce.umin.v16i8(<16 x i8> %x)
+  ret i8 %z
+}
+
+define i16 @test_smin_v8i16(<8 x i16> %x) {
+; CHECK-LABEL: test_smin_v8i16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmin.s16 d16, d0, d1
+; CHECK-NEXT:    vpmin.s16 d16, d16, d16
+; CHECK-NEXT:    vpmin.s16 d16, d16, d16
+; CHECK-NEXT:    vmov.s16 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i16 @llvm.vector.reduce.smin.v8i16(<8 x i16> %x)
+  ret i16 %z
+}
+
+define i32 @test_umax_v4i32(<4 x i32> %x) {
+; CHECK-LABEL: test_umax_v4i32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vpmax.u32 d16, d0, d1
+; CHECK-NEXT:    vpmax.u32 d16, d16, d16
+; CHECK-NEXT:    vmov.32 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i32 @llvm.vector.reduce.umax.v4i32(<4 x i32> %x)
+  ret i32 %z
+}
+
+define i8 @test_umin_v32i8(<32 x i8> %x) {
+; CHECK-LABEL: test_umin_v32i8:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmin.u8 q8, q0, q1
+; CHECK-NEXT:    vpmin.u8 d16, d16, d17
+; CHECK-NEXT:    vpmin.u8 d16, d16, d16
+; CHECK-NEXT:    vpmin.u8 d16, d16, d16
+; CHECK-NEXT:    vpmin.u8 d16, d16, d16
+; CHECK-NEXT:    vmov.u8 r0, d16[0]
+; CHECK-NEXT:    bx lr
+entry:
+  %z = call i8 @llvm.vector.reduce.umin.v32i8(<32 x i8> %x)
+  ret i8 %z
+}
+
+declare i8 @llvm.vector.reduce.umin.v8i8(<8 x i8>)
+declare i8 @llvm.vector.reduce.smin.v8i8(<8 x i8>)
+declare i8 @llvm.vector.reduce.umax.v8i8(<8 x i8>)
+declare i8 @llvm.vector.reduce.smax.v8i8(<8 x i8>)
+declare i16 @llvm.vector.reduce.umin.v4i16(<4 x i16>)
+declare i16 @llvm.vector.reduce.smin.v4i16(<4 x i16>)
+declare i16 @llvm.vector.reduce.umax.v4i16(<4 x i16>)
+declare i16 @llvm.vector.reduce.smax.v4i16(<4 x i16>)
+declare i32 @llvm.vector.reduce.umin.v2i32(<2 x i32>)
+declare i32 @llvm.vector.reduce.smin.v2i32(<2 x i32>)
+declare i32 @llvm.vector.reduce.umax.v2i32(<2 x i32>)
+declare i32 @llvm.vector.reduce.smax.v2i32(<2 x i32>)
+
+declare i8 @llvm.vector.reduce.umin.v16i8(<16 x i8>)
+declare i16 @llvm.vector.reduce.smin.v8i16(<8 x i16>)
+declare i32 @llvm.vector.reduce.umax.v4i32(<4 x i32>)
+
+declare i8 @llvm.vector.reduce.umin.v32i8(<32 x i8>)


        


More information about the llvm-commits mailing list