[llvm] [SelectionDAG] Do partial reduction lowering using ISD nodes (PR #129268)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 28 09:00:43 PST 2025


https://github.com/JamesChesterman created https://github.com/llvm/llvm-project/pull/129268

Series of patches that together produce a new method of lowering partial reduction intrinsics. They transform them into ISD nodes, which are then combined and lowered appropriately.

>From 5053db67347bfa6328a5dca2a3a9ff0bd041d464 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 13 Feb 2025 15:35:55 +0000
Subject: [PATCH 1/9] [DAGCombiner] Add generic DAG combine for
 ISD::PARTIAL_REDUCE_MLA

Add generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA nodes.
Transforms the DAG from:
PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1))
to
PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS).
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  48 +++++++++
 .../neon-partial-reduce-dot-product.ll        |  75 +++++++------
 .../AArch64/sve-partial-reduce-dot-product.ll | 100 +++++++++---------
 3 files changed, 138 insertions(+), 85 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bc7cdf38dbc2a..8f35675ff1509 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -545,6 +545,7 @@ namespace {
     SDValue visitMGATHER(SDNode *N);
     SDValue visitMSCATTER(SDNode *N);
     SDValue visitMHISTOGRAM(SDNode *N);
+    SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
     SDValue visitVPGATHER(SDNode *N);
     SDValue visitVPSCATTER(SDNode *N);
     SDValue visitVP_STRIDED_LOAD(SDNode *N);
@@ -1972,6 +1973,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::MSCATTER:           return visitMSCATTER(N);
   case ISD::MSTORE:             return visitMSTORE(N);
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
+  case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_UMLA:
+                                return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
@@ -12497,6 +12501,50 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+  // Makes PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1))
+  // into PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS)
+  SDLoc DL(N);
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
+
+  if (Op1->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue ExtMulOpLHS = Op1->getOperand(0);
+  SDValue ExtMulOpRHS = Op1->getOperand(1);
+  unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+  unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+  if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+      !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+    return SDValue();
+
+  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+  EVT MulOpLHSVT = MulOpLHS.getValueType();
+  if (MulOpLHSVT != MulOpRHS.getValueType())
+    return SDValue();
+
+  if (!TLI.isTypeLegal(MulOpLHSVT) || !TLI.isTypeLegal(N->getValueType(0)))
+    return SDValue();
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+  bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+  if (LHSIsSigned != RHSIsSigned)
+    return SDValue();
+
+  unsigned NewOpcode =
+      LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  return DAG.getNode(NewOpcode, DL, Op0->getValueType(0), Op0, MulOpLHS,
+                     MulOpRHS);
+}
+
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
   auto *SLD = cast<VPStridedLoadSDNode>(N);
   EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 40daf8ffb63ea..7ec166aa8ed36 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -12,13 +12,15 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    umull v3.8h, v2.8b, v1.8b
-; CHECK-NODOT-NEXT:    umull2 v1.8h, v2.16b, v1.16b
-; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v3.4h
-; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
-; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    umlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT:    umull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    umlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
@@ -35,17 +37,19 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    ushll v2.8h, v2.8b, #0
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT:    ushll2 v3.4s, v1.8h, #0
-; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    umull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    umull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    umlal v0.4s, v2.4h, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT:    umlal v3.4s, v6.4h, v5.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
@@ -62,13 +66,15 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    smull v3.8h, v2.8b, v1.8b
-; CHECK-NODOT-NEXT:    smull2 v1.8h, v2.16b, v1.16b
-; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v3.4h
-; CHECK-NODOT-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
-; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
@@ -85,17 +91,19 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    smull v1.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    sshll v2.8h, v2.8b, #0
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT:    sshll2 v3.4s, v1.8h, #0
-; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    smlal v0.4s, v2.4h, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT:    smlal v3.4s, v6.4h, v5.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
@@ -531,9 +539,10 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
 define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull v1.8h, v2.8b, v1.8b
-; CHECK-NEXT:    uaddw v0.4s, v0.4s, v1.4h
-; CHECK-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
 ; CHECK-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 455231dd37be6..c6dc0ed5651ec 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -11,24 +11,23 @@ define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: udot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z1.b
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT:    mad z1.s, p0/m, z2.s, z3.s
 ; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -47,24 +46,23 @@ define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -83,24 +81,23 @@ define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z1.b
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT:    mad z1.s, p0/m, z2.s, z3.s
 ; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -119,24 +116,23 @@ define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:

>From 406041faa99b841fb753bbb98cd873056601509e Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 19 Feb 2025 15:14:03 +0000
Subject: [PATCH 2/9] Split the DAG combine into two.

Also make sure the DAG combine is only done when the action for the
partial reduction's type combination is either Legal or Custom.
This ensures that the combines are not performed only for the
resulting DAG to then be expanded, as this leads to worse code
generation.
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |  35 ++++++
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  55 +++++++---
 .../SelectionDAG/LegalizeVectorOps.cpp        |   7 +-
 llvm/lib/CodeGen/TargetLoweringBase.cpp       |   5 +-
 .../neon-partial-reduce-dot-product.ll        |  75 ++++++-------
 .../AArch64/sve-partial-reduce-dot-product.ll | 100 +++++++++---------
 6 files changed, 165 insertions(+), 112 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a4c3d042fe3a4..52e57365dceab 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1639,6 +1639,25 @@ class TargetLoweringBase {
            getCondCodeAction(CC, VT) == Custom;
   }
 
+  /// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
+  /// InputVT should be treated. Either it's legal, needs to be promoted to a
+  /// larger size, needs to be expanded to some other code sequence, or the
+  /// target has a custom expander for it.
+  LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
+    unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
+    unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
+    assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
+           "Table isn't big enough!");
+    return PartialReduceMLAActions[AccI][InputI];
+  }
+
+  /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+  /// legal or custom for this target.
+  bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
+    return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
+           getPartialReduceMLAAction(AccVT, InputVT) == Custom;
+  }
+
   /// If the action for this operation is to promote, this method returns the
   /// ValueType to promote to.
   MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2704,6 +2723,16 @@ class TargetLoweringBase {
       setCondCodeAction(CCs, VT, Action);
   }
 
+  /// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
+  /// type InputVT should be treated by the target. Either it's legal, needs to
+  /// be promoted to a larger size, needs to be expanded to some other code
+  /// sequence, or the target has a custom expander for it.
+  void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+                                 LegalizeAction Action) {
+    assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
+    PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
+  }
+
   /// If Opc/OrigVT is specified as being promoted, the promotion code defaults
   /// to trying a larger integer/fp until it can find one that works. If that
   /// default is insufficient, this method can be used by the target to override
@@ -3650,6 +3679,12 @@ class TargetLoweringBase {
   /// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
   uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
 
+  /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
+  /// nodes, keep a LegalizeAction which indicates how instruction selection
+  /// should deal with this operation.
+  LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
+                                        [MVT::VALUETYPE_SIZE];
+
   ValueTypeActionImpl ValueTypeActions;
 
 private:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8f35675ff1509..223260c43a38e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -622,6 +622,8 @@ namespace {
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI);
+    SDValue foldMulPARTIAL_REDUCE_MLA(SDNode *N);
+    SDValue foldExtendPARTIAL_REDUCE_MLA(SDNode *N);
 
     SDValue CombineExtLoad(SDNode *N);
     SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12502,18 +12504,45 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
 }
 
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
-  // Makes PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1))
-  // into PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS)
+  // Only perform the DAG combine if there is custom lowering provided by the
+  // target.
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0),
+                                           N->getOperand(1).getValueType()))
+    return SDValue();
+
+  if (SDValue Res = foldMulPARTIAL_REDUCE_MLA(N))
+    return Res;
+  if (SDValue Res = foldExtendPARTIAL_REDUCE_MLA(N))
+    return Res;
+  return SDValue();
+}
+
+SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
+  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(MulOpLHS, MulOpRHS), Splat(1)) into
+  // PARTIAL_REDUCE_*MLA(Acc, MulOpLHS, MulOpRHS)
   SDLoc DL(N);
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  SDValue Op2 = N->getOperand(2);
 
+  SDValue Op1 = N->getOperand(1);
   if (Op1->getOpcode() != ISD::MUL)
     return SDValue();
 
-  SDValue ExtMulOpLHS = Op1->getOperand(0);
-  SDValue ExtMulOpRHS = Op1->getOperand(1);
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
+                     Op1->getOperand(0), Op1->getOperand(1));
+}
+
+SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
+  // Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(MulOpLHS), ZEXT(MulOpRHS)) into
+  // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS) and
+  // PARTIAL_REDUCE_*MLA(Acc, SEXT(MulOpLHS), SEXT(MulOpRHS)) into
+  // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS)
+  SDLoc DL(N);
+  SDValue ExtMulOpLHS = N->getOperand(1);
+  SDValue ExtMulOpRHS = N->getOperand(2);
   unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
   unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
   if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
@@ -12526,14 +12555,6 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   if (MulOpLHSVT != MulOpRHS.getValueType())
     return SDValue();
 
-  if (!TLI.isTypeLegal(MulOpLHSVT) || !TLI.isTypeLegal(N->getValueType(0)))
-    return SDValue();
-
-  APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
-    return SDValue();
-
   bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
   bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
   if (LHSIsSigned != RHSIsSigned)
@@ -12541,8 +12562,8 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
 
   unsigned NewOpcode =
       LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-  return DAG.getNode(NewOpcode, DL, Op0->getValueType(0), Op0, MulOpLHS,
-                     MulOpRHS);
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), N->getOperand(0),
+                     MulOpLHS, MulOpRHS);
 }
 
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index de4447fb0cf1a..e43b14a47e565 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECTOR_COMPRESS:
   case ISD::SCMP:
   case ISD::UCMP:
-  case ISD::PARTIAL_REDUCE_UMLA:
-  case ISD::PARTIAL_REDUCE_SMLA:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
@@ -524,6 +522,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
     break;
   }
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
+                                           Node->getOperand(1).getValueType());
+    break;
 
 #define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...)                          \
   case ISD::VPID: {                                                            \
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f5ea3c0b47d6a..af97ce20fdb10 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SET_FPENV, VT, Expand);
     setOperationAction(ISD::RESET_FPENV, VT, Expand);
 
-    // PartialReduceMLA operations default to expand.
-    setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
-                       Expand);
+    for (MVT InputVT : MVT::all_valuetypes())
+      setPartialReduceMLAAction(VT, InputVT, Expand);
   }
 
   // Most targets ignore the @llvm.prefetch intrinsic.
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 7ec166aa8ed36..40daf8ffb63ea 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -12,15 +12,13 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    ushll v4.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    umlal v0.4s, v4.4h, v3.4h
-; CHECK-NODOT-NEXT:    umull v5.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    umlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT:    umull v3.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    umull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
@@ -37,19 +35,17 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    umull v1.8h, v2.8b, v1.8b
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    umull v3.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    umull2 v4.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NODOT-NEXT:    umlal v3.4s, v6.4h, v5.4h
-; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
 ; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
@@ -66,15 +62,13 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    sshll v3.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    sshll v4.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    sshll2 v1.8h, v1.16b, #0
-; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    smlal v0.4s, v4.4h, v3.4h
-; CHECK-NODOT-NEXT:    smull v5.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT:    smull v3.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    smull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
@@ -91,19 +85,17 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    smull v1.8h, v2.8b, v1.8b
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    smull v3.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    smull2 v4.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NODOT-NEXT:    smlal v3.4s, v6.4h, v5.4h
-; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
 ; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
@@ -539,10 +531,9 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
 define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
-; CHECK-NEXT:    umlal v0.4s, v2.4h, v1.4h
-; CHECK-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index c6dc0ed5651ec..455231dd37be6 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -11,23 +11,24 @@ define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: udot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    mul z3.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
-; CHECK-NEWLOWERING-NEXT:    mad z1.s, p0/m, z2.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
 ; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -46,23 +47,24 @@ define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -81,23 +83,24 @@ define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    mul z3.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
-; CHECK-NEWLOWERING-NEXT:    mad z1.s, p0/m, z2.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
 ; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -116,23 +119,24 @@ define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z2.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:

>From bd7d333aaca6d82b13fef4acedcb50c671d6e3d7 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 20 Feb 2025 11:41:30 +0000
Subject: [PATCH 3/9] Make DAG combine one function again

This is so the MUL fold does not happen unless the extend fold can
be performed as well; otherwise a lot of code would need to be
repeated to check that the extend fold can happen.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 46 +++++++------------
 1 file changed, 16 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 223260c43a38e..9073f814e7e4f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12504,25 +12504,17 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
 }
 
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
-  // Only perform the DAG combine if there is custom lowering provided by the
-  // target.
-  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0),
-                                           N->getOperand(1).getValueType()))
-    return SDValue();
-
-  if (SDValue Res = foldMulPARTIAL_REDUCE_MLA(N))
-    return Res;
-  if (SDValue Res = foldExtendPARTIAL_REDUCE_MLA(N))
-    return Res;
-  return SDValue();
-}
-
-SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
-  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(MulOpLHS, MulOpRHS), Splat(1)) into
-  // PARTIAL_REDUCE_*MLA(Acc, MulOpLHS, MulOpRHS)
+  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
+  // Splat(1)) into
+  // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
+  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
+  // Splat(1)) into
+  // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
   SDLoc DL(N);
 
+  SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
+
   if (Op1->getOpcode() != ISD::MUL)
     return SDValue();
 
@@ -12531,18 +12523,8 @@ SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
       !ConstantOne.isOne())
     return SDValue();
 
-  return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
-                     Op1->getOperand(0), Op1->getOperand(1));
-}
-
-SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
-  // Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(MulOpLHS), ZEXT(MulOpRHS)) into
-  // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS) and
-  // PARTIAL_REDUCE_*MLA(Acc, SEXT(MulOpLHS), SEXT(MulOpRHS)) into
-  // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS)
-  SDLoc DL(N);
-  SDValue ExtMulOpLHS = N->getOperand(1);
-  SDValue ExtMulOpRHS = N->getOperand(2);
+  SDValue ExtMulOpLHS = Op1->getOperand(0);
+  SDValue ExtMulOpRHS = Op1->getOperand(1);
   unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
   unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
   if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
@@ -12554,6 +12536,10 @@ SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
   EVT MulOpLHSVT = MulOpLHS.getValueType();
   if (MulOpLHSVT != MulOpRHS.getValueType())
     return SDValue();
+  // Only perform the DAG combine if there is custom lowering provided by the
+  // target
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), MulOpLHSVT))
+    return SDValue();
 
   bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
   bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
@@ -12562,8 +12548,8 @@ SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
 
   unsigned NewOpcode =
       LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-  return DAG.getNode(NewOpcode, DL, N->getValueType(0), N->getOperand(0),
-                     MulOpLHS, MulOpRHS);
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Op0, MulOpLHS,
+                     MulOpRHS);
 }
 
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {

>From 793115a81e1eca816410426d7c5defdd67edc8ec Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 21 Feb 2025 14:28:35 +0000
Subject: [PATCH 4/9] Remove isLegalOrCustom check from DAG combine

This makes it so the changes are reflected in the tests, so that we can tell the DAG combine is actually happening.
It has been replaced with a FIXME note saying to potentially add it back in when the rest of the implementation is complete.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |   7 +-
 .../neon-partial-reduce-dot-product.ll        | 139 ++++---
 .../AArch64/sve-partial-reduce-dot-product.ll | 386 +++++++++---------
 3 files changed, 277 insertions(+), 255 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 9073f814e7e4f..111ecb61c2c07 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12536,10 +12536,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   EVT MulOpLHSVT = MulOpLHS.getValueType();
   if (MulOpLHSVT != MulOpRHS.getValueType())
     return SDValue();
-  // Only perform the DAG combine if there is custom lowering provided by the
-  // target
-  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), MulOpLHSVT))
-    return SDValue();
+
+  // FIXME: Add a check to only perform the DAG combine if there is lowering
+  // provided by the target
 
   bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
   bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 40daf8ffb63ea..3938a57d0152c 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -12,13 +12,15 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    umull v3.8h, v2.8b, v1.8b
-; CHECK-NODOT-NEXT:    umull2 v1.8h, v2.16b, v1.16b
-; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v3.4h
-; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
-; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    umlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT:    umull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    umlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
@@ -35,17 +37,19 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    ushll v2.8h, v2.8b, #0
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT:    ushll2 v3.4s, v1.8h, #0
-; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    umull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    umull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    umlal v0.4s, v2.4h, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT:    umlal v3.4s, v6.4h, v5.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
@@ -62,13 +66,15 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    smull v3.8h, v2.8b, v1.8b
-; CHECK-NODOT-NEXT:    smull2 v1.8h, v2.16b, v1.16b
-; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v3.4h
-; CHECK-NODOT-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
-; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
@@ -85,17 +91,19 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    smull v1.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    sshll v2.8h, v2.8b, #0
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT:    sshll2 v3.4s, v1.8h, #0
-; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    smlal v0.4s, v2.4h, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT:    smlal v3.4s, v6.4h, v5.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
@@ -223,19 +231,27 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
 ;
 ; CHECK-NODOT-LABEL: udot_8to64:
 ; CHECK-NODOT:       // %bb.0: // %entry
-; CHECK-NODOT-NEXT:    umull v4.8h, v2.8b, v3.8b
-; CHECK-NODOT-NEXT:    umull2 v2.8h, v2.16b, v3.16b
-; CHECK-NODOT-NEXT:    ushll v3.4s, v4.4h, #0
-; CHECK-NODOT-NEXT:    ushll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    ushll v4.8h, v3.8b, #0
+; CHECK-NODOT-NEXT:    ushll v5.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v3.8h, v3.16b, #0
+; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    ushll v6.4s, v4.4h, #0
+; CHECK-NODOT-NEXT:    ushll v7.4s, v5.4h, #0
 ; CHECK-NODOT-NEXT:    ushll2 v4.4s, v4.8h, #0
-; CHECK-NODOT-NEXT:    ushll2 v2.4s, v2.8h, #0
-; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v3.4s
-; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v3.2s
-; CHECK-NODOT-NEXT:    uaddl2 v3.2d, v4.4s, v5.4s
-; CHECK-NODOT-NEXT:    uaddl v4.2d, v4.2s, v5.2s
-; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
-; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v2.2s
-; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
+; CHECK-NODOT-NEXT:    ushll2 v5.4s, v5.8h, #0
+; CHECK-NODOT-NEXT:    ushll2 v16.4s, v3.8h, #0
+; CHECK-NODOT-NEXT:    ushll2 v17.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    ushll v3.4s, v3.4h, #0
+; CHECK-NODOT-NEXT:    ushll v2.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    umlal2 v1.2d, v7.4s, v6.4s
+; CHECK-NODOT-NEXT:    umlal v0.2d, v7.2s, v6.2s
+; CHECK-NODOT-NEXT:    umull2 v18.2d, v5.4s, v4.4s
+; CHECK-NODOT-NEXT:    umull v4.2d, v5.2s, v4.2s
+; CHECK-NODOT-NEXT:    umlal2 v1.2d, v17.4s, v16.4s
+; CHECK-NODOT-NEXT:    umlal v0.2d, v17.2s, v16.2s
+; CHECK-NODOT-NEXT:    umlal2 v18.2d, v2.4s, v3.4s
+; CHECK-NODOT-NEXT:    umlal v4.2d, v2.2s, v3.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v18.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
 entry:
@@ -258,19 +274,27 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
 ;
 ; CHECK-NODOT-LABEL: sdot_8to64:
 ; CHECK-NODOT:       // %bb.0: // %entry
-; CHECK-NODOT-NEXT:    smull v4.8h, v2.8b, v3.8b
-; CHECK-NODOT-NEXT:    smull2 v2.8h, v2.16b, v3.16b
-; CHECK-NODOT-NEXT:    sshll v3.4s, v4.4h, #0
-; CHECK-NODOT-NEXT:    sshll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    sshll v4.8h, v3.8b, #0
+; CHECK-NODOT-NEXT:    sshll v5.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    sshll2 v3.8h, v3.16b, #0
+; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    sshll v6.4s, v4.4h, #0
+; CHECK-NODOT-NEXT:    sshll v7.4s, v5.4h, #0
 ; CHECK-NODOT-NEXT:    sshll2 v4.4s, v4.8h, #0
-; CHECK-NODOT-NEXT:    sshll2 v2.4s, v2.8h, #0
-; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v3.4s
-; CHECK-NODOT-NEXT:    saddw v0.2d, v0.2d, v3.2s
-; CHECK-NODOT-NEXT:    saddl2 v3.2d, v4.4s, v5.4s
-; CHECK-NODOT-NEXT:    saddl v4.2d, v4.2s, v5.2s
-; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v2.4s
-; CHECK-NODOT-NEXT:    saddw v0.2d, v0.2d, v2.2s
-; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
+; CHECK-NODOT-NEXT:    sshll2 v5.4s, v5.8h, #0
+; CHECK-NODOT-NEXT:    sshll2 v16.4s, v3.8h, #0
+; CHECK-NODOT-NEXT:    sshll2 v17.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    sshll v3.4s, v3.4h, #0
+; CHECK-NODOT-NEXT:    sshll v2.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    smlal2 v1.2d, v7.4s, v6.4s
+; CHECK-NODOT-NEXT:    smlal v0.2d, v7.2s, v6.2s
+; CHECK-NODOT-NEXT:    smull2 v18.2d, v5.4s, v4.4s
+; CHECK-NODOT-NEXT:    smull v4.2d, v5.2s, v4.2s
+; CHECK-NODOT-NEXT:    smlal2 v1.2d, v17.4s, v16.4s
+; CHECK-NODOT-NEXT:    smlal v0.2d, v17.2s, v16.2s
+; CHECK-NODOT-NEXT:    smlal2 v18.2d, v2.4s, v3.4s
+; CHECK-NODOT-NEXT:    smlal v4.2d, v2.2s, v3.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v18.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
 entry:
@@ -531,9 +555,10 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
 define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull v1.8h, v2.8b, v1.8b
-; CHECK-NEXT:    uaddw v0.4s, v0.4s, v1.4h
-; CHECK-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
 ; CHECK-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 455231dd37be6..d7bab3297cf29 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -11,24 +11,23 @@ define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: udot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z1.b
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT:    mad z1.s, p0/m, z2.s, z3.s
 ; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -47,24 +46,23 @@ define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -83,24 +81,23 @@ define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z1.b
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT:    mad z1.s, p0/m, z2.s, z3.s
 ; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -119,24 +116,23 @@ define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -278,59 +274,46 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #-2
-; CHECK-NEWLOWERING-NEXT:    str z9, [sp] // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
-; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
-; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
-; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z5.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z30.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z31.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z8.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z9.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mul z27.d, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z26.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z27.d, z29.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mul z4.d, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z7.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z9.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z2, z27
-; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z24.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z31.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z3, z4
-; CHECK-NEWLOWERING-NEXT:    mla z3.d, p0/m, z30.d, z8.d
-; CHECK-NEWLOWERING-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z3.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #2
-; CHECK-NEWLOWERING-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    mul z4.d, z5.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z26.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z5, z27
+; CHECK-NEWLOWERING-NEXT:    mla z5.d, p0/m, z28.d, z7.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z25.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mad z2.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z5.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -354,59 +337,46 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #-2
-; CHECK-NEWLOWERING-NEXT:    str z9, [sp] // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
-; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
-; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
-; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z5.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z30.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z31.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z8.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z9.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mul z27.d, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z26.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z27.d, z29.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.d, z25.s
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mul z4.d, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z7.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z9.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z2, z27
-; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z24.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z31.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    movprfx z3, z4
-; CHECK-NEWLOWERING-NEXT:    mla z3.d, p0/m, z30.d, z8.d
-; CHECK-NEWLOWERING-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z3.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #2
-; CHECK-NEWLOWERING-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    mul z4.d, z5.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z26.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z5, z27
+; CHECK-NEWLOWERING-NEXT:    mla z5.d, p0/m, z28.d, z7.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z25.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mad z2.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z5.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -875,11 +845,11 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEXT:    and z2.h, z2.h, #0xff
 ; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    uunpklo z3.s, z1.h
-; CHECK-NEXT:    uunpklo z4.s, z2.h
-; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEXT:    uunpklo z4.s, z1.h
 ; CHECK-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z3.s
 ; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
 ;
@@ -888,11 +858,11 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEWLOWERING-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEWLOWERING-NEXT:    and z2.h, z2.h, #0xff
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z4.s, z3.s
 ; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -909,11 +879,11 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEXT:    and z2.s, z2.s, #0xffff
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    uunpklo z3.d, z1.s
-; CHECK-NEXT:    uunpklo z4.d, z2.s
-; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEXT:    uunpklo z4.d, z1.s
 ; CHECK-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    mla z0.d, p0/m, z4.d, z3.d
 ; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
 ;
@@ -922,11 +892,11 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEWLOWERING-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEWLOWERING-NEXT:    and z2.s, z2.s, #0xffff
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.d, z1.s
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z4.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
@@ -1278,34 +1248,48 @@ define <vscale x 2 x i16> @udot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEXT:    mul z1.h, z1.h, z2.h
-; CHECK-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
 ; CHECK-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEXT:    uunpklo z4.d, z1.s
-; CHECK-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEXT:    add z0.d, z0.d, z3.d
-; CHECK-NEXT:    add z2.d, z2.d, z4.d
+; CHECK-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEXT:    uunpkhi z5.d, z2.s
+; CHECK-NEXT:    uunpkhi z6.d, z1.s
+; CHECK-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEXT:    uunpklo z1.d, z1.s
+; CHECK-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEXT:    mad z1.d, p0/m, z2.d, z3.d
 ; CHECK-NEXT:    add z0.d, z1.d, z0.d
-; CHECK-NEXT:    add z0.d, z2.d, z0.d
 ; CHECK-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_nxv8i8_promote:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEWLOWERING-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-NEXT:    mul z1.h, z1.h, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
 ; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z2.d, z2.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i16>
@@ -1321,17 +1305,24 @@ define <vscale x 2 x i16> @sdot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT:    sxtb z2.h, p0/m, z2.h
-; CHECK-NEXT:    mul z1.h, z1.h, z2.h
-; CHECK-NEXT:    uunpklo z2.s, z1.h
-; CHECK-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEXT:    uunpklo z4.d, z1.s
-; CHECK-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEXT:    add z0.d, z0.d, z3.d
-; CHECK-NEXT:    add z2.d, z2.d, z4.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    sunpklo z5.d, z3.s
+; CHECK-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEXT:    sunpkhi z5.d, z2.s
+; CHECK-NEXT:    sunpkhi z6.d, z1.s
+; CHECK-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEXT:    sunpklo z1.d, z1.s
+; CHECK-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEXT:    mad z1.d, p0/m, z2.d, z3.d
 ; CHECK-NEXT:    add z0.d, z1.d, z0.d
-; CHECK-NEXT:    add z0.d, z2.d, z0.d
 ; CHECK-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_nxv8i8_promote:
@@ -1339,17 +1330,24 @@ define <vscale x 2 x i16> @sdot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.h
 ; CHECK-NEWLOWERING-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEWLOWERING-NEXT:    sxtb z2.h, p0/m, z2.h
-; CHECK-NEWLOWERING-NEXT:    mul z1.h, z1.h, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z2.d, z2.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i8> %a to <vscale x 8 x i16>

>From 9908933f772a2ac33fed3836b8326a39848322f2 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 21 Feb 2025 15:03:34 +0000
Subject: [PATCH 5/9] Remove unnecessary functions for getting PR operation
 action

---
 llvm/include/llvm/CodeGen/TargetLowering.h    | 35 -------------------
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  2 --
 .../SelectionDAG/LegalizeVectorOps.cpp        |  7 ++--
 llvm/lib/CodeGen/TargetLoweringBase.cpp       |  5 +--
 4 files changed, 5 insertions(+), 44 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 52e57365dceab..a4c3d042fe3a4 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1639,25 +1639,6 @@ class TargetLoweringBase {
            getCondCodeAction(CC, VT) == Custom;
   }
 
-  /// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
-  /// InputVT should be treated. Either it's legal, needs to be promoted to a
-  /// larger size, needs to be expanded to some other code sequence, or the
-  /// target has a custom expander for it.
-  LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
-    unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
-    unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
-    assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
-           "Table isn't big enough!");
-    return PartialReduceMLAActions[AccI][InputI];
-  }
-
-  /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
-  /// legal or custom for this target.
-  bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
-    return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
-           getPartialReduceMLAAction(AccVT, InputVT) == Custom;
-  }
-
   /// If the action for this operation is to promote, this method returns the
   /// ValueType to promote to.
   MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2723,16 +2704,6 @@ class TargetLoweringBase {
       setCondCodeAction(CCs, VT, Action);
   }
 
-  /// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
-  /// type InputVT should be treated by the target. Either it's legal, needs to
-  /// be promoted to a larger size, needs to be expanded to some other code
-  /// sequence, or the target has a custom expander for it.
-  void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
-                                 LegalizeAction Action) {
-    assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
-    PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
-  }
-
   /// If Opc/OrigVT is specified as being promoted, the promotion code defaults
   /// to trying a larger integer/fp until it can find one that works. If that
   /// default is insufficient, this method can be used by the target to override
@@ -3679,12 +3650,6 @@ class TargetLoweringBase {
   /// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
   uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
 
-  /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
-  /// nodes, keep a LegalizeAction which indicates how instruction selection
-  /// should deal with this operation.
-  LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
-                                        [MVT::VALUETYPE_SIZE];
-
   ValueTypeActionImpl ValueTypeActions;
 
 private:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 111ecb61c2c07..a25ce015be45e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -622,8 +622,6 @@ namespace {
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI);
-    SDValue foldMulPARTIAL_REDUCE_MLA(SDNode *N);
-    SDValue foldExtendPARTIAL_REDUCE_MLA(SDNode *N);
 
     SDValue CombineExtLoad(SDNode *N);
     SDValue CombineZExtLogicopShiftLoad(SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index e43b14a47e565..de4447fb0cf1a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -469,6 +469,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECTOR_COMPRESS:
   case ISD::SCMP:
   case ISD::UCMP:
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
@@ -522,11 +524,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
     break;
   }
-  case ISD::PARTIAL_REDUCE_UMLA:
-  case ISD::PARTIAL_REDUCE_SMLA:
-    Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
-                                           Node->getOperand(1).getValueType());
-    break;
 
 #define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...)                          \
   case ISD::VPID: {                                                            \
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index af97ce20fdb10..f5ea3c0b47d6a 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -836,8 +836,9 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SET_FPENV, VT, Expand);
     setOperationAction(ISD::RESET_FPENV, VT, Expand);
 
-    for (MVT InputVT : MVT::all_valuetypes())
-      setPartialReduceMLAAction(VT, InputVT, Expand);
+    // PartialReduceMLA operations default to expand.
+    setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
+                       Expand);
   }
 
   // Most targets ignore the @llvm.prefetch intrinsic.

>From 5ca85da2cf4b596d3826ab1bade94d81d5402b18 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 28 Feb 2025 08:56:22 +0000
Subject: [PATCH 6/9] Rename variables and move a comment.

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 39 +++++++++----------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a25ce015be45e..03e31376fc5c6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12501,13 +12501,13 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
+// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
+// Splat(1)) into
+// PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
+// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
+// Splat(1)) into
+// PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
-  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
-  // Splat(1)) into
-  // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
-  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
-  // Splat(1)) into
-  // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
   SDLoc DL(N);
 
   SDValue Op0 = N->getOperand(0);
@@ -12521,32 +12521,31 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
       !ConstantOne.isOne())
     return SDValue();
 
-  SDValue ExtMulOpLHS = Op1->getOperand(0);
-  SDValue ExtMulOpRHS = Op1->getOperand(1);
-  unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
-  unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
-  if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
-      !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+  SDValue LHS = Op1->getOperand(0);
+  SDValue RHS = Op1->getOperand(1);
+  unsigned LHSOpcode = LHS->getOpcode();
+  unsigned RHSOpcode = RHS->getOpcode();
+  if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
     return SDValue();
 
-  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
-  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
-  EVT MulOpLHSVT = MulOpLHS.getValueType();
-  if (MulOpLHSVT != MulOpRHS.getValueType())
+  SDValue LHSExtOp = LHS->getOperand(0);
+  SDValue RHSExtOp = RHS->getOperand(0);
+  EVT LHSExtOpVT = LHSExtOp.getValueType();
+  if (LHSExtOpVT != RHSExtOp.getValueType())
     return SDValue();
 
   // FIXME: Add a check to only perform the DAG combine if there is lowering
   // provided by the target
 
-  bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
-  bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+  bool LHSIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+  bool RHSIsSigned = RHSOpcode == ISD::SIGN_EXTEND;
   if (LHSIsSigned != RHSIsSigned)
     return SDValue();
 
   unsigned NewOpcode =
       LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Op0, MulOpLHS,
-                     MulOpRHS);
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Op0, LHSExtOp,
+                     RHSExtOp);
 }
 
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {

>From ce0b0984b6e8b579248cb34a78e5e998a5df1151 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 28 Feb 2025 15:44:13 +0000
Subject: [PATCH 7/9] Add an additional check to the DAG combine

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 03e31376fc5c6..5955700d747ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12510,14 +12510,15 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   SDLoc DL(N);
 
-  SDValue Op0 = N->getOperand(0);
+  SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
 
   if (Op1->getOpcode() != ISD::MUL)
     return SDValue();
 
   APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
+  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
       !ConstantOne.isOne())
     return SDValue();
 
@@ -12542,9 +12543,16 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   if (LHSIsSigned != RHSIsSigned)
     return SDValue();
 
+  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  EVT AccElemVT = Acc.getValueType().getVectorElementType();
+  if (LHSIsSigned != NodeIsSigned &&
+      (Op1.getValueType().getVectorElementType() != AccElemVT ||
+       Op2.getValueType().getVectorElementType() != AccElemVT))
+    return SDValue();
+
   unsigned NewOpcode =
       LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Op0, LHSExtOp,
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
                      RHSExtOp);
 }
 

>From e694bcf17271b0a209d89de25db54f750e0d0c34 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 28 Feb 2025 15:53:43 +0000
Subject: [PATCH 8/9] Rename variables to match those in the function

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 5955700d747ef..4bcc39d27f886 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12501,12 +12501,12 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
+// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
 // Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
+// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
+// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
 // Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
+// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   SDLoc DL(N);
 

>From 24a613f84a30d6341e0743dd034bc69f9fc78dd6 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 28 Feb 2025 16:56:11 +0000
Subject: [PATCH 9/9] [AArch64][SVE] Add dot product lowering for
 PARTIAL_REDUCE_MLA node

Add lowering in TableGen for the PARTIAL_REDUCE_U/SMLA ISD nodes.
This only happens when the combine has been performed on the ISD node.
Also add a check so the DAG combine is only performed when the node
can eventually be lowered; this changes the NEON tests too.
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |  35 ++++
 .../include/llvm/Target/TargetSelectionDAG.td |   9 +
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |   6 +-
 .../SelectionDAG/LegalizeVectorOps.cpp        |   7 +-
 llvm/lib/CodeGen/TargetLoweringBase.cpp       |   5 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    |  15 ++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |   3 +
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  11 ++
 .../neon-partial-reduce-dot-product.ll        | 139 ++++++-------
 .../AArch64/sve-partial-reduce-dot-product.ll | 186 ++++--------------
 10 files changed, 176 insertions(+), 240 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a4c3d042fe3a4..52e57365dceab 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1639,6 +1639,25 @@ class TargetLoweringBase {
            getCondCodeAction(CC, VT) == Custom;
   }
 
+  /// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
+  /// InputVT should be treated. Either it's legal, needs to be promoted to a
+  /// larger size, needs to be expanded to some other code sequence, or the
+  /// target has a custom expander for it.
+  LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
+    unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
+    unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
+    assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
+           "Table isn't big enough!");
+    return PartialReduceMLAActions[AccI][InputI];
+  }
+
+  /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+  /// legal or custom for this target.
+  bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
+    return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
+           getPartialReduceMLAAction(AccVT, InputVT) == Custom;
+  }
+
   /// If the action for this operation is to promote, this method returns the
   /// ValueType to promote to.
   MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2704,6 +2723,16 @@ class TargetLoweringBase {
       setCondCodeAction(CCs, VT, Action);
   }
 
+  /// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
+  /// type InputVT should be treated by the target. Either it's legal, needs to
+  /// be promoted to a larger size, needs to be expanded to some other code
+  /// sequence, or the target has a custom expander for it.
+  void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+                                 LegalizeAction Action) {
+    assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
+    PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
+  }
+
   /// If Opc/OrigVT is specified as being promoted, the promotion code defaults
   /// to trying a larger integer/fp until it can find one that works. If that
   /// default is insufficient, this method can be used by the target to override
@@ -3650,6 +3679,12 @@ class TargetLoweringBase {
   /// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
   uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
 
+  /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
+  /// nodes, keep a LegalizeAction which indicates how instruction selection
+  /// should deal with this operation.
+  LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
+                                        [MVT::VALUETYPE_SIZE];
+
   ValueTypeActionImpl ValueTypeActions;
 
 private:
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 42a5fbec95174..64c27dbace397 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -313,6 +313,10 @@ def SDTSubVecInsert : SDTypeProfile<1, 3, [ // subvector insert
   SDTCisSubVecOfVec<2, 1>, SDTCisSameAs<0,1>, SDTCisInt<3>
 ]>;
 
+def SDTPartialReduceMLA : SDTypeProfile<1, 3, [ // partial reduce mla
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>
+]>;
+
 def SDTPrefetch : SDTypeProfile<0, 4, [     // prefetch
   SDTCisPtrTy<0>, SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>, SDTCisInt<1>
 ]>;
@@ -513,6 +517,11 @@ def vecreduce_fmax  : SDNode<"ISD::VECREDUCE_FMAX", SDTFPVecReduce>;
 def vecreduce_fminimum : SDNode<"ISD::VECREDUCE_FMINIMUM", SDTFPVecReduce>;
 def vecreduce_fmaximum : SDNode<"ISD::VECREDUCE_FMAXIMUM", SDTFPVecReduce>;
 
+def partial_reduce_umla : SDNode<"ISD::PARTIAL_REDUCE_UMLA",
+                                 SDTPartialReduceMLA>;
+def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
+                                 SDTPartialReduceMLA>;
+
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;
 def fmul       : SDNode<"ISD::FMUL"       , SDTFPBinOp, [SDNPCommutative]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 4bcc39d27f886..c80cd41aed783 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12535,8 +12535,10 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   if (LHSExtOpVT != RHSExtOp.getValueType())
     return SDValue();
 
-  // FIXME: Add a check to only perform the DAG combine if there is lowering
-  // provided by the target
+  // Only perform the DAG combine if the target provides legal or custom
+  // lowering for the resulting PARTIAL_REDUCE_MLA node.
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), LHSExtOpVT))
+    return SDValue();
 
   bool LHSIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
   bool RHSIsSigned = RHSOpcode == ISD::SIGN_EXTEND;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index de4447fb0cf1a..e43b14a47e565 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECTOR_COMPRESS:
   case ISD::SCMP:
   case ISD::UCMP:
-  case ISD::PARTIAL_REDUCE_UMLA:
-  case ISD::PARTIAL_REDUCE_SMLA:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
@@ -524,6 +522,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
     break;
   }
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
+                                           Node->getOperand(1).getValueType());
+    break;
 
 #define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...)                          \
   case ISD::VPID: {                                                            \
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f5ea3c0b47d6a..af97ce20fdb10 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SET_FPENV, VT, Expand);
     setOperationAction(ISD::RESET_FPENV, VT, Expand);
 
-    // PartialReduceMLA operations default to expand.
-    setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
-                       Expand);
+    for (MVT InputVT : MVT::all_valuetypes())
+      setPartialReduceMLAAction(VT, InputVT, Expand);
   }
 
   // Most targets ignore the @llvm.prefetch intrinsic.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 50be082777835..b9923ce603a13 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1581,6 +1581,21 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::MSTORE, VT, Custom);
     }
 
+    for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
+      if (!EnablePartialReduceNodes)
+        break;
+      for (MVT InnerVT : MVT::integer_scalable_vector_valuetypes()) {
+        ElementCount VTElemCount = VT.getVectorElementCount();
+        if (VTElemCount.getKnownMinValue() == 1)
+          continue;
+        if (VTElemCount * 4 == InnerVT.getVectorElementCount())
+          setPartialReduceMLAAction(VT, InnerVT, Custom);
+        if (InnerVT.getVectorElementType().getSizeInBits() * 4 ==
+            VT.getVectorElementType().getSizeInBits())
+          setPartialReduceMLAAction(VT, InnerVT, Legal);
+      }
+    }
+
     // Firstly, exclude all scalable vector extending loads/truncating stores,
     // include both integer and floating scalable vector.
     for (MVT VT : MVT::scalable_vector_valuetypes()) {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 93a6100ce54e9..deb3bb0d4dbd9 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -143,6 +143,9 @@ def HasFuseAES       : Predicate<"Subtarget->hasFuseAES()">,
                                  "fuse-aes">;
 def HasSVE           : Predicate<"Subtarget->isSVEAvailable()">,
                                  AssemblerPredicateWithAll<(all_of FeatureSVE), "sve">;
+def HasSVEorStreamingSVE
+                     : Predicate<"Subtarget->isSVEorStreamingSVEAvailable()">,
+                                 AssemblerPredicateWithAll<(all_of FeatureSVE), "sve">;
 def HasSVEB16B16     : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSVEB16B16()">,
                                  AssemblerPredicateWithAll<(all_of FeatureSVEB16B16), "sve-b16b16">;
 def HasSVE2          : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2()">,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 28aecd14e33fa..f8ee2d7c1de0c 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -655,6 +655,17 @@ let Predicates = [HasSVE_or_SME] in {
   defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
   defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
 
+  let Predicates = [HasSVEorStreamingSVE] in {
+    def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)),
+              (UDOT_ZZZ_S $Acc, $MulLHS, $MulRHS)>;
+    def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)),
+              (SDOT_ZZZ_S $Acc, $MulLHS, $MulRHS)>;
+    def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+              (UDOT_ZZZ_D $Acc, $MulLHS, $MulRHS)>;
+    def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+              (SDOT_ZZZ_D $Acc, $MulLHS, $MulRHS)>;
+  } // End HasSVEorStreamingSVE
+
   defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
   defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;
 
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 3938a57d0152c..40daf8ffb63ea 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -12,15 +12,13 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    ushll v4.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    umlal v0.4s, v4.4h, v3.4h
-; CHECK-NODOT-NEXT:    umull v5.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    umlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT:    umull v3.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    umull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
@@ -37,19 +35,17 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    umull v1.8h, v2.8b, v1.8b
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    umull v3.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    umull2 v4.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NODOT-NEXT:    umlal v3.4s, v6.4h, v5.4h
-; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
 ; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
@@ -66,15 +62,13 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    sshll v3.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    sshll v4.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    sshll2 v1.8h, v1.16b, #0
-; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    smlal v0.4s, v4.4h, v3.4h
-; CHECK-NODOT-NEXT:    smull v5.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT:    smull v3.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    smull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
@@ -91,19 +85,17 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    smull v1.8h, v2.8b, v1.8b
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    smull v3.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    smull2 v4.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NODOT-NEXT:    smlal v3.4s, v6.4h, v5.4h
-; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
 ; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
@@ -231,27 +223,19 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
 ;
 ; CHECK-NODOT-LABEL: udot_8to64:
 ; CHECK-NODOT:       // %bb.0: // %entry
-; CHECK-NODOT-NEXT:    ushll v4.8h, v3.8b, #0
-; CHECK-NODOT-NEXT:    ushll v5.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    ushll2 v3.8h, v3.16b, #0
-; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    ushll v6.4s, v4.4h, #0
-; CHECK-NODOT-NEXT:    ushll v7.4s, v5.4h, #0
+; CHECK-NODOT-NEXT:    umull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT:    umull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT:    ushll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT:    ushll v5.4s, v2.4h, #0
 ; CHECK-NODOT-NEXT:    ushll2 v4.4s, v4.8h, #0
-; CHECK-NODOT-NEXT:    ushll2 v5.4s, v5.8h, #0
-; CHECK-NODOT-NEXT:    ushll2 v16.4s, v3.8h, #0
-; CHECK-NODOT-NEXT:    ushll2 v17.4s, v2.8h, #0
-; CHECK-NODOT-NEXT:    ushll v3.4s, v3.4h, #0
-; CHECK-NODOT-NEXT:    ushll v2.4s, v2.4h, #0
-; CHECK-NODOT-NEXT:    umlal2 v1.2d, v7.4s, v6.4s
-; CHECK-NODOT-NEXT:    umlal v0.2d, v7.2s, v6.2s
-; CHECK-NODOT-NEXT:    umull2 v18.2d, v5.4s, v4.4s
-; CHECK-NODOT-NEXT:    umull v4.2d, v5.2s, v4.2s
-; CHECK-NODOT-NEXT:    umlal2 v1.2d, v17.4s, v16.4s
-; CHECK-NODOT-NEXT:    umlal v0.2d, v17.2s, v16.2s
-; CHECK-NODOT-NEXT:    umlal2 v18.2d, v2.4s, v3.4s
-; CHECK-NODOT-NEXT:    umlal v4.2d, v2.2s, v3.2s
-; CHECK-NODOT-NEXT:    add v1.2d, v18.2d, v1.2d
+; CHECK-NODOT-NEXT:    ushll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT:    uaddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT:    uaddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
 entry:
@@ -274,27 +258,19 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
 ;
 ; CHECK-NODOT-LABEL: sdot_8to64:
 ; CHECK-NODOT:       // %bb.0: // %entry
-; CHECK-NODOT-NEXT:    sshll v4.8h, v3.8b, #0
-; CHECK-NODOT-NEXT:    sshll v5.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    sshll2 v3.8h, v3.16b, #0
-; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    sshll v6.4s, v4.4h, #0
-; CHECK-NODOT-NEXT:    sshll v7.4s, v5.4h, #0
+; CHECK-NODOT-NEXT:    smull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT:    smull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT:    sshll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT:    sshll v5.4s, v2.4h, #0
 ; CHECK-NODOT-NEXT:    sshll2 v4.4s, v4.8h, #0
-; CHECK-NODOT-NEXT:    sshll2 v5.4s, v5.8h, #0
-; CHECK-NODOT-NEXT:    sshll2 v16.4s, v3.8h, #0
-; CHECK-NODOT-NEXT:    sshll2 v17.4s, v2.8h, #0
-; CHECK-NODOT-NEXT:    sshll v3.4s, v3.4h, #0
-; CHECK-NODOT-NEXT:    sshll v2.4s, v2.4h, #0
-; CHECK-NODOT-NEXT:    smlal2 v1.2d, v7.4s, v6.4s
-; CHECK-NODOT-NEXT:    smlal v0.2d, v7.2s, v6.2s
-; CHECK-NODOT-NEXT:    smull2 v18.2d, v5.4s, v4.4s
-; CHECK-NODOT-NEXT:    smull v4.2d, v5.2s, v4.2s
-; CHECK-NODOT-NEXT:    smlal2 v1.2d, v17.4s, v16.4s
-; CHECK-NODOT-NEXT:    smlal v0.2d, v17.2s, v16.2s
-; CHECK-NODOT-NEXT:    smlal2 v18.2d, v2.4s, v3.4s
-; CHECK-NODOT-NEXT:    smlal v4.2d, v2.2s, v3.2s
-; CHECK-NODOT-NEXT:    add v1.2d, v18.2d, v1.2d
+; CHECK-NODOT-NEXT:    sshll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT:    saddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT:    saddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT:    saddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT:    saddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
 entry:
@@ -555,10 +531,9 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
 define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
-; CHECK-NEXT:    umlal v0.4s, v2.4h, v1.4h
-; CHECK-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index d7bab3297cf29..5974bac348531 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -11,24 +11,7 @@ define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: udot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    mul z3.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
-; CHECK-NEWLOWERING-NEXT:    mad z1.s, p0/m, z2.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    udot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -46,24 +29,7 @@ define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    udot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -81,24 +47,7 @@ define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    mul z3.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z6.s, z5.s
-; CHECK-NEWLOWERING-NEXT:    mad z1.s, p0/m, z2.s, z3.s
-; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    sdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -116,24 +65,7 @@ define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_wide:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    sdot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -845,11 +777,11 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEXT:    and z2.h, z2.h, #0xff
 ; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEXT:    uunpklo z4.s, z1.h
-; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpklo z4.s, z2.h
 ; CHECK-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z4.s
 ; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
 ;
@@ -879,11 +811,11 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEXT:    and z2.s, z2.s, #0xffff
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    uunpklo z3.d, z2.s
-; CHECK-NEXT:    uunpklo z4.d, z1.s
-; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NEXT:    uunpklo z4.d, z2.s
 ; CHECK-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEXT:    mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mla z0.d, p0/m, z3.d, z4.d
 ; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
 ;
@@ -1248,48 +1180,24 @@ define <vscale x 2 x i16> @udot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEXT:    uunpklo z4.s, z1.h
-; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    mul z1.h, z1.h, z2.h
+; CHECK-NEXT:    uunpklo z2.s, z1.h
 ; CHECK-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEXT:    uunpklo z5.d, z3.s
-; CHECK-NEXT:    uunpklo z6.d, z4.s
-; CHECK-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEXT:    uunpkhi z5.d, z2.s
-; CHECK-NEXT:    uunpkhi z6.d, z1.s
-; CHECK-NEXT:    mul z3.d, z4.d, z3.d
-; CHECK-NEXT:    uunpklo z2.d, z2.s
-; CHECK-NEXT:    uunpklo z1.d, z1.s
-; CHECK-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEXT:    mad z1.d, p0/m, z2.d, z3.d
+; CHECK-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEXT:    uunpklo z4.d, z1.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    add z0.d, z0.d, z3.d
+; CHECK-NEXT:    add z2.d, z2.d, z4.d
 ; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    add z0.d, z2.d, z0.d
 ; CHECK-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_nxv8i8_promote:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEWLOWERING-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-NEXT:    udot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i16>
@@ -1305,49 +1213,25 @@ define <vscale x 2 x i16> @sdot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT:    sxtb z2.h, p0/m, z2.h
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    sunpklo z3.s, z2.h
-; CHECK-NEXT:    sunpklo z4.s, z1.h
-; CHECK-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEXT:    sunpklo z5.d, z3.s
-; CHECK-NEXT:    sunpklo z6.d, z4.s
-; CHECK-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEXT:    sunpkhi z5.d, z2.s
-; CHECK-NEXT:    sunpkhi z6.d, z1.s
-; CHECK-NEXT:    mul z3.d, z4.d, z3.d
-; CHECK-NEXT:    sunpklo z2.d, z2.s
-; CHECK-NEXT:    sunpklo z1.d, z1.s
-; CHECK-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEXT:    mad z1.d, p0/m, z2.d, z3.d
+; CHECK-NEXT:    mul z1.h, z1.h, z2.h
+; CHECK-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEXT:    uunpklo z4.d, z1.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    add z0.d, z0.d, z3.d
+; CHECK-NEXT:    add z2.d, z2.d, z4.d
 ; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    add z0.d, z2.d, z0.d
 ; CHECK-NEXT:    ret
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_nxv8i8_promote:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.h
-; CHECK-NEWLOWERING-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEWLOWERING-NEXT:    sxtb z2.h, p0/m, z2.h
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mul z3.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mad z1.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    sxtb z1.h, p0/m, z1.h
+; CHECK-NEWLOWERING-NEXT:    sdot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i8> %a to <vscale x 8 x i16>



More information about the llvm-commits mailing list