[llvm] [RISCV] Support scalable vectors for the zvqdotq lowering paths (PR #140922)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 21 08:59:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

Missing scalable-vector support in the zvqdotq lowering paths was an oversight
in the original patch series. Without this change, the newly added tests fail assertions.


---

Patch is 25.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140922.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+29-15) 
- (added) llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll (+581) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 1158499718737..d69e04a9912a2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18177,17 +18177,20 @@ static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
   assert(VT == Op1.getSimpleValueType() &&
          VT.getVectorElementType() == MVT::i32);
 
-  assert(VT.isFixedLengthVector());
-  MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
-  SDValue Passthru = convertToScalableVector(
-      ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
-  Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
-  Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
-
+  SDValue Passthru = DAG.getConstant(0, DL, VT);
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+    Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
+    Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+    Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+  }
   auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
   SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
                                    {Op0, Op1, Passthru, Mask, VL});
-  return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+  if (VT.isFixedLengthVector())
+    return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+  return LocalAccum;
 }
 
 static MVT getQDOTXResultType(MVT OpVT) {
@@ -18207,7 +18210,7 @@ static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
   EVT AVT = A.getValueType();
   EVT BVT = B.getValueType();
   assert(AVT.getVectorElementType() == BVT.getVectorElementType());
-  if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
+  if (AVT.getVectorMinNumElements() > BVT.getVectorMinNumElements()) {
     std::swap(A, B);
     std::swap(AVT, BVT);
   }
@@ -18641,7 +18644,7 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
 static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
 
-  assert(N->getOpcode() == RISCVISD::ADD_VL);
+  assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
 
   if (!N->getValueType(0).isVector())
     return SDValue();
@@ -18649,9 +18652,11 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
   SDValue Addend = N->getOperand(0);
   SDValue DotOp = N->getOperand(1);
 
-  SDValue AddPassthruOp = N->getOperand(2);
-  if (!AddPassthruOp.isUndef())
-    return SDValue();
+  if (N->getOpcode() == RISCVISD::ADD_VL) {
+    SDValue AddPassthruOp = N->getOperand(2);
+    if (!AddPassthruOp.isUndef())
+      return SDValue();
+  }
 
   auto IsVqdotqOpc = [](unsigned Opc) {
     switch (Opc) {
@@ -18670,8 +18675,15 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
   if (!IsVqdotqOpc(DotOp.getOpcode()))
     return SDValue();
 
-  SDValue AddMask = N->getOperand(3);
-  SDValue AddVL = N->getOperand(4);
+  auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
+                             const RISCVSubtarget &Subtarget) {
+    if (N->getOpcode() == ISD::ADD) {
+      SDLoc DL(N);
+      return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
+                                     Subtarget);
+    }
+    return std::make_pair(N->getOperand(3), N->getOperand(4));
+  }(N, DAG, Subtarget);
 
   SDValue MulVL = DotOp.getOperand(4);
   if (AddVL != MulVL)
@@ -19309,6 +19321,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       return V;
     if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
       return V;
+    if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
+      return V;
     return performADDCombine(N, DCI, Subtarget);
   }
   case ISD::SUB: {
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
new file mode 100644
index 0000000000000..a56ef0cd75d6a
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -0,0 +1,581 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+
+define i32 @vqdot_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; NODOT-LABEL: vqdot_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vsext.vf2 v20, v10
+; NODOT-NEXT:    vwmul.vv v8, v16, v20
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v16, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v16
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdot.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @vqdot_vx_constant(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vqdot_vx_constant:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vsext.vf2 v16, v8
+; CHECK-NEXT:    li a0, 23
+; CHECK-NEXT:    vwmul.vx v8, v16, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v16, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v16
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, splat (i32 23)
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @vqdot_vx_constant_swapped(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vqdot_vx_constant_swapped:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vsext.vf2 v16, v8
+; CHECK-NEXT:    li a0, 23
+; CHECK-NEXT:    vwmul.vx v8, v16, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v16, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v16
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> splat (i32 23), %a.sext
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @vqdotu_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; NODOT-LABEL: vqdotu_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; NODOT-NEXT:    vwmulu.vv v12, v8, v10
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vsetvli zero, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vwredsumu.vs v8, v12, v8
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotu_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdotu.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, %b.zext
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @vqdotu_vx_constant(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vqdotu_vx_constant:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vzext.vf2 v16, v8
+; CHECK-NEXT:    li a0, 123
+; CHECK-NEXT:    vwmulu.vx v8, v16, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v16, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v16
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, splat (i32 123)
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @vqdotsu_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; NODOT-LABEL: vqdotsu_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vzext.vf2 v20, v10
+; NODOT-NEXT:    vwmulsu.vv v8, v16, v20
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v16, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v16
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdotsu.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.zext
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @vqdotsu_vv_swapped(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; NODOT-LABEL: vqdotsu_vv_swapped:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vzext.vf2 v20, v10
+; NODOT-NEXT:    vwmulsu.vv v8, v16, v20
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v16, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v16
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_swapped:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdotsu.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %b.zext, %a.sext
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @vdotqsu_vx_constant(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vdotqsu_vx_constant:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vsext.vf2 v16, v8
+; CHECK-NEXT:    li a0, 123
+; CHECK-NEXT:    vwmul.vx v8, v16, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v16, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v16
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, splat (i32 123)
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @vdotqus_vx_constant(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vdotqus_vx_constant:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vzext.vf2 v16, v8
+; CHECK-NEXT:    li a0, -23
+; CHECK-NEXT:    vmv.v.x v20, a0
+; CHECK-NEXT:    vwmulsu.vv v8, v20, v16
+; CHECK-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v16, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v16
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, splat (i32 -23)
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  ret i32 %res
+}
+
+define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: reduce_of_sext:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vsext.vf4 v16, v8
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v16, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT32-LABEL: reduce_of_sext:
+; DOT32:       # %bb.0: # %entry
+; DOT32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT32-NEXT:    vmv.v.i v10, 0
+; DOT32-NEXT:    lui a0, 4112
+; DOT32-NEXT:    addi a0, a0, 257
+; DOT32-NEXT:    vqdot.vx v10, v8, a0
+; DOT32-NEXT:    vmv.s.x v8, zero
+; DOT32-NEXT:    vredsum.vs v8, v10, v8
+; DOT32-NEXT:    vmv.x.s a0, v8
+; DOT32-NEXT:    ret
+;
+; DOT64-LABEL: reduce_of_sext:
+; DOT64:       # %bb.0: # %entry
+; DOT64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT64-NEXT:    vmv.v.i v10, 0
+; DOT64-NEXT:    lui a0, 4112
+; DOT64-NEXT:    addiw a0, a0, 257
+; DOT64-NEXT:    vqdot.vx v10, v8, a0
+; DOT64-NEXT:    vmv.s.x v8, zero
+; DOT64-NEXT:    vredsum.vs v8, v10, v8
+; DOT64-NEXT:    vmv.x.s a0, v8
+; DOT64-NEXT:    ret
+entry:
+  %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
+  ret i32 %res
+}
+
+define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: reduce_of_zext:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vzext.vf4 v16, v8
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v16, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT32-LABEL: reduce_of_zext:
+; DOT32:       # %bb.0: # %entry
+; DOT32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT32-NEXT:    vmv.v.i v10, 0
+; DOT32-NEXT:    lui a0, 4112
+; DOT32-NEXT:    addi a0, a0, 257
+; DOT32-NEXT:    vqdotu.vx v10, v8, a0
+; DOT32-NEXT:    vmv.s.x v8, zero
+; DOT32-NEXT:    vredsum.vs v8, v10, v8
+; DOT32-NEXT:    vmv.x.s a0, v8
+; DOT32-NEXT:    ret
+;
+; DOT64-LABEL: reduce_of_zext:
+; DOT64:       # %bb.0: # %entry
+; DOT64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT64-NEXT:    vmv.v.i v10, 0
+; DOT64-NEXT:    lui a0, 4112
+; DOT64-NEXT:    addiw a0, a0, 257
+; DOT64-NEXT:    vqdotu.vx v10, v8, a0
+; DOT64-NEXT:    vmv.s.x v8, zero
+; DOT64-NEXT:    vredsum.vs v8, v10, v8
+; DOT64-NEXT:    vmv.x.s a0, v8
+; DOT64-NEXT:    ret
+entry:
+  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
+  ret i32 %res
+}
+
+define i32 @vqdot_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i32> %x) {
+; NODOT-LABEL: vqdot_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vsext.vf2 v24, v10
+; NODOT-NEXT:    vwmacc.vv v16, v12, v24
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v16, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.s.x v12, zero
+; DOT-NEXT:    vqdot.vv v16, v8, v10
+; DOT-NEXT:    vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v16, v12
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+  %add = add <vscale x 16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdotu_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i32> %x) {
+; NODOT-LABEL: vqdotu_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; NODOT-NEXT:    vwmulu.vv v12, v8, v10
+; NODOT-NEXT:    vsetvli zero, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vwaddu.wv v16, v16, v12
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v16, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotu_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.s.x v12, zero
+; DOT-NEXT:    vqdotu.vv v16, v8, v10
+; DOT-NEXT:    vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v16, v12
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, %b.zext
+  %add = add <vscale x 16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdotsu_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i32> %x) {
+; NODOT-LABEL: vqdotsu_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vzext.vf2 v24, v10
+; NODOT-NEXT:    vwmaccsu.vv v16, v12, v24
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v16, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.s.x v12, zero
+; DOT-NEXT:    vqdotsu.vv v16, v8, v10
+; DOT-NEXT:    vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v16, v12
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.zext
+  %add = add <vscale x 16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdot_vv_scalar_add(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, i32 %x) {
+; NODOT-LABEL: vqdot_vv_scalar_add:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vsext.vf2 v20, v10
+; NODOT-NEXT:    vwmul.vv v8, v16, v20
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v16, a0
+; NODOT-NEXT:    vredsum.vs v8, v8, v16
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv_scalar_add:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a1, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdot.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, a0
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+  %add = add i32 %sum, %x
+  ret i32 %add
+}
+
+define i32 @vqdotu_vv_scalar_add(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, i32 %x) {
+; NODOT-LABEL: vqdotu_vv_scalar_add:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a1, zero, e8, m2, ta, ma
+; NODOT-NEXT:    vwmulu.vv v12, v8, v10
+; NODOT-NEXT:    vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, a0
+; NODOT-NEXT:    vsetvli zero, zero...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/140922


More information about the llvm-commits mailing list