[llvm] [AArch64][SVE] Fold zero-extend into add reduction. (PR #102325)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 7 08:47:34 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

The original pull request #97339 from @dtemirbulatov was reverted due to some regressions with NEON. This PR removes the additional type legalisation for vectors that are too wide, which should make this change less contentious.

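For reference, this is the basic scalable-vector pattern the combine folds, quoted from the new `sve-int-reduce.ll` tests in this patch (with a `declare` added so the snippet stands alone); it can be run with something like `llc -mtriple=aarch64 -mattr=+sve`:

```llvm
declare i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32>)

; An add reduction of a zero-extended nxv16i8 input.
define i32 @uaddv_nxv16i8_nxv16i32(<vscale x 16 x i8> %a) {
  %1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %2 = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %1)
  ret i32 %2
}
; With this patch the whole reduction selects a single UADDV on the
; unextended input (per the new CHECK lines), instead of unpacking and
; adding in the wider element type first:
;   ptrue p0.b
;   uaddv d0, p0, z0.b
;   fmov  x0, d0
```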
---
Full diff: https://github.com/llvm/llvm-project/pull/102325.diff


5 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+43-2) 
- (modified) llvm/test/CodeGen/AArch64/sve-doublereduct.ll (+11-17) 
- (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-int-reduce.ll (+28) 
- (modified) llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll (+1-6) 
- (modified) llvm/test/CodeGen/AArch64/sve-int-reduce.ll (+38) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2e869f11b84314..3b6a177add91d9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17858,6 +17858,44 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
 }
 
+// Turn vecreduce_add(zext/sext(...)) into SVE's [US]ADDV instruction.
+static SDValue
+performVecReduceAddExtCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                              const AArch64TargetLowering &TLI) {
+  if (N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
+      N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND)
+    return SDValue();
+
+  SDValue VecOp = N->getOperand(0).getOperand(0);
+  EVT VecOpVT = VecOp.getValueType();
+  if (VecOpVT.getScalarType() == MVT::i1 || !TLI.isTypeLegal(VecOpVT) ||
+      (VecOpVT.isFixedLengthVector() &&
+       !TLI.useSVEForFixedLengthVectorVT(VecOpVT, /*OverrideNEON=*/true)))
+    return SDValue();
+
+  SDLoc DL(N);
+  SelectionDAG &DAG = DCI.DAG;
+
+  // The input type is legal so map VECREDUCE_ADD to UADDV/SADDV, e.g.
+  // i32 (vecreduce_add (zext nxv16i8 %op to nxv16i32))
+  // ->
+  // i32 (UADDV nxv16i8:%op)
+  EVT ElemType = N->getValueType(0);
+  SDValue Pg = getPredicateForVector(DAG, DL, VecOpVT);
+  if (VecOpVT.isFixedLengthVector()) {
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, VecOpVT);
+    VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
+  }
+  bool IsSigned = N->getOperand(0).getOpcode() == ISD::SIGN_EXTEND;
+  SDValue Res =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
+                  DAG.getConstant(IsSigned ? Intrinsic::aarch64_sve_saddv
+                                           : Intrinsic::aarch64_sve_uaddv,
+                                  DL, MVT::i64),
+                  Pg, VecOp);
+  return DAG.getAnyExtOrTrunc(Res, DL, ElemType);
+}
+
 // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
 //   vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
 //   vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
@@ -25546,8 +25584,11 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performInsertVectorEltCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
     return performExtractVectorEltCombine(N, DCI, Subtarget);
-  case ISD::VECREDUCE_ADD:
-    return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
+  case ISD::VECREDUCE_ADD: {
+    if (SDValue Val = performVecReduceAddCombine(N, DCI.DAG, Subtarget))
+      return Val;
+    return performVecReduceAddExtCombine(N, DCI, *this);
+  }
   case AArch64ISD::UADDV:
     return performUADDVCombine(N, DAG);
   case AArch64ISD::SMULL:
diff --git a/llvm/test/CodeGen/AArch64/sve-doublereduct.ll b/llvm/test/CodeGen/AArch64/sve-doublereduct.ll
index 7bc31d44bb6547..84a60f853ce440 100644
--- a/llvm/test/CodeGen/AArch64/sve-doublereduct.ll
+++ b/llvm/test/CodeGen/AArch64/sve-doublereduct.ll
@@ -103,17 +103,12 @@ define i32 @add_i32(<vscale x 8 x i32> %a, <vscale x 4 x i32> %b) {
 define i16 @add_ext_i16(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: add_ext_i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    uunpkhi z2.h, z0.b
-; CHECK-NEXT:    uunpklo z0.h, z0.b
-; CHECK-NEXT:    uunpkhi z3.h, z1.b
-; CHECK-NEXT:    uunpklo z1.h, z1.b
-; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    add z0.h, z0.h, z2.h
-; CHECK-NEXT:    add z1.h, z1.h, z3.h
-; CHECK-NEXT:    add z0.h, z0.h, z1.h
-; CHECK-NEXT:    uaddv d0, p0, z0.h
-; CHECK-NEXT:    fmov x0, d0
-; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    uaddv d0, p0, z0.b
+; CHECK-NEXT:    uaddv d1, p0, z1.b
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    fmov w9, s1
+; CHECK-NEXT:    add w0, w8, w9
 ; CHECK-NEXT:    ret
   %ae = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
   %be = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
@@ -130,17 +125,16 @@ define i16 @add_ext_v32i16(<vscale x 32 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-NEXT:    uunpklo z4.h, z0.b
 ; CHECK-NEXT:    uunpkhi z1.h, z1.b
 ; CHECK-NEXT:    uunpkhi z0.h, z0.b
-; CHECK-NEXT:    uunpkhi z5.h, z2.b
-; CHECK-NEXT:    uunpklo z2.h, z2.b
 ; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    add z0.h, z0.h, z1.h
 ; CHECK-NEXT:    add z1.h, z4.h, z3.h
 ; CHECK-NEXT:    add z0.h, z1.h, z0.h
-; CHECK-NEXT:    add z1.h, z2.h, z5.h
-; CHECK-NEXT:    add z0.h, z0.h, z1.h
+; CHECK-NEXT:    uaddv d1, p1, z2.b
 ; CHECK-NEXT:    uaddv d0, p0, z0.h
-; CHECK-NEXT:    fmov x0, d0
-; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    fmov w9, s1
+; CHECK-NEXT:    fmov x8, d0
+; CHECK-NEXT:    add w0, w8, w9
 ; CHECK-NEXT:    ret
   %ae = zext <vscale x 32 x i8> %a to <vscale x 32 x i16>
   %be = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-reduce.ll
index 752c2cd34bfe48..350f219d5446d7 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-reduce.ll
@@ -364,6 +364,34 @@ define i64 @uaddv_v32i64(ptr %a) vscale_range(16,0) #0 {
   ret i64 %res
 }
 
+define i64 @uaddv_zext_v32i8_v32i64(ptr %a) vscale_range(2,0) #0 {
+; CHECK-LABEL: uaddv_zext_v32i8_v32i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b, vl32
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    uaddv d0, p0, z0.b
+; CHECK-NEXT:    fmov x0, d0
+; CHECK-NEXT:    ret
+  %op = load <32 x i8>, ptr %a
+  %op.zext = zext <32 x i8> %op to <32 x i64>
+  %res = call i64 @llvm.vector.reduce.add.v32i64(<32 x i64> %op.zext)
+  ret i64 %res
+}
+
+define i64 @uaddv_zext_v64i8_v64i64(ptr %a) vscale_range(4,0) #0 {
+; CHECK-LABEL: uaddv_zext_v64i8_v64i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b, vl64
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    uaddv d0, p0, z0.b
+; CHECK-NEXT:    fmov x0, d0
+; CHECK-NEXT:    ret
+  %op = load <64 x i8>, ptr %a
+  %op.zext = zext <64 x i8> %op to <64 x i64>
+  %res = call i64 @llvm.vector.reduce.add.v64i64(<64 x i64> %op.zext)
+  ret i64 %res
+}
+
 ;
 ; SMAXV
 ;
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll b/llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll
index 1ab2589bccd5ba..f8923ffc443155 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll
@@ -1,4 +1,3 @@
-
 ; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v1 -O3 -aarch64-sve-vector-bits-min=256 -verify-machineinstrs | FileCheck %s --check-prefixes=SVE256
 ; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v1 -O3 -aarch64-sve-vector-bits-min=128 -verify-machineinstrs | FileCheck %s --check-prefixes=NEON
 ; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-n1 -O3 -verify-machineinstrs | FileCheck %s --check-prefixes=NEON
@@ -9,11 +8,7 @@ define internal i32 @test(ptr nocapture readonly %p1, i32 %i1, ptr nocapture rea
 ; SVE256:       ld1b    { z0.h }, p0/z,
 ; SVE256:       ld1b    { z1.h }, p0/z,
 ; SVE256:       sub z0.h, z0.h, z1.h
-; SVE256-NEXT:  sunpklo z1.s, z0.h
-; SVE256-NEXT:  ext z0.b, z0.b, z0.b, #16
-; SVE256-NEXT:  sunpklo z0.s, z0.h
-; SVE256-NEXT:  add z0.s, z1.s, z0.s
-; SVE256-NEXT:  uaddv   d0, p1, z0.s
+; SVE256-NEXT:  saddv d0, p0, z0.h
 
 ; NEON-LABEL: test:
 ; NEON:       ldr q0, [x0, w9, sxtw]
diff --git a/llvm/test/CodeGen/AArch64/sve-int-reduce.ll b/llvm/test/CodeGen/AArch64/sve-int-reduce.ll
index 8c1b5225b7f257..c11e8bf5d10c37 100644
--- a/llvm/test/CodeGen/AArch64/sve-int-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-int-reduce.ll
@@ -188,6 +188,44 @@ define i64 @uaddv_nxv2i64(<vscale x 2 x i64> %a) {
   ret i64 %res
 }
 
+define i16 @uaddv_nxv16i8_nxv16i16(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: uaddv_nxv16i8_nxv16i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    uaddv d0, p0, z0.b
+; CHECK-NEXT:    fmov x0, d0
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+  %1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
+  %2 = call i16 @llvm.vector.reduce.add.nxv16i16(<vscale x 16 x i16> %1)
+  ret i16 %2
+}
+
+define i32 @uaddv_nxv16i8_nxv16i32(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: uaddv_nxv16i8_nxv16i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    uaddv d0, p0, z0.b
+; CHECK-NEXT:    fmov x0, d0
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+  %1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %2 = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %1)
+  ret i32 %2
+}
+
+define i64 @uaddv_nxv16i8_nxv16i64(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: uaddv_nxv16i8_nxv16i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    uaddv d0, p0, z0.b
+; CHECK-NEXT:    fmov x0, d0
+; CHECK-NEXT:    ret
+  %1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %2 = call i64 @llvm.vector.reduce.add.nxv16i64(<vscale x 16 x i64> %1)
+  ret i64 %2
+}
+
 ; UMINV
 
 define i8 @umin_nxv16i8(<vscale x 16 x i8> %a) {

``````````
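As the new `sve-fixed-length-int-reduce.ll` tests above show, the fold also applies to fixed-length vectors that fit an SVE register, by converting the operand to a scalable container first. A distilled version of one of those tests (the `#0` attribute group elided, so pass `-mattr=+sve` instead):

```llvm
declare i64 @llvm.vector.reduce.add.v32i64(<32 x i64>)

; vscale_range(2,0) guarantees SVE registers of at least 256 bits,
; so the <32 x i8> operand fits in a single z-register.
define i64 @uaddv_zext_v32i8_v32i64(ptr %a) vscale_range(2,0) {
  %op = load <32 x i8>, ptr %a
  %op.zext = zext <32 x i8> %op to <32 x i64>
  %res = call i64 @llvm.vector.reduce.add.v32i64(<32 x i64> %op.zext)
  ret i64 %res
}
; Per the CHECK lines, this now selects a predicated load plus one reduction:
;   ptrue p0.b, vl32
;   ld1b  { z0.b }, p0/z, [x0]
;   uaddv d0, p0, z0.b
;   fmov  x0, d0
```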

https://github.com/llvm/llvm-project/pull/102325

