[llvm] 1ac489c - [RISCV] Initial codegen support for zvqdotq extension (#137039)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 7 08:15:48 PDT 2025


Author: Philip Reames
Date: 2025-05-07T08:15:44-07:00
New Revision: 1ac489c8e38ecaeccba7d8826273395eaba2db6c

URL: https://github.com/llvm/llvm-project/commit/1ac489c8e38ecaeccba7d8826273395eaba2db6c
DIFF: https://github.com/llvm/llvm-project/commit/1ac489c8e38ecaeccba7d8826273395eaba2db6c.diff

LOG: [RISCV] Initial codegen support for zvqdotq extension (#137039)

This patch adds pattern matching for the basic usages of the dot product
instructions introduced by the experimental zvqdotq extension. It
specifically only handles the case where the pattern is feeding an i32
sum reduction, as we need to reassociate the reduction tree to use these
instructions.

The vecreduce_add (sext) and vecreduce_add (zext) cases are included
mostly to exercise the VX matchers. For the generic matching, we fail to
match due to a combine-ordering issue which results in the bitcast
being separated from the splat.

I chose to do this lowering as an early combine so as to avoid having to
integrate the entire logic into the reduction lowering flow. In
particular, that would get a lot more complicated as we extend this to
handle add-trees feeding the reductions.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 86f8873c135ef..698b951ad4928 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -6971,7 +6971,7 @@ static bool hasPassthruOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(
-      RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
+      RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
       RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
       "adding target specific op should update this function");
   if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6995,7 +6995,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(
-      RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
+      RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
       RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
       "adding target specific op should update this function");
   if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18101,6 +18101,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
                      DAG.getBuildVector(VT, DL, RHSOps));
 }
 
+static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
+                          const SDLoc &DL, SelectionDAG &DAG,
+                          const RISCVSubtarget &Subtarget) {
+  assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
+         RISCVISD::VQDOTSU_VL == Opc);
+  MVT VT = Op0.getSimpleValueType();
+  assert(VT == Op1.getSimpleValueType() &&
+         VT.getVectorElementType() == MVT::i32);
+
+  assert(VT.isFixedLengthVector());
+  MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+  SDValue Passthru = convertToScalableVector(
+      ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
+  Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+  Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+
+  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+  const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
+  SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
+  SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
+                                   {Op0, Op1, Passthru, Mask, VL, PolicyOp});
+  return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+}
+
+static MVT getQDOTXResultType(MVT OpVT) {
+  ElementCount OpEC = OpVT.getVectorElementCount();
+  assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
+  return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
+}
+
+static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
+                                         SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget,
+                                         const RISCVTargetLowering &TLI) {
+  // Note: We intentionally do not check the legality of the reduction type.
+  // We want to handle the m4/m8 *src*  types, and thus need to let illegal
+  // intermediate types flow through here.
+  if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
+      !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
+    return SDValue();
+
+  // reduce (zext a) <--> reduce (mul zext a. zext 1)
+  // reduce (sext a) <--> reduce (mul sext a. sext 1)
+  if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
+      InVec.getOpcode() == ISD::SIGN_EXTEND) {
+    SDValue A = InVec.getOperand(0);
+    if (A.getValueType().getVectorElementType() != MVT::i8 ||
+        !TLI.isTypeLegal(A.getValueType()))
+      return SDValue();
+
+    MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
+    A = DAG.getBitcast(ResVT, A);
+    SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
+
+    bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
+    unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+    return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+  }
+
+  // mul (sext, sext) -> vqdot
+  // mul (zext, zext) -> vqdotu
+  // mul (sext, zext) -> vqdotsu
+  // mul (zext, sext) -> vqdotsu (swapped)
+  // TODO: Improve .vx handling - we end up with a sub-vector insert
+  // which confuses the splat pattern matching.  Also, match vqdotus.vx
+  if (InVec.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue A = InVec.getOperand(0);
+  SDValue B = InVec.getOperand(1);
+  unsigned Opc = 0;
+  if (A.getOpcode() == B.getOpcode()) {
+    if (A.getOpcode() == ISD::SIGN_EXTEND)
+      Opc = RISCVISD::VQDOT_VL;
+    else if (A.getOpcode() == ISD::ZERO_EXTEND)
+      Opc = RISCVISD::VQDOTU_VL;
+    else
+      return SDValue();
+  } else {
+    if (B.getOpcode() != ISD::ZERO_EXTEND)
+      std::swap(A, B);
+    if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+      return SDValue();
+    Opc = RISCVISD::VQDOTSU_VL;
+  }
+  assert(Opc);
+
+  if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+      A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
+      !TLI.isTypeLegal(A.getValueType()))
+    return SDValue();
+
+  MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
+  A = DAG.getBitcast(ResVT, A.getOperand(0));
+  B = DAG.getBitcast(ResVT, B.getOperand(0));
+  return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+}
+
+static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
+                                       const RISCVSubtarget &Subtarget,
+                                       const RISCVTargetLowering &TLI) {
+  if (!Subtarget.hasStdExtZvqdotq())
+    return SDValue();
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue InVec = N->getOperand(0);
+  if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
+    return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
+  return SDValue();
+}
+
 static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
                                                const RISCVSubtarget &Subtarget,
                                                const RISCVTargetLowering &TLI) {
@@ -19878,8 +19990,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 
     return SDValue();
   }
-  case ISD::CTPOP:
   case ISD::VECREDUCE_ADD:
+    if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
+      return V;
+    [[fallthrough]];
+  case ISD::CTPOP:
     if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
       return V;
     break;
@@ -22401,6 +22516,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(RI_VUNZIP2A_VL)
   NODE_NAME_CASE(RI_VUNZIP2B_VL)
   NODE_NAME_CASE(RI_VEXTRACT)
+  NODE_NAME_CASE(VQDOT_VL)
+  NODE_NAME_CASE(VQDOTU_VL)
+  NODE_NAME_CASE(VQDOTSU_VL)
   NODE_NAME_CASE(READ_CSR)
   NODE_NAME_CASE(WRITE_CSR)
   NODE_NAME_CASE(SWAP_CSR)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ba24a0c324f51..3f1fce5d9f7e5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -416,7 +416,12 @@ enum NodeType : unsigned {
   RI_VUNZIP2A_VL,
   RI_VUNZIP2B_VL,
 
-  LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
+  // zvqdot instructions with additional passthru, mask and VL operands
+  VQDOT_VL,
+  VQDOTU_VL,
+  VQDOTSU_VL,
+
+  LAST_VL_VECTOR_OP = VQDOTSU_VL,
 
   // XRivosVisni
   // VEXTRACT matches the semantics of ri.vextract.x.v. The result is always

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
index 205fffd5115ee..6018958f6eb27 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
@@ -26,3 +26,34 @@ let Predicates = [HasStdExtZvqdotq] in {
   def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
   def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
 } // Predicates = [HasStdExtZvqdotq]
+
+
+def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;
+
+multiclass VPseudoVQDOT_VV_VX {
+  foreach m = MxSet<32>.m in {
+    defm "" : VPseudoBinaryV_VV<m>,
+            SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
+                        forcePassthruRead=true>;
+    defm "" : VPseudoBinaryV_VX<m>,
+            SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
+                        forcePassthruRead=true>;
+  }
+}
+
+// TODO: Add pseudo and patterns for vqdotus.vx
+// TODO: Add isCommutable for VQDOT and VQDOTU
+let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
+    hasSideEffects = 0 in {
+  defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
+  defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
+  defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
+}
+
+defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
+defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
+

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 25192ea19aab3..e48bc9cdfea4e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,21 +1,31 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
 
 define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdot_vv:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vsext.vf2 v14, v9
-; CHECK-NEXT:    vwmul.vv v8, v12, v14
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdot_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vsext.vf2 v14, v9
+; NODOT-NEXT:    vwmul.vv v8, v12, v14
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.sext = sext <16 x i8> %b to <16 x i32>
@@ -63,17 +73,27 @@ entry:
 }
 
 define i32 @vqdotu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotu_vv:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT:    vwmulu.vv v10, v8, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vwredsumu.vs v8, v10, v8
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotu_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT:    vwmulu.vv v10, v8, v9
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT:    vwredsumu.vs v8, v10, v8
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotu_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotu.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.zext = zext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -102,17 +122,27 @@ entry:
 }
 
 define i32 @vqdotsu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vzext.vf2 v14, v9
-; CHECK-NEXT:    vwmulsu.vv v8, v12, v14
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vzext.vf2 v14, v9
+; NODOT-NEXT:    vwmulsu.vv v8, v12, v14
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -122,17 +152,27 @@ entry:
 }
 
 define i32 @vqdotsu_vv_swapped(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv_swapped:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vzext.vf2 v14, v9
-; CHECK-NEXT:    vwmulsu.vv v8, v12, v14
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv_swapped:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vzext.vf2 v14, v9
+; NODOT-NEXT:    vwmulsu.vv v8, v12, v14
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_swapped:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -181,14 +221,38 @@ entry:
 }
 
 define i32 @reduce_of_sext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_sext:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vsext.vf4 v12, v8
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: reduce_of_sext:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT:    vsext.vf4 v12, v8
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT32-LABEL: reduce_of_sext:
+; DOT32:       # %bb.0: # %entry
+; DOT32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT:    vmv.v.i v9, 0
+; DOT32-NEXT:    lui a0, 4112
+; DOT32-NEXT:    addi a0, a0, 257
+; DOT32-NEXT:    vqdot.vx v9, v8, a0
+; DOT32-NEXT:    vmv.s.x v8, zero
+; DOT32-NEXT:    vredsum.vs v8, v9, v8
+; DOT32-NEXT:    vmv.x.s a0, v8
+; DOT32-NEXT:    ret
+;
+; DOT64-LABEL: reduce_of_sext:
+; DOT64:       # %bb.0: # %entry
+; DOT64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT:    vmv.v.i v9, 0
+; DOT64-NEXT:    lui a0, 4112
+; DOT64-NEXT:    addiw a0, a0, 257
+; DOT64-NEXT:    vqdot.vx v9, v8, a0
+; DOT64-NEXT:    vmv.s.x v8, zero
+; DOT64-NEXT:    vredsum.vs v8, v9, v8
+; DOT64-NEXT:    vmv.x.s a0, v8
+; DOT64-NEXT:    ret
 entry:
   %a.ext = sext <16 x i8> %a to <16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -196,14 +260,38 @@ entry:
 }
 
 define i32 @reduce_of_zext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_zext:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vzext.vf4 v12, v8
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: reduce_of_zext:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT:    vzext.vf4 v12, v8
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT32-LABEL: reduce_of_zext:
+; DOT32:       # %bb.0: # %entry
+; DOT32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT:    vmv.v.i v9, 0
+; DOT32-NEXT:    lui a0, 4112
+; DOT32-NEXT:    addi a0, a0, 257
+; DOT32-NEXT:    vqdotu.vx v9, v8, a0
+; DOT32-NEXT:    vmv.s.x v8, zero
+; DOT32-NEXT:    vredsum.vs v8, v9, v8
+; DOT32-NEXT:    vmv.x.s a0, v8
+; DOT32-NEXT:    ret
+;
+; DOT64-LABEL: reduce_of_zext:
+; DOT64:       # %bb.0: # %entry
+; DOT64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT:    vmv.v.i v9, 0
+; DOT64-NEXT:    lui a0, 4112
+; DOT64-NEXT:    addiw a0, a0, 257
+; DOT64-NEXT:    vqdotu.vx v9, v8, a0
+; DOT64-NEXT:    vmv.s.x v8, zero
+; DOT64-NEXT:    vredsum.vs v8, v9, v8
+; DOT64-NEXT:    vmv.x.s a0, v8
+; DOT64-NEXT:    ret
 entry:
   %a.ext = zext <16 x i8> %a to <16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)


        


More information about the llvm-commits mailing list