[llvm] c3c2e1e - [AArch64][SVE] Add codegen support for partial reduction lowering to wide add instructions (#114406)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 05:53:40 PST 2024


Author: James Chesterman
Date: 2024-11-12T13:53:35Z
New Revision: c3c2e1e161b4f11a2070966453067584223427de

URL: https://github.com/llvm/llvm-project/commit/c3c2e1e161b4f11a2070966453067584223427de
DIFF: https://github.com/llvm/llvm-project/commit/c3c2e1e161b4f11a2070966453067584223427de.diff

LOG: [AArch64][SVE] Add codegen support for partial reduction lowering to wide add instructions (#114406)

For partial reductions where the number of elements is halved (the
accumulator has half as many elements as the extended input), a pair of
SVE2 wide add instructions (SADDWB/SADDWT or UADDWB/UADDWT) can be used.

Added: 
    llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 069aab274d3126..e7923ff02de704 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2039,8 +2039,17 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   EVT VT = EVT::getEVT(I->getType());
-  return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
-         VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
+  EVT InputVT = EVT::getEVT(I->getOperand(1)->getType());
+
+  // Returning false keeps the partial-reduction intrinsic so that
+  // performIntrinsicCombine can try the dot-product and wide-add lowerings
+  // (falling back to getPartialReduceAdd). Do this when the input has the
+  // same element type as the accumulator and 4x or 2x as many elements.
+  bool SameElementType =
+      InputVT.getVectorElementType() == VT.getVectorElementType();
+  ElementCount AccEC = VT.getVectorElementCount();
+  ElementCount InputEC = InputVT.getVectorElementCount();
+  return !(SameElementType && (InputEC == AccEC * 4 || InputEC == AccEC * 2));
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21784,6 +21793,72 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
+/// Try to lower a partial reduction that halves the number of elements to a
+/// pair of SVE2 wide adds. For an accumulator Acc and an input extended from
+/// Input (whose elements are half the width of Acc's), this emits
+///   ADDWT(ADDWB(Acc, Input), Input)
+/// using SADDWB/SADDWT for sign extends and UADDWB/UADDWT otherwise. The
+/// bottom (B) form adds the widened even-numbered elements of Input and the
+/// top (T) form the odd-numbered ones, so together every element is added.
+/// Returns an empty SDValue if the pattern does not match.
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+                                          const AArch64Subtarget *Subtarget,
+                                          SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  // saddwb/saddwt/uaddwb/uaddwt are SVE2 instructions (also available in
+  // streaming mode), so SVE alone is not sufficient.
+  if (!Subtarget->isSVE2orStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // Operand 0 is the intrinsic ID; the accumulator and input follow.
+  SDValue Acc = N->getOperand(1);
+  SDValue ExtInput = N->getOperand(2);
+
+  EVT AccVT = Acc.getValueType();
+  EVT AccElemVT = AccVT.getVectorElementType();
+
+  if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
+    return SDValue();
+
+  // The wide adds consume the narrow, unextended vector, so the input must
+  // be an extend (sign, zero or any).
+  unsigned ExtInputOpcode = ExtInput->getOpcode();
+  if (!ISD::isExtOpcode(ExtInputOpcode))
+    return SDValue();
+
+  SDValue Input = ExtInput->getOperand(0);
+  EVT InputVT = Input.getValueType();
+
+  // Only packed SVE vectors whose element width is doubled by the extend are
+  // handled; other shapes fall back to the generic expansion.
+  if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+      !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+      !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+    return SDValue();
+
+  // ZERO_EXTEND and ANY_EXTEND both use the unsigned forms.
+  bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
+  auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
+                                       : Intrinsic::aarch64_sve_uaddwb;
+  auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
+                                    : Intrinsic::aarch64_sve_uaddwt;
+
+  // The intrinsic ID operand is unrelated to the data element type, so use a
+  // fixed i64 constant instead of AccElemVT.
+  SDValue BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, MVT::i64);
+  SDValue BottomNode =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
+  SDValue TopID = DAG.getTargetConstant(TopIntrinsic, DL, MVT::i64);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
+                     Input);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21795,6 +21870,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::experimental_vector_partial_reduce_add: {
     if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
       return Dot;
+    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+      return WideAdd;
     return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
                                    N->getOperand(1), N->getOperand(2));
   }

diff  --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
new file mode 100644
index 00000000000000..1d05649964670d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -0,0 +1,149 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+; Packed SVE types where the accumulator has half as many elements as the
+; extended input are lowered to a bottom/top pair of wide adds (SVE2).
+; TODO: add a RUN line with only +sve to check that saddw[bt]/uaddw[bt] are
+; not emitted without SVE2.
+
+define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: signed_wide_add_nxv4i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+    %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+    ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.d, z0.d, z1.s
+; CHECK-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+    %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+    ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: signed_wide_add_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.s, z0.s, z1.h
+; CHECK-NEXT:    saddwt z0.s, z0.s, z1.h
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+    %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+    ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.s, z0.s, z1.h
+; CHECK-NEXT:    uaddwt z0.s, z0.s, z1.h
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+    %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+    ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: signed_wide_add_nxv16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.h, z0.h, z1.b
+; CHECK-NEXT:    saddwt z0.h, z0.h, z1.b
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+    %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+    ret <vscale x 8 x i16> %partial.reduce
+}
+
+define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.h, z0.h, z1.b
+; CHECK-NEXT:    uaddwt z0.h, z0.h, z1.b
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+    %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+    ret <vscale x 8 x i16> %partial.reduce
+}
+
+; The following cases are not matched by the wide-add lowering (unpacked
+; input, or types that span more than one register) and use the generic
+; unpack-and-add expansion instead.
+define <vscale x 2 x i32> @signed_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
+; CHECK-LABEL: signed_wide_add_nxv4i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    sxth z1.s, p0/m, z1.s
+; CHECK-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = sext <vscale x 4 x i16> %input to <vscale x 4 x i32>
+    %partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
+    ret <vscale x 2 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i32> @unsigned_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = zext <vscale x 4 x i16> %input to <vscale x 4 x i32>
+    %partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
+    ret <vscale x 2 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i64> @signed_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
+; CHECK-LABEL: signed_wide_add_nxv8i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpkhi z4.d, z2.s
+; CHECK-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEXT:    sunpkhi z5.d, z3.s
+; CHECK-NEXT:    sunpklo z3.d, z3.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-NEXT:    add z0.d, z3.d, z0.d
+; CHECK-NEXT:    add z1.d, z5.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = sext <vscale x 8 x i32> %input to <vscale x 8 x i64>
+    %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
+    ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @unsigned_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpkhi z4.d, z2.s
+; CHECK-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEXT:    uunpkhi z5.d, z3.s
+; CHECK-NEXT:    uunpklo z3.d, z3.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-NEXT:    add z0.d, z3.d, z0.d
+; CHECK-NEXT:    add z1.d, z5.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = zext <vscale x 8 x i32> %input to <vscale x 8 x i64>
+    %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
+    ret <vscale x 4 x i64> %partial.reduce
+}


        


More information about the llvm-commits mailing list