[llvm] 9396663 - [SDAG] Add partial_reduce_sumla node (#141267)

via llvm-commits <llvm-commits at lists.llvm.org>
Mon Jun 9 07:17:48 PDT 2025


Author: Philip Reames
Date: 2025-06-09T07:17:45-07:00
New Revision: 939666380fba5d6db3d224fc358fd3e0f40a9b53

URL: https://github.com/llvm/llvm-project/commit/939666380fba5d6db3d224fc358fd3e0f40a9b53
DIFF: https://github.com/llvm/llvm-project/commit/939666380fba5d6db3d224fc358fd3e0f40a9b53.diff

LOG: [SDAG] Add partial_reduce_sumla node (#141267)

We recently added the partial_reduce_smla and partial_reduce_umla
nodes to represent Acc += ext(a) * ext(b), where the two extends must
have the same source type and the same extend kind.

For riscv64 with zvqdotq, we have the vqdot and vqdotu instructions, which
correspond to the existing nodes, but we also have vqdotsu, which
covers the case where the two extends are a sign extend and a zero extend
respectively (i.e. not the same kind of extend).

This patch adds a partial_reduce_sumla node which uses sign extension for
A and zero extension for B. The addition is somewhat mechanical.
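
For reference, a minimal LLVM IR sketch of the mixed-extension pattern the
new node captures (mirroring the tests updated below); the function name is
illustrative only:

define <1 x i32> @sumla_example(<4 x i8> %a, <4 x i8> %b) {
entry:
  ; one operand is sign extended, the other is zero extended
  %a.ext = sext <4 x i8> %a to <4 x i32>
  %b.ext = zext <4 x i8> %b to <4 x i32>
  %mul = mul <4 x i32> %a.ext, %b.ext
  ; the partial reduction sums the products into the single accumulator lane
  %res = call <1 x i32> @llvm.experimental.vector.partial.reduce.add(<1 x i32> zeroinitializer, <4 x i32> %mul)
  ret <1 x i32> %res
}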

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ISDOpcodes.h
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
    llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 47a1aec3da06a..465e4a0a9d0d8 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1493,8 +1493,9 @@ enum NodeType {
   VECREDUCE_UMIN,
 
   // PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
-  // The partial reduction nodes sign or zero extend Input1 and Input2 to the
-  // element type of Accumulator before multiplying their results.
+  // The partial reduction nodes sign or zero extend Input1 and Input2
+  // (with the extension kind noted below) to the element type of
+  // Accumulator before multiplying their results.
   // This result is concatenated to the Accumulator, and this is then reduced,
   // using addition, to the result type.
   // The output is only expected to either be given to another partial reduction
@@ -1506,8 +1507,9 @@ enum NodeType {
   // multiple of the number of elements in the Accumulator / output type.
   // Input1 and Input2 must have an element type which is the same as or smaller
   // than the element type of the Accumulator and output.
-  PARTIAL_REDUCE_SMLA,
-  PARTIAL_REDUCE_UMLA,
+  PARTIAL_REDUCE_SMLA,  // sext, sext
+  PARTIAL_REDUCE_UMLA,  // zext, zext
+  PARTIAL_REDUCE_SUMLA, // sext, zext
 
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9c453f51e129d..04bc0e9353101 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1661,7 +1661,8 @@ class LLVM_ABI TargetLoweringBase {
   /// target has a custom expander for it.
   LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
                                            EVT InputVT) const {
-    assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
+    assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
+           Opc == ISD::PARTIAL_REDUCE_SUMLA);
     PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
                                     InputVT.getSimpleVT().SimpleTy};
     auto It = PartialReduceMLAActions.find(Key);
@@ -2759,7 +2760,8 @@ class LLVM_ABI TargetLoweringBase {
   /// sequence, or the target has a custom expander for it.
   void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
                                  LegalizeAction Action) {
-    assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
+    assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
+           Opc == ISD::PARTIAL_REDUCE_SUMLA);
     assert(AccVT.isValid() && InputVT.isValid() &&
            "setPartialReduceMLAAction types aren't valid");
     PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 414cf22d43471..1712f56f4719d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1992,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
                                 return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
@@ -12737,26 +12738,27 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHSExtOp = LHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
 
-  bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
-  unsigned NewOpcode =
-      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-
-  // Only perform these combines if the target supports folding
-  // the extends into the operation.
-  if (!TLI.isPartialReduceMLALegalOrCustom(
-          NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
-          TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
-    return SDValue();
-
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
   // -> partial_reduce_*mla(acc, x, C)
   if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+    // TODO: Make use of partial_reduce_sumla here
     APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
     unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
     if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
         (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
       return SDValue();
 
+    unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
+                             ? ISD::PARTIAL_REDUCE_SMLA
+                             : ISD::PARTIAL_REDUCE_UMLA;
+
+    // Only perform these combines if the target supports folding
+    // the extends into the operation.
+    if (!TLI.isPartialReduceMLALegalOrCustom(
+            NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+            TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
+      return SDValue();
+
     return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
                        DAG.getConstant(CTrunc, DL, LHSExtOpVT));
   }
@@ -12766,26 +12768,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
     return SDValue();
 
   SDValue RHSExtOp = RHS->getOperand(0);
-  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+  if (LHSExtOpVT != RHSExtOp.getValueType())
+    return SDValue();
+
+  unsigned NewOpc;
+  if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
+    NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+  else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+    NewOpc = ISD::PARTIAL_REDUCE_UMLA;
+  else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+    NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
+  else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
+    NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
+    std::swap(LHSExtOp, RHSExtOp);
+  } else
     return SDValue();
-
-  // For a 2-stage extend the signedness of both of the extends must be the
-  // same. This is so the node can be folded into only a signed or unsigned
-  // node.
-  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  // For a 2-stage extend the signedness of both of the extends must match
+  // If the mul has the same type, there is no outer extend, and thus we
+  // can simply use the inner extends to pick the result node.
+  // TODO: extend to handle nonneg zext as sext
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
-  if (ExtIsSigned != NodeIsSigned &&
-      Op1.getValueType().getVectorElementType() != AccElemVT)
+  if (Op1.getValueType().getVectorElementType() != AccElemVT &&
+      NewOpc != N->getOpcode())
     return SDValue();
 
-  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
-                     RHSExtOp);
+  // Only perform these combines if the target supports folding
+  // the extends into the operation.
+  if (!TLI.isPartialReduceMLALegalOrCustom(
+          NewOpc, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+          TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
+    return SDValue();
+
+  return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
 }
 
 // partial.reduce.umla(acc, zext(op), splat(1))
 // -> partial.reduce.umla(acc, op, splat(trunc(1)))
 // partial.reduce.smla(acc, sext(op), splat(1))
 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.sumla(acc, sext(op), splat(1))
+// -> partial.reduce.smla(acc, op, splat(trunc(1)))
 SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   SDLoc DL(N);
   SDValue Acc = N->getOperand(0);
@@ -12802,7 +12824,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
     return SDValue();
 
   bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
-  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
   if (Op1IsSigned != NodeIsSigned &&
       Op1.getValueType().getVectorElementType() != AccElemVT)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 2bcca91f6f81a..dd64676222055 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -166,6 +166,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
 
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
     break;
 
@@ -2093,6 +2094,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
     break;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
     break;
   }
@@ -2886,12 +2888,21 @@ SDValue DAGTypeLegalizer::PromoteIntOp_GET_ACTIVE_LANE_MASK(SDNode *N) {
 
 SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
   SmallVector<SDValue, 1> NewOps(N->ops());
-  if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
+  switch (N->getOpcode()) {
+  case ISD::PARTIAL_REDUCE_SMLA:
     NewOps[1] = SExtPromotedInteger(N->getOperand(1));
     NewOps[2] = SExtPromotedInteger(N->getOperand(2));
-  } else {
+    break;
+  case ISD::PARTIAL_REDUCE_UMLA:
     NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
     NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
+    break;
+  case ISD::PARTIAL_REDUCE_SUMLA:
+    NewOps[1] = SExtPromotedInteger(N->getOperand(1));
+    NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
+    break;
+  default:
+    llvm_unreachable("unexpected opcode");
   }
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 910a40e5b5141..4a1cd642233ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -530,6 +530,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Action =
         TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
                                       Node->getOperand(1).getValueType());
@@ -1211,6 +1212,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
     return;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
     return;
   case ISD::VECREDUCE_SEQ_FADD:

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 5582dc98d35cb..f63fe17da51ff 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1387,6 +1387,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     break;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
     break;
   case ISD::GET_ACTIVE_LANE_MASK:
@@ -3473,6 +3474,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
     break;
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
     break;
   }

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 279c7daf71c33..049c24288344d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7981,7 +7981,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
   case ISD::PARTIAL_REDUCE_UMLA:
-  case ISD::PARTIAL_REDUCE_SMLA: {
+  case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA: {
     [[maybe_unused]] EVT AccVT = N1.getValueType();
     [[maybe_unused]] EVT Input1VT = N2.getValueType();
     [[maybe_unused]] EVT Input2VT = N3.getValueType();

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 539f583ea361f..7fc15581c17e4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -585,6 +585,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
     return "partial_reduce_umla";
   case ISD::PARTIAL_REDUCE_SMLA:
     return "partial_reduce_smla";
+  case ISD::PARTIAL_REDUCE_SUMLA:
+    return "partial_reduce_sumla";
 
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index e8e820ac1f695..23304c1ce8cc4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11891,13 +11891,17 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   EVT ExtMulOpVT =
       EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
                        MulOpVT.getVectorElementCount());
-  unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
-                        ? ISD::SIGN_EXTEND
-                        : ISD::ZERO_EXTEND;
+
+  unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
+                      ? ISD::ZERO_EXTEND
+                      : ISD::SIGN_EXTEND;
+  unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
+                      ? ISD::SIGN_EXTEND
+                      : ISD::ZERO_EXTEND;
 
   if (ExtMulOpVT != MulOpVT) {
-    MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
-    MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
+    MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
+    MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
   }
   SDValue Input = MulLHS;
   APInt ConstantOne;

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 1cbd3f4233eee..ab8b36df44d3f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1574,7 +1574,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   // zve32x is broken for partial_reduce_umla, but let's not make it worse.
   if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
     static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
-                                      ISD::PARTIAL_REDUCE_UMLA};
+                                      ISD::PARTIAL_REDUCE_UMLA,
+                                      ISD::PARTIAL_REDUCE_SUMLA};
     setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
     setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
     setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
@@ -8318,6 +8319,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerADJUST_TRAMPOLINE(Op, DAG);
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     return lowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
@@ -8534,8 +8536,20 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
     B = convertToScalableVector(ContainerVT, B, DAG, Subtarget);
   }
 
-  bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
-  unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+  unsigned Opc;
+  switch (Op.getOpcode()) {
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Opc = RISCVISD::VQDOT_VL;
+    break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+    Opc = RISCVISD::VQDOTU_VL;
+    break;
+  case ISD::PARTIAL_REDUCE_SUMLA:
+    Opc = RISCVISD::VQDOTSU_VL;
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
   auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
   SDValue Res = DAG.getNode(Opc, DL, ContainerVT, {A, B, Accum, Mask, VL});
   if (VT.isFixedLengthVector())

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 7ebbe3feaadda..6e8eaa2ab6f74 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -646,23 +646,31 @@ entry:
 }
 
 define <1 x i32> @vqdotsu_vv_partial_reduce_v1i32_v4i8(<4 x i8> %a, <4 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv_partial_reduce_v1i32_v4i8:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v8, v9
-; CHECK-NEXT:    vwmulsu.vv v9, v10, v8
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vi v8, v9, 3
-; CHECK-NEXT:    vslidedown.vi v10, v9, 2
-; CHECK-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v8, v9
-; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vi v9, v9, 1
-; CHECK-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
-; CHECK-NEXT:    vadd.vv v9, v9, v10
-; CHECK-NEXT:    vadd.vv v8, v9, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv_partial_reduce_v1i32_v4i8:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vzext.vf2 v8, v9
+; NODOT-NEXT:    vwmulsu.vv v9, v10, v8
+; NODOT-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; NODOT-NEXT:    vslidedown.vi v8, v9, 3
+; NODOT-NEXT:    vslidedown.vi v10, v9, 2
+; NODOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v8, v9
+; NODOT-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; NODOT-NEXT:    vslidedown.vi v9, v9, 1
+; NODOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; NODOT-NEXT:    vadd.vv v9, v9, v10
+; NODOT-NEXT:    vadd.vv v8, v9, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_partial_reduce_v1i32_v4i8:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT:    vmv.s.x v10, zero
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vmv1r.v v8, v10
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <4 x i8> %a to <4 x i32>
   %b.sext = zext <4 x i8> %b to <4 x i32>
@@ -672,23 +680,31 @@ entry:
 }
 
 define <1 x i32> @vqdotsu_vv_partial_reduce_swapped(<4 x i8> %a, <4 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv_partial_reduce_swapped:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v8, v9
-; CHECK-NEXT:    vwmulsu.vv v9, v10, v8
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vi v8, v9, 3
-; CHECK-NEXT:    vslidedown.vi v10, v9, 2
-; CHECK-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v8, v9
-; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vi v9, v9, 1
-; CHECK-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
-; CHECK-NEXT:    vadd.vv v9, v9, v10
-; CHECK-NEXT:    vadd.vv v8, v9, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv_partial_reduce_swapped:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vzext.vf2 v8, v9
+; NODOT-NEXT:    vwmulsu.vv v9, v10, v8
+; NODOT-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; NODOT-NEXT:    vslidedown.vi v8, v9, 3
+; NODOT-NEXT:    vslidedown.vi v10, v9, 2
+; NODOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v8, v9
+; NODOT-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; NODOT-NEXT:    vslidedown.vi v9, v9, 1
+; NODOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; NODOT-NEXT:    vadd.vv v9, v9, v10
+; NODOT-NEXT:    vadd.vv v8, v9, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_partial_reduce_swapped:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT:    vmv.s.x v10, zero
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vmv1r.v v8, v10
+; DOT-NEXT:    ret
 entry:
   %a.ext = sext <4 x i8> %a to <4 x i32>
   %b.ext = zext <4 x i8> %b to <4 x i32>
@@ -1065,222 +1081,291 @@ entry:
 
 ; Test legalization - type split
 define <64 x i32> @vqdotsu_vv_partial_v64i32_v256i8(<256 x i8> %a, <256 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv_partial_v64i32_v256i8:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    addi sp, sp, -16
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 3
-; CHECK-NEXT:    mv a2, a1
-; CHECK-NEXT:    slli a1, a1, 2
-; CHECK-NEXT:    add a1, a1, a2
-; CHECK-NEXT:    sub sp, sp, a1
-; CHECK-NEXT:    .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x28, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 40 * vlenb
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 4
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 5
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    addi a1, a0, 128
-; CHECK-NEXT:    li a2, 128
-; CHECK-NEXT:    vsetvli zero, a2, e8, m8, ta, ma
-; CHECK-NEXT:    vle8.v v0, (a0)
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 3
-; CHECK-NEXT:    mv a3, a0
-; CHECK-NEXT:    slli a0, a0, 1
-; CHECK-NEXT:    add a0, a0, a3
-; CHECK-NEXT:    add a0, sp, a0
-; CHECK-NEXT:    addi a0, a0, 16
-; CHECK-NEXT:    vs8r.v v0, (a0) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    li a0, 32
-; CHECK-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
-; CHECK-NEXT:    vslidedown.vx v24, v8, a0
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v8, v24
-; CHECK-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
-; CHECK-NEXT:    vslidedown.vx v12, v0, a0
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vzext.vf2 v4, v12
-; CHECK-NEXT:    vwmulsu.vv v24, v8, v4
-; CHECK-NEXT:    csrr a3, vlenb
-; CHECK-NEXT:    slli a3, a3, 5
-; CHECK-NEXT:    add a3, sp, a3
-; CHECK-NEXT:    addi a3, a3, 16
-; CHECK-NEXT:    vl8r.v v8, (a3) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v4, v8
-; CHECK-NEXT:    csrr a3, vlenb
-; CHECK-NEXT:    slli a3, a3, 3
-; CHECK-NEXT:    mv a4, a3
-; CHECK-NEXT:    slli a3, a3, 1
-; CHECK-NEXT:    add a3, a3, a4
-; CHECK-NEXT:    add a3, sp, a3
-; CHECK-NEXT:    addi a3, a3, 16
-; CHECK-NEXT:    vl8r.v v8, (a3) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vzext.vf2 v0, v8
-; CHECK-NEXT:    vsetvli zero, a2, e8, m8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a1)
-; CHECK-NEXT:    addi a1, sp, 16
-; CHECK-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vwmaccsu.vv v24, v4, v0
-; CHECK-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
-; CHECK-NEXT:    vslidedown.vx v4, v16, a0
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v4
-; CHECK-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
-; CHECK-NEXT:    vslidedown.vx v4, v8, a0
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vzext.vf2 v16, v4
-; CHECK-NEXT:    vwmulsu.vv v0, v12, v16
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 4
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vl8r.v v16, (a1) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v12, v16
-; CHECK-NEXT:    vzext.vf2 v20, v8
-; CHECK-NEXT:    vwmaccsu.vv v0, v12, v20
-; CHECK-NEXT:    li a1, 64
-; CHECK-NEXT:    csrr a2, vlenb
-; CHECK-NEXT:    slli a2, a2, 5
-; CHECK-NEXT:    add a2, sp, a2
-; CHECK-NEXT:    addi a2, a2, 16
-; CHECK-NEXT:    vl8r.v v16, (a2) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsetvli zero, a1, e8, m8, ta, ma
-; CHECK-NEXT:    vslidedown.vx v8, v16, a1
-; CHECK-NEXT:    csrr a2, vlenb
-; CHECK-NEXT:    slli a2, a2, 5
-; CHECK-NEXT:    add a2, sp, a2
-; CHECK-NEXT:    addi a2, a2, 16
-; CHECK-NEXT:    vs8r.v v8, (a2) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    csrr a2, vlenb
-; CHECK-NEXT:    slli a2, a2, 3
-; CHECK-NEXT:    mv a3, a2
-; CHECK-NEXT:    slli a2, a2, 1
-; CHECK-NEXT:    add a2, a2, a3
-; CHECK-NEXT:    add a2, sp, a2
-; CHECK-NEXT:    addi a2, a2, 16
-; CHECK-NEXT:    vl8r.v v16, (a2) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vslidedown.vx v8, v16, a1
-; CHECK-NEXT:    csrr a2, vlenb
-; CHECK-NEXT:    slli a2, a2, 3
-; CHECK-NEXT:    add a2, sp, a2
-; CHECK-NEXT:    addi a2, a2, 16
-; CHECK-NEXT:    vs8r.v v8, (a2) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    csrr a2, vlenb
-; CHECK-NEXT:    slli a2, a2, 5
-; CHECK-NEXT:    add a2, sp, a2
-; CHECK-NEXT:    addi a2, a2, 16
-; CHECK-NEXT:    vl8r.v v8, (a2) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v16, v8
-; CHECK-NEXT:    csrr a2, vlenb
-; CHECK-NEXT:    slli a2, a2, 3
-; CHECK-NEXT:    add a2, sp, a2
-; CHECK-NEXT:    addi a2, a2, 16
-; CHECK-NEXT:    vl8r.v v8, (a2) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vzext.vf2 v20, v8
-; CHECK-NEXT:    vwmaccsu.vv v24, v16, v20
-; CHECK-NEXT:    csrr a2, vlenb
-; CHECK-NEXT:    slli a2, a2, 4
-; CHECK-NEXT:    add a2, sp, a2
-; CHECK-NEXT:    addi a2, a2, 16
-; CHECK-NEXT:    vl8r.v v16, (a2) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsetvli zero, a1, e8, m8, ta, ma
-; CHECK-NEXT:    vslidedown.vx v16, v16, a1
-; CHECK-NEXT:    addi a2, sp, 16
-; CHECK-NEXT:    vl8r.v v8, (a2) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vslidedown.vx v8, v8, a1
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 3
-; CHECK-NEXT:    mv a2, a1
-; CHECK-NEXT:    slli a1, a1, 1
-; CHECK-NEXT:    add a1, a1, a2
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v8, v16
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 4
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vs4r.v v8, (a1) # vscale x 32-byte Folded Spill
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 3
-; CHECK-NEXT:    mv a2, a1
-; CHECK-NEXT:    slli a1, a1, 1
-; CHECK-NEXT:    add a1, a1, a2
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vl8r.v v8, (a1) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vzext.vf2 v20, v8
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 4
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vl4r.v v8, (a1) # vscale x 32-byte Folded Reload
-; CHECK-NEXT:    vwmaccsu.vv v0, v8, v20
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 5
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vl8r.v v8, (a1) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
-; CHECK-NEXT:    vslidedown.vx v20, v8, a0
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 3
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vl8r.v v8, (a1) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v20
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 5
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vs4r.v v12, (a1) # vscale x 32-byte Folded Spill
-; CHECK-NEXT:    vzext.vf2 v12, v8
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 5
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vl4r.v v8, (a1) # vscale x 32-byte Folded Reload
-; CHECK-NEXT:    vwmaccsu.vv v24, v8, v12
-; CHECK-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
-; CHECK-NEXT:    vslidedown.vx v12, v16, a0
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 3
-; CHECK-NEXT:    mv a2, a1
-; CHECK-NEXT:    slli a1, a1, 1
-; CHECK-NEXT:    add a1, a1, a2
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vl8r.v v16, (a1) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vslidedown.vx v8, v16, a0
-; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v16, v12
-; CHECK-NEXT:    vzext.vf2 v12, v8
-; CHECK-NEXT:    vwmaccsu.vv v0, v16, v12
-; CHECK-NEXT:    vmv8r.v v8, v24
-; CHECK-NEXT:    vmv8r.v v16, v0
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 3
-; CHECK-NEXT:    mv a1, a0
-; CHECK-NEXT:    slli a0, a0, 2
-; CHECK-NEXT:    add a0, a0, a1
-; CHECK-NEXT:    add sp, sp, a0
-; CHECK-NEXT:    .cfi_def_cfa sp, 16
-; CHECK-NEXT:    addi sp, sp, 16
-; CHECK-NEXT:    .cfi_def_cfa_offset 0
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv_partial_v64i32_v256i8:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    addi sp, sp, -16
+; NODOT-NEXT:    .cfi_def_cfa_offset 16
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 3
+; NODOT-NEXT:    mv a2, a1
+; NODOT-NEXT:    slli a1, a1, 2
+; NODOT-NEXT:    add a1, a1, a2
+; NODOT-NEXT:    sub sp, sp, a1
+; NODOT-NEXT:    .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x28, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 40 * vlenb
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 4
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 5
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    addi a1, a0, 128
+; NODOT-NEXT:    li a2, 128
+; NODOT-NEXT:    vsetvli zero, a2, e8, m8, ta, ma
+; NODOT-NEXT:    vle8.v v0, (a0)
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 3
+; NODOT-NEXT:    mv a3, a0
+; NODOT-NEXT:    slli a0, a0, 1
+; NODOT-NEXT:    add a0, a0, a3
+; NODOT-NEXT:    add a0, sp, a0
+; NODOT-NEXT:    addi a0, a0, 16
+; NODOT-NEXT:    vs8r.v v0, (a0) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    li a0, 32
+; NODOT-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
+; NODOT-NEXT:    vslidedown.vx v24, v8, a0
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v8, v24
+; NODOT-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
+; NODOT-NEXT:    vslidedown.vx v12, v0, a0
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vzext.vf2 v4, v12
+; NODOT-NEXT:    vwmulsu.vv v24, v8, v4
+; NODOT-NEXT:    csrr a3, vlenb
+; NODOT-NEXT:    slli a3, a3, 5
+; NODOT-NEXT:    add a3, sp, a3
+; NODOT-NEXT:    addi a3, a3, 16
+; NODOT-NEXT:    vl8r.v v8, (a3) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v4, v8
+; NODOT-NEXT:    csrr a3, vlenb
+; NODOT-NEXT:    slli a3, a3, 3
+; NODOT-NEXT:    mv a4, a3
+; NODOT-NEXT:    slli a3, a3, 1
+; NODOT-NEXT:    add a3, a3, a4
+; NODOT-NEXT:    add a3, sp, a3
+; NODOT-NEXT:    addi a3, a3, 16
+; NODOT-NEXT:    vl8r.v v8, (a3) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vzext.vf2 v0, v8
+; NODOT-NEXT:    vsetvli zero, a2, e8, m8, ta, ma
+; NODOT-NEXT:    vle8.v v8, (a1)
+; NODOT-NEXT:    addi a1, sp, 16
+; NODOT-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vwmaccsu.vv v24, v4, v0
+; NODOT-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
+; NODOT-NEXT:    vslidedown.vx v4, v16, a0
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v4
+; NODOT-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
+; NODOT-NEXT:    vslidedown.vx v4, v8, a0
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vzext.vf2 v16, v4
+; NODOT-NEXT:    vwmulsu.vv v0, v12, v16
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 4
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vl8r.v v16, (a1) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v12, v16
+; NODOT-NEXT:    vzext.vf2 v20, v8
+; NODOT-NEXT:    vwmaccsu.vv v0, v12, v20
+; NODOT-NEXT:    li a1, 64
+; NODOT-NEXT:    csrr a2, vlenb
+; NODOT-NEXT:    slli a2, a2, 5
+; NODOT-NEXT:    add a2, sp, a2
+; NODOT-NEXT:    addi a2, a2, 16
+; NODOT-NEXT:    vl8r.v v16, (a2) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsetvli zero, a1, e8, m8, ta, ma
+; NODOT-NEXT:    vslidedown.vx v8, v16, a1
+; NODOT-NEXT:    csrr a2, vlenb
+; NODOT-NEXT:    slli a2, a2, 5
+; NODOT-NEXT:    add a2, sp, a2
+; NODOT-NEXT:    addi a2, a2, 16
+; NODOT-NEXT:    vs8r.v v8, (a2) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    csrr a2, vlenb
+; NODOT-NEXT:    slli a2, a2, 3
+; NODOT-NEXT:    mv a3, a2
+; NODOT-NEXT:    slli a2, a2, 1
+; NODOT-NEXT:    add a2, a2, a3
+; NODOT-NEXT:    add a2, sp, a2
+; NODOT-NEXT:    addi a2, a2, 16
+; NODOT-NEXT:    vl8r.v v16, (a2) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vslidedown.vx v8, v16, a1
+; NODOT-NEXT:    csrr a2, vlenb
+; NODOT-NEXT:    slli a2, a2, 3
+; NODOT-NEXT:    add a2, sp, a2
+; NODOT-NEXT:    addi a2, a2, 16
+; NODOT-NEXT:    vs8r.v v8, (a2) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    csrr a2, vlenb
+; NODOT-NEXT:    slli a2, a2, 5
+; NODOT-NEXT:    add a2, sp, a2
+; NODOT-NEXT:    addi a2, a2, 16
+; NODOT-NEXT:    vl8r.v v8, (a2) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    csrr a2, vlenb
+; NODOT-NEXT:    slli a2, a2, 3
+; NODOT-NEXT:    add a2, sp, a2
+; NODOT-NEXT:    addi a2, a2, 16
+; NODOT-NEXT:    vl8r.v v8, (a2) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vzext.vf2 v20, v8
+; NODOT-NEXT:    vwmaccsu.vv v24, v16, v20
+; NODOT-NEXT:    csrr a2, vlenb
+; NODOT-NEXT:    slli a2, a2, 4
+; NODOT-NEXT:    add a2, sp, a2
+; NODOT-NEXT:    addi a2, a2, 16
+; NODOT-NEXT:    vl8r.v v16, (a2) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsetvli zero, a1, e8, m8, ta, ma
+; NODOT-NEXT:    vslidedown.vx v16, v16, a1
+; NODOT-NEXT:    addi a2, sp, 16
+; NODOT-NEXT:    vl8r.v v8, (a2) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vslidedown.vx v8, v8, a1
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 3
+; NODOT-NEXT:    mv a2, a1
+; NODOT-NEXT:    slli a1, a1, 1
+; NODOT-NEXT:    add a1, a1, a2
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v8, v16
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 4
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vs4r.v v8, (a1) # vscale x 32-byte Folded Spill
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 3
+; NODOT-NEXT:    mv a2, a1
+; NODOT-NEXT:    slli a1, a1, 1
+; NODOT-NEXT:    add a1, a1, a2
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vl8r.v v8, (a1) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vzext.vf2 v20, v8
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 4
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vl4r.v v8, (a1) # vscale x 32-byte Folded Reload
+; NODOT-NEXT:    vwmaccsu.vv v0, v8, v20
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 5
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vl8r.v v8, (a1) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
+; NODOT-NEXT:    vslidedown.vx v20, v8, a0
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 3
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vl8r.v v8, (a1) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vslidedown.vx v8, v8, a0
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v20
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 5
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vs4r.v v12, (a1) # vscale x 32-byte Folded Spill
+; NODOT-NEXT:    vzext.vf2 v12, v8
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 5
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vl4r.v v8, (a1) # vscale x 32-byte Folded Reload
+; NODOT-NEXT:    vwmaccsu.vv v24, v8, v12
+; NODOT-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
+; NODOT-NEXT:    vslidedown.vx v12, v16, a0
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 3
+; NODOT-NEXT:    mv a2, a1
+; NODOT-NEXT:    slli a1, a1, 1
+; NODOT-NEXT:    add a1, a1, a2
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vl8r.v v16, (a1) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vslidedown.vx v8, v16, a0
+; NODOT-NEXT:    vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v12
+; NODOT-NEXT:    vzext.vf2 v12, v8
+; NODOT-NEXT:    vwmaccsu.vv v0, v16, v12
+; NODOT-NEXT:    vmv8r.v v8, v24
+; NODOT-NEXT:    vmv8r.v v16, v0
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 3
+; NODOT-NEXT:    mv a1, a0
+; NODOT-NEXT:    slli a0, a0, 2
+; NODOT-NEXT:    add a0, a0, a1
+; NODOT-NEXT:    add sp, sp, a0
+; NODOT-NEXT:    .cfi_def_cfa sp, 16
+; NODOT-NEXT:    addi sp, sp, 16
+; NODOT-NEXT:    .cfi_def_cfa_offset 0
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_partial_v64i32_v256i8:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    addi sp, sp, -16
+; DOT-NEXT:    .cfi_def_cfa_offset 16
+; DOT-NEXT:    csrr a1, vlenb
+; DOT-NEXT:    slli a1, a1, 5
+; DOT-NEXT:    sub sp, sp, a1
+; DOT-NEXT:    .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x20, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 32 * vlenb
+; DOT-NEXT:    csrr a1, vlenb
+; DOT-NEXT:    slli a1, a1, 3
+; DOT-NEXT:    mv a2, a1
+; DOT-NEXT:    slli a1, a1, 1
+; DOT-NEXT:    add a1, a1, a2
+; DOT-NEXT:    add a1, sp, a1
+; DOT-NEXT:    addi a1, a1, 16
+; DOT-NEXT:    vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
+; DOT-NEXT:    csrr a1, vlenb
+; DOT-NEXT:    slli a1, a1, 4
+; DOT-NEXT:    add a1, sp, a1
+; DOT-NEXT:    addi a1, a1, 16
+; DOT-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
+; DOT-NEXT:    addi a1, a0, 128
+; DOT-NEXT:    li a2, 128
+; DOT-NEXT:    vsetvli zero, a2, e8, m8, ta, ma
+; DOT-NEXT:    vle8.v v8, (a0)
+; DOT-NEXT:    csrr a0, vlenb
+; DOT-NEXT:    slli a0, a0, 3
+; DOT-NEXT:    add a0, sp, a0
+; DOT-NEXT:    addi a0, a0, 16
+; DOT-NEXT:    vs8r.v v8, (a0) # vscale x 64-byte Folded Spill
+; DOT-NEXT:    li a0, 32
+; DOT-NEXT:    vle8.v v8, (a1)
+; DOT-NEXT:    addi a1, sp, 16
+; DOT-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
+; DOT-NEXT:    vsetvli zero, a0, e32, m8, ta, ma
+; DOT-NEXT:    vmv.v.i v24, 0
+; DOT-NEXT:    vmv.v.i v0, 0
+; DOT-NEXT:    csrr a0, vlenb
+; DOT-NEXT:    slli a0, a0, 4
+; DOT-NEXT:    add a0, sp, a0
+; DOT-NEXT:    addi a0, a0, 16
+; DOT-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; DOT-NEXT:    csrr a0, vlenb
+; DOT-NEXT:    slli a0, a0, 3
+; DOT-NEXT:    add a0, sp, a0
+; DOT-NEXT:    addi a0, a0, 16
+; DOT-NEXT:    vl8r.v v8, (a0) # vscale x 64-byte Folded Reload
+; DOT-NEXT:    vqdotsu.vv v0, v16, v8
+; DOT-NEXT:    csrr a0, vlenb
+; DOT-NEXT:    slli a0, a0, 3
+; DOT-NEXT:    mv a1, a0
+; DOT-NEXT:    slli a0, a0, 1
+; DOT-NEXT:    add a0, a0, a1
+; DOT-NEXT:    add a0, sp, a0
+; DOT-NEXT:    addi a0, a0, 16
+; DOT-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; DOT-NEXT:    addi a0, sp, 16
+; DOT-NEXT:    vl8r.v v8, (a0) # vscale x 64-byte Folded Reload
+; DOT-NEXT:    vqdotsu.vv v24, v16, v8
+; DOT-NEXT:    vmv.v.v v8, v0
+; DOT-NEXT:    vmv.v.v v16, v24
+; DOT-NEXT:    csrr a0, vlenb
+; DOT-NEXT:    slli a0, a0, 5
+; DOT-NEXT:    add sp, sp, a0
+; DOT-NEXT:    .cfi_def_cfa sp, 16
+; DOT-NEXT:    addi sp, sp, 16
+; DOT-NEXT:    .cfi_def_cfa_offset 0
+; DOT-NEXT:    ret
 entry:
   %a.ext = sext <256 x i8> %a to <256 x i32>
   %b.ext = zext <256 x i8> %b to <256 x i32>
@@ -1289,6 +1374,56 @@ entry:
   ret <64 x i32> %res
 }
 
+; Test legalization - integer promote
+define <4 x i31> @vqdotsu_vv_partial_v4i31_v16i7(<16 x i7> %a, <16 x i7> %b) {
+; NODOT-LABEL: vqdotsu_vv_partial_v4i31_v16i7:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT:    vzext.vf4 v12, v8
+; NODOT-NEXT:    li a0, 127
+; NODOT-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
+; NODOT-NEXT:    vand.vx v16, v9, a0
+; NODOT-NEXT:    lui a0, 524288
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vsll.vi v8, v12, 25
+; NODOT-NEXT:    addi a0, a0, -1
+; NODOT-NEXT:    vsra.vi v8, v8, 25
+; NODOT-NEXT:    vzext.vf4 v12, v16
+; NODOT-NEXT:    vmul.vv v8, v12, v8
+; NODOT-NEXT:    vand.vx v8, v8, a0
+; NODOT-NEXT:    vsetivli zero, 4, e32, m4, ta, ma
+; NODOT-NEXT:    vslidedown.vi v12, v8, 12
+; NODOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; NODOT-NEXT:    vadd.vv v16, v12, v8
+; NODOT-NEXT:    vsetivli zero, 4, e32, m4, ta, ma
+; NODOT-NEXT:    vslidedown.vi v12, v8, 8
+; NODOT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; NODOT-NEXT:    vslidedown.vi v8, v8, 4
+; NODOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v8, v12
+; NODOT-NEXT:    vadd.vv v8, v8, v16
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_partial_v4i31_v16i7:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    li a0, 127
+; DOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT:    vadd.vv v8, v8, v8
+; DOT-NEXT:    vand.vx v9, v9, a0
+; DOT-NEXT:    vsra.vi v10, v8, 1
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v8, 0
+; DOT-NEXT:    vqdotsu.vv v8, v10, v9
+; DOT-NEXT:    ret
+entry:
+  %a.ext = sext <16 x i7> %a to <16 x i31>
+  %b.ext = zext <16 x i7> %b to <16 x i31>
+  %mul = mul <16 x i31> %b.ext, %a.ext
+  %res = call <4 x i31> @llvm.experimental.vector.partial.reduce.add(<4 x i31> zeroinitializer, <16 x i31> %mul)
+  ret <4 x i31> %res
+}
+
+
 ; Test legalization - expand
 define <1 x i32> @vqdotsu_vv_partial_v1i32_v2i8(<2 x i8> %a, <2 x i8> %b) {
 ; CHECK-LABEL: vqdotsu_vv_partial_v1i32_v2i8:

diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 5272f1be50e97..0b6f8a7a838bc 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -910,22 +910,30 @@ entry:
 }
 
 define <vscale x 1 x i32> @partial_reduce_vqdotsu(<vscale x 4 x i8> %a, <vscale x 4 x i8> %b) {
-; CHECK-LABEL: partial_reduce_vqdotsu:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    vwmulsu.vv v8, v10, v11
-; CHECK-NEXT:    srli a0, a0, 3
-; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vx v10, v9, a0
-; CHECK-NEXT:    vslidedown.vx v11, v8, a0
-; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8
-; CHECK-NEXT:    vadd.vv v9, v11, v9
-; CHECK-NEXT:    vadd.vv v8, v9, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_vqdotsu:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vzext.vf2 v11, v9
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    vwmulsu.vv v8, v10, v11
+; NODOT-NEXT:    srli a0, a0, 3
+; NODOT-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
+; NODOT-NEXT:    vslidedown.vx v10, v9, a0
+; NODOT-NEXT:    vslidedown.vx v11, v8, a0
+; NODOT-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v10, v8
+; NODOT-NEXT:    vadd.vv v9, v11, v9
+; NODOT-NEXT:    vadd.vv v8, v9, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_reduce_vqdotsu:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vmv1r.v v8, v10
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <vscale x 4 x i8> %a to <vscale x 4 x i32>
   %b.sext = zext <vscale x 4 x i8> %b to <vscale x 4 x i32>

