[llvm] [WIP][SDAG] Add partial_reduce_sumla node (PR #141267)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri May 30 11:27:31 PDT 2025


https://github.com/preames updated https://github.com/llvm/llvm-project/pull/141267

>From 0015c50d6566f8d11ffc3f300897d5defe544ab6 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 23 May 2025 07:51:50 -0700
Subject: [PATCH 1/2] [WIP][SDAG] Add partial_reduce_sumla node

We have recently added the partial_reduce_smla and partial_reduce_umla
nodes to represent Acc += ext(a) * ext(b) where the two extends must
have the same source type and the same extend kind.

For riscv64 w/zvqdotq, we have the vqdot and vqdotu instructions
which correspond to the existing nodes, but we also have vqdotsu
which represents the case where the two extends are sign and zero
respectively (i.e. not the same kind of extend).

This patch adds a partial_reduce_sumla node which has sign
extension for A, and zero extension for B.  The addition is somewhat
mechanical, except that it exposes an implementation challenge
because AArch64 doesn't have an analogous instruction (that I've
found).

The current legalization table assumes that all of the partial_reduce*mla
variants have the same handling for a given type pair.

Questions to the AArch64 folks:
* Does aarch64 have a good implementation for this that I missed?
* If not, are you okay with my somewhat hacky custom legalization
  approach (in this patch)?  It does look like there are some small
  regressions here, but I haven't dug into why.
* If not, any suggestions on how to structure splitting the legalization
  table?  I could add the opcode to the table key; that's probably the
  easiest.
---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  10 +-
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  53 ++-
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  15 +-
 .../SelectionDAG/LegalizeVectorOps.cpp        |   2 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      |   2 +
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |   3 +-
 .../SelectionDAG/SelectionDAGDumper.cpp       |   2 +
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |  22 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    |  18 +-
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  17 +-
 .../AArch64/sve-partial-reduce-dot-product.ll | 346 ++++++++++++++----
 llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll |  40 +-
 12 files changed, 400 insertions(+), 130 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 9f66402e4c820..848631c7ffb03 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1484,8 +1484,9 @@ enum NodeType {
   VECREDUCE_UMIN,
 
   // PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
-  // The partial reduction nodes sign or zero extend Input1 and Input2 to the
-  // element type of Accumulator before multiplying their results.
+  // The partial reduction nodes sign or zero extend Input1 and Input2
+  // (with the extension kind noted below) to the element type of
+  // Accumulator before multiplying their results.
   // This result is concatenated to the Accumulator, and this is then reduced,
   // using addition, to the result type.
   // The output is only expected to either be given to another partial reduction
@@ -1497,8 +1498,9 @@ enum NodeType {
   // multiple of the number of elements in the Accumulator / output type.
   // Input1 and Input2 must have an element type which is the same as or smaller
   // than the element type of the Accumulator and output.
-  PARTIAL_REDUCE_SMLA,
-  PARTIAL_REDUCE_UMLA,
+  PARTIAL_REDUCE_SMLA,  // sext, sext
+  PARTIAL_REDUCE_UMLA,  // zext, zext
+  PARTIAL_REDUCE_SUMLA, // sext, zext
 
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index efaa8bd4a7950..df6702c390fc7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1991,6 +1991,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
                                 return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
@@ -12675,19 +12676,19 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
-  bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
-  unsigned NewOpcode =
-      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
   // -> partial_reduce_*mla(acc, x, C)
   if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+    // TODO: Make use of partial_reduce_sumla here
     APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
     unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
     if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
         (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
       return SDValue();
 
+    unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
+                             ? ISD::PARTIAL_REDUCE_SMLA
+                             : ISD::PARTIAL_REDUCE_UMLA;
     return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
                        DAG.getConstant(CTrunc, DL, LHSExtOpVT));
   }
@@ -12697,26 +12698,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
     return SDValue();
 
   SDValue RHSExtOp = RHS->getOperand(0);
-  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+  if (LHSExtOpVT != RHSExtOp.getValueType())
     return SDValue();
 
-  // For a 2-stage extend the signedness of both of the extends must be the
-  // same. This is so the node can be folded into only a signed or unsigned
-  // node.
-  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  unsigned NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+  // For a 2-stage extend the signedness of both of the extends must match
+  // If the mul has the same type, there is no outer extend, and thus we
+  // can simply use the inner extends to pick the result node.
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
-  if (ExtIsSigned != NodeIsSigned &&
-      Op1.getValueType().getVectorElementType() != AccElemVT)
-    return SDValue();
-
-  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
-                     RHSExtOp);
+  if (Op1.getValueType().getVectorElementType() != AccElemVT) {
+    // TODO: Split this into canonicalization rules
+    if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND &&
+        (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ||
+         N->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA))
+      NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+    else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND &&
+             N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA)
+      NewOpc = ISD::PARTIAL_REDUCE_UMLA;
+    else
+      return SDValue();
+  } else {
+    // TODO: Add canonicalization rule
+    if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
+      NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+    else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+      NewOpc = ISD::PARTIAL_REDUCE_UMLA;
+    else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+      NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
+    else
+      // TODO: Handle the swapped sumla case here
+      return SDValue();
+  }
+  return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
 }
 
 // partial.reduce.umla(acc, zext(op), splat(1))
 // -> partial.reduce.umla(acc, op, splat(trunc(1)))
 // partial.reduce.smla(acc, sext(op), splat(1))
 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.sumla(acc, sext(op), splat(1))
+// -> partial.reduce.smla(acc, op, splat(trunc(1)))
 SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   SDLoc DL(N);
   SDValue Acc = N->getOperand(0);
@@ -12738,7 +12759,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
     return SDValue();
 
   bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
-  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
   if (Op1IsSigned != NodeIsSigned &&
       Op1.getValueType().getVectorElementType() != AccElemVT)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 90af5f2cd8e70..5eb2f8c9150e9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -166,6 +166,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
 
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
     break;
 
@@ -2090,6 +2091,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
     break;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
     break;
   }
@@ -2876,12 +2878,21 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
 
 SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
   SmallVector<SDValue, 1> NewOps(N->ops());
-  if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
+  switch (N->getOpcode()) {
+  case ISD::PARTIAL_REDUCE_SMLA:
     NewOps[1] = SExtPromotedInteger(N->getOperand(1));
     NewOps[2] = SExtPromotedInteger(N->getOperand(2));
-  } else {
+    break;
+  case ISD::PARTIAL_REDUCE_UMLA:
     NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
     NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
+    break;
+  case ISD::PARTIAL_REDUCE_SUMLA:
+    NewOps[1] = SExtPromotedInteger(N->getOperand(1));
+    NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
+    break;
+  default:
+    llvm_unreachable("unexpected opcode");
   }
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index affcd78ea61b0..4a12b76851966 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -530,6 +530,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
                                            Node->getOperand(1).getValueType());
     break;
@@ -1210,6 +1211,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
     return;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
     return;
   case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index c011a0a61d698..d3200b38c350e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1387,6 +1387,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     break;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
     break;
   }
@@ -3454,6 +3455,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
     break;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
     break;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5400f3eaf373d..b8288af53de1e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7967,7 +7967,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
   case ISD::PARTIAL_REDUCE_UMLA:
-  case ISD::PARTIAL_REDUCE_SMLA: {
+  case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA: {
     [[maybe_unused]] EVT AccVT = N1.getValueType();
     [[maybe_unused]] EVT Input1VT = N2.getValueType();
     [[maybe_unused]] EVT Input2VT = N3.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 803894e298dd5..4fd0b7fd873e6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -584,6 +584,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
     return "partial_reduce_umla";
   case ISD::PARTIAL_REDUCE_SMLA:
     return "partial_reduce_smla";
+  case ISD::PARTIAL_REDUCE_SUMLA:
+    return "partial_reduce_sumla";
 
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 75c9bbaec7603..4d627af0e9bca 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11887,13 +11887,23 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   EVT ExtMulOpVT =
       EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
                        MulOpVT.getVectorElementCount());
-  unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
-                        ? ISD::SIGN_EXTEND
-                        : ISD::ZERO_EXTEND;
-
   if (ExtMulOpVT != MulOpVT) {
-    MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
-    MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
+    switch (N->getOpcode()) {
+    case ISD::PARTIAL_REDUCE_SMLA:
+      MulLHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
+      MulRHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulRHS);
+      break;
+    case ISD::PARTIAL_REDUCE_UMLA:
+      MulLHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulLHS);
+      MulRHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
+      break;
+    case ISD::PARTIAL_REDUCE_SUMLA:
+      MulLHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
+      MulRHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
+      break;
+    default:
+      llvm_unreachable("unexpected opcode");
+    }
   }
   SDValue Input = MulLHS;
   APInt ConstantOne;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b7f0bcfd015bc..f3e8a6974c25f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1874,8 +1874,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
     // Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
     // Other pairs will default to 'Expand'.
-    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
-    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Custom);
+    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
 
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
   }
@@ -7745,6 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerVECTOR_HISTOGRAM(Op, DAG);
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
@@ -29532,13 +29533,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                                SelectionDAG &DAG) const {
+  // No support for sumla forms, let generic legalization handle them
+  if (Op->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA)
+    return SDValue();
+
   SDLoc DL(Op);
 
   SDValue Acc = Op.getOperand(0);
   SDValue LHS = Op.getOperand(1);
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
-  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
+  EVT OpVT = LHS.getValueType();
+
+  // These two are legal...
+  if ((ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv8i16) ||
+      (ResultVT == MVT::nxv4i32 && OpVT == MVT::nxv16i8))
+    return Op;
+
+  assert(ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv16i8);
 
   SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
                                 DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 476596e4e0104..5622b68475305 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8240,6 +8240,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerADJUST_TRAMPOLINE(Op, DAG);
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     return lowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
@@ -8391,8 +8392,20 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
   SDValue B = Op.getOperand(2);
   assert(A.getSimpleValueType() == B.getSimpleValueType() &&
          A.getSimpleValueType().getVectorElementType() == MVT::i8);
-  bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
-  unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+  unsigned Opc;
+  switch (Op.getOpcode()) {
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Opc = RISCVISD::VQDOT_VL;
+    break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+    Opc = RISCVISD::VQDOTU_VL;
+    break;
+  case ISD::PARTIAL_REDUCE_SUMLA:
+    Opc = RISCVISD::VQDOTSU_VL;
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
   auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
   return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 5bc9a101b1e44..35b19eee5d983 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -159,26 +159,71 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ; CHECK-NOI8MM-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: sudot:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: sudot:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z4.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    mul z3.s, p0/m, z3.s, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mad z1.s, p0/m, z2.s, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sudot:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z4.h, z1.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mad z1.s, p0/m, z2.s, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sudot:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z4.h, z1.b
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT:    mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SME-NEXT:    mad z1.s, p0/m, z2.s, z3.s
+; CHECK-NEWLOWERING-SME-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -430,46 +475,142 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z2.d, z3.d
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: sudot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z5.h, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z24.s, z3.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z25.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z28.d, z24.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z29.d, z25.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z30.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z31.d, z5.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z4.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z5.d, z5.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.d, p0/m, z27.d, z26.d
+; CHECK-NEWLOWERING-SVE-NEXT:    movprfx z26, z29
+; CHECK-NEWLOWERING-SVE-NEXT:    mul z26.d, p0/m, z26.d, z28.d
+; CHECK-NEWLOWERING-SVE-NEXT:    movprfx z27, z31
+; CHECK-NEWLOWERING-SVE-NEXT:    mul z27.d, p0/m, z27.d, z30.d
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z28.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mul z6.d, p0/m, z6.d, z7.d
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z7.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.d, p0/m, z7.d, z28.d
+; CHECK-NEWLOWERING-SVE-NEXT:    mad z4.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-SVE-NEXT:    mad z2.d, p0/m, z3.d, z26.d
+; CHECK-NEWLOWERING-SVE-NEXT:    movprfx z3, z27
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z3.d, p0/m, z25.d, z24.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z4.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z5.h, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z24.s, z3.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z25.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z28.d, z24.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z29.d, z25.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z30.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z31.d, z5.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z4.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z5.d, z5.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z27.d, z26.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    mul z26.d, z29.d, z28.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    mul z27.d, z31.d, z30.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z28.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mul z6.d, z7.d, z6.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z7.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z7.d, z28.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    mad z4.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    mad z2.d, p0/m, z3.d, z26.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    movprfx z3, z27
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z3.d, p0/m, z25.d, z24.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    add z2.d, z4.d, z3.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z5.h, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z24.s, z3.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z25.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z28.d, z24.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z29.d, z25.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z30.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z31.d, z5.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z4.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z5.d, z5.s
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z27.d, z26.d
+; CHECK-NEWLOWERING-SME-NEXT:    mul z26.d, z29.d, z28.d
+; CHECK-NEWLOWERING-SME-NEXT:    mul z27.d, z31.d, z30.d
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z28.d, z3.s
+; CHECK-NEWLOWERING-SME-NEXT:    mul z6.d, z7.d, z6.d
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z7.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z7.d, z28.d
+; CHECK-NEWLOWERING-SME-NEXT:    mad z4.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-SME-NEXT:    mad z2.d, p0/m, z3.d, z26.d
+; CHECK-NEWLOWERING-SME-NEXT:    movprfx z3, z27
+; CHECK-NEWLOWERING-SME-NEXT:    mla z3.d, p0/m, z25.d, z24.d
+; CHECK-NEWLOWERING-SME-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-SME-NEXT:    add z2.d, z4.d, z3.d
+; CHECK-NEWLOWERING-SME-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -773,26 +914,71 @@ define <vscale x 2 x i64> @not_sudot(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: not_sudot:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z5.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_sudot:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mul z3.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-SVE-NEXT:    mad z1.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_sudot:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    mad z1.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_sudot:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT:    mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-SME-NEXT:    mad z1.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-SME-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 2bd2ef2878fd5..7e85feefea1e3 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -934,22 +934,30 @@ entry:
 }
 
 define <vscale x 1 x i32> @partial_reduce_vqdotsu(<vscale x 4 x i8> %a, <vscale x 4 x i8> %b) {
-; CHECK-LABEL: partial_reduce_vqdotsu:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    vwmulsu.vv v8, v10, v11
-; CHECK-NEXT:    srli a0, a0, 3
-; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vx v10, v9, a0
-; CHECK-NEXT:    vslidedown.vx v11, v8, a0
-; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8
-; CHECK-NEXT:    vadd.vv v9, v11, v9
-; CHECK-NEXT:    vadd.vv v8, v9, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_vqdotsu:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vzext.vf2 v11, v9
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    vwmulsu.vv v8, v10, v11
+; NODOT-NEXT:    srli a0, a0, 3
+; NODOT-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
+; NODOT-NEXT:    vslidedown.vx v10, v9, a0
+; NODOT-NEXT:    vslidedown.vx v11, v8, a0
+; NODOT-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v10, v8
+; NODOT-NEXT:    vadd.vv v9, v11, v9
+; NODOT-NEXT:    vadd.vv v8, v9, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_reduce_vqdotsu:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vmv1r.v v8, v10
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <vscale x 4 x i8> %a to <vscale x 4 x i32>
   %b.sext = zext <vscale x 4 x i8> %b to <vscale x 4 x i32>

>From 9d624e97c6cbf16eef545257023bd76a441c5bfb Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 30 May 2025 11:25:46 -0700
Subject: [PATCH 2/2] clang-format

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 1bcc82fef00ea..d98312432475a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12726,7 +12726,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
   if (Op1.getValueType().getVectorElementType() != AccElemVT &&
       NewOpc != N->getOpcode())
-      return SDValue();
+    return SDValue();
 
   // Only perform these combines if the target supports folding
   // the extends into the operation.



More information about the llvm-commits mailing list