[llvm] [SelectionDAG] Add PARTIAL_REDUCE_U/SMLA ISD Nodes (PR #125207)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 4 05:57:38 PST 2025


https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/125207

>From c264dc2ed989da10717ed0529d1af5ee9815e72b Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 31 Jan 2025 11:44:55 +0000
Subject: [PATCH 1/3] [AArch64] Add PARTIAL_REDUCE_U/SMLA ISD Nodes

Add signed and unsigned PARTIAL_REDUCE_MLA ISD nodes.
Add a command line argument (new-partial-reduce-lowering) that
controls whether the experimental_vector_partial_reduce_add
intrinsic is transformed into the new ISD nodes.
Lowering with the new ISD nodes is, for now, always done as an
expand.
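
For example, the kind of IR affected (a minimal sketch based on the
existing sve-partial-reduce-dot-product.ll tests, with hypothetical
value names) is:

  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul <vscale x 16 x i32> %a.wide, %b.wide
  %red = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)

With -new-partial-reduce-lowering the intrinsic is built as a
PARTIAL_REDUCE_UMLA node and then expanded, instead of being matched
directly to instructions such as udot.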
---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  14 +
 llvm/include/llvm/CodeGen/SelectionDAG.h      |   7 +
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  21 +
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |   4 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  17 +
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  17 +
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  27 +
 .../SelectionDAG/SelectionDAGDumper.cpp       |   5 +
 .../AArch64/sve-partial-reduce-dot-product.ll | 709 ++++++++++++++++--
 .../AArch64/sve-partial-reduce-wide-add.ll    |  49 ++
 10 files changed, 796 insertions(+), 74 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index fd8784a4c10034..3f235ee358e0ed 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,20 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
+  // PARTIAL_REDUCE_*MLA (Accumulator, Input1, Input2)
+  // Partial reduction nodes. Input1 and Input2 are multiplied together before
+  // being reduced, by addition, down to the number of elements that
+  // Accumulator's type has.
+  // Input1 and Input2 must be the same type. Accumulator and the output must be
+  // the same type.
+  // The number of elements in Input1 and Input2 must be a positive integer
+  // multiple of the number of elements in the Accumulator / output type.
+  // All operands, as well as the output, must have the same element type.
+  // Operands: Accumulator, Input1, Input2
+  // Outputs: Output
+  PARTIAL_REDUCE_SMLA,
+  PARTIAL_REDUCE_UMLA,
+
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
   // Outputs: output chain, glue
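
(As a hypothetical illustration of the constraints above, not taken from the
patch, a fixed-width instance of the new node could look like

  v4i32 = PARTIAL_REDUCE_UMLA v4i32:Acc, v16i32:Input1, v16i32:Input2

where 16 is a positive integer multiple of 4 and every operand, as well as the
result, uses the i32 element type.)
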
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 461c0c1ead16d2..0fc6f6ccf85bd9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1607,6 +1607,13 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  // Expand a PARTIAL_REDUCE_S/UMLA node \p N by multiplying together its two
+  // input operands and reducing the result, by addition, into its accumulator
+  // operand.
+  // Operands of \p N: the accumulator, followed by the first and second
+  // inputs of the partial reduction operation.
+  SDValue expandPartialReduceMLA(SDNode *N);
+
   /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
   /// its operands and ReducedTY is the intrinsic's return type.
   SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 625052be657ca0..3a9518ea569ebc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -159,6 +159,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
     Res = PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(N);
     break;
 
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
+    break;
+
   case ISD::SIGN_EXTEND:
   case ISD::VP_SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
@@ -2076,6 +2081,10 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::VECTOR_FIND_LAST_ACTIVE:
     Res = PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(N, OpNo);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
+    break;
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -2824,6 +2833,12 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = DAG.expandPartialReduceMLA(N);
+  ReplaceValueWith(SDValue(N, 0), Res);
+  return SDValue();
+}
+
 //===----------------------------------------------------------------------===//
 //  Integer Result Expansion
 //===----------------------------------------------------------------------===//
@@ -6139,6 +6154,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
   return DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, SDLoc(N), NVT, N->ops());
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = DAG.expandPartialReduceMLA(N);
+  ReplaceValueWith(SDValue(N, 0), Res);
+  return SDValue();
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
   EVT OutVT = N->getValueType(0);
   EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index f13f70e66cfaa6..cb9c1b239c0fa9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -379,6 +379,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
   SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
   SDValue PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N);
+  SDValue PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N);
 
   // Integer Operand Promotion.
   bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
@@ -430,6 +431,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
+  SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);
 
   void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
   void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -968,6 +970,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
 
   // Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
   bool SplitVectorOperand(SDNode *N, unsigned OpNo);
@@ -999,6 +1002,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
   SDValue SplitVecOp_VP_CttzElements(SDNode *N);
   SDValue SplitVecOp_VECTOR_HISTOGRAM(SDNode *N);
+  SDValue SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
 
   //===--------------------------------------------------------------------===//
   // Vector Widening Support: LegalizeVectorTypes.cpp
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1000235ab4061f..b01470028981e7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1373,6 +1373,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::EXPERIMENTAL_VP_REVERSE:
     SplitVecRes_VP_REVERSE(N, Lo, Hi);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    SplitVecRes_PARTIAL_REDUCE_MLA(N);
   }
 
   // If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -3182,6 +3185,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
   std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
 }
 
+void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = DAG.expandPartialReduceMLA(N);
+  ReplaceValueWith(SDValue(N, 0), Res);
+}
+
 void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
 
   SDValue Op0Lo, Op0Hi, Op1Lo, Op1Hi;
@@ -3381,6 +3389,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     Res = SplitVecOp_VECTOR_HISTOGRAM(N);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -4435,6 +4446,12 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
                                 MMO, IndexType);
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDValue Res = DAG.expandPartialReduceMLA(N);
+  ReplaceValueWith(SDValue(N, 0), Res);
+  return SDValue();
+}
+
 //===----------------------------------------------------------------------===//
 //  Result Vector Widening
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b416c0efbbc4fc..7240e4e00dfa07 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2473,6 +2473,23 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::expandPartialReduceMLA(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Input1 = N->getOperand(1);
+  SDValue Input2 = N->getOperand(2);
+
+  EVT FullTy = Input1.getValueType();
+
+  SDValue Input = Input1;
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
+
+  return getPartialReduceAdd(DL, Acc.getValueType(), Acc, Input);
+}
+
 SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
                                           SDValue Op2) {
   EVT FullTy = Op2.getValueType();
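
(In IR terms, the expansion added above behaves roughly like the following
sketch, assuming an nxv4i32 accumulator and nxv16i32 inputs: when Input2 is
not a splat of constant 1, the inputs are multiplied first and the product is
reduced into the accumulator,

  %mult = mul <vscale x 16 x i32> %input1, %input2
  %red = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)

while for a splat-of-1 Input2 the multiply is skipped and Input1 is reduced
directly.)
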
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 428e7a316d247b..144439f136ff16 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -135,6 +135,10 @@ static cl::opt<unsigned> SwitchPeelThreshold(
              "switch statement. A value greater than 100 will void this "
              "optimization"));
 
+static cl::opt<bool> NewPartialReduceLowering(
+    "new-partial-reduce-lowering", cl::init(false), cl::ReallyHidden,
+    cl::desc("Use the new method of lowering partial reductions."));
+
 // Limit the width of DAG chains. This is important in general to prevent
 // DAG-based analysis from blowing up. For example, alias analysis and
 // load clustering may not complete in reasonable time. It is difficult to
@@ -8118,6 +8122,29 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+    if (NewPartialReduceLowering) {
+      SDValue Acc = getValue(I.getOperand(0));
+      EVT AccVT = Acc.getValueType();
+      SDValue Input = getValue(I.getOperand(1));
+      EVT InputVT = Input.getValueType();
+
+      assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
+             "Expected operands to have the same vector element type!");
+      assert(
+          InputVT.getVectorElementCount().getKnownMinValue() %
+                  AccVT.getVectorElementCount().getKnownMinValue() ==
+              0 &&
+          "Expected the element count of the Input operand to be a positive "
+          "integer multiple of the element count of the Accumulator operand!");
+
+      // ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
+      // same if ISD::PARTIAL_REDUCE_SMLA were chosen instead. It should be
+      // changed to the correct signedness when combining or expanding,
+      // according to the extends performed on the Input operand.
+      setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc, Input,
+                               DAG.getConstant(1, sdl, InputVT)));
+      return;
+    }
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
       visitTargetIntrinsic(I, Intrinsic);
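
(A sketch of why the signedness is left for later: whether the partial
reduction is really signed or unsigned is only visible from the extends on
the intrinsic's input, for example

  %s.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32> ; signed, eventually PARTIAL_REDUCE_SMLA
  %u.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32> ; unsigned, eventually PARTIAL_REDUCE_UMLA

so the builder always emits PARTIAL_REDUCE_UMLA with a splat-of-1 second
operand and leaves the signedness correction to later combines or expansion.)
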
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index f63c8dd3df1c83..a387c10679261b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -570,6 +570,11 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::VECTOR_FIND_LAST_ACTIVE:
     return "find_last_active";
 
+  case ISD::PARTIAL_REDUCE_UMLA:
+    return "partial_reduce_umla";
+  case ISD::PARTIAL_REDUCE_SMLA:
+    return "partial_reduce_smla";
+
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
   case ISD::SDID:                                                              \
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 66f83c658ff4f2..16c0001dbdb838 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,12 +1,41 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: udot:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    udot z0.s, z1.b, z2.b
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: udot:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: udot:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -16,10 +45,38 @@ entry:
 }
 
 define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: udot_wide:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    udot z0.d, z1.h, z2.h
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: udot_wide:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: udot_wide:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_wide:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -29,10 +86,38 @@ entry:
 }
 
 define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: sdot:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: sdot:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sdot:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -42,10 +127,38 @@ entry:
 }
 
 define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: sdot_wide:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: sdot_wide:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sdot_wide:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_wide:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -82,6 +195,29 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ; CHECK-NOI8MM-NEXT:    mla z1.s, p0/m, z7.s, z24.s
 ; CHECK-NOI8MM-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -118,6 +254,29 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ; CHECK-NOI8MM-NEXT:    mla z1.s, p0/m, z7.s, z24.s
 ; CHECK-NOI8MM-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -127,15 +286,82 @@ entry:
 }
 
 define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: udot_8to64:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z4.s, #0 // =0x0
-; CHECK-NEXT:    udot z4.s, z2.b, z3.b
-; CHECK-NEXT:    sunpklo z2.d, z4.s
-; CHECK-NEXT:    sunpkhi z3.d, z4.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: udot_8to64:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: udot_8to64:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NOI8MM-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NOI8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NOI8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NOI8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_8to64:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #-2
+; CHECK-NEWLOWERING-NEXT:    str z9, [sp] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z30.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z8.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z9.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z27.d, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z4.d, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z2, z27
+; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT:    mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NEWLOWERING-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #2
+; CHECK-NEWLOWERING-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -146,15 +372,82 @@ entry:
 }
 
 define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
-; CHECK-LABEL: sdot_8to64:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z4.s, #0 // =0x0
-; CHECK-NEXT:    sdot z4.s, z2.b, z3.b
-; CHECK-NEXT:    sunpklo z2.d, z4.s
-; CHECK-NEXT:    sunpkhi z3.d, z4.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: sdot_8to64:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sdot_8to64:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NOI8MM-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NOI8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NOI8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NOI8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #-2
+; CHECK-NEWLOWERING-NEXT:    str z9, [sp] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z30.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z8.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z9.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z27.d, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z4.d, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z2, z27
+; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT:    mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NEWLOWERING-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #2
+; CHECK-NEWLOWERING-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -231,6 +524,63 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NOI8MM-NEXT:    addvl sp, sp, #2
 ; CHECK-NOI8MM-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #-2
+; CHECK-NEWLOWERING-NEXT:    str z9, [sp] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z30.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z8.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z9.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z27.d, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z4.d, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z2, z27
+; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT:    mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NEWLOWERING-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #2
+; CHECK-NEWLOWERING-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -307,6 +657,63 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NOI8MM-NEXT:    addvl sp, sp, #2
 ; CHECK-NOI8MM-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #-2
+; CHECK-NEWLOWERING-NEXT:    str z9, [sp] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEWLOWERING-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z30.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z8.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z9.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z27.d, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mul z4.d, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z2, z27
+; CHECK-NEWLOWERING-NEXT:    mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT:    mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NEWLOWERING-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #2
+; CHECK-NEWLOWERING-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -317,33 +724,93 @@ entry:
 }
 
 define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
-; CHECK-LABEL: udot_no_bin_op:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.b, #1 // =0x1
-; CHECK-NEXT:    udot z0.s, z1.b, z2.b
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: udot_no_bin_op:
+; CHECK-I8MM:       // %bb.0:
+; CHECK-I8MM-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-I8MM-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: udot_no_bin_op:
+; CHECK-NOI8MM:       // %bb.0:
+; CHECK-NOI8MM-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-NOI8MM-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z1.s, z2.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z4.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
   ret <vscale x 4 x i32> %partial.reduce
 }
 
 define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
-; CHECK-LABEL: sdot_no_bin_op:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.b, #1 // =0x1
-; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: sdot_no_bin_op:
+; CHECK-I8MM:       // %bb.0:
+; CHECK-I8MM-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-I8MM-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sdot_no_bin_op:
+; CHECK-NOI8MM:       // %bb.0:
+; CHECK-NOI8MM-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-NOI8MM-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z1.s, z2.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z4.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
   ret <vscale x 4 x i32> %partial.reduce
 }
 
 define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
-; CHECK-LABEL: udot_no_bin_op_wide:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z2.h, #1 // =0x1
-; CHECK-NEXT:    udot z0.d, z1.h, z2.h
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: udot_no_bin_op_wide:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-I8MM-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: udot_no_bin_op_wide:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-NOI8MM-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z4.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
@@ -351,11 +818,31 @@ entry:
 }
 
 define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
-; CHECK-LABEL: sdot_no_bin_op_wide:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z2.h, #1 // =0x1
-; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: sdot_no_bin_op_wide:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-I8MM-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sdot_no_bin_op_wide:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-NOI8MM-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z4.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
@@ -363,32 +850,106 @@ entry:
 }
 
 define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
-; CHECK-LABEL: udot_no_bin_op_8to64:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z3.b, #1 // =0x1
-; CHECK-NEXT:    mov z4.s, #0 // =0x0
-; CHECK-NEXT:    udot z4.s, z2.b, z3.b
-; CHECK-NEXT:    sunpklo z2.d, z4.s
-; CHECK-NEXT:    sunpkhi z3.d, z4.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: udot_no_bin_op_8to64:
+; CHECK-I8MM:       // %bb.0:
+; CHECK-I8MM-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-I8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: udot_no_bin_op_8to64:
+; CHECK-NOI8MM:       // %bb.0:
+; CHECK-NOI8MM-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NOI8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NOI8MM-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NOI8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NOI8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NOI8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z25.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    add z4.d, z25.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z5.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z7.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z4.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
   ret <vscale x 4 x i64> %partial.reduce
 }
 
 define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
-; CHECK-LABEL: sdot_no_bin_op_8to64:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z3.b, #1 // =0x1
-; CHECK-NEXT:    mov z4.s, #0 // =0x0
-; CHECK-NEXT:    sdot z4.s, z2.b, z3.b
-; CHECK-NEXT:    sunpklo z2.d, z4.s
-; CHECK-NEXT:    sunpkhi z3.d, z4.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEXT:    ret
+; CHECK-I8MM-LABEL: sdot_no_bin_op_8to64:
+; CHECK-I8MM:       // %bb.0:
+; CHECK-I8MM-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-I8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NOI8MM:       // %bb.0:
+; CHECK-NOI8MM-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NOI8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NOI8MM-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NOI8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NOI8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NOI8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z25.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z4.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z6.d
+; CHECK-NEWLOWERING-NEXT:    add z4.d, z25.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z5.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z7.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z4.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
   ret <vscale x 4 x i64> %partial.reduce
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index b4b946c68566ed..62b5039259392c 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE2
 ; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
 
 define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
 ; CHECK-SVE2-LABEL: signed_wide_add_nxv4i32:
@@ -16,6 +17,14 @@ define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vsc
 ; CHECK-SVE-NEXT:    add z0.d, z0.d, z2.d
 ; CHECK-SVE-NEXT:    add z0.d, z1.d, z0.d
 ; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
     %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
@@ -36,6 +45,14 @@ define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <v
 ; CHECK-SVE-NEXT:    add z0.d, z0.d, z2.d
 ; CHECK-SVE-NEXT:    add z0.d, z1.d, z0.d
 ; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
     %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
@@ -56,6 +73,14 @@ define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vsc
 ; CHECK-SVE-NEXT:    add z0.s, z0.s, z2.s
 ; CHECK-SVE-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv8i16:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
     %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
@@ -76,6 +101,14 @@ define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <v
 ; CHECK-SVE-NEXT:    add z0.s, z0.s, z2.s
 ; CHECK-SVE-NEXT:    add z0.s, z1.s, z0.s
 ; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
     %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
@@ -96,6 +129,14 @@ define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vsc
 ; CHECK-SVE-NEXT:    add z0.h, z0.h, z2.h
 ; CHECK-SVE-NEXT:    add z0.h, z1.h, z0.h
 ; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv16i8:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    add z0.h, z0.h, z2.h
+; CHECK-NEWLOWERING-NEXT:    add z0.h, z1.h, z0.h
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
     %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
@@ -116,6 +157,14 @@ define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <v
 ; CHECK-SVE-NEXT:    add z0.h, z0.h, z2.h
 ; CHECK-SVE-NEXT:    add z0.h, z1.h, z0.h
 ; CHECK-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    add z0.h, z0.h, z2.h
+; CHECK-NEWLOWERING-NEXT:    add z0.h, z1.h, z0.h
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
     %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
     %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
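
For anyone reading the CHECK-NEWLOWERING sequences above, the sunpklo/sunpkhi (or uunpklo/uunpkhi) + add pattern is the generic expand path splitting the extended input into accumulator-sized halves and adding both halves into the accumulator. A minimal scalar sketch of that lane mapping, assuming a fixed 4-element input and 2-element accumulator; the name signedWideAdd is illustrative and not part of the patch:

#include <array>
#include <cstdint>
#include <cstdio>

// Scalar model of signed_wide_add_nxv4i32 for a fixed 4 x i32 input and
// 2 x i64 accumulator: the input is split into a low and a high half, each
// half is sign-extended, and both halves are added lane-wise into the
// accumulator (mirroring the sunpklo/sunpkhi + add + add sequence above).
std::array<int64_t, 2> signedWideAdd(std::array<int64_t, 2> Acc,
                                     std::array<int32_t, 4> In) {
  for (int Lane = 0; Lane < 2; ++Lane) {
    int64_t Lo = static_cast<int64_t>(In[Lane]);     // low-half lane (sunpklo)
    int64_t Hi = static_cast<int64_t>(In[Lane + 2]); // high-half lane (sunpkhi)
    Acc[Lane] += Lo + Hi;
  }
  return Acc;
}

int main() {
  std::array<int64_t, 2> R = signedWideAdd({10, 20}, {1, -2, 3, -4});
  std::printf("%lld %lld\n", (long long)R[0], (long long)R[1]); // prints: 14 14
}

The unsigned variants above differ only in using zero extension instead of sign extension.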

>From df8f9121c13d21d073e5780c12cdacabd3f0ea6f Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 3 Feb 2025 15:17:31 +0000
Subject: [PATCH 2/3] Address comments on PR.

---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        | 11 +++---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  7 ----
 llvm/include/llvm/CodeGen/TargetLowering.h    |  4 ++
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp |  5 +++
 .../SelectionDAG/LegalizeIntegerTypes.cpp     | 21 ++++++----
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |  2 +-
 .../SelectionDAG/LegalizeVectorOps.cpp        |  6 +++
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  4 +-
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 17 ---------
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  3 ++
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 38 +++++++++++++++++++
 .../Target/AArch64/AArch64ISelLowering.cpp    |  5 +++
 12 files changed, 82 insertions(+), 41 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 3f235ee358e0ed..422a70bb6641bd 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,17 +1451,16 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
-  // PARTIAL_REDUCE_*MLA (Accumulator, Input1, Input2)
-  // Partial reduction nodes. Input1 and Input2 are multiplied together before
-  // being reduced, by addition to the number of elements that Accumulator's
-  // type has.
+  // PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
+  // The partial reduction nodes sign- or zero-extend Input1 and Input2 to the
+  // element type of Accumulator before multiplying them together. The
+  // products are then reduced by addition down to the result type of
+  // Accumulator, and that reduction is added to Accumulator and returned.
   // Input1 and Input2 must be the same type. Accumulator and the output must be
   // the same type.
   // The number of elements in Input1 and Input2 must be a positive integer
   // multiple of the number of elements in the Accumulator / output type.
   // All operands, as well as the output, must have the same element type.
-  // Operands: Accumulator, Input1, Input2
-  // Outputs: Output
   PARTIAL_REDUCE_SMLA,
   PARTIAL_REDUCE_UMLA,
 
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 0fc6f6ccf85bd9..461c0c1ead16d2 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1607,13 +1607,6 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
-  // Expands PARTIAL_REDUCE_S/UMLA nodes
-  // \p Acc Accumulator for where the result is stored for the partial reduction
-  // operation.
-  // \p Input1 First input for the partial reduction operation
-  // \p Input2 Second input for the partial reduction operation
-  SDValue expandPartialReduceMLA(SDNode *N);
-
   /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
   /// its operands and ReducedTY is the intrinsic's return type.
   SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9fcd2ac9514e56..b9fc21ae21139a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5622,6 +5622,10 @@ class TargetLowering : public TargetLoweringBase {
   // joining their results. SDValue() is returned when expansion did not happen.
   SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;
 
+  // Expands PARTIAL_REDUCE_S/UMLA nodes to a series of simpler operations,
+  // consisting of zext/sext, extract_subvector, mul and add operations.
+  SDValue expandPartialReduceMLA(SDNode *N, SelectionDAG &DAG) const;
+
 private:
   SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                            const SDLoc &DL, DAGCombinerInfo &DCI) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index c6475f02199033..77fc529ab34db1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1245,6 +1245,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
         Node->getOpcode(),
         cast<MaskedHistogramSDNode>(Node)->getIndex().getValueType());
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Action = TLI.getOperationAction(Node->getOpcode(),
+                                    Node->getOperand(0).getValueType());
+    break;
   default:
     if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
       Action = TLI.getCustomOperationAction(*Node);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 3a9518ea569ebc..312f35f8106e52 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2083,7 +2083,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
     break;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
-    Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
+    Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N, OpNo);
     break;
   }
 
@@ -2833,10 +2833,11 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
 
-SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
-  SDValue Res = DAG.expandPartialReduceMLA(N);
-  ReplaceValueWith(SDValue(N, 0), Res);
-  return SDValue();
+SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N,
+                                                          unsigned OpNo) {
+  SmallVector<SDValue, 1> NewOps(N->ops());
+  NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
+  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
 
 //===----------------------------------------------------------------------===//
@@ -6155,9 +6156,13 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
 }
 
 SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
-  SDValue Res = DAG.expandPartialReduceMLA(N);
-  ReplaceValueWith(SDValue(N, 0), Res);
-  return SDValue();
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+  SDValue ExtAcc = GetPromotedInteger(N->getOperand(0));
+  SDValue V = DAG.getNode(N->getOpcode(), DL, NVT, ExtAcc, N->getOperand(1),
+                          N->getOperand(2));
+  return V;
 }
 
 SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index cb9c1b239c0fa9..e48aa357737802 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -431,7 +431,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
-  SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);
+  SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N, unsigned OpNo);
 
   void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
   void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6ad08bce44b0a4..ab74e8e1bac001 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -503,6 +503,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECREDUCE_FMIN:
   case ISD::VECREDUCE_FMAXIMUM:
   case ISD::VECREDUCE_FMINIMUM:
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::VECTOR_FIND_LAST_ACTIVE:
     Action = TLI.getOperationAction(Node->getOpcode(),
                                     Node->getOperand(0).getValueType());
@@ -1195,6 +1197,10 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VECREDUCE_FMINIMUM:
     Results.push_back(TLI.expandVecReduce(Node, DAG));
     return;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
+    return;
   case ISD::VECREDUCE_SEQ_FADD:
   case ISD::VECREDUCE_SEQ_FMUL:
     Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index b01470028981e7..97396627888f4f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3186,7 +3186,7 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
 }
 
 void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
-  SDValue Res = DAG.expandPartialReduceMLA(N);
+  SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
   ReplaceValueWith(SDValue(N, 0), Res);
 }
 
@@ -4447,7 +4447,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
 }
 
 SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
-  SDValue Res = DAG.expandPartialReduceMLA(N);
+  SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
   ReplaceValueWith(SDValue(N, 0), Res);
   return SDValue();
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7240e4e00dfa07..b416c0efbbc4fc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2473,23 +2473,6 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
-SDValue SelectionDAG::expandPartialReduceMLA(SDNode *N) {
-  SDLoc DL(N);
-  SDValue Acc = N->getOperand(0);
-  SDValue Input1 = N->getOperand(1);
-  SDValue Input2 = N->getOperand(2);
-
-  EVT FullTy = Input1.getValueType();
-
-  SDValue Input = Input1;
-  APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
-    Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
-
-  return getPartialReduceAdd(DL, Acc.getValueType(), Acc, Input);
-}
-
 SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
                                           SDValue Op2) {
   EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 144439f136ff16..530e1ff3d0af07 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -135,6 +135,9 @@ static cl::opt<unsigned> SwitchPeelThreshold(
              "switch statement. A value greater than 100 will void this "
              "optimization"));
 
+// FIXME: This is a temporary flag, used to ease the transition to lowering
+// partial reductions the proper way via the new PARTIAL_REDUCE_MLA ISD
+// nodes.
 static cl::opt<bool> NewPartialReduceLowering(
     "new-partial-reduce-lowering", cl::init(false), cl::ReallyHidden,
     cl::desc("Use the new method of lowering partial reductions."));
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 49ec47f4e8a702..f0743113df6864 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12114,3 +12114,41 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
   SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
 }
+
+SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Input1 = N->getOperand(1);
+  SDValue Input2 = N->getOperand(2);
+
+  EVT ReducedTy = Acc.getValueType();
+  EVT FullTy = Input1.getValueType();
+
+  auto ExtendToAccEltVT = [&](SDValue V) {
+    unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
+                          ? ISD::ZERO_EXTEND
+                          : ISD::SIGN_EXTEND;
+    EVT ExtVT = V.getValueType().changeVectorElementType(
+        Acc.getValueType().getVectorElementType());
+    if (ExtVT != FullTy)
+      return DAG.getNode(ExtOpc, DL, ExtVT, V);
+    return V;
+  };
+
+  SDValue Input;
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne()) {
+    EVT NewVT =
+        EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
+                         FullTy.getVectorElementCount());
+    Input1 = ExtendToAccEltVT(Input1);
+    Input2 = ExtendToAccEltVT(Input2);
+    Input = DAG.getNode(ISD::MUL, DL, NewVT, Input1, Input2);
+  } else {
+    Input = ExtendToAccEltVT(Input1);
+  }
+
+  return DAG.getPartialReduceAdd(DL, ReducedTy, Acc, Input);
+}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index bd9994bcb669ca..8eabddf0b9a010 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1568,6 +1568,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::MSTORE, VT, Custom);
     }
 
+    for (MVT VT : MVT::scalable_vector_valuetypes()) {
+      setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Expand);
+      setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Expand);
+    }
+
     // Firstly, exclude all scalable vector extending loads/truncating stores,
     // include both integer and floating scalable vector.
     for (MVT VT : MVT::scalable_vector_valuetypes()) {
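
Putting the pieces of this patch together: the expand path introduced in TargetLowering::expandPartialReduceMLA extends both inputs to the accumulator's element type, multiplies them (skipping the multiply when the second input is a splat of one), and folds the products chunk-wise into the accumulator. A scalar sketch of the resulting values, assuming fixed-width vectors; partialReduceMLA and its parameters are illustrative names, not an API from the patch:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar sketch of the expansion: extend both inputs to the accumulator's
// element type, multiply them, and fold the products chunk-wise into the
// accumulator. When In2 is known to be a splat of 1 the multiply is skipped
// and In1 alone is accumulated. "Signed" selects sign vs. zero extension,
// i.e. PARTIAL_REDUCE_SMLA vs. PARTIAL_REDUCE_UMLA.
std::vector<int64_t> partialReduceMLA(std::vector<int64_t> Acc,
                                      const std::vector<int32_t> &In1,
                                      const std::vector<int32_t> &In2,
                                      bool Signed, bool In2IsSplatOfOne) {
  const size_t Stride = Acc.size();               // ReducedTy element count
  const size_t ScaleFactor = In1.size() / Stride; // FullTy / ReducedTy ratio
  auto Ext = [Signed](int32_t V) -> int64_t {
    return Signed ? static_cast<int64_t>(V)
                  : static_cast<int64_t>(static_cast<uint32_t>(V));
  };
  for (size_t Chunk = 0; Chunk < ScaleFactor; ++Chunk)
    for (size_t Lane = 0; Lane < Stride; ++Lane) {
      const size_t I = Chunk * Stride + Lane;
      Acc[Lane] += In2IsSplatOfOne ? Ext(In1[I]) : Ext(In1[I]) * Ext(In2[I]);
    }
  return Acc;
}

int main() {
  std::vector<int64_t> R = partialReduceMLA({0, 0}, {1, 2, 3, 4}, {2, 2, 2, 2},
                                            /*Signed=*/true,
                                            /*In2IsSplatOfOne=*/false);
  std::printf("%lld %lld\n", (long long)R[0], (long long)R[1]); // prints: 8 12
}

In the actual DAG code the extend is skipped when an input already has the accumulator's element type, and the chunk additions are built as a tree of EXTRACT_SUBVECTOR + ADD nodes rather than this flat loop.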

>From a4fb454a827fc7ebc596f8177c018f601e10709a Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 4 Feb 2025 13:55:24 +0000
Subject: [PATCH 3/3] Address comments

Move the getPartialReduceAdd function from SelectionDAG into TargetLowering.
Make the new code path work for fixed-length NEON vectors too.
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |   5 -
 llvm/include/llvm/CodeGen/TargetLowering.h    |   9 +-
 .../SelectionDAG/LegalizeVectorOps.cpp        |   8 +-
 .../SelectionDAG/LegalizeVectorTypes.cpp      |   8 +-
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  29 --
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  31 +-
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |  45 ++-
 .../Target/AArch64/AArch64ISelLowering.cpp    |   9 +-
 .../neon-partial-reduce-dot-product.ll        | 289 ++++++++++++++++++
 9 files changed, 365 insertions(+), 68 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 461c0c1ead16d2..cf8e4a3d2513b7 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1607,11 +1607,6 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
-  /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
-  /// its operands and ReducedTY is the intrinsic's return type.
-  SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
-                              SDValue Op2);
-
   /// Expands a node with multiple results to an FP or vector libcall. The
   /// libcall is expected to take all the operands of the \p Node followed by
   /// output pointers for each of the results. \p CallRetResNo can be optionally
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index b9fc21ae21139a..4c98d9927e5fff 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5624,7 +5624,14 @@ class TargetLowering : public TargetLoweringBase {
 
   // Expands PARTIAL_REDUCE_S/UMLA nodes to a series of simpler operations,
   // consisting of zext/sext, extract_subvector, mul and add operations.
-  SDValue expandPartialReduceMLA(SDNode *N, SelectionDAG &DAG) const;
+  SDValue expandPartialReduceMLA(SDLoc DL, SDValue Acc, SDValue Input1,
+                                 SDValue Input2, SelectionDAG &DAG) const;
+
+  // Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
+  // its operands and ReducedTy is the return type.
+  static SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, EVT FullTy,
+                                     SDValue Op1, SDValue Op2,
+                                     SelectionDAG &DAG);
 
 private:
   SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index ab74e8e1bac001..da86996e88ed88 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1198,9 +1198,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
     Results.push_back(TLI.expandVecReduce(Node, DAG));
     return;
   case ISD::PARTIAL_REDUCE_UMLA:
-  case ISD::PARTIAL_REDUCE_SMLA:
-    Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
+  case ISD::PARTIAL_REDUCE_SMLA: {
+    SDLoc DL(Node);
+    Results.push_back(TLI.expandPartialReduceMLA(DL, Node->getOperand(0),
+                                                 Node->getOperand(1),
+                                                 Node->getOperand(2), DAG));
     return;
+  }
   case ISD::VECREDUCE_SEQ_FADD:
   case ISD::VECREDUCE_SEQ_FMUL:
     Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 97396627888f4f..ed6e1b57542afd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3186,7 +3186,9 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
 }
 
 void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
-  SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
+  SDLoc DL(N);
+  SDValue Res = TLI.expandPartialReduceMLA(
+      DL, N->getOperand(0), N->getOperand(1), N->getOperand(2), DAG);
   ReplaceValueWith(SDValue(N, 0), Res);
 }
 
@@ -4447,7 +4449,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
 }
 
 SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
-  SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
+  SDLoc DL(N);
+  SDValue Res = TLI.expandPartialReduceMLA(
+      DL, N->getOperand(0), N->getOperand(1), N->getOperand(2), DAG);
   ReplaceValueWith(SDValue(N, 0), Res);
   return SDValue();
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b416c0efbbc4fc..af7bc3ff69bab2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2473,35 +2473,6 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
-SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
-                                          SDValue Op2) {
-  EVT FullTy = Op2.getValueType();
-
-  unsigned Stride = ReducedTy.getVectorMinNumElements();
-  unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
-  // Collect all of the subvectors
-  std::deque<SDValue> Subvectors = {Op1};
-  for (unsigned I = 0; I < ScaleFactor; I++) {
-    auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
-    Subvectors.push_back(
-        getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
-  }
-
-  // Flatten the subvector tree
-  while (Subvectors.size() > 1) {
-    Subvectors.push_back(
-        getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
-    Subvectors.pop_front();
-    Subvectors.pop_front();
-  }
-
-  assert(Subvectors.size() == 1 &&
-         "There should only be one subvector after tree flattening");
-
-  return Subvectors[0];
-}
-
 /// Given a store node \p StoreNode, return true if it is safe to fold that node
 /// into \p FPNode, which expands to a library call with output pointers.
 static bool canFoldStoreIntoLibCallOutputPointers(StoreSDNode *StoreNode,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 530e1ff3d0af07..bc08b6c5580f77 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8125,21 +8125,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+    SDValue Acc = getValue(I.getOperand(0));
+    EVT AccVT = Acc.getValueType();
+    SDValue Input = getValue(I.getOperand(1));
+    EVT InputVT = Input.getValueType();
+
+    assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
+           "Expected operands to have the same vector element type!");
+    assert(InputVT.getVectorElementCount().getKnownMinValue() %
+                   AccVT.getVectorElementCount().getKnownMinValue() ==
+               0 &&
+           "Expected the element count of the Input operand to be a positive "
+           "integer multiple of the element count of the Accumulator operand!");
     if (NewPartialReduceLowering) {
-      SDValue Acc = getValue(I.getOperand(0));
-      EVT AccVT = Acc.getValueType();
-      SDValue Input = getValue(I.getOperand(1));
-      EVT InputVT = Input.getValueType();
-
-      assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
-             "Expected operands to have the same vector element type!");
-      assert(
-          InputVT.getVectorElementCount().getKnownMinValue() %
-                  AccVT.getVectorElementCount().getKnownMinValue() ==
-              0 &&
-          "Expected the element count of the Input operand to be a positive "
-          "integer multiple of the element count of the Accumulator operand!");
-
       // ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
       // same if ISD::PARTIAL_REDUCE_SMLA was chosen instead. It should be
       // changed to its correct signedness when combining or expanding,
@@ -8154,9 +8152,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
       return;
     }
 
-    setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
-                                         getValue(I.getOperand(0)),
-                                         getValue(I.getOperand(1))));
+    setValue(&I, TLI.expandPartialReduceMLA(
+                     sdl, Acc, Input, DAG.getConstant(1, sdl, InputVT), DAG));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index f0743113df6864..c56779f151ac9c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -34,6 +34,7 @@
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Target/TargetMachine.h"
 #include <cctype>
+#include <deque>
 using namespace llvm;
 
 /// NOTE: The TargetMachine owns TLOF.
@@ -12115,20 +12116,15 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
 }
 
-SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
+SDValue TargetLowering::expandPartialReduceMLA(SDLoc DL, SDValue Acc,
+                                               SDValue Input1, SDValue Input2,
                                                SelectionDAG &DAG) const {
-  SDLoc DL(N);
-  SDValue Acc = N->getOperand(0);
-  SDValue Input1 = N->getOperand(1);
-  SDValue Input2 = N->getOperand(2);
-
   EVT ReducedTy = Acc.getValueType();
   EVT FullTy = Input1.getValueType();
 
   auto ExtendToAccEltVT = [&](SDValue V) {
-    unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
-                          ? ISD::ZERO_EXTEND
-                          : ISD::SIGN_EXTEND;
+    unsigned ExtOpc = V->getOpcode() == ISD::SIGN_EXTEND ? ISD::SIGN_EXTEND
+                                                         : ISD::ZERO_EXTEND;
     EVT ExtVT = V.getValueType().changeVectorElementType(
         Acc.getValueType().getVectorElementType());
     if (ExtVT != FullTy)
@@ -12150,5 +12146,34 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
     Input = ExtendToAccEltVT(Input1);
   }
 
-  return DAG.getPartialReduceAdd(DL, ReducedTy, Acc, Input);
+  return TargetLowering::getPartialReduceAdd(DL, ReducedTy, FullTy, Acc, Input,
+                                             DAG);
+}
+
+SDValue TargetLowering::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, EVT FullTy,
+                                            SDValue Op1, SDValue Op2,
+                                            SelectionDAG &DAG) {
+  unsigned Stride = ReducedTy.getVectorMinNumElements();
+  unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+  // Collect all of the subvectors
+  std::deque<SDValue> Subvectors = {Op1};
+  for (unsigned I = 0; I < ScaleFactor; I++) {
+    auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
+    Subvectors.push_back(
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+  }
+
+  // Flatten the subvector tree
+  while (Subvectors.size() > 1) {
+    Subvectors.push_back(
+        DAG.getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+    Subvectors.pop_front();
+    Subvectors.pop_front();
+  }
+
+  assert(Subvectors.size() == 1 &&
+         "There should only be one subvector after tree flattening");
+
+  return Subvectors[0];
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8eabddf0b9a010..1fdea6b021c66c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1355,6 +1355,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::BSWAP, VT, Expand);
       setOperationAction(ISD::CTTZ, VT, Expand);
 
+      setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Expand);
+      setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Expand);
+
       for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
         setTruncStoreAction(VT, InnerVT, Expand);
         setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
@@ -21981,8 +21984,10 @@ static SDValue performIntrinsicCombine(SDNode *N,
       return Dot;
     if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
       return WideAdd;
-    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
-                                   N->getOperand(1), N->getOperand(2));
+    SDValue Input = N->getOperand(2);
+    return TargetLowering::getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+                                               Input.getValueType(),
+                                               N->getOperand(1), Input, DAG);
   }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 9ece9edb843439..4fa10d0e9c0f98 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -2,6 +2,7 @@
 ; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
 ; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
 ; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -new-partial-reduce-lowering < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
 
 define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-DOT-LABEL: udot:
@@ -19,6 +20,17 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    umull v3.8h, v2.8b, v1.8b
+; CHECK-NEWLOWERING-NEXT:    umull2 v1.8h, v2.16b, v1.16b
+; CHECK-NEWLOWERING-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT:    uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NEWLOWERING-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NEWLOWERING-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEWLOWERING-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -46,6 +58,21 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_narrow:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NEWLOWERING-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEWLOWERING-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -69,6 +96,17 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    smull v3.8h, v2.8b, v1.8b
+; CHECK-NEWLOWERING-NEXT:    smull2 v1.8h, v2.16b, v1.16b
+; CHECK-NEWLOWERING-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT:    saddw v0.4s, v0.4s, v3.4h
+; CHECK-NEWLOWERING-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NEWLOWERING-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEWLOWERING-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -96,6 +134,21 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_narrow:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    smull v1.8h, v2.8b, v1.8b
+; CHECK-NEWLOWERING-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NEWLOWERING-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -121,6 +174,19 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-I8MM:       // %bb.0:
 ; CHECK-I8MM-NEXT:    usdot v0.4s, v1.16b, v2.16b
 ; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NEWLOWERING-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NEWLOWERING-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -150,6 +216,23 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-I8MM:       // %bb.0:
 ; CHECK-I8MM-NEXT:    usdot v0.2s, v1.8b, v2.8b
 ; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot_narrow:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NEWLOWERING-NEXT:    smlal v3.4s, v6.4h, v5.4h
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -175,6 +258,19 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
 ; CHECK-I8MM:       // %bb.0:
 ; CHECK-I8MM-NEXT:    usdot v0.4s, v2.16b, v1.16b
 ; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NEWLOWERING-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NEWLOWERING-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -204,6 +300,23 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-I8MM:       // %bb.0:
 ; CHECK-I8MM-NEXT:    usdot v0.2s, v2.8b, v1.8b
 ; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot_narrow:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NEWLOWERING-NEXT:    smlal v3.4s, v6.4h, v5.4h
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -237,6 +350,24 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
 ; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_8to64:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    umull v4.8h, v2.8b, v3.8b
+; CHECK-NEWLOWERING-NEXT:    umull2 v2.8h, v2.16b, v3.16b
+; CHECK-NEWLOWERING-NEXT:    ushll v3.4s, v4.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v5.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v4.4s, v4.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v2.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT:    uaddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NEWLOWERING-NEXT:    uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NEWLOWERING-NEXT:    uaddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT:    uaddl v4.2d, v4.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NEWLOWERING-NEXT:    uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NEWLOWERING-NEXT:    add v1.2d, v3.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT:    add v0.2d, v4.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <16 x i8> %a to <16 x i64>
   %b.wide = zext <16 x i8> %b to <16 x i64>
@@ -272,6 +403,24 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
 ; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    smull v4.8h, v2.8b, v3.8b
+; CHECK-NEWLOWERING-NEXT:    smull2 v2.8h, v2.16b, v3.16b
+; CHECK-NEWLOWERING-NEXT:    sshll v3.4s, v4.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v5.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v4.4s, v4.8h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v2.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT:    saddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NEWLOWERING-NEXT:    saddw v0.2d, v0.2d, v3.2s
+; CHECK-NEWLOWERING-NEXT:    saddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT:    saddl v4.2d, v4.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT:    saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NEWLOWERING-NEXT:    saddw v0.2d, v0.2d, v2.2s
+; CHECK-NEWLOWERING-NEXT:    add v1.2d, v3.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT:    add v0.2d, v4.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <16 x i8> %a to <16 x i64>
   %b.wide = sext <16 x i8> %b to <16 x i64>
@@ -315,6 +464,32 @@ define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
 ; CHECK-I8MM-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
 ; CHECK-I8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
 ; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v5.8h, v3.8b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v6.4s, v4.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v7.4s, v5.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v4.4s, v4.8h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v5.4s, v5.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v16.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v17.4s, v3.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v2.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v3.4s, v3.4h, #0
+; CHECK-NEWLOWERING-NEXT:    smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NEWLOWERING-NEXT:    smlal v0.2d, v6.2s, v7.2s
+; CHECK-NEWLOWERING-NEXT:    smull v18.2d, v4.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT:    smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT:    smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NEWLOWERING-NEXT:    smlal v0.2d, v16.2s, v17.2s
+; CHECK-NEWLOWERING-NEXT:    smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NEWLOWERING-NEXT:    smlal v18.2d, v2.2s, v3.2s
+; CHECK-NEWLOWERING-NEXT:    add v1.2d, v4.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT:    add v0.2d, v18.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <16 x i8> %a to <16 x i64>
   %b.wide = sext <16 x i8> %b to <16 x i64>
@@ -358,6 +533,32 @@ define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
 ; CHECK-I8MM-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
 ; CHECK-I8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
 ; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v5.8h, v3.8b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v6.4s, v4.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v7.4s, v5.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v4.4s, v4.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v5.4s, v5.8h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v16.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v17.4s, v3.8h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v2.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v3.4s, v3.4h, #0
+; CHECK-NEWLOWERING-NEXT:    smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NEWLOWERING-NEXT:    smlal v0.2d, v6.2s, v7.2s
+; CHECK-NEWLOWERING-NEXT:    smull v18.2d, v4.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT:    smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT:    smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NEWLOWERING-NEXT:    smlal v0.2d, v16.2s, v17.2s
+; CHECK-NEWLOWERING-NEXT:    smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NEWLOWERING-NEXT:    smlal v18.2d, v2.2s, v3.2s
+; CHECK-NEWLOWERING-NEXT:    add v1.2d, v4.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT:    add v0.2d, v18.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <16 x i8> %a to <16 x i64>
   %b.wide = zext <16 x i8> %b to <16 x i64>
@@ -384,6 +585,17 @@ define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
 ; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    ushll v2.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v3.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NEWLOWERING-NEXT:    uaddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NEWLOWERING-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEWLOWERING-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.wide = zext <16 x i8> %a to <16 x i32>
   %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
   ret <4 x i32> %partial.reduce
@@ -406,6 +618,17 @@ define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
 ; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    sshll v2.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v3.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT:    saddw v0.4s, v0.4s, v2.4h
+; CHECK-NEWLOWERING-NEXT:    saddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NEWLOWERING-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEWLOWERING-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.wide = sext <16 x i8> %a to <16 x i32>
   %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
   ret <4 x i32> %partial.reduce
@@ -432,6 +655,21 @@ define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
 ; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_narrow:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEWLOWERING-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.wide = zext <8 x i8> %a to <8 x i32>
   %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
   ret <2 x i32> %partial.reduce
@@ -458,6 +696,21 @@ define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
 ; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_narrow:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NEWLOWERING-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NEWLOWERING-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.wide = sext <8 x i8> %a to <8 x i32>
   %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
   ret <2 x i32> %partial.reduce
@@ -490,6 +743,24 @@ define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
 ; CHECK-NODOT-NEXT:    add v1.2d, v4.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v3.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    ushll v3.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v4.4s, v3.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll v5.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v3.4s, v3.8h, #0
+; CHECK-NEWLOWERING-NEXT:    ushll2 v2.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT:    uaddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NEWLOWERING-NEXT:    uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-NEXT:    uaddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT:    uaddl v3.2d, v3.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NEWLOWERING-NEXT:    uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NEWLOWERING-NEXT:    add v1.2d, v4.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT:    add v0.2d, v3.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.wide = zext <16 x i8> %a to <16 x i64>
   %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
   ret <4 x i64> %partial.reduce
@@ -522,6 +793,24 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
 ; CHECK-NODOT-NEXT:    add v1.2d, v4.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v3.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    sshll v3.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v4.4s, v3.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll v5.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v3.4s, v3.8h, #0
+; CHECK-NEWLOWERING-NEXT:    sshll2 v2.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NEWLOWERING-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-NEXT:    saddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT:    saddl v3.2d, v3.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT:    saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NEWLOWERING-NEXT:    saddw v0.2d, v0.2d, v2.2s
+; CHECK-NEWLOWERING-NEXT:    add v1.2d, v4.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT:    add v0.2d, v3.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT:    ret
   %a.wide = sext <16 x i8> %a to <16 x i64>
   %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
   ret <4 x i64> %partial.reduce
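
To close, a small scalar analogue of the getPartialReduceAdd helper that this patch turns into a static TargetLowering member: it splits the wide operand into ReducedTy-sized subvectors, queues them behind the accumulator, and repeatedly adds the front two until one value remains. The Vec/addVec types below are illustrative stand-ins for SDValues, not code from the patch:

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <deque>
#include <vector>

using Vec = std::vector<long long>; // illustrative stand-in for an SDValue

static Vec addVec(const Vec &A, const Vec &B) {
  Vec R(A.size());
  for (size_t I = 0; I < A.size(); ++I)
    R[I] = A[I] + B[I];
  return R;
}

// Scalar analogue of getPartialReduceAdd: split the wide operand into
// ReducedTy-sized subvectors, queue them behind the accumulator, then
// repeatedly pop the front two and push their sum until one value remains.
static Vec partialReduceAdd(const Vec &Acc, const Vec &Wide) {
  const size_t Stride = Acc.size();
  const size_t ScaleFactor = Wide.size() / Stride;

  std::deque<Vec> Subvectors = {Acc};
  for (size_t I = 0; I < ScaleFactor; ++I) // models EXTRACT_SUBVECTOR
    Subvectors.push_back(Vec(Wide.begin() + I * Stride,
                             Wide.begin() + (I + 1) * Stride));

  while (Subvectors.size() > 1) { // flatten the subvector tree
    Subvectors.push_back(addVec(Subvectors[0], Subvectors[1]));
    Subvectors.pop_front();
    Subvectors.pop_front();
  }
  assert(Subvectors.size() == 1);
  return Subvectors[0];
}

int main() {
  Vec R = partialReduceAdd({1, 1}, {1, 2, 3, 4, 5, 6, 7, 8});
  std::printf("%lld %lld\n", R[0], R[1]); // prints: 17 21
}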


