[llvm] [AArch64][SelectionDAG] Add type legalization for partial reduce wide adds (PR #141075)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 04:56:12 PDT 2025


https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/141075

>From 4ffe5e5cb8f33bfc0f5eceb62b75ea1288373ff9 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 28 May 2025 16:33:04 +0100
Subject: [PATCH 1/7] [AArch64][SelectionDAG] Add type legalization for partial
 reduce wide adds

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  12 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    |  35 ++
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  13 +
 .../AArch64/sve-partial-reduce-dot-product.ll | 206 +++++++----
 .../AArch64/sve-partial-reduce-wide-add.ll    | 322 +++++++++++++-----
 5 files changed, 426 insertions(+), 162 deletions(-)
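
A note on the shape being handled: the "wide add" legalized here is a partial reduction whose extended input has twice the element count and half the element width of the accumulator, with no multiply involved. A minimal IR example, mirroring the signed_wide_add_nxv4i32 test updated further down:

  define <vscale x 2 x i64> @wide_add(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input) {
  entry:
    ; A single extension step (i32 -> i64); with SVE2 this can now be selected
    ; as a saddwb/saddwt pair instead of being unpacked and added lane by lane.
    %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
    %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
    ret <vscale x 2 x i64> %partial.reduce
  }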

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 9e418329d15be..af504df596615 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12676,6 +12676,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
+  EVT ResultVT = N->getValueType(0);
+
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
   unsigned NewOpcode =
       ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
@@ -12689,7 +12691,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
         (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
       return SDValue();
 
-    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+    return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp,
                        DAG.getConstant(CTrunc, DL, LHSExtOpVT));
   }
 
@@ -12710,8 +12712,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
-                     RHSExtOp);
+  return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, RHSExtOp);
 }
 
 // partial.reduce.umla(acc, zext(op), splat(1))
@@ -12735,7 +12736,10 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
-  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+  auto *Context = DAG.getContext();
+  if (!TLI.isPartialReduceMLALegalOrCustom(
+          TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+          TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
     return SDValue();
 
   bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
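
The foldPartialReduceAdd change above queries the partial-reduce MLA action on the types the accumulator and the unextended operand will legalize to, rather than on their original types, so a reduction whose types merely need splitting can still reach the new lowering. The signed_wide_add_nxv8i32 test further down is the case this enables (sketch copied from that test): nxv4i64/nxv8i32 split into two nxv2i64/nxv4i32 halves, each of which becomes a saddwb/saddwt pair.

  define <vscale x 4 x i64> @signed_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input) {
  entry:
    ; nxv4i64 and nxv8i32 are wider than one SVE register; type legalization
    ; splits them in half, and the split halves hit the wide-add lowering.
    %input.wide = sext <vscale x 8 x i32> %input to <vscale x 8 x i64>
    %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
    ret <vscale x 4 x i64> %partial.reduce
  }
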
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a817ed5f0e917..fecdfe95a082d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1885,6 +1885,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
 
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+
+    // Wide add types
+    if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
+      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Custom);
+      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Custom);
+      setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Custom);
+    }
   }
 
   // Handle operations that are only available in non-streaming SVE mode.
@@ -29230,6 +29237,34 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
 
+  // Recognise Op as a wide add, if it is then we leave it as-is
+  // Base: nxv2i64, Subdivision: nxv4i32
+  auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
+    assert(Base.isVector() && Subdivision.isVector());
+    assert(Base.isScalableVector() == Subdivision.isScalableVector());
+
+    ElementCount BaseCount = Base.getVectorElementCount();
+    ElementCount SubCount = Subdivision.getVectorElementCount();
+    if (BaseCount * 2 != SubCount)
+      return false;
+
+    uint64_t BaseScalarSize = Base.getScalarSizeInBits();
+    uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
+    if (BaseScalarSize != SubScalarSize * 2)
+      return false;
+
+    return true;
+  };
+  if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
+    // If it looks like a real wide add, we can leave it as-is and treat it as
+    // Legal
+    APInt C;
+    if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
+      return Op;
+    // If it doesn't, then we need to expand it.
+    return SDValue();
+  }
+
   assert((Scalable && ResultVT == MVT::nxv2i64 &&
           LHS.getValueType() == MVT::nxv16i8) ||
          (!Scalable && ResultVT == MVT::v2i64 &&
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index a40ef56f30486..1b1a24394e1f1 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3813,6 +3813,19 @@ let Predicates = [HasSVE2_or_SME] in {
   defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>;
   defm USUBWT_ZZZ : sve2_wide_int_arith_wide<0b111, "usubwt", int_aarch64_sve_usubwt>;
 
+  def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
+            (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+  def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
+            (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+  def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
+            (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+  def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
+            (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+  def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
+            (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+  def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
+            (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+
   // SVE2 integer multiply long
   defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>;
   defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>;
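
The patterns above map a partial reduction with a splat-of-one multiplicand onto the SVE2 bottom/top widening adds: UADDWB/SADDWB add the extended even-numbered elements of the narrower operand into the wide accumulator, and UADDWT/SADDWT the odd-numbered ones, so every input lane is accumulated exactly once. Because a partial reduction does not prescribe which input lane folds into which accumulator lane, that assignment is free to choose. A purely illustrative IR equivalence for the unsigned nxv2i64/nxv4i32 pairing (not code from this patch):

  define <vscale x 2 x i64> @uaddw_pair_equiv(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %in) {
    ; Separate the even (bottom) and odd (top) elements of the input.
    %halves = call { <vscale x 2 x i32>, <vscale x 2 x i32> } @llvm.vector.deinterleave2.nxv4i32(<vscale x 4 x i32> %in)
    %bot = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i32> } %halves, 0
    %top = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i32> } %halves, 1
    ; Accumulate both halves after zero-extension; this is what the selected
    ; uaddwt(uaddwb(acc, in), in) sequence computes per lane.
    %bot.wide = zext <vscale x 2 x i32> %bot to <vscale x 2 x i64>
    %top.wide = zext <vscale x 2 x i32> %top to <vscale x 2 x i64>
    %acc.bot = add <vscale x 2 x i64> %acc, %bot.wide
    %res = add <vscale x 2 x i64> %acc.bot, %top.wide
    ret <vscale x 2 x i64> %res
  }
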
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 809a45045b0db..a45b8b710c63a 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -561,31 +561,34 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ; CHECK-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
-; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    add z5.d, z5.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z3.d, z7.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0:
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE-NEXT:    udot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0:
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE2-NEXT:    udot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0:
+; CHECK-NEWLOWERING-SME-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
   ret <vscale x 4 x i64> %partial.reduce
@@ -603,31 +606,34 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ; CHECK-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
-; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    add z5.d, z5.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z3.d, z7.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0:
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE-NEXT:    sdot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0:
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE2-NEXT:    sdot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0:
+; CHECK-NEWLOWERING-SME-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
   ret <vscale x 4 x i64> %partial.reduce
@@ -647,18 +653,44 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: not_udot:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    and z1.h, z1.h, #0xff
-; CHECK-NEWLOWERING-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_udot:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SVE-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_udot:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_udot:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
   %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
@@ -681,18 +713,44 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: not_udot_wide:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    and z1.s, z1.s, #0xffff
-; CHECK-NEWLOWERING-NEXT:    and z2.s, z2.s, #0xffff
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z4.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
   %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index 5148d3da6c737..8f9f26a5d5b23 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -1,7 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK-SVE2
+; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE2
 
 define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
 ; CHECK-SVE2-LABEL: signed_wide_add_nxv4i32:
@@ -18,13 +19,19 @@ define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vsc
 ; CHECK-SVE-NEXT:    add z0.d, z0.d, z1.d
 ; CHECK-SVE-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv4i32:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
     %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
@@ -46,13 +53,19 @@ define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <v
 ; CHECK-SVE-NEXT:    add z0.d, z0.d, z1.d
 ; CHECK-SVE-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv4i32:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
     %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
@@ -74,13 +87,19 @@ define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vsc
 ; CHECK-SVE-NEXT:    add z0.s, z0.s, z1.s
 ; CHECK-SVE-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv8i16:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv8i16:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv8i16:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
     %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
@@ -102,13 +121,19 @@ define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <v
 ; CHECK-SVE-NEXT:    add z0.s, z0.s, z1.s
 ; CHECK-SVE-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv8i16:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
     %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
@@ -130,13 +155,19 @@ define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vsc
 ; CHECK-SVE-NEXT:    add z0.h, z0.h, z1.h
 ; CHECK-SVE-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv16i8:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    add z0.h, z0.h, z2.h
-; CHECK-NEWLOWERING-NEXT:    add z0.h, z0.h, z1.h
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv16i8:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.h, z0.h, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.h, z0.h, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv16i8:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
     %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
@@ -158,13 +189,19 @@ define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <v
 ; CHECK-SVE-NEXT:    add z0.h, z0.h, z1.h
 ; CHECK-SVE-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv16i8:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    add z0.h, z0.h, z2.h
-; CHECK-NEWLOWERING-NEXT:    add z0.h, z0.h, z1.h
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.h, z0.h, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.h, z0.h, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
     %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
@@ -172,15 +209,43 @@ entry:
 }
 
 define <vscale x 2 x i32> @signed_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
-; CHECK-LABEL: signed_wide_add_nxv4i16:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    sxth z1.s, p0/m, z1.s
-; CHECK-NEXT:    uunpklo z2.d, z1.s
-; CHECK-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z0.d, z1.d, z0.d
-; CHECK-NEXT:    ret
+; CHECK-SVE2-LABEL: signed_wide_add_nxv4i16:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    ptrue p0.s
+; CHECK-SVE2-NEXT:    sxth z1.s, p0/m, z1.s
+; CHECK-SVE2-NEXT:    uunpklo z2.d, z1.s
+; CHECK-SVE2-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-SVE2-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-SVE-LABEL: signed_wide_add_nxv4i16:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    ptrue p0.s
+; CHECK-SVE-NEXT:    sxth z1.s, p0/m, z1.s
+; CHECK-SVE-NEXT:    uunpklo z2.d, z1.s
+; CHECK-SVE-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv4i16:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sxth z1.s, p0/m, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv4i16:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sxth z1.s, p0/m, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 4 x i16> %input to <vscale x 4 x i32>
     %partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
@@ -188,14 +253,39 @@ entry:
 }
 
 define <vscale x 2 x i32> @unsigned_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
-; CHECK-LABEL: unsigned_wide_add_nxv4i16:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    and z1.s, z1.s, #0xffff
-; CHECK-NEXT:    uunpklo z2.d, z1.s
-; CHECK-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z0.d, z1.d, z0.d
-; CHECK-NEXT:    ret
+; CHECK-SVE2-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-SVE2-NEXT:    uunpklo z2.d, z1.s
+; CHECK-SVE2-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-SVE2-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-SVE-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-SVE-NEXT:    uunpklo z2.d, z1.s
+; CHECK-SVE-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 4 x i16> %input to <vscale x 4 x i32>
     %partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
@@ -203,17 +293,49 @@ entry:
 }
 
 define <vscale x 4 x i64> @signed_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
-; CHECK-LABEL: signed_wide_add_nxv8i32:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sunpklo z4.d, z3.s
-; CHECK-NEXT:    sunpklo z5.d, z2.s
-; CHECK-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEXT:    add z0.d, z0.d, z5.d
-; CHECK-NEXT:    add z1.d, z1.d, z4.d
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEXT:    ret
+; CHECK-SVE2-LABEL: signed_wide_add_nxv8i32:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    sunpklo z4.d, z3.s
+; CHECK-SVE2-NEXT:    sunpklo z5.d, z2.s
+; CHECK-SVE2-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-SVE2-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-SVE2-NEXT:    add z0.d, z0.d, z5.d
+; CHECK-SVE2-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-SVE2-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-SVE-LABEL: signed_wide_add_nxv8i32:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    sunpklo z4.d, z3.s
+; CHECK-SVE-NEXT:    sunpklo z5.d, z2.s
+; CHECK-SVE-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-SVE-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-SVE-NEXT:    add z0.d, z0.d, z5.d
+; CHECK-SVE-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv8i32:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z4.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z5.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z5.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv8i32:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z1.d, z1.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z1.d, z1.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 8 x i32> %input to <vscale x 8 x i64>
     %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
@@ -221,17 +343,49 @@ entry:
 }
 
 define <vscale x 4 x i64> @unsigned_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
-; CHECK-LABEL: unsigned_wide_add_nxv8i32:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    uunpklo z4.d, z3.s
-; CHECK-NEXT:    uunpklo z5.d, z2.s
-; CHECK-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEXT:    add z0.d, z0.d, z5.d
-; CHECK-NEXT:    add z1.d, z1.d, z4.d
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEXT:    ret
+; CHECK-SVE2-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    uunpklo z4.d, z3.s
+; CHECK-SVE2-NEXT:    uunpklo z5.d, z2.s
+; CHECK-SVE2-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-SVE2-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-SVE2-NEXT:    add z0.d, z0.d, z5.d
+; CHECK-SVE2-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-SVE2-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-SVE-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    uunpklo z4.d, z3.s
+; CHECK-SVE-NEXT:    uunpklo z5.d, z2.s
+; CHECK-SVE-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-SVE-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-SVE-NEXT:    add z0.d, z0.d, z5.d
+; CHECK-SVE-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-SVE-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z4.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z5.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z5.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z1.d, z1.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z1.d, z1.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z1.d, z1.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 8 x i32> %input to <vscale x 8 x i64>
     %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)

>From 6a7b359b3346c0751fb6728d3d2aa31ecaed46dd Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 27 May 2025 17:29:00 +0100
Subject: [PATCH 2/7] Replace custom lowering with tablegen patterns.

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  7 +--
 .../Target/AArch64/AArch64ISelLowering.cpp    | 63 ++++++++++---------
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td | 13 ++++
 .../AArch64/sve-partial-reduce-dot-product.ll | 44 ++++---------
 4 files changed, 60 insertions(+), 67 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index af504df596615..b0ce39010c97c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12676,8 +12676,6 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
-  EVT ResultVT = N->getValueType(0);
-
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
   unsigned NewOpcode =
       ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
@@ -12691,7 +12689,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
         (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
       return SDValue();
 
-    return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp,
+    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
                        DAG.getConstant(CTrunc, DL, LHSExtOpVT));
   }
 
@@ -12712,7 +12710,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, RHSExtOp);
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+                     RHSExtOp);
 }
 
 // partial.reduce.umla(acc, zext(op), splat(1))
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fecdfe95a082d..0120eba2c894c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1888,9 +1888,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     // Wide add types
     if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
-      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Custom);
-      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Custom);
-      setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Custom);
+      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
+      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
+      setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
     }
   }
 
@@ -29236,34 +29236,35 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
   SDValue LHS = Op.getOperand(1);
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
-
-  // Recognise Op as a wide add, if it is then we leave it as-is
-  // Base: nxv2i64, Subdivision: nxv4i32
-  auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
-    assert(Base.isVector() && Subdivision.isVector());
-    assert(Base.isScalableVector() == Subdivision.isScalableVector());
-
-    ElementCount BaseCount = Base.getVectorElementCount();
-    ElementCount SubCount = Subdivision.getVectorElementCount();
-    if (BaseCount * 2 != SubCount)
-      return false;
-
-    uint64_t BaseScalarSize = Base.getScalarSizeInBits();
-    uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
-    if (BaseScalarSize != SubScalarSize * 2)
-      return false;
-
-    return true;
-  };
-  if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
-    // If it looks like a real wide add, we can leave it as-is and treat it as
-    // Legal
-    APInt C;
-    if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
-      return Op;
-    // If it doesn't, then we need to expand it.
-    return SDValue();
-  }
+  //
+  // // Recognise Op as a wide add, if it is then we leave it as-is
+  // // Base: nxv2i64, Subdivision: nxv4i32
+  // auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
+  //   assert(Base.isVector() && Subdivision.isVector());
+  //   assert(Base.isScalableVector() == Subdivision.isScalableVector());
+  //
+  //   ElementCount BaseCount = Base.getVectorElementCount();
+  //   ElementCount SubCount = Subdivision.getVectorElementCount();
+  //   if (BaseCount * 2 != SubCount)
+  //     return false;
+  //
+  //   uint64_t BaseScalarSize = Base.getScalarSizeInBits();
+  //   uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
+  //   if (BaseScalarSize != SubScalarSize * 2)
+  //     return false;
+  //
+  //   return true;
+  // };
+  // if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
+  //   // If it looks like a real wide add, we can leave it as-is and treat it
+  //   as
+  //   // Legal
+  //   APInt C;
+  //   if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
+  //     return Op;
+  //   // If it doesn't, then we need to expand it.
+  //   return SDValue();
+  // }
 
   assert((Scalable && ResultVT == MVT::nxv2i64 &&
           LHS.getValueType() == MVT::nxv16i8) ||
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 1b1a24394e1f1..487650f9ad9c0 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3826,6 +3826,19 @@ let Predicates = [HasSVE2_or_SME] in {
   def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
             (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
 
+  def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
+            (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+  def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
+            (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+  def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
+            (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+  def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
+            (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+  def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+            (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+  def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+            (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+
   // SVE2 integer multiply long
   defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>;
   defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>;
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index a45b8b710c63a..203606c8ffacc 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -668,28 +668,18 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ;
 ; CHECK-NEWLOWERING-SVE2-LABEL: not_udot:
 ; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.s, p0/m, z4.s, z3.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT:    umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    umlalt z0.h, z1.b, z2.b
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-SME-LABEL: not_udot:
 ; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-SME-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEWLOWERING-SME-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-SME-NEXT:    mla z0.s, p0/m, z4.s, z3.s
-; CHECK-NEWLOWERING-SME-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT:    umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    umlalt z0.h, z1.b, z2.b
 ; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
@@ -728,28 +718,18 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ;
 ; CHECK-NEWLOWERING-SVE2-LABEL: not_udot_wide:
 ; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.s, z2.s, #0xffff
-; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z4.d, z1.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z4.d, z3.d
-; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT:    umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    umlalt z0.h, z1.b, z2.b
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-SME-LABEL: not_udot_wide:
 ; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-SME-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEWLOWERING-SME-NEXT:    and z2.s, z2.s, #0xffff
-; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z4.d, z1.s
-; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z4.d, z3.d
-; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SME-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT:    umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    umlalt z0.h, z1.b, z2.b
 ; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>

>From ba78d71047d5c6103ac16b112d0560ab797bff3a Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 27 May 2025 18:08:52 +0100
Subject: [PATCH 3/7] Remove dead code

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 29 -------------------
 1 file changed, 29 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0120eba2c894c..2f47f2610b78c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29236,35 +29236,6 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
   SDValue LHS = Op.getOperand(1);
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
-  //
-  // // Recognise Op as a wide add, if it is then we leave it as-is
-  // // Base: nxv2i64, Subdivision: nxv4i32
-  // auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
-  //   assert(Base.isVector() && Subdivision.isVector());
-  //   assert(Base.isScalableVector() == Subdivision.isScalableVector());
-  //
-  //   ElementCount BaseCount = Base.getVectorElementCount();
-  //   ElementCount SubCount = Subdivision.getVectorElementCount();
-  //   if (BaseCount * 2 != SubCount)
-  //     return false;
-  //
-  //   uint64_t BaseScalarSize = Base.getScalarSizeInBits();
-  //   uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
-  //   if (BaseScalarSize != SubScalarSize * 2)
-  //     return false;
-  //
-  //   return true;
-  // };
-  // if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
-  //   // If it looks like a real wide add, we can leave it as-is and treat it
-  //   as
-  //   // Legal
-  //   APInt C;
-  //   if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
-  //     return Op;
-  //   // If it doesn't, then we need to expand it.
-  //   return SDValue();
-  // }
 
   assert((Scalable && ResultVT == MVT::nxv2i64 &&
           LHS.getValueType() == MVT::nxv16i8) ||

>From 9aca4589b98d1f9c84dd162c693279ef395971e9 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 28 May 2025 16:34:08 +0100
Subject: [PATCH 4/7] Use correct instructions for types

---
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 487650f9ad9c0..51cadc3b73c31 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3818,25 +3818,25 @@ let Predicates = [HasSVE2_or_SME] in {
   def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
             (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
   def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
-            (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+            (UADDWT_ZZZ_S (UADDWB_ZZZ_S $Acc, $Input), $Input)>;
   def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
-            (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+            (SADDWT_ZZZ_S (SADDWB_ZZZ_S $Acc, $Input), $Input)>;
   def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
-            (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+            (UADDWT_ZZZ_H (UADDWB_ZZZ_H $Acc, $Input), $Input)>;
   def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
-            (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+            (SADDWT_ZZZ_H (SADDWB_ZZZ_H $Acc, $Input), $Input)>;
 
-  def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
-            (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
-  def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
-            (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+  def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+            (UMLALT_ZZZ_D (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+  def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+            (SMLALT_ZZZ_D (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
   def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
-            (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+            (UMLALT_ZZZ_S (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
   def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
-            (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
-  def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+            (SMLALT_ZZZ_S (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+  def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
             (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
-  def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+  def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
             (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
 
   // SVE2 integer multiply long
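
With this fixup the splat-of-one wide-add patterns select the element-size variants that match their types; summarising the hunk above:

  accumulator    input       unsigned pair                 signed pair
  nxv2i64     <- nxv4i32     UADDWB_ZZZ_D / UADDWT_ZZZ_D   SADDWB_ZZZ_D / SADDWT_ZZZ_D
  nxv4i32     <- nxv8i16     UADDWB_ZZZ_S / UADDWT_ZZZ_S   SADDWB_ZZZ_S / SADDWT_ZZZ_S
  nxv8i16     <- nxv16i8     UADDWB_ZZZ_H / UADDWT_ZZZ_H   SADDWB_ZZZ_H / SADDWT_ZZZ_H
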

>From 4c514eac0355e14e4ffa75885e43b1ff89f58162 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 28 May 2025 15:52:30 +0100
Subject: [PATCH 5/7] Update tests

---
 .../AArch64/sve-partial-reduce-dot-product.ll    |  8 ++++----
 .../AArch64/sve-partial-reduce-wide-add.ll       | 16 ++++++++--------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 203606c8ffacc..55c879deb6217 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -671,7 +671,7 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.h, z2.h, #0xff
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEWLOWERING-SVE2-NEXT:    umlalb z0.h, z1.b, z2.b
-; CHECK-NEWLOWERING-SVE2-NEXT:    umlalt z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    umlalt z0.s, z1.h, z2.h
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-SME-LABEL: not_udot:
@@ -679,7 +679,7 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEWLOWERING-SME-NEXT:    and z2.h, z2.h, #0xff
 ; CHECK-NEWLOWERING-SME-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEWLOWERING-SME-NEXT:    umlalb z0.h, z1.b, z2.b
-; CHECK-NEWLOWERING-SME-NEXT:    umlalt z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    umlalt z0.s, z1.h, z2.h
 ; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
@@ -721,7 +721,7 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.s, z2.s, #0xffff
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEWLOWERING-SVE2-NEXT:    umlalb z0.h, z1.b, z2.b
-; CHECK-NEWLOWERING-SVE2-NEXT:    umlalt z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    umlalt z0.d, z1.s, z2.s
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-SME-LABEL: not_udot_wide:
@@ -729,7 +729,7 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEWLOWERING-SME-NEXT:    and z2.s, z2.s, #0xffff
 ; CHECK-NEWLOWERING-SME-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEWLOWERING-SME-NEXT:    umlalb z0.h, z1.b, z2.b
-; CHECK-NEWLOWERING-SME-NEXT:    umlalt z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    umlalt z0.d, z1.s, z2.s
 ; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index 8f9f26a5d5b23..428dd4c3a0154 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -97,8 +97,8 @@ define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vsc
 ;
 ; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv8i16:
 ; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z1.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.s, z0.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.s, z0.s, z1.h
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
@@ -131,8 +131,8 @@ define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <v
 ;
 ; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv8i16:
 ; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z1.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.s, z0.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.s, z0.s, z1.h
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
@@ -165,8 +165,8 @@ define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vsc
 ;
 ; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv16i8:
 ; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z1.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.h, z0.h, z1.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.h, z0.h, z1.b
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
@@ -199,8 +199,8 @@ define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <v
 ;
 ; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv16i8:
 ; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z1.s
-; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.h, z0.h, z1.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.h, z0.h, z1.b
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
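
The unsigned_wide_add_nxv16i8 test updated just above has roughly the shape below, reconstructed from the signature and entry block visible in the hunk (the intrinsic call is not shown in the diff, so its exact mangling here is an assumption). With the new nxv8i16 <- nxv16i8 legality, the SVE2 checks now expect a uaddwb/uaddwt pair at the matching element widths (z0.h, z1.b) instead of the stale .d-suffixed expectations:

define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input) {
entry:
  ; No multiply: the zero-extended input is accumulated directly,
  ; which maps onto the SVE2 wide-add (uaddwb/uaddwt) patterns.
  %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
  %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
  ret <vscale x 8 x i16> %partial.reduce
}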

>From 26c1098be6578d514d4bfa8c9b5b1e31c3dbb33f Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 28 May 2025 16:55:59 +0100
Subject: [PATCH 6/7] Update tests after rebase

---
 .../neon-partial-reduce-dot-product.ll        | 30 ++++---------------
 .../AArch64/sve-partial-reduce-dot-product.ll | 12 ++++----
 2 files changed, 12 insertions(+), 30 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 2b68c963ad319..d977d8fc9cf21 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -917,20 +917,11 @@ define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
 ;
 ; CHECK-NEWLOWERING-I8MM-LABEL: udot_no_bin_op_8to64:
 ; CHECK-NEWLOWERING-I8MM:       // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v3.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v4.4s, v3.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v5.4s, v2.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v3.4s, v3.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw v1.2d, v1.2d, v5.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v3.16b, #1
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT:    udot v4.4s, v2.16b, v3.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw v0.2d, v0.2d, v4.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw2 v1.2d, v1.2d, v5.4s
 ; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw2 v0.2d, v0.2d, v4.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw v1.2d, v1.2d, v2.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw v0.2d, v0.2d, v3.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw2 v0.2d, v0.2d, v3.4s
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %a.wide = zext <16 x i8> %a to <16 x i64>
   %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
@@ -967,20 +958,11 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
 ;
 ; CHECK-NEWLOWERING-I8MM-LABEL: sdot_no_bin_op_8to64:
 ; CHECK-NEWLOWERING-I8MM:       // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v3.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v4.4s, v3.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v5.4s, v2.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v3.4s, v3.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    saddw v1.2d, v1.2d, v5.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v3.16b, #1
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT:    sdot v4.4s, v2.16b, v3.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    saddw2 v1.2d, v1.2d, v5.4s
 ; CHECK-NEWLOWERING-I8MM-NEXT:    saddw2 v0.2d, v0.2d, v4.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    saddw v1.2d, v1.2d, v2.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    saddw v0.2d, v0.2d, v3.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    saddw2 v1.2d, v1.2d, v2.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    saddw2 v0.2d, v0.2d, v3.4s
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %a.wide = sext <16 x i8> %a to <16 x i64>
   %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 55c879deb6217..006083d843370 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -566,10 +566,10 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ; CHECK-NEWLOWERING-SVE-NEXT:    movi v3.2d, #0000000000000000
 ; CHECK-NEWLOWERING-SVE-NEXT:    mov z4.b, #1 // =0x1
 ; CHECK-NEWLOWERING-SVE-NEXT:    udot z3.s, z2.b, z4.b
-; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.d, z3.s
-; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z3.d
 ; CHECK-NEWLOWERING-SVE-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-SVE2-LABEL: udot_no_bin_op_8to64:
@@ -611,10 +611,10 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ; CHECK-NEWLOWERING-SVE-NEXT:    movi v3.2d, #0000000000000000
 ; CHECK-NEWLOWERING-SVE-NEXT:    mov z4.b, #1 // =0x1
 ; CHECK-NEWLOWERING-SVE-NEXT:    sdot z3.s, z2.b, z4.b
-; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z2.d, z3.s
-; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z3.d
 ; CHECK-NEWLOWERING-SVE-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-SVE2-LABEL: sdot_no_bin_op_8to64:
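
The neon udot_no_bin_op_8to64/sdot_no_bin_op_8to64 updates earlier in this patch reflect a two-step lowering: the i8 inputs are first reduced into an i32 vector with a dot product against a splat of 1, and the i32 partial sums are then widened into the i64 accumulator with uaddw/uaddw2 (saddw/saddw2 for the signed case). A rough, value-level sketch of that decomposition follows (hand-written for illustration; names and intrinsic mangling are assumptions, and the intrinsic leaves the distribution of partial sums across accumulator lanes to the target, which is why the generated code folds them into the low half of the accumulator):

define <4 x i64> @udot_8to64_decomposed(<4 x i64> %acc, <16 x i8> %a) {
  ; Step 1: i8 -> i32 partial reduction, selectable as a udot against a splat of 1.
  %a.wide32 = zext <16 x i8> %a to <16 x i32>
  %dot = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> zeroinitializer, <16 x i32> %a.wide32)
  ; Step 2: widen the i32 partial sums and add them into the i64 accumulator
  ; (the uaddw/uaddw2 pair in the CHECK lines above).
  %dot.wide = zext <4 x i32> %dot to <4 x i64>
  %res = add <4 x i64> %acc, %dot.wide
  ret <4 x i64> %res
}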

>From e1d2ed62ebf4df6226da075156654fda334ce7ec Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Thu, 29 May 2025 12:49:00 +0100
Subject: [PATCH 7/7] Use correctly-typed instructions for the lower half too

---
 llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td            | 8 ++++----
 .../CodeGen/AArch64/sve-partial-reduce-dot-product.ll     | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 51cadc3b73c31..91db6b6fc7984 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3827,13 +3827,13 @@ let Predicates = [HasSVE2_or_SME] in {
             (SADDWT_ZZZ_H (SADDWB_ZZZ_H $Acc, $Input), $Input)>;
 
   def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
-            (UMLALT_ZZZ_D (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+            (UMLALT_ZZZ_D (UMLALB_ZZZ_D $Acc, $LHS, $RHS), $LHS, $RHS)>;
   def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
-            (SMLALT_ZZZ_D (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+            (SMLALT_ZZZ_D (SMLALB_ZZZ_D $Acc, $LHS, $RHS), $LHS, $RHS)>;
   def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
-            (UMLALT_ZZZ_S (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+            (UMLALT_ZZZ_S (UMLALB_ZZZ_S $Acc, $LHS, $RHS), $LHS, $RHS)>;
   def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
-            (SMLALT_ZZZ_S (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+            (SMLALT_ZZZ_S (SMLALB_ZZZ_S $Acc, $LHS, $RHS), $LHS, $RHS)>;
   def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
             (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
   def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 006083d843370..d3ccfaaf20a22 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -670,7 +670,7 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.h, z2.h, #0xff
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.h, z1.h, #0xff
-; CHECK-NEWLOWERING-SVE2-NEXT:    umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    umlalb z0.s, z1.h, z2.h
 ; CHECK-NEWLOWERING-SVE2-NEXT:    umlalt z0.s, z1.h, z2.h
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 ;
@@ -678,7 +678,7 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-SME-NEXT:    and z2.h, z2.h, #0xff
 ; CHECK-NEWLOWERING-SME-NEXT:    and z1.h, z1.h, #0xff
-; CHECK-NEWLOWERING-SME-NEXT:    umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    umlalb z0.s, z1.h, z2.h
 ; CHECK-NEWLOWERING-SME-NEXT:    umlalt z0.s, z1.h, z2.h
 ; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
@@ -720,7 +720,7 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.s, z2.s, #0xffff
 ; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.s, z1.s, #0xffff
-; CHECK-NEWLOWERING-SVE2-NEXT:    umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    umlalb z0.d, z1.s, z2.s
 ; CHECK-NEWLOWERING-SVE2-NEXT:    umlalt z0.d, z1.s, z2.s
 ; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 ;
@@ -728,7 +728,7 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-SME-NEXT:    and z2.s, z2.s, #0xffff
 ; CHECK-NEWLOWERING-SME-NEXT:    and z1.s, z1.s, #0xffff
-; CHECK-NEWLOWERING-SME-NEXT:    umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    umlalb z0.d, z1.s, z2.s
 ; CHECK-NEWLOWERING-SME-NEXT:    umlalt z0.d, z1.s, z2.s
 ; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:


